diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index eb54dbf32..93c0b8b57 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -103,17 +103,19 @@ func (i *commitSweepResolverTestContext) waitForResult() { } type mockSweeper struct { - sweptInputs chan input.Input - updatedInputs chan wire.OutPoint - sweepTx *wire.MsgTx - sweepErr error + sweptInputs chan input.Input + updatedInputs chan wire.OutPoint + sweepTx *wire.MsgTx + sweepErr error + createSweepTxChan chan *wire.MsgTx } func newMockSweeper() *mockSweeper { return &mockSweeper{ - sweptInputs: make(chan input.Input), - updatedInputs: make(chan wire.OutPoint), - sweepTx: &wire.MsgTx{}, + sweptInputs: make(chan input.Input), + updatedInputs: make(chan wire.OutPoint), + sweepTx: &wire.MsgTx{}, + createSweepTxChan: make(chan *wire.MsgTx), } } @@ -133,7 +135,9 @@ func (s *mockSweeper) SweepInput(input input.Input, params sweep.Params) ( func (s *mockSweeper) CreateSweepTx(inputs []input.Input, feePref sweep.FeePreference, currentBlockHeight uint32) (*wire.MsgTx, error) { - return nil, nil + // We will wait for the test to supply the sweep tx to return. + sweepTx := <-s.createSweepTxChan + return sweepTx, nil } func (s *mockSweeper) RelayFeePerKW() chainfee.SatPerKWeight { diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index 6e44c22c7..0366853a3 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -1,10 +1,14 @@ package contractcourt import ( + "bytes" + "io" + "reflect" "testing" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/kvdb" @@ -22,11 +26,11 @@ type htlcSuccessResolverTestContext struct { t *testing.T } -func newHtlcSuccessResolverTextContext(t *testing.T) *htlcSuccessResolverTestContext { +func newHtlcSuccessResolverTextContext(t *testing.T, checkpoint io.Reader) *htlcSuccessResolverTestContext { notifier := &mock.ChainNotifier{ - EpochChan: make(chan *chainntnfs.BlockEpoch), - SpendChan: make(chan *chainntnfs.SpendDetail), - ConfChan: make(chan *chainntnfs.TxConfirmation), + EpochChan: make(chan *chainntnfs.BlockEpoch, 1), + SpendChan: make(chan *chainntnfs.SpendDetail, 1), + ConfChan: make(chan *chainntnfs.TxConfirmation, 1), } checkPointChan := make(chan struct{}, 1) @@ -42,6 +46,11 @@ func newHtlcSuccessResolverTextContext(t *testing.T) *htlcSuccessResolverTestCon PublishTx: func(_ *wire.MsgTx, _ string) error { return nil }, + Sweeper: newMockSweeper(), + IncubateOutputs: func(wire.OutPoint, *lnwallet.OutgoingHtlcResolution, + *lnwallet.IncomingHtlcResolution, uint32) error { + return nil + }, }, PutResolverReport: func(_ kvdb.RwTx, report *channeldb.ResolverReport) error { @@ -59,15 +68,27 @@ func newHtlcSuccessResolverTextContext(t *testing.T) *htlcSuccessResolverTestCon return nil }, } + htlc := channeldb.HTLC{ + RHash: testResHash, + OnionBlob: testOnionBlob, + Amt: testHtlcAmt, + } + if checkpoint != nil { + var err error + testCtx.resolver, err = newSuccessResolverFromReader(checkpoint, cfg) + if err != nil { + t.Fatal(err) + } - testCtx.resolver = &htlcSuccessResolver{ - contractResolverKit: *newContractResolverKit(cfg), - htlcResolution: lnwallet.IncomingHtlcResolution{}, - htlc: channeldb.HTLC{ - RHash: testResHash, - OnionBlob: testOnionBlob, - Amt: testHtlcAmt, - }, + testCtx.resolver.Supplement(htlc) + + } else { + + testCtx.resolver = &htlcSuccessResolver{ + contractResolverKit: *newContractResolverKit(cfg), + htlcResolution: lnwallet.IncomingHtlcResolution{}, + htlc: htlc, + } } return testCtx @@ -98,8 +119,9 @@ func (i *htlcSuccessResolverTestContext) waitForResult() { } } -// TestSingleStageSuccess tests successful sweep of a single stage htlc claim. -func TestSingleStageSuccess(t *testing.T) { +// TestHtlcSuccessSingleStage tests successful sweep of a single stage htlc +// claim. +func TestHtlcSuccessSingleStage(t *testing.T) { htlcOutpoint := wire.OutPoint{Index: 3} sweepTx := &wire.MsgTx{ @@ -114,15 +136,6 @@ func TestSingleStageSuccess(t *testing.T) { ClaimOutpoint: htlcOutpoint, } - // We send a confirmation for our sweep tx to indicate that our sweep - // succeeded. - resolve := func(ctx *htlcSuccessResolverTestContext) { - ctx.notifier.ConfChan <- &chainntnfs.TxConfirmation{ - Tx: ctx.resolver.sweepTx, - BlockHeight: testInitialBlockHeight - 1, - } - } - sweepTxid := sweepTx.TxHash() claim := &channeldb.ResolverReport{ OutPoint: htlcOutpoint, @@ -131,14 +144,45 @@ func TestSingleStageSuccess(t *testing.T) { ResolverOutcome: channeldb.ResolverOutcomeClaimed, SpendTxID: &sweepTxid, } + + checkpoints := []checkpoint{ + { + // We send a confirmation for our sweep tx to indicate + // that our sweep succeeded. + preCheckpoint: func(ctx *htlcSuccessResolverTestContext, + _ bool) error { + // The resolver will create and publish a sweep + // tx. + ctx.resolver.Sweeper.(*mockSweeper). + createSweepTxChan <- sweepTx + + // Confirm the sweep, which should resolve it. + ctx.notifier.ConfChan <- &chainntnfs.TxConfirmation{ + Tx: sweepTx, + BlockHeight: testInitialBlockHeight - 1, + } + + return nil + }, + + // After the sweep has confirmed, we expect the + // checkpoint to be resolved, and with the above + // report. + resolved: true, + reports: []*channeldb.ResolverReport{ + claim, + }, + }, + } + testHtlcSuccess( - t, singleStageResolution, resolve, sweepTx, claim, + t, singleStageResolution, checkpoints, ) } // TestSecondStageResolution tests successful sweep of a second stage htlc -// claim. -func TestSecondStageResolution(t *testing.T) { +// claim, going through the Nursery. +func TestHtlcSuccessSecondStageResolution(t *testing.T) { commitOutpoint := wire.OutPoint{Index: 2} htlcOutpoint := wire.OutPoint{Index: 3} @@ -158,20 +202,17 @@ func TestSecondStageResolution(t *testing.T) { PreviousOutPoint: commitOutpoint, }, }, - TxOut: []*wire.TxOut{}, + TxOut: []*wire.TxOut{ + { + Value: 111, + PkScript: []byte{0xaa, 0xaa}, + }, + }, }, ClaimOutpoint: htlcOutpoint, SweepSignDesc: testSignDesc, } - // We send a spend notification for our output to resolve our htlc. - resolve := func(ctx *htlcSuccessResolverTestContext) { - ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ - SpendingTx: sweepTx, - SpenderTxHash: &sweepHash, - } - } - successTx := twoStageResolution.SignedSuccessTx.TxHash() firstStage := &channeldb.ResolverReport{ OutPoint: commitOutpoint, @@ -189,54 +230,167 @@ func TestSecondStageResolution(t *testing.T) { SpendTxID: &sweepHash, } + checkpoints := []checkpoint{ + { + // The resolver will send the output to the Nursery. + incubating: true, + }, + { + // It will then wait for the Nursery to spend the + // output. We send a spend notification for our output + // to resolve our htlc. + preCheckpoint: func(ctx *htlcSuccessResolverTestContext, + _ bool) error { + ctx.notifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: sweepTx, + SpenderTxHash: &sweepHash, + } + + return nil + }, + incubating: true, + resolved: true, + reports: []*channeldb.ResolverReport{ + secondStage, + firstStage, + }, + }, + } + testHtlcSuccess( - t, twoStageResolution, resolve, sweepTx, secondStage, firstStage, + t, twoStageResolution, checkpoints, ) } -// testHtlcSuccess tests resolution of a success resolver. It takes a resolve -// function which triggers resolution and the sweeptxid that will resolve it. +// checkpoint holds expected data we expect the resolver to checkpoint itself +// to the DB next. +type checkpoint struct { + // preCheckpoint is a method that will be called before we reach the + // checkpoint, to carry out any needed operations to drive the resolver + // in this stage. + preCheckpoint func(*htlcSuccessResolverTestContext, bool) error + + // data we expect the resolver to be checkpointed with next. + incubating bool + resolved bool + reports []*channeldb.ResolverReport +} + +// testHtlcSuccess tests resolution of a success resolver. It takes a a list of +// checkpoints that it expects the resolver to go through. And will run the +// resolver all the way through these checkpoints, and also attempt to resume +// the resolver from every checkpoint. func testHtlcSuccess(t *testing.T, resolution lnwallet.IncomingHtlcResolution, - resolve func(*htlcSuccessResolverTestContext), - sweepTx *wire.MsgTx, reports ...*channeldb.ResolverReport) { + checkpoints []checkpoint) { defer timeout(t)() - ctx := newHtlcSuccessResolverTextContext(t) - - // Replace our checkpoint with one which will push reports into a - // channel for us to consume. We replace this function on the resolver - // itself because it is created by the test context. - reportChan := make(chan *channeldb.ResolverReport) - ctx.resolver.Checkpoint = func(_ ContractResolver, - reports ...*channeldb.ResolverReport) error { - - // Send all of our reports into the channel. - for _, report := range reports { - reportChan <- report - } - - return nil - } - + // We first run the resolver from start to finish, ensuring it gets + // checkpointed at every expected stage. We store the checkpointed data + // for the next portion of the test. + ctx := newHtlcSuccessResolverTextContext(t, nil) ctx.resolver.htlcResolution = resolution - // We set the sweepTx to be non-nil and mark the output as already - // incubating so that we do not need to set test values for crafting - // our own sweep transaction. - ctx.resolver.sweepTx = sweepTx - ctx.resolver.outputIncubating = true + checkpointedState := runFromCheckpoint(t, ctx, checkpoints) + + // Now, from every checkpoint created, we re-create the resolver, and + // run the test from that checkpoint. + for i := range checkpointedState { + cp := bytes.NewReader(checkpointedState[i]) + ctx := newHtlcSuccessResolverTextContext(t, cp) + ctx.resolver.htlcResolution = resolution + + // Run from the given checkpoint, ensuring we'll hit the rest. + _ = runFromCheckpoint(t, ctx, checkpoints[i+1:]) + } +} + +// runFromCheckpoint executes the Resolve method on the success resolver, and +// asserts that it checkpoints itself according to the expected checkpoints. +func runFromCheckpoint(t *testing.T, ctx *htlcSuccessResolverTestContext, + expectedCheckpoints []checkpoint) [][]byte { + + defer timeout(t)() + + var checkpointedState [][]byte + + // Replace our checkpoint method with one which we'll use to assert the + // checkpointed state and reports are equal to what we expect. + nextCheckpoint := 0 + checkpointChan := make(chan struct{}) + ctx.resolver.Checkpoint = func(resolver ContractResolver, + reports ...*channeldb.ResolverReport) error { + + if nextCheckpoint >= len(expectedCheckpoints) { + t.Fatal("did not expect more checkpoints") + } + + h := resolver.(*htlcSuccessResolver) + cp := expectedCheckpoints[nextCheckpoint] + + if h.resolved != cp.resolved { + t.Fatalf("expected checkpoint to be resolve=%v, had %v", + cp.resolved, h.resolved) + } + + if !reflect.DeepEqual(h.outputIncubating, cp.incubating) { + t.Fatalf("expected checkpoint to be have "+ + "incubating=%v, had %v", cp.incubating, + h.outputIncubating) + + } + + // Check we go the expected reports. + if len(reports) != len(cp.reports) { + t.Fatalf("unexpected number of reports. Expected %v "+ + "got %v", len(cp.reports), len(reports)) + } + + for i, report := range reports { + if !reflect.DeepEqual(report, cp.reports[i]) { + t.Fatalf("expected: %v, got: %v", + spew.Sdump(cp.reports[i]), + spew.Sdump(report)) + } + } + + // Finally encode the resolver, and store it for later use. + b := bytes.Buffer{} + if err := resolver.Encode(&b); err != nil { + t.Fatal(err) + } + + checkpointedState = append(checkpointedState, b.Bytes()) + nextCheckpoint++ + checkpointChan <- struct{}{} + return nil + } // Start the htlc success resolver. ctx.resolve() - // Trigger and event that will resolve our test context. - resolve(ctx) + // Go through our list of expected checkpoints, so we can run the + // preCheckpoint logic if needed. + resumed := true + for i, cp := range expectedCheckpoints { + if cp.preCheckpoint != nil { + if err := cp.preCheckpoint(ctx, resumed); err != nil { + t.Fatalf("failure at stage %d: %v", i, err) + } - for _, report := range reports { - assertResolverReport(t, reportChan, report) + } + resumed = false + + // Wait for the resolver to have checkpointed its state. + <-checkpointChan } // Wait for the resolver to fully complete. ctx.waitForResult() + + if nextCheckpoint < len(expectedCheckpoints) { + t.Fatalf("not all checkpoints hit") + } + + return checkpointedState }