diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index d5a2ac1a7..5cf30d8ff 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -20,8 +20,10 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) const ( @@ -2113,13 +2115,30 @@ func TestChannelArbitratorAnchors(t *testing.T) { reports, ) + // Add a dummy payment hash to the preimage lookup. + rHash := [lntypes.PreimageSize]byte{1, 2, 3} + mockPreimageDB := newMockWitnessBeacon() + mockPreimageDB.lookupPreimage[rHash] = rHash + + // Attack a mock PreimageDB and Registry to channel arbitrator. chanArb := chanArbCtx.chanArb - chanArb.cfg.PreimageDB = newMockWitnessBeacon() + chanArb.cfg.PreimageDB = mockPreimageDB chanArb.cfg.Registry = &mockRegistry{} // Setup two pre-confirmation anchor resolutions on the mock channel. chanArb.cfg.Channel.(*mockChannel).anchorResolutions = - &lnwallet.AnchorResolutions{} + &lnwallet.AnchorResolutions{ + Local: &lnwallet.AnchorResolution{ + AnchorSignDescriptor: input.SignDescriptor{ + Output: &wire.TxOut{Value: 1}, + }, + }, + Remote: &lnwallet.AnchorResolution{ + AnchorSignDescriptor: input.SignDescriptor{ + Output: &wire.TxOut{Value: 1}, + }, + }, + } if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -2139,6 +2158,41 @@ func TestChannelArbitratorAnchors(t *testing.T) { } chanArb.UpdateContractSignals(signals) + // Set current block height. + heightHint := uint32(1000) + chanArbCtx.chanArb.blocks <- int32(heightHint) + + // Create testing HTLCs. + htlcExpiryBase := heightHint + uint32(10) + htlcWithPreimage := channeldb.HTLC{ + HtlcIndex: 99, + RefundTimeout: htlcExpiryBase + 2, + RHash: rHash, + Incoming: true, + } + htlc := channeldb.HTLC{ + HtlcIndex: 100, + RefundTimeout: htlcExpiryBase + 3, + } + + // We now send two HTLC updates, one for local HTLC set and the other + // for remote HTLC set. + htlcUpdates <- &ContractUpdate{ + HtlcKey: LocalHtlcSet, + // This will make the deadline of the local anchor resolution + // to be htlcWithPreimage's CLTV minus heightHint since the + // incoming HTLC (toLocalHTLCs) has a lower CLTV value and is + // preimage available. + Htlcs: []channeldb.HTLC{htlc, htlcWithPreimage}, + } + htlcUpdates <- &ContractUpdate{ + HtlcKey: RemoteHtlcSet, + // This will make the deadline of the remote anchor resolution + // to be htlcWithPreimage's CLTV minus heightHint because the + // incoming HTLC (toRemoteHTLCs) has a lower CLTV. + Htlcs: []channeldb.HTLC{htlc, htlcWithPreimage}, + } + errChan := make(chan error, 1) respChan := make(chan *wire.MsgTx, 1) @@ -2254,6 +2308,20 @@ func TestChannelArbitratorAnchors(t *testing.T) { } assertResolverReport(t, reports, expectedReport) + + // We expect two anchor inputs, the local and the remote to be swept. + // Thus we should expect there are two deadlines used, both are equal + // to htlcWithPreimage's CLTV minus current block height. + require.Equal(t, 2, len(chanArbCtx.sweeper.deadlines)) + require.EqualValues(t, + htlcWithPreimage.RefundTimeout-heightHint, + chanArbCtx.sweeper.deadlines[0], + ) + require.EqualValues(t, + htlcWithPreimage.RefundTimeout-heightHint, + chanArbCtx.sweeper.deadlines[1], + ) + } // putResolverReportInChannel returns a put report function which will pipe diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index 37c0fe59b..cec788934 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -108,14 +108,17 @@ type mockSweeper struct { sweepTx *wire.MsgTx sweepErr error createSweepTxChan chan *wire.MsgTx + + deadlines []uint32 } func newMockSweeper() *mockSweeper { return &mockSweeper{ - sweptInputs: make(chan input.Input), + sweptInputs: make(chan input.Input, 3), updatedInputs: make(chan wire.OutPoint), sweepTx: &wire.MsgTx{}, createSweepTxChan: make(chan *wire.MsgTx), + deadlines: []uint32{}, } } @@ -124,6 +127,11 @@ func (s *mockSweeper) SweepInput(input input.Input, params sweep.Params) ( s.sweptInputs <- input + // Update the deadlines used if it's set. + if params.Fee.ConfTarget != 0 { + s.deadlines = append(s.deadlines, params.Fee.ConfTarget) + } + result := make(chan sweep.Result, 1) result <- sweep.Result{ Tx: s.sweepTx,