From 10052ff4f5564b8ce2382316ad99a9064a9d43c2 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Wed, 8 Mar 2023 02:50:33 +0800 Subject: [PATCH] routing: patch unit tests for payment lifecycle This commit adds unit tests for `resumePayment`. In addition, the `resumePayment` has been split into two parts so it's easier to be tested, 1) sending the htlc, and 2) collecting results. As seen in the new tests, this split largely reduces the complexity involved and makes the unit test flow sequential. This commit also makes full use of `mock.Mock` in the unit tests to provide a more clear testing flow. --- routing/payment_lifecycle.go | 16 +- routing/payment_lifecycle_test.go | 1164 ++++++++++++++++++++++++++++- 2 files changed, 1156 insertions(+), 24 deletions(-) diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 38489b141..20c9c9cde 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -41,6 +41,12 @@ type paymentLifecycle struct { // or failed with temporary error. Otherwise, we should exit the // lifecycle loop as a terminal error has occurred. resultCollected chan error + + // resultCollector is a function that is used to collect the result of + // an HTLC attempt, which is always mounted to `p.collectResultAsync` + // except in unit test, where we use a much simpler resultCollector to + // decouple the test flow for the payment lifecycle. + resultCollector func(attempt *channeldb.HTLCAttempt) } // newPaymentLifecycle initiates a new payment lifecycle and returns it. @@ -60,6 +66,9 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, resultCollected: make(chan error, 1), } + // Mount the result collector. + p.resultCollector = p.collectResultAsync + // If a timeout is specified, create a timeout channel. If no timeout is // specified, the channel is left nil and will never abort the payment // loop. @@ -178,7 +187,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { log.Infof("Resuming payment shard %v for payment %v", a.AttemptID, p.identifier) - p.collectResultAsync(&a) + p.resultCollector(&a) } // exitWithErr is a helper closure that logs and returns an error. @@ -280,7 +289,7 @@ lifecycle: // Now that the shard was successfully sent, launch a go // routine that will handle its result when its back. if result.err == nil { - p.collectResultAsync(attempt) + p.resultCollector(attempt) } } @@ -416,6 +425,9 @@ type attemptResult struct { // will send a nil error to channel `resultCollected` to indicate there's an // result. func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { + log.Debugf("Collecting result for attempt %v in payment %v", + attempt.AttemptID, p.identifier) + go func() { // Block until the result is available. _, err := p.collectResult(attempt) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index e90f162fc..f4b637ee4 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -1,22 +1,211 @@ package routing import ( + "sync/atomic" "testing" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/lnmock" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) -var ( - dummyErr = errors.New("dummy") -) +var errDummy = errors.New("dummy") + +// createTestPaymentLifecycle creates a `paymentLifecycle` without mocks. +func createTestPaymentLifecycle() *paymentLifecycle { + paymentHash := lntypes.Hash{1, 2, 3} + quitChan := make(chan struct{}) + rt := &ChannelRouter{ + cfg: &Config{}, + quit: quitChan, + } + + return &paymentLifecycle{ + router: rt, + identifier: paymentHash, + } +} + +// mockers wraps a list of mocked interfaces used inside payment lifecycle. +type mockers struct { + shard *mockShard + shardTracker *mockShardTracker + control *mockControlTower + paySession *mockPaymentSession + payer *mockPaymentAttemptDispatcher + clock *lnmock.MockClock + missionControl *mockMissionControl + + // collectResultsCount is the number of times the collectResultAsync + // has been called. + collectResultsCount int + + // payment is the mocked `dbMPPayment` used in the test. + payment *mockMPPayment +} + +// newTestPaymentLifecycle creates a `paymentLifecycle` using the mocks. It +// also asserts the mockers are called as expected when the test is finished. +func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { + paymentHash := lntypes.Hash{1, 2, 3} + quitChan := make(chan struct{}) + + // Create a mock shard to be return from `NewShard`. + mockShard := &mockShard{} + + // Create a list of mocks and add it to the router config. + mockControlTower := &mockControlTower{} + mockPayer := &mockPaymentAttemptDispatcher{} + mockClock := &lnmock.MockClock{} + mockMissionControl := &mockMissionControl{} + + // Make a channel router. + rt := &ChannelRouter{ + cfg: &Config{ + Control: mockControlTower, + Payer: mockPayer, + Clock: mockClock, + MissionControl: mockMissionControl, + }, + quit: quitChan, + } + + // Create mockers to init a payment lifecycle. + mockPaymentSession := &mockPaymentSession{} + mockShardTracker := &mockShardTracker{} + + // Create a test payment lifecycle with no fee limit and no timeout. + p := newPaymentLifecycle( + rt, noFeeLimit, paymentHash, mockPaymentSession, + mockShardTracker, 0, 0, + ) + + // Create a mock payment which is returned from mockControlTower. + mockPayment := &mockMPPayment{} + + mockers := &mockers{ + shard: mockShard, + shardTracker: mockShardTracker, + control: mockControlTower, + paySession: mockPaymentSession, + payer: mockPayer, + clock: mockClock, + missionControl: mockMissionControl, + payment: mockPayment, + } + + // Overwrite the collectResultAsync to focus on testing the payment + // lifecycle within the goroutine. + resultCollector := func(attempt *channeldb.HTLCAttempt) { + mockers.collectResultsCount++ + } + p.resultCollector = resultCollector + + // Validate the mockers are called as expected before exiting the test. + t.Cleanup(func() { + mockShard.AssertExpectations(t) + mockShardTracker.AssertExpectations(t) + mockControlTower.AssertExpectations(t) + mockPaymentSession.AssertExpectations(t) + mockPayer.AssertExpectations(t) + mockClock.AssertExpectations(t) + mockMissionControl.AssertExpectations(t) + mockPayment.AssertExpectations(t) + }) + + return p, mockers +} + +// setupTestPaymentLifecycle creates a new `paymentLifecycle` and mocks the +// initial steps of the payment lifecycle so we can enter into the loop +// directly. +func setupTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { + p, m := newTestPaymentLifecycle(t) + + // Mock the first two calls. + m.control.On("FetchPayment", p.identifier).Return( + m.payment, nil, + ).Once() + + htlcs := []channeldb.HTLCAttempt{} + m.payment.On("InFlightHTLCs").Return(htlcs).Once() + + return p, m +} + +// resumePaymentResult is used to hold the returned values from +// `resumePayment`. +type resumePaymentResult struct { + preimage lntypes.Hash + err error +} + +// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error +// is returned. +func sendPaymentAndAssertFailed(t *testing.T, + p *paymentLifecycle, errExpected error) { + + resultChan := make(chan *resumePaymentResult, 1) + + // We now make a call to `resumePayment` and expect it to return the + // error. + go func() { + preimage, _, err := p.resumePayment() + resultChan <- &resumePaymentResult{ + preimage: preimage, + err: err, + } + }() + + // Validate the returned values or timeout. + select { + case r := <-resultChan: + require.ErrorIs(t, r.err, errExpected, "expected error") + require.Empty(t, r.preimage, "preimage should be empty") + + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for result") + } +} + +// sendPaymentAndAssertSucceeded calls `resumePayment` and asserts that the +// returned preimage is correct. +func sendPaymentAndAssertSucceeded(t *testing.T, + p *paymentLifecycle, expected lntypes.Preimage) { + + resultChan := make(chan *resumePaymentResult, 1) + + // We now make a call to `resumePayment` and expect it to return the + // preimage. + go func() { + preimage, _, err := p.resumePayment() + resultChan <- &resumePaymentResult{ + preimage: preimage, + err: err, + } + }() + + // Validate the returned values or timeout. + select { + case r := <-resultChan: + require.NoError(t, r.err, "unexpected error") + require.EqualValues(t, expected, r.preimage, + "preimage not match") + + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for result") + } +} // createDummyRoute builds a route a->b->c paying the given amt to c. func createDummyRoute(t *testing.T, amt lnwire.MilliSatoshi) *route.Route { @@ -149,21 +338,6 @@ func TestCheckTimeoutOnRouterQuit(t *testing.T) { require.ErrorIs(t, err, ErrRouterShuttingDown) } -// createTestPaymentLifecycle creates a `paymentLifecycle` using the mocks. -func createTestPaymentLifecycle() *paymentLifecycle { - paymentHash := lntypes.Hash{1, 2, 3} - quitChan := make(chan struct{}) - rt := &ChannelRouter{ - cfg: &Config{}, - quit: quitChan, - } - - return &paymentLifecycle{ - router: rt, - identifier: paymentHash, - } -} - // TestRequestRouteSucceed checks that `requestRoute` can successfully request // a route. func TestRequestRouteSucceed(t *testing.T) { @@ -409,9 +583,9 @@ func TestDecideNextStep(t *testing.T) { }, { name: "error on allow more attempts", - allowMoreAttempts: &mockReturn{false, dummyErr}, + allowMoreAttempts: &mockReturn{false, errDummy}, expectedStep: stepExit, - expectedErr: dummyErr, + expectedErr: errDummy, }, { name: "no wait and exit", @@ -423,9 +597,9 @@ func TestDecideNextStep(t *testing.T) { { name: "wait returns an error", allowMoreAttempts: &mockReturn{false, nil}, - needWaitAttempts: &mockReturn{false, dummyErr}, + needWaitAttempts: &mockReturn{false, errDummy}, expectedStep: stepExit, - expectedErr: dummyErr, + expectedErr: errDummy, }, { @@ -491,3 +665,949 @@ func TestDecideNextStep(t *testing.T) { payment.AssertExpectations(t) } } + +// TestResumePaymentFailOnFetchPayment checks when we fail to fetch the +// payment, the error is returned. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnFetchPayment(t *testing.T) { + // Create a test paymentLifecycle. + p, m := newTestPaymentLifecycle(t) + + // Mock an error returned. + m.control.On("FetchPayment", p.identifier).Return(nil, errDummy) + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnTimeout checks that when timeout is reached, the +// payment is failed. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnTimeout(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. make the timeout happens instantly and sleep one millisecond to + // make sure it timed out. + p.timeoutChan = time.After(1 * time.Nanosecond) + time.Sleep(1 * time.Millisecond) + + // 4. the payment should be failed with reason timeout. + m.control.On("FailPayment", + p.identifier, channeldb.FailureReasonTimeout, + ).Return(nil).Once() + + // 5. decideNextStep now returns stepExit. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // 6. control tower deletes failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // 7. the payment returns the failed reason. + reason := channeldb.FailureReasonTimeout + m.payment.On("TerminalInfo").Return(nil, &reason) + + // Send the payment and assert it failed with the timeout reason. + sendPaymentAndAssertFailed(t, p, reason) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnTimeoutErr checks that the lifecycle fails when an +// error is returned from `checkTimeout`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnTimeoutErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. quit the router to return an error. + close(p.router.quit) + + // Send the payment and assert it failed when router is shutting down. + sendPaymentAndAssertFailed(t, p, ErrRouterShuttingDown) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnStepErr checks that the lifecycle fails when an +// error is returned from `decideNextStep`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnStepErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns an error. + m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnRequestRouteErr checks that the lifecycle fails when +// an error is returned from `requestRoute`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnRequestRouteErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 4. mock requestRoute to return an error. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(nil, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnRegisterAttemptErr checks that the lifecycle fails +// when an error is returned from `registerAttempt`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 5. mock shardTracker used in `createNewPaymentAttempt` to return an + // error. + // + // Mock NextPaymentID to always return the attemptID. + attemptID := uint64(1) + p.router.cfg.NextPaymentID = func() (uint64, error) { + return attemptID, nil + } + + // Return an error to end the lifecycle. + m.shardTracker.On("NewShard", + attemptID, true, + ).Return(nil, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnSendAttemptErr checks that the lifecycle fails when +// an error is returned from `sendAttempt`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 5. mock `registerAttempt` to return an attempt. + // + // Mock NextPaymentID to always return the attemptID. + attemptID := uint64(1) + p.router.cfg.NextPaymentID = func() (uint64, error) { + return attemptID, nil + } + + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID, true, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // Mock the time and expect it to be call twice. + m.clock.On("Now").Return(time.Now()).Twice() + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Once() + + // 6. mock `sendAttempt` to return an error. + m.payer.On("SendHTLC", + mock.Anything, attemptID, mock.Anything, + ).Return(errDummy).Once() + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd fail the payment, cancel the shard and fail the attempt. + // + // `FailPayment` should be called with an internal reason. + reason := channeldb.FailureReasonError + m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() + + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a dummy error to exit the loop. + m.control.On("FailAttempt", + p.identifier, attemptID, mock.Anything, + ).Return(nil, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentSuccess checks that a normal payment flow that is +// succeeded. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentSuccess(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt) + + // We now enter the payment lifecycle loop. + // + // 1.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 1.2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 1.3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 1.4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 1.5. mock `registerAttempt` to return an attempt. + // + // Mock NextPaymentID to always return the attemptID. + attemptID := uint64(1) + p.router.cfg.NextPaymentID = func() (uint64, error) { + return attemptID, nil + } + + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID, true, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // Mock the time and expect it to be called. + m.clock.On("Now").Return(time.Now()) + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Once() + + // 1.6. mock `sendAttempt` to succeed, which brings us into the next + // iteration of the lifecycle. + m.payer.On("SendHTLC", + mock.Anything, attemptID, mock.Anything, + ).Return(nil).Once() + + // We now enter the second iteration of the lifecycle loop. + // + // 2.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2.2. calls `GetState` and return the state. + m.payment.On("GetState").Return(ps).Run(func(args mock.Arguments) { + ps.RemainingAmt = 0 + }).Once() + + // 2.3. decideNextStep now returns stepExit and exits the loop. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // We should perform an optional deletion over failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // Finally, mock the `TerminalInfo` to return the settled attempt. + // Create a SettleAttempt. + testPreimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt(t, int(paymentAmt), testPreimage) + m.payment.On("TerminalInfo").Return(settledAttempt, nil).Once() + + // Send the payment and assert the preimage is matched. + sendPaymentAndAssertSucceeded(t, p, testPreimage) + + // Expected collectResultAsync to called. + require.Equal(t, 1, m.collectResultsCount) +} + +// TestResumePaymentSuccessWithTwoAttempts checks a successful payment flow +// with two HTLC attempts. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt/2) + + // We now enter the payment lifecycle loop. + // + // 1.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 1.2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 1.3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 1.4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // Create two attempt IDs here. + attemptID1 := uint64(1) + attemptID2 := uint64(2) + + // 1.5. mock `registerAttempt` to return an attempt. + // + // Mock NextPaymentID to return the first attemptID on the first call + // and the second attemptID on the second call. + var numAttempts atomic.Uint64 + p.router.cfg.NextPaymentID = func() (uint64, error) { + numAttempts.Add(1) + if numAttempts.Load() == 1 { + return attemptID1, nil + } + + return attemptID2, nil + } + + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID1, false, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // Mock the time and expect it to be called. + m.clock.On("Now").Return(time.Now()) + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + ps.RemainingAmt = paymentAmt / 2 + }).Once() + + // 1.6. mock `sendAttempt` to succeed, which brings us into the next + // iteration of the lifecycle where we mock a temp failure. + m.payer.On("SendHTLC", + mock.Anything, attemptID1, mock.Anything, + ).Return(nil).Once() + + // We now enter the second iteration of the lifecycle loop. + // + // 2.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2.2. calls `GetState` and return the state. + m.payment.On("GetState").Return(ps).Once() + + // 2.3. decideNextStep now returns stepProceed so we can send the + // second attempt. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 2.4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt/2, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 2.5. mock `registerAttempt` to return an attempt. + // + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID2, true, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Once() + + // 2.6. mock `sendAttempt` to succeed, which brings us into the next + // iteration of the lifecycle. + m.payer.On("SendHTLC", + mock.Anything, attemptID2, mock.Anything, + ).Return(nil).Once() + + // We now enter the third iteration of the lifecycle loop. + // + // 3.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 3.2. calls `GetState` and return the state. + m.payment.On("GetState").Return(ps).Once() + + // 3.3. decideNextStep now returns stepExit to exit the loop. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // We should perform an optional deletion over failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // Finally, mock the `TerminalInfo` to return the settled attempt. + // Create a SettleAttempt. + testPreimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt(t, int(paymentAmt), testPreimage) + m.payment.On("TerminalInfo").Return(settledAttempt, nil).Once() + + // Send the payment and assert the preimage is matched. + sendPaymentAndAssertSucceeded(t, p, testPreimage) + + // Expected collectResultAsync to called. + require.Equal(t, 2, m.collectResultsCount) +} + +// TestCollectResultExitOnErr checks that when there's an error returned from +// htlcswitch via `GetAttemptResult`, it's handled and returned. +func TestCollectResultExitOnErr(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a dummy error. + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(nil, errDummy).Once() + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd fail the payment, cancel the shard and fail the attempt. + // + // `FailPayment` should be called with an internal reason. + reason := channeldb.FailureReasonError + m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() + + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a switch error. + switchErr := errors.New("switch err") + m.control.On("FailAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, switchErr).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, switchErr, "expected switch error") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnResultErr checks that when there's an error returned +// from htlcswitch via the result channel, it's handled and returned. +func TestCollectResultExitOnResultErr(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send an error to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Error: errDummy, + } + }) + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd fail the payment, cancel the shard and fail the attempt. + // + // `FailPayment` should be called with an internal reason. + reason := channeldb.FailureReasonError + m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() + + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a switch error. + switchErr := errors.New("switch err") + m.control.On("FailAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, switchErr).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, switchErr, "expected switch error") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnSwitcQuit checks that when the htlcswitch is shutting +// down an error is returned. +func TestCollectResultExitOnSwitchQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Close the result chan to simulate a htlcswitch quit. + close(resultChan) + }) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, htlcswitch.ErrSwitchExiting, + "expected switch exit") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnRouterQuit checks that when the channel router is +// shutting down an error is returned. +func TestCollectResultExitOnRouterQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Close the channel router. + close(p.router.quit) + }) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, ErrRouterShuttingDown, "expected router exit") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnLifecycleQuit checks that when the payment lifecycle +// is shutting down an error is returned. +func TestCollectResultExitOnLifecycleQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Stop the lifecycle. + p.stop() + }) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, ErrPaymentLifecycleExiting, + "expected lifecycle exit") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnSettleErr checks that when settling the attempt +// fails an error is returned. +func TestCollectResultExitOnSettleErr(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Preimage: preimage, + } + }) + + // Once the result is received, `ReportPaymentSuccess` should be + // called. + m.missionControl.On("ReportPaymentSuccess", + attempt.AttemptID, &attempt.Route, + ).Return(nil).Once() + + // Now mock an error being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, errDummy).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, errDummy, "expected settle error") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultSuccess checks a successful htlc settlement. +func TestCollectResultSuccess(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Preimage: preimage, + } + }) + + // Once the result is received, `ReportPaymentSuccess` should be + // called. + m.missionControl.On("ReportPaymentSuccess", + attempt.AttemptID, &attempt.Route, + ).Return(nil).Once() + + // Now the settled htlc being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(attempt, nil).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.NoError(t, err, "expected no error") + require.Equal(t, preimage, result.attempt.Settle.Preimage, + "preimage mismatch") +} + +// TestCollectResultAsyncSuccess checks a successful htlc settlement. +func TestCollectResultAsyncSuccess(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Preimage: preimage, + } + }) + + // Once the result is received, `ReportPaymentSuccess` should be + // called. + m.missionControl.On("ReportPaymentSuccess", + attempt.AttemptID, &attempt.Route, + ).Return(nil).Once() + + // Now the settled htlc being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(attempt, nil).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + p.collectResultAsync(attempt) + + // Assert the result is returned within 5 seconds. + var err error + waitErr := wait.NoError(func() error { + err = <-p.resultCollected + return nil + }, testTimeout) + require.NoError(t, waitErr, "timeout waiting for result") + + // Assert that a nil error is received. + require.NoError(t, err, "expected no error") +}