diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 38489b141..20c9c9cde 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -41,6 +41,12 @@ type paymentLifecycle struct { // or failed with temporary error. Otherwise, we should exit the // lifecycle loop as a terminal error has occurred. resultCollected chan error + + // resultCollector is a function that is used to collect the result of + // an HTLC attempt, which is always mounted to `p.collectResultAsync` + // except in unit test, where we use a much simpler resultCollector to + // decouple the test flow for the payment lifecycle. + resultCollector func(attempt *channeldb.HTLCAttempt) } // newPaymentLifecycle initiates a new payment lifecycle and returns it. @@ -60,6 +66,9 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, resultCollected: make(chan error, 1), } + // Mount the result collector. + p.resultCollector = p.collectResultAsync + // If a timeout is specified, create a timeout channel. If no timeout is // specified, the channel is left nil and will never abort the payment // loop. @@ -178,7 +187,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { log.Infof("Resuming payment shard %v for payment %v", a.AttemptID, p.identifier) - p.collectResultAsync(&a) + p.resultCollector(&a) } // exitWithErr is a helper closure that logs and returns an error. @@ -280,7 +289,7 @@ lifecycle: // Now that the shard was successfully sent, launch a go // routine that will handle its result when its back. if result.err == nil { - p.collectResultAsync(attempt) + p.resultCollector(attempt) } } @@ -416,6 +425,9 @@ type attemptResult struct { // will send a nil error to channel `resultCollected` to indicate there's an // result. func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { + log.Debugf("Collecting result for attempt %v in payment %v", + attempt.AttemptID, p.identifier) + go func() { // Block until the result is available. _, err := p.collectResult(attempt) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index e90f162fc..f4b637ee4 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -1,22 +1,211 @@ package routing import ( + "sync/atomic" "testing" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/lnmock" + "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) -var ( - dummyErr = errors.New("dummy") -) +var errDummy = errors.New("dummy") + +// createTestPaymentLifecycle creates a `paymentLifecycle` without mocks. +func createTestPaymentLifecycle() *paymentLifecycle { + paymentHash := lntypes.Hash{1, 2, 3} + quitChan := make(chan struct{}) + rt := &ChannelRouter{ + cfg: &Config{}, + quit: quitChan, + } + + return &paymentLifecycle{ + router: rt, + identifier: paymentHash, + } +} + +// mockers wraps a list of mocked interfaces used inside payment lifecycle. +type mockers struct { + shard *mockShard + shardTracker *mockShardTracker + control *mockControlTower + paySession *mockPaymentSession + payer *mockPaymentAttemptDispatcher + clock *lnmock.MockClock + missionControl *mockMissionControl + + // collectResultsCount is the number of times the collectResultAsync + // has been called. + collectResultsCount int + + // payment is the mocked `dbMPPayment` used in the test. + payment *mockMPPayment +} + +// newTestPaymentLifecycle creates a `paymentLifecycle` using the mocks. It +// also asserts the mockers are called as expected when the test is finished. +func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { + paymentHash := lntypes.Hash{1, 2, 3} + quitChan := make(chan struct{}) + + // Create a mock shard to be return from `NewShard`. + mockShard := &mockShard{} + + // Create a list of mocks and add it to the router config. + mockControlTower := &mockControlTower{} + mockPayer := &mockPaymentAttemptDispatcher{} + mockClock := &lnmock.MockClock{} + mockMissionControl := &mockMissionControl{} + + // Make a channel router. + rt := &ChannelRouter{ + cfg: &Config{ + Control: mockControlTower, + Payer: mockPayer, + Clock: mockClock, + MissionControl: mockMissionControl, + }, + quit: quitChan, + } + + // Create mockers to init a payment lifecycle. + mockPaymentSession := &mockPaymentSession{} + mockShardTracker := &mockShardTracker{} + + // Create a test payment lifecycle with no fee limit and no timeout. + p := newPaymentLifecycle( + rt, noFeeLimit, paymentHash, mockPaymentSession, + mockShardTracker, 0, 0, + ) + + // Create a mock payment which is returned from mockControlTower. + mockPayment := &mockMPPayment{} + + mockers := &mockers{ + shard: mockShard, + shardTracker: mockShardTracker, + control: mockControlTower, + paySession: mockPaymentSession, + payer: mockPayer, + clock: mockClock, + missionControl: mockMissionControl, + payment: mockPayment, + } + + // Overwrite the collectResultAsync to focus on testing the payment + // lifecycle within the goroutine. + resultCollector := func(attempt *channeldb.HTLCAttempt) { + mockers.collectResultsCount++ + } + p.resultCollector = resultCollector + + // Validate the mockers are called as expected before exiting the test. + t.Cleanup(func() { + mockShard.AssertExpectations(t) + mockShardTracker.AssertExpectations(t) + mockControlTower.AssertExpectations(t) + mockPaymentSession.AssertExpectations(t) + mockPayer.AssertExpectations(t) + mockClock.AssertExpectations(t) + mockMissionControl.AssertExpectations(t) + mockPayment.AssertExpectations(t) + }) + + return p, mockers +} + +// setupTestPaymentLifecycle creates a new `paymentLifecycle` and mocks the +// initial steps of the payment lifecycle so we can enter into the loop +// directly. +func setupTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { + p, m := newTestPaymentLifecycle(t) + + // Mock the first two calls. + m.control.On("FetchPayment", p.identifier).Return( + m.payment, nil, + ).Once() + + htlcs := []channeldb.HTLCAttempt{} + m.payment.On("InFlightHTLCs").Return(htlcs).Once() + + return p, m +} + +// resumePaymentResult is used to hold the returned values from +// `resumePayment`. +type resumePaymentResult struct { + preimage lntypes.Hash + err error +} + +// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error +// is returned. +func sendPaymentAndAssertFailed(t *testing.T, + p *paymentLifecycle, errExpected error) { + + resultChan := make(chan *resumePaymentResult, 1) + + // We now make a call to `resumePayment` and expect it to return the + // error. + go func() { + preimage, _, err := p.resumePayment() + resultChan <- &resumePaymentResult{ + preimage: preimage, + err: err, + } + }() + + // Validate the returned values or timeout. + select { + case r := <-resultChan: + require.ErrorIs(t, r.err, errExpected, "expected error") + require.Empty(t, r.preimage, "preimage should be empty") + + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for result") + } +} + +// sendPaymentAndAssertSucceeded calls `resumePayment` and asserts that the +// returned preimage is correct. +func sendPaymentAndAssertSucceeded(t *testing.T, + p *paymentLifecycle, expected lntypes.Preimage) { + + resultChan := make(chan *resumePaymentResult, 1) + + // We now make a call to `resumePayment` and expect it to return the + // preimage. + go func() { + preimage, _, err := p.resumePayment() + resultChan <- &resumePaymentResult{ + preimage: preimage, + err: err, + } + }() + + // Validate the returned values or timeout. + select { + case r := <-resultChan: + require.NoError(t, r.err, "unexpected error") + require.EqualValues(t, expected, r.preimage, + "preimage not match") + + case <-time.After(testTimeout): + require.Fail(t, "timeout waiting for result") + } +} // createDummyRoute builds a route a->b->c paying the given amt to c. func createDummyRoute(t *testing.T, amt lnwire.MilliSatoshi) *route.Route { @@ -149,21 +338,6 @@ func TestCheckTimeoutOnRouterQuit(t *testing.T) { require.ErrorIs(t, err, ErrRouterShuttingDown) } -// createTestPaymentLifecycle creates a `paymentLifecycle` using the mocks. -func createTestPaymentLifecycle() *paymentLifecycle { - paymentHash := lntypes.Hash{1, 2, 3} - quitChan := make(chan struct{}) - rt := &ChannelRouter{ - cfg: &Config{}, - quit: quitChan, - } - - return &paymentLifecycle{ - router: rt, - identifier: paymentHash, - } -} - // TestRequestRouteSucceed checks that `requestRoute` can successfully request // a route. func TestRequestRouteSucceed(t *testing.T) { @@ -409,9 +583,9 @@ func TestDecideNextStep(t *testing.T) { }, { name: "error on allow more attempts", - allowMoreAttempts: &mockReturn{false, dummyErr}, + allowMoreAttempts: &mockReturn{false, errDummy}, expectedStep: stepExit, - expectedErr: dummyErr, + expectedErr: errDummy, }, { name: "no wait and exit", @@ -423,9 +597,9 @@ func TestDecideNextStep(t *testing.T) { { name: "wait returns an error", allowMoreAttempts: &mockReturn{false, nil}, - needWaitAttempts: &mockReturn{false, dummyErr}, + needWaitAttempts: &mockReturn{false, errDummy}, expectedStep: stepExit, - expectedErr: dummyErr, + expectedErr: errDummy, }, { @@ -491,3 +665,949 @@ func TestDecideNextStep(t *testing.T) { payment.AssertExpectations(t) } } + +// TestResumePaymentFailOnFetchPayment checks when we fail to fetch the +// payment, the error is returned. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnFetchPayment(t *testing.T) { + // Create a test paymentLifecycle. + p, m := newTestPaymentLifecycle(t) + + // Mock an error returned. + m.control.On("FetchPayment", p.identifier).Return(nil, errDummy) + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnTimeout checks that when timeout is reached, the +// payment is failed. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnTimeout(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. make the timeout happens instantly and sleep one millisecond to + // make sure it timed out. + p.timeoutChan = time.After(1 * time.Nanosecond) + time.Sleep(1 * time.Millisecond) + + // 4. the payment should be failed with reason timeout. + m.control.On("FailPayment", + p.identifier, channeldb.FailureReasonTimeout, + ).Return(nil).Once() + + // 5. decideNextStep now returns stepExit. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // 6. control tower deletes failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // 7. the payment returns the failed reason. + reason := channeldb.FailureReasonTimeout + m.payment.On("TerminalInfo").Return(nil, &reason) + + // Send the payment and assert it failed with the timeout reason. + sendPaymentAndAssertFailed(t, p, reason) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnTimeoutErr checks that the lifecycle fails when an +// error is returned from `checkTimeout`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnTimeoutErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. quit the router to return an error. + close(p.router.quit) + + // Send the payment and assert it failed when router is shutting down. + sendPaymentAndAssertFailed(t, p, ErrRouterShuttingDown) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnStepErr checks that the lifecycle fails when an +// error is returned from `decideNextStep`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnStepErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns an error. + m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnRequestRouteErr checks that the lifecycle fails when +// an error is returned from `requestRoute`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnRequestRouteErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + paymentAmt := lnwire.MilliSatoshi(10000) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 4. mock requestRoute to return an error. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(nil, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnRegisterAttemptErr checks that the lifecycle fails +// when an error is returned from `registerAttempt`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 5. mock shardTracker used in `createNewPaymentAttempt` to return an + // error. + // + // Mock NextPaymentID to always return the attemptID. + attemptID := uint64(1) + p.router.cfg.NextPaymentID = func() (uint64, error) { + return attemptID, nil + } + + // Return an error to end the lifecycle. + m.shardTracker.On("NewShard", + attemptID, true, + ).Return(nil, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentFailOnSendAttemptErr checks that the lifecycle fails when +// an error is returned from `sendAttempt`. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt) + + // We now enter the payment lifecycle loop. + // + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 5. mock `registerAttempt` to return an attempt. + // + // Mock NextPaymentID to always return the attemptID. + attemptID := uint64(1) + p.router.cfg.NextPaymentID = func() (uint64, error) { + return attemptID, nil + } + + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID, true, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // Mock the time and expect it to be call twice. + m.clock.On("Now").Return(time.Now()).Twice() + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Once() + + // 6. mock `sendAttempt` to return an error. + m.payer.On("SendHTLC", + mock.Anything, attemptID, mock.Anything, + ).Return(errDummy).Once() + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd fail the payment, cancel the shard and fail the attempt. + // + // `FailPayment` should be called with an internal reason. + reason := channeldb.FailureReasonError + m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() + + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a dummy error to exit the loop. + m.control.On("FailAttempt", + p.identifier, attemptID, mock.Anything, + ).Return(nil, errDummy).Once() + + // Send the payment and assert it failed. + sendPaymentAndAssertFailed(t, p, errDummy) + + // Expected collectResultAsync to not be called. + require.Zero(t, m.collectResultsCount) +} + +// TestResumePaymentSuccess checks that a normal payment flow that is +// succeeded. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentSuccess(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt) + + // We now enter the payment lifecycle loop. + // + // 1.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 1.2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 1.3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 1.4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 1.5. mock `registerAttempt` to return an attempt. + // + // Mock NextPaymentID to always return the attemptID. + attemptID := uint64(1) + p.router.cfg.NextPaymentID = func() (uint64, error) { + return attemptID, nil + } + + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID, true, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // Mock the time and expect it to be called. + m.clock.On("Now").Return(time.Now()) + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Once() + + // 1.6. mock `sendAttempt` to succeed, which brings us into the next + // iteration of the lifecycle. + m.payer.On("SendHTLC", + mock.Anything, attemptID, mock.Anything, + ).Return(nil).Once() + + // We now enter the second iteration of the lifecycle loop. + // + // 2.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2.2. calls `GetState` and return the state. + m.payment.On("GetState").Return(ps).Run(func(args mock.Arguments) { + ps.RemainingAmt = 0 + }).Once() + + // 2.3. decideNextStep now returns stepExit and exits the loop. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // We should perform an optional deletion over failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // Finally, mock the `TerminalInfo` to return the settled attempt. + // Create a SettleAttempt. + testPreimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt(t, int(paymentAmt), testPreimage) + m.payment.On("TerminalInfo").Return(settledAttempt, nil).Once() + + // Send the payment and assert the preimage is matched. + sendPaymentAndAssertSucceeded(t, p, testPreimage) + + // Expected collectResultAsync to called. + require.Equal(t, 1, m.collectResultsCount) +} + +// TestResumePaymentSuccessWithTwoAttempts checks a successful payment flow +// with two HTLC attempts. +// +// NOTE: No parallel test because it overwrites global variables. +// +//nolint:paralleltest +func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := setupTestPaymentLifecycle(t) + + // Create a dummy route that will be returned by `RequestRoute`. + paymentAmt := lnwire.MilliSatoshi(10000) + rt := createDummyRoute(t, paymentAmt/2) + + // We now enter the payment lifecycle loop. + // + // 1.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 1.2. calls `GetState` and return the state. + ps := &channeldb.MPPaymentState{ + RemainingAmt: paymentAmt, + } + m.payment.On("GetState").Return(ps).Once() + + // NOTE: GetStatus is only used to populate the logs which is + // not critical so we loosen the checks on how many times it's + // been called. + m.payment.On("GetStatus").Return(channeldb.StatusInFlight) + + // 1.3. decideNextStep now returns stepProceed. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 1.4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // Create two attempt IDs here. + attemptID1 := uint64(1) + attemptID2 := uint64(2) + + // 1.5. mock `registerAttempt` to return an attempt. + // + // Mock NextPaymentID to return the first attemptID on the first call + // and the second attemptID on the second call. + var numAttempts atomic.Uint64 + p.router.cfg.NextPaymentID = func() (uint64, error) { + numAttempts.Add(1) + if numAttempts.Load() == 1 { + return attemptID1, nil + } + + return attemptID2, nil + } + + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID1, false, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // Mock the time and expect it to be called. + m.clock.On("Now").Return(time.Now()) + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Run(func(args mock.Arguments) { + ps.RemainingAmt = paymentAmt / 2 + }).Once() + + // 1.6. mock `sendAttempt` to succeed, which brings us into the next + // iteration of the lifecycle where we mock a temp failure. + m.payer.On("SendHTLC", + mock.Anything, attemptID1, mock.Anything, + ).Return(nil).Once() + + // We now enter the second iteration of the lifecycle loop. + // + // 2.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2.2. calls `GetState` and return the state. + m.payment.On("GetState").Return(ps).Once() + + // 2.3. decideNextStep now returns stepProceed so we can send the + // second attempt. + m.payment.On("AllowMoreAttempts").Return(true, nil).Once() + + // 2.4. mock requestRoute to return an route. + m.paySession.On("RequestRoute", + paymentAmt/2, p.feeLimit, uint32(ps.NumAttemptsInFlight), + uint32(p.currentHeight), + ).Return(rt, nil).Once() + + // 2.5. mock `registerAttempt` to return an attempt. + // + // Mock shardTracker to return the mock shard. + m.shardTracker.On("NewShard", + attemptID2, true, + ).Return(m.shard, nil).Once() + + // Mock the methods on the shard. + m.shard.On("MPP").Return(&record.MPP{}).Twice(). + On("AMP").Return(nil).Once(). + On("Hash").Return(p.identifier).Once() + + // We now register attempt and return no error. + m.control.On("RegisterAttempt", + p.identifier, mock.Anything, + ).Return(nil).Once() + + // 2.6. mock `sendAttempt` to succeed, which brings us into the next + // iteration of the lifecycle. + m.payer.On("SendHTLC", + mock.Anything, attemptID2, mock.Anything, + ).Return(nil).Once() + + // We now enter the third iteration of the lifecycle loop. + // + // 3.1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 3.2. calls `GetState` and return the state. + m.payment.On("GetState").Return(ps).Once() + + // 3.3. decideNextStep now returns stepExit to exit the loop. + m.payment.On("AllowMoreAttempts").Return(false, nil).Once(). + On("NeedWaitAttempts").Return(false, nil).Once() + + // We should perform an optional deletion over failed attempts. + m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once() + + // Finally, mock the `TerminalInfo` to return the settled attempt. + // Create a SettleAttempt. + testPreimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt(t, int(paymentAmt), testPreimage) + m.payment.On("TerminalInfo").Return(settledAttempt, nil).Once() + + // Send the payment and assert the preimage is matched. + sendPaymentAndAssertSucceeded(t, p, testPreimage) + + // Expected collectResultAsync to called. + require.Equal(t, 2, m.collectResultsCount) +} + +// TestCollectResultExitOnErr checks that when there's an error returned from +// htlcswitch via `GetAttemptResult`, it's handled and returned. +func TestCollectResultExitOnErr(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a dummy error. + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(nil, errDummy).Once() + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd fail the payment, cancel the shard and fail the attempt. + // + // `FailPayment` should be called with an internal reason. + reason := channeldb.FailureReasonError + m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() + + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a switch error. + switchErr := errors.New("switch err") + m.control.On("FailAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, switchErr).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, switchErr, "expected switch error") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnResultErr checks that when there's an error returned +// from htlcswitch via the result channel, it's handled and returned. +func TestCollectResultExitOnResultErr(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send an error to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Error: errDummy, + } + }) + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd fail the payment, cancel the shard and fail the attempt. + // + // `FailPayment` should be called with an internal reason. + reason := channeldb.FailureReasonError + m.control.On("FailPayment", p.identifier, reason).Return(nil).Once() + + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a switch error. + switchErr := errors.New("switch err") + m.control.On("FailAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, switchErr).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, switchErr, "expected switch error") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnSwitcQuit checks that when the htlcswitch is shutting +// down an error is returned. +func TestCollectResultExitOnSwitchQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Close the result chan to simulate a htlcswitch quit. + close(resultChan) + }) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, htlcswitch.ErrSwitchExiting, + "expected switch exit") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnRouterQuit checks that when the channel router is +// shutting down an error is returned. +func TestCollectResultExitOnRouterQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Close the channel router. + close(p.router.quit) + }) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, ErrRouterShuttingDown, "expected router exit") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnLifecycleQuit checks that when the payment lifecycle +// is shutting down an error is returned. +func TestCollectResultExitOnLifecycleQuit(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + attempt := makeFailedAttempt(t, paymentAmt) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Stop the lifecycle. + p.stop() + }) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, ErrPaymentLifecycleExiting, + "expected lifecycle exit") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultExitOnSettleErr checks that when settling the attempt +// fails an error is returned. +func TestCollectResultExitOnSettleErr(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Preimage: preimage, + } + }) + + // Once the result is received, `ReportPaymentSuccess` should be + // called. + m.missionControl.On("ReportPaymentSuccess", + attempt.AttemptID, &attempt.Route, + ).Return(nil).Once() + + // Now mock an error being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, errDummy).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.ErrorIs(t, err, errDummy, "expected settle error") + require.Nil(t, result, "expected nil attempt") +} + +// TestCollectResultSuccess checks a successful htlc settlement. +func TestCollectResultSuccess(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Preimage: preimage, + } + }) + + // Once the result is received, `ReportPaymentSuccess` should be + // called. + m.missionControl.On("ReportPaymentSuccess", + attempt.AttemptID, &attempt.Route, + ).Return(nil).Once() + + // Now the settled htlc being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(attempt, nil).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + result, err := p.collectResult(attempt) + require.NoError(t, err, "expected no error") + require.Equal(t, preimage, result.attempt.Settle.Preimage, + "preimage mismatch") +} + +// TestCollectResultAsyncSuccess checks a successful htlc settlement. +func TestCollectResultAsyncSuccess(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Mock shardTracker to return the payment hash. + m.shardTracker.On("GetHash", + attempt.AttemptID, + ).Return(p.identifier, nil).Once() + + // Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- &htlcswitch.PaymentResult{ + Preimage: preimage, + } + }) + + // Once the result is received, `ReportPaymentSuccess` should be + // called. + m.missionControl.On("ReportPaymentSuccess", + attempt.AttemptID, &attempt.Route, + ).Return(nil).Once() + + // Now the settled htlc being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(attempt, nil).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Now call the method under test. + p.collectResultAsync(attempt) + + // Assert the result is returned within 5 seconds. + var err error + waitErr := wait.NoError(func() error { + err = <-p.resultCollected + return nil + }, testTimeout) + require.NoError(t, waitErr, "timeout waiting for result") + + // Assert that a nil error is received. + require.NoError(t, err, "expected no error") +}