From 34d0e5d4c5ae9cef2aea5a48c2355c54bccd77a8 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 9 Feb 2023 12:51:43 +0800 Subject: [PATCH] routing+channeldb: make MPPayment into an interface This commit turns `MPPayment` into an interface inside `routing`. Having this interface gives us the benefit to write more granular unit tests inside payment lifecycle. As seen from the modified unit tests, several hacky ways of testing the `SendPayment` method is now replaced by a mock over `MPPayment`. --- .golangci.yml | 7 + channeldb/mp_payment.go | 20 +++ routing/control_tower.go | 30 +++- routing/control_tower_test.go | 14 +- routing/mock_test.go | 64 +++++-- routing/payment_lifecycle.go | 8 +- routing/payment_lifecycle_test.go | 6 - routing/router_test.go | 275 +++++++++++++++--------------- 8 files changed, 257 insertions(+), 167 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 7d3154222..951831384 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -268,3 +268,10 @@ issues: # if the returned value doesn't match the type, so there's no need to # check the convert. - forcetypeassert + + - path: mock* + linters: + # forcetypeassert is skipped for the mock because the test would fail + # if the returned value doesn't match the type, so there's no need to + # check the convert. + - forcetypeassert diff --git a/channeldb/mp_payment.go b/channeldb/mp_payment.go index 55112f3c0..9a5e0fddd 100644 --- a/channeldb/mp_payment.go +++ b/channeldb/mp_payment.go @@ -446,6 +446,26 @@ func (m *MPPayment) NeedWaitAttempts() (bool, error) { } } +// GetState returns the internal state of the payment. +func (m *MPPayment) GetState() *MPPaymentState { + return m.State +} + +// Status returns the current status of the payment. +func (m *MPPayment) GetStatus() PaymentStatus { + return m.Status +} + +// GetPayment returns all the HTLCs for this payment. +func (m *MPPayment) GetHTLCs() []HTLCAttempt { + return m.HTLCs +} + +// GetFailureReason returns the failure reason. +func (m *MPPayment) GetFailureReason() *FailureReason { + return m.FailureReason +} + // serializeHTLCSettleInfo serializes the details of a settled htlc. func serializeHTLCSettleInfo(w io.Writer, s *HTLCSettleInfo) error { if _, err := w.Write(s.Preimage[:]); err != nil { diff --git a/routing/control_tower.go b/routing/control_tower.go index 0590debee..b23b7df5c 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -9,6 +9,32 @@ import ( "github.com/lightningnetwork/lnd/queue" ) +// dbMPPayment is an interface derived from channeldb.MPPayment that is used by +// the payment lifecycle. +type dbMPPayment interface { + // GetState returns the current state of the payment. + GetState() *channeldb.MPPaymentState + + // Terminated returns true if the payment is in a final state. + Terminated() bool + + // GetStatus returns the current status of the payment. + GetStatus() channeldb.PaymentStatus + + // NeedWaitAttempts specifies whether the payment needs to wait for the + // outcome of an attempt. + NeedWaitAttempts() (bool, error) + + // GetHTLCs returns all HTLCs of this payment. + GetHTLCs() []channeldb.HTLCAttempt + + // InFlightHTLCs returns all HTLCs that are in flight. + InFlightHTLCs() []channeldb.HTLCAttempt + + // GetFailureReason returns the reason the payment failed. + GetFailureReason() *channeldb.FailureReason +} + // ControlTower tracks all outgoing payments made, whose primary purpose is to // prevent duplicate payments to the same payment hash. In production, a // persistent implementation is preferred so that tracking can survive across @@ -44,7 +70,7 @@ type ControlTower interface { // FetchPayment fetches the payment corresponding to the given payment // hash. - FetchPayment(paymentHash lntypes.Hash) (*channeldb.MPPayment, error) + FetchPayment(paymentHash lntypes.Hash) (dbMPPayment, error) // FailPayment transitions a payment into the Failed state, and records // the ultimate reason the payment failed. Note that this should only @@ -224,7 +250,7 @@ func (p *controlTower) FailAttempt(paymentHash lntypes.Hash, // FetchPayment fetches the payment corresponding to the given payment hash. func (p *controlTower) FetchPayment(paymentHash lntypes.Hash) ( - *channeldb.MPPayment, error) { + dbMPPayment, error) { return p.db.FetchPayment(paymentHash) } diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 3681b647d..c60f72d37 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -130,9 +130,9 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { } } - require.Equalf(t, channeldb.StatusSucceeded, result.Status, + require.Equalf(t, channeldb.StatusSucceeded, result.GetStatus(), "subscriber %v failed, want %s, got %s", i, - channeldb.StatusSucceeded, result.Status) + channeldb.StatusSucceeded, result.GetStatus()) settle, _ := result.TerminalInfo() if settle.Preimage != preimg { @@ -259,7 +259,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { result1 := results[info1.PaymentIdentifier] require.Equal( - t, channeldb.StatusSucceeded, result1.Status, + t, channeldb.StatusSucceeded, result1.GetStatus(), "unexpected payment state payment 1", ) @@ -278,7 +278,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) { result2 := results[info2.PaymentIdentifier] require.Equal( - t, channeldb.StatusSucceeded, result2.Status, + t, channeldb.StatusSucceeded, result2.GetStatus(), "unexpected payment state payment 2", ) @@ -486,7 +486,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, } } - if result.Status == channeldb.StatusSucceeded { + if result.GetStatus() == channeldb.StatusSucceeded { t.Fatal("unexpected payment state") } @@ -511,9 +511,9 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt, len(result.HTLCs)) } - require.Equalf(t, channeldb.StatusFailed, result.Status, + require.Equalf(t, channeldb.StatusFailed, result.GetStatus(), "subscriber %v failed, want %s, got %s", i, - channeldb.StatusFailed, result.Status) + channeldb.StatusFailed, result.GetStatus()) if *result.FailureReason != channeldb.FailureReasonTimeout { t.Fatal("unexpected failure reason") diff --git a/routing/mock_test.go b/routing/mock_test.go index a4d51e164..c24ca2508 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -491,7 +491,7 @@ func (m *mockControlTowerOld) FailPayment(phash lntypes.Hash, } func (m *mockControlTowerOld) FetchPayment(phash lntypes.Hash) ( - *channeldb.MPPayment, error) { + dbMPPayment, error) { m.Lock() defer m.Unlock() @@ -750,10 +750,8 @@ func (m *mockControlTower) FailPayment(phash lntypes.Hash, } func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( - *channeldb.MPPayment, error) { + dbMPPayment, error) { - m.Lock() - defer m.Unlock() args := m.Called(phash) // Type assertion on nil will fail, so we check and return here. @@ -761,15 +759,7 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( return nil, args.Error(1) } - // Make a copy of the payment here to avoid data race. - p := args.Get(0).(*channeldb.MPPayment) - payment := &channeldb.MPPayment{ - Info: p.Info, - FailureReason: p.FailureReason, - } - payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs)) - copy(payment.HTLCs, p.HTLCs) - + payment := args.Get(0).(*mockMPPayment) return payment, args.Error(1) } @@ -794,6 +784,54 @@ func (m *mockControlTower) SubscribeAllPayments() ( return args.Get(0).(ControlTowerSubscriber), args.Error(1) } +type mockMPPayment struct { + mock.Mock +} + +var _ dbMPPayment = (*mockMPPayment)(nil) + +func (m *mockMPPayment) GetState() *channeldb.MPPaymentState { + args := m.Called() + return args.Get(0).(*channeldb.MPPaymentState) +} + +func (m *mockMPPayment) GetStatus() channeldb.PaymentStatus { + args := m.Called() + return args.Get(0).(channeldb.PaymentStatus) +} + +func (m *mockMPPayment) Terminated() bool { + args := m.Called() + + return args.Bool(0) +} + +func (m *mockMPPayment) NeedWaitAttempts() (bool, error) { + args := m.Called() + return args.Bool(0), args.Error(1) +} + +func (m *mockMPPayment) GetHTLCs() []channeldb.HTLCAttempt { + args := m.Called() + return args.Get(0).([]channeldb.HTLCAttempt) +} + +func (m *mockMPPayment) InFlightHTLCs() []channeldb.HTLCAttempt { + args := m.Called() + return args.Get(0).([]channeldb.HTLCAttempt) +} + +func (m *mockMPPayment) GetFailureReason() *channeldb.FailureReason { + args := m.Called() + + reason := args.Get(0) + if reason == nil { + return nil + } + + return reason.(*channeldb.FailureReason) +} + type mockLink struct { htlcswitch.ChannelLink bandwidth lnwire.MilliSatoshi diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 51f314a4a..5c3d0e6c9 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -86,7 +86,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { // exitWithErr is a helper closure that logs and returns an error. exitWithErr := func(err error) ([32]byte, *route.Route, error) { log.Errorf("Payment %v with status=%v failed: %v", - p.identifier, payment.Status, err) + p.identifier, payment.GetStatus(), err) return [32]byte{}, nil, err } @@ -112,7 +112,7 @@ lifecycle: return exitWithErr(err) } - ps := payment.State + ps := payment.GetState() remainingFees := p.calcFeeBudget(ps.FeesPaid) log.Debugf("Payment %v in state terminate=%v, "+ @@ -127,7 +127,7 @@ lifecycle: if payment.Terminated() { // Find the first successful shard and return // the preimage and route. - for _, a := range payment.HTLCs { + for _, a := range payment.GetHTLCs() { if a.Settle == nil { continue } @@ -146,7 +146,7 @@ lifecycle: } // Payment failed. - return exitWithErr(*payment.FailureReason) + return exitWithErr(*payment.GetFailureReason()) } // If we either reached a terminal error condition (but had diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 63f1a80be..c1dfe4e44 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -791,12 +791,6 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, } } -func makeActiveAttempt(total, fee int) channeldb.HTLCAttempt { - return channeldb.HTLCAttempt{ - HTLCAttemptInfo: makeAttemptInfo(total, total-fee), - } -} - func makeSettledAttempt(total, fee int, preimage lntypes.Preimage) channeldb.HTLCAttempt { diff --git a/routing/router_test.go b/routing/router_test.go index 6b80929d8..ce5925ef8 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1120,15 +1120,16 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { p, err := ctx.router.cfg.Control.FetchPayment(payHash) require.NoError(t, err) - require.Equal(t, 2, len(p.HTLCs), "expected two attempts") + htlcs := p.GetHTLCs() + require.Equal(t, 2, len(htlcs), "expected two attempts") // We expect the first attempt to have failed with a // TemporaryChannelFailure, the second with UnknownNextPeer. - msg := p.HTLCs[0].Failure.Message + msg := htlcs[0].Failure.Message _, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok, "unexpected fail message") - msg = p.HTLCs[1].Failure.Message + msg = htlcs[1].Failure.Message _, ok = msg.(*lnwire.FailUnknownNextPeer) require.True(t, ok, "unexpected fail message") @@ -3470,29 +3471,64 @@ func TestSendMPPaymentSucceed(t *testing.T) { sessionSource.On("NewPaymentSession", req).Return(session, nil) controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - // The following mocked methods are called inside resumePayment. Note - // that the payment object below will determine the state of the - // paymentLifecycle. - payment := &channeldb.MPPayment{ - Info: &channeldb.PaymentCreationInfo{Value: paymentAmt}, - } - controlTower.On("FetchPayment", identifier).Return(payment, nil) + // Mock the InFlightHTLCs. + var ( + htlcs []channeldb.HTLCAttempt + numAttempts atomic.Uint32 + ) + + // Make a mock MPPayment. + payment := &mockMPPayment{} + payment.On("InFlightHTLCs").Return(htlcs). + On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}) + + // Mock FetchPayment to return the payment. + controlTower.On("FetchPayment", + identifier, + ).Return(payment, nil).Run(func(args mock.Arguments) { + // When number of attempts made is less than 4, we will mock + // the payment's methods to allow the lifecycle to continue. + if numAttempts.Load() < 4 { + payment.On("Terminated").Return(false).Times(2). + On("NeedWaitAttempts").Return(false, nil).Once() + return + } + + // Otherwise, terminate the lifecycle. + payment.On("Terminated").Return(true). + On("NeedWaitAttempts").Return(true, nil) + }) + + // Mock SettleAttempt. + preimage := lntypes.Preimage{1, 2, 3} + settledAttempt := makeSettledAttempt( + int(paymentAmt/4), 0, preimage, + ) + + controlTower.On("SettleAttempt", + identifier, mock.Anything, mock.Anything, + ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { + payment.On("GetHTLCs").Return( + []channeldb.HTLCAttempt{settledAttempt}, + ) + }) // Create a route that can send 1/4 of the total amount. This value // will be returned by calling RequestRoute. shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) require.NoError(t, err, "failed to create route") + session.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, ).Return(shard, nil) // Make a new htlc attempt with zero fee and append it to the payment's // HTLCs when calling RegisterAttempt. - activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) controlTower.On("RegisterAttempt", identifier, mock.Anything, ).Return(nil).Run(func(args mock.Arguments) { - payment.HTLCs = append(payment.HTLCs, activeAttempt) + // Increase the counter whenever an attempt is made. + numAttempts.Add(1) }) // Create a buffered chan and it will be returned by GetAttemptResult. @@ -3509,30 +3545,12 @@ func TestSendMPPaymentSucceed(t *testing.T) { payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, ).Return(nil) + missionControl.On("ReportPaymentSuccess", mock.Anything, mock.Anything, - ).Return(nil) - - // Mock SettleAttempt by changing one of the HTLCs to be settled. - preimage := lntypes.Preimage{1, 2, 3} - settledAttempt := makeSettledAttempt( - int(paymentAmt/4), 0, preimage, - ) - controlTower.On("SettleAttempt", - identifier, mock.Anything, mock.Anything, - ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mark the first - // active attempt settled and exit. - for i, attempt := range payment.HTLCs { - if attempt.Settle == nil { - attempt.Settle = &channeldb.HTLCSettleInfo{ - Preimage: preimage, - } - payment.HTLCs[i] = attempt - return - } - } + ).Return(nil).Run(func(args mock.Arguments) { }) + controlTower.On("DeleteFailedAttempts", identifier).Return(nil) // Call the actual method SendPayment on router. This is place inside a @@ -3565,6 +3583,7 @@ func TestSendMPPaymentSucceed(t *testing.T) { sessionSource.AssertExpectations(t) session.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if @@ -3639,13 +3658,34 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { sessionSource.On("NewPaymentSession", req).Return(session, nil) controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - // The following mocked methods are called inside resumePayment. Note - // that the payment object below will determine the state of the - // paymentLifecycle. - payment := &channeldb.MPPayment{ - Info: &channeldb.PaymentCreationInfo{Value: paymentAmt}, - } - controlTower.On("FetchPayment", identifier).Return(payment, nil) + // Mock the InFlightHTLCs. + var ( + htlcs []channeldb.HTLCAttempt + numAttempts atomic.Uint32 + failAttemptCount atomic.Uint32 + ) + + // Make a mock MPPayment. + payment := &mockMPPayment{} + payment.On("InFlightHTLCs").Return(htlcs). + On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}) + + // Mock FetchPayment to return the payment. + controlTower.On("FetchPayment", + identifier, + ).Return(payment, nil).Run(func(args mock.Arguments) { + // When number of attempts made is less than 6, we will mock + // the payment's methods to allow the lifecycle to continue. + if numAttempts.Load() < 6 { + payment.On("Terminated").Return(false).Times(2). + On("NeedWaitAttempts").Return(false, nil).Once() + return + } + + // Otherwise, terminate the lifecycle. + payment.On("Terminated").Return(true). + On("NeedWaitAttempts").Return(true, nil) + }) // Create a route that can send 1/4 of the total amount. This value // will be returned by calling RequestRoute. @@ -3657,11 +3697,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { // Make a new htlc attempt with zero fee and append it to the payment's // HTLCs when calling RegisterAttempt. - activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) controlTower.On("RegisterAttempt", identifier, mock.Anything, ).Return(nil).Run(func(args mock.Arguments) { - payment.HTLCs = append(payment.HTLCs, activeAttempt) + // Increase the counter whenever an attempt is made. + numAttempts.Add(1) }) // Create a buffered chan and it will be returned by GetAttemptResult. @@ -3670,7 +3710,6 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { // We use the failAttemptCount to track how many attempts we want to // fail. Each time the following mock method is called, the count gets // updated. - failAttemptCount := 0 payer.On("GetAttemptResult", mock.Anything, identifier, mock.Anything, ).Run(func(args mock.Arguments) { @@ -3678,11 +3717,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { // the read-only chan. // Update the counter. - failAttemptCount++ + failAttemptCount.Add(1) // We will make the first two attempts failed with temporary // error. - if failAttemptCount <= 2 { + if failAttemptCount.Load() <= 2 { payer.resultChan <- &htlcswitch.PaymentResult{ Error: htlcswitch.NewForwardingError( &lnwire.FailTemporaryChannelFailure{}, @@ -3700,20 +3739,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { var failedAttempt channeldb.HTLCAttempt controlTower.On("FailAttempt", identifier, mock.Anything, mock.Anything, - ).Return(&failedAttempt, nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mark the first - // active attempt as failed and exit. - for i, attempt := range payment.HTLCs { - if attempt.Settle != nil || attempt.Failure != nil { - continue - } - - attempt.Failure = &channeldb.HTLCFailInfo{} - failedAttempt = attempt - payment.HTLCs[i] = attempt - return - } - }) + ).Return(&failedAttempt, nil) // Setup ReportPaymentFail to return nil reason and error so the // payment won't fail. @@ -3737,20 +3763,13 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { controlTower.On("SettleAttempt", identifier, mock.Anything, mock.Anything, ).Return(&settledAttempt, nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mark the first - // active attempt settled and exit. - for i, attempt := range payment.HTLCs { - if attempt.Settle != nil || attempt.Failure != nil { - continue - } - - attempt.Settle = &channeldb.HTLCSettleInfo{ - Preimage: preimage, - } - payment.HTLCs[i] = attempt - return - } + // Whenever this method is invoked, we will mock the payment's + // GetHTLCs() to return the settled htlc. + payment.On("GetHTLCs").Return( + []channeldb.HTLCAttempt{settledAttempt}, + ) }) + controlTower.On("DeleteFailedAttempts", identifier).Return(nil) // Call the actual method SendPayment on router. This is place inside a @@ -3779,6 +3798,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { sessionSource.AssertExpectations(t) session.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestSendMPPaymentFailed tests that when one of the shard fails with a @@ -3853,12 +3873,18 @@ func TestSendMPPaymentFailed(t *testing.T) { sessionSource.On("NewPaymentSession", req).Return(session, nil) controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - // The following mocked methods are called inside resumePayment. Note - // that the payment object below will determine the state of the - // paymentLifecycle. - payment := &channeldb.MPPayment{ - Info: &channeldb.PaymentCreationInfo{Value: paymentAmt}, - } + // Mock the InFlightHTLCs. + var htlcs []channeldb.HTLCAttempt + + // Make a mock MPPayment. + payment := &mockMPPayment{} + payment.On("InFlightHTLCs").Return(htlcs). + On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). + On("GetStatus").Return(channeldb.StatusInFlight). + On("Terminated").Return(false). + On("NeedWaitAttempts").Return(false, nil) + + // Mock FetchPayment to return the payment. controlTower.On("FetchPayment", identifier).Return(payment, nil) // Create a route that can send 1/4 of the total amount. This value @@ -3871,12 +3897,9 @@ func TestSendMPPaymentFailed(t *testing.T) { // Make a new htlc attempt with zero fee and append it to the payment's // HTLCs when calling RegisterAttempt. - activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) controlTower.On("RegisterAttempt", identifier, mock.Anything, - ).Return(nil).Run(func(args mock.Arguments) { - payment.HTLCs = append(payment.HTLCs, activeAttempt) - }) + ).Return(nil) // Create a buffered chan and it will be returned by GetAttemptResult. payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) @@ -3918,43 +3941,24 @@ func TestSendMPPaymentFailed(t *testing.T) { var failedAttempt channeldb.HTLCAttempt controlTower.On("FailAttempt", identifier, mock.Anything, mock.Anything, - ).Return(&failedAttempt, nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mark the first - // active attempt as failed and exit. - for i, attempt := range payment.HTLCs { - if attempt.Settle != nil || attempt.Failure != nil { - continue - } - - attempt.Failure = &channeldb.HTLCFailInfo{} - failedAttempt = attempt - payment.HTLCs[i] = attempt - return - } - }) + ).Return(&failedAttempt, nil) // Setup ReportPaymentFail to return nil reason and error so the // payment won't fail. - var called bool failureReason := channeldb.FailureReasonPaymentDetails missionControl.On("ReportPaymentFail", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&failureReason, nil).Run(func(args mock.Arguments) { - // We only return the terminal error once, thus when the method - // is called, we will return it with a nil error. - if called { - args[0] = nil - return - } - - // If it's the first time calling this method, we will return a - // terminal error. - payment.FailureReason = &failureReason - called = true - }) + ).Return(&failureReason, nil) // Simple mocking the rest. - controlTower.On("FailPayment", identifier, failureReason).Return(nil) + controlTower.On("FailPayment", + identifier, failureReason, + ).Return(nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mock the payment's + // Terminated() to be True. + payment.On("Terminated").Return(true) + }) + payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, ).Return(nil) @@ -3985,6 +3989,7 @@ func TestSendMPPaymentFailed(t *testing.T) { sessionSource.AssertExpectations(t) session.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in @@ -4059,12 +4064,18 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { sessionSource.On("NewPaymentSession", req).Return(session, nil) controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - // The following mocked methods are called inside resumePayment. Note - // that the payment object below will determine the state of the - // paymentLifecycle. - payment := &channeldb.MPPayment{ - Info: &channeldb.PaymentCreationInfo{Value: paymentAmt}, - } + // Mock the InFlightHTLCs. + var htlcs []channeldb.HTLCAttempt + + // Make a mock MPPayment. + payment := &mockMPPayment{} + payment.On("InFlightHTLCs").Return(htlcs). + On("GetStatus").Return(channeldb.StatusInFlight). + On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). + On("Terminated").Return(false). + On("NeedWaitAttempts").Return(false, nil) + + // Mock FetchPayment to return the payment. controlTower.On("FetchPayment", identifier).Return(payment, nil) // Create a route that can send 1/4 of the total amount. This value @@ -4077,12 +4088,9 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { // Make a new htlc attempt with zero fee and append it to the payment's // HTLCs when calling RegisterAttempt. - activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0) controlTower.On("RegisterAttempt", identifier, mock.Anything, - ).Return(nil).Run(func(args mock.Arguments) { - payment.HTLCs = append(payment.HTLCs, activeAttempt) - }) + ).Return(nil) // Create a buffered chan and it will be returned by GetAttemptResult. payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) @@ -4130,28 +4138,25 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { var failedAttempt channeldb.HTLCAttempt controlTower.On("FailAttempt", identifier, mock.Anything, mock.Anything, - ).Return(&failedAttempt, nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mark the first - // active attempt as failed and exit. - failedAttempt = payment.HTLCs[0] - failedAttempt.Failure = &channeldb.HTLCFailInfo{} - payment.HTLCs[0] = failedAttempt - }) + ).Return(&failedAttempt, nil) // Setup ReportPaymentFail to return nil reason and error so the // payment won't fail. failureReason := channeldb.FailureReasonPaymentDetails - cntReportPaymentFail := 0 missionControl.On("ReportPaymentFail", mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&failureReason, nil).Run(func(args mock.Arguments) { - payment.FailureReason = &failureReason - cntReportPaymentFail++ - }) + ).Return(&failureReason, nil) // Simple mocking the rest. cntFail := 0 - controlTower.On("FailPayment", identifier, failureReason).Return(nil) + controlTower.On("FailPayment", + identifier, failureReason, + ).Return(nil).Run(func(args mock.Arguments) { + // Whenever this method is invoked, we will mock the payment's + // Terminated() to be True. + payment.On("Terminated").Return(true) + }) + payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, ).Return(nil).Run(func(args mock.Arguments) { @@ -4179,7 +4184,6 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { require.Error(t, err, "expected send payment error") require.EqualValues(t, [32]byte{}, p, "preimage not match") require.GreaterOrEqual(t, getPaymentResultCnt, 1) - require.Equal(t, getPaymentResultCnt, cntReportPaymentFail) require.Equal(t, getPaymentResultCnt, cntFail) controlTower.AssertExpectations(t) @@ -4187,6 +4191,7 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { sessionSource.AssertExpectations(t) session.AssertExpectations(t) missionControl.AssertExpectations(t) + payment.AssertExpectations(t) } // TestBlockDifferenceFix tests if when the router is behind on blocks, the