From 9a0db291b5ba147b6eb4a86836377660ec343085 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Wed, 27 Sep 2023 09:36:55 -0400 Subject: [PATCH] routing: fix tests after main refactor Delete TestSendMPPaymentFailedWithShardsInFlight as it seems to be the same test as TestSendMPPaymentFailed. --- routing/payment_lifecycle_test.go | 2 +- routing/router_test.go | 264 +++++------------------------- 2 files changed, 46 insertions(+), 220 deletions(-) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 8e1389687..eca18305f 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -732,7 +732,7 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, select { case err := <-paymentResult: - require.Equal(t, test.paymentErr, err) + require.ErrorIs(t, err, test.paymentErr) case <-time.After(stepTimeout): fatal("got no payment result") diff --git a/routing/router_test.go b/routing/router_test.go index f18e80024..6adefe12a 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3874,18 +3874,43 @@ func TestSendMPPaymentFailed(t *testing.T) { controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) // Mock the InFlightHTLCs. - var htlcs []channeldb.HTLCAttempt + var ( + htlcs []channeldb.HTLCAttempt + numAttempts atomic.Uint32 + failAttemptCount atomic.Uint32 + failed atomic.Bool + numParts = uint32(4) + ) // Make a mock MPPayment. payment := &mockMPPayment{} - payment.On("InFlightHTLCs").Return(htlcs). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("GetStatus").Return(channeldb.StatusInFlight). - On("Terminated").Return(false). - On("NeedWaitAttempts").Return(false, nil) + payment.On("InFlightHTLCs").Return(htlcs).Once() + payment.On("GetStatus").Return(channeldb.StatusInFlight).Once() + payment.On("GetState").Return(&channeldb.MPPaymentState{}) + controlTower.On("FetchPayment", identifier).Return(payment, nil).Once() - // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", identifier).Return(payment, nil) + // Mock the sequential FetchPayment to return the payment. + controlTower.On("FetchPayment", identifier).Return(payment, nil).Run( + func(_ mock.Arguments) { + // We want to at least send out all parts in order to + // wait for them later. + if numAttempts.Load() < numParts { + payment.On("Terminated").Return(false).Times(2). + On("NeedWaitAttempts").Return(false, nil).Once() + return + } + + // Wait if the payment wasn't failed yet. + if !failed.Load() { + payment.On("Terminated").Return(false).Times(2). + On("NeedWaitAttempts").Return(true, nil).Once() + + return + } + + payment.On("Terminated").Return(true). + On("GetHTLCs").Return(htlcs).Once() + }) // Create a route that can send 1/4 of the total amount. This value // will be returned by calling RequestRoute. @@ -3899,7 +3924,9 @@ func TestSendMPPaymentFailed(t *testing.T) { // HTLCs when calling RegisterAttempt. controlTower.On("RegisterAttempt", identifier, mock.Anything, - ).Return(nil) + ).Return(nil).Run(func(args mock.Arguments) { + numAttempts.Add(1) + }) // Create a buffered chan and it will be returned by GetAttemptResult. payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) @@ -3907,18 +3934,17 @@ func TestSendMPPaymentFailed(t *testing.T) { // We use the failAttemptCount to track how many attempts we want to // fail. Each time the following mock method is called, the count gets // updated. - failAttemptCount := 0 payer.On("GetAttemptResult", mock.Anything, identifier, mock.Anything, - ).Run(func(args mock.Arguments) { + ).Run(func(_ mock.Arguments) { // Before the mock method is returned, we send the result to // the read-only chan. // Update the counter. - failAttemptCount++ + failAttemptCount.Add(1) // We fail the first attempt with terminal error. - if failAttemptCount == 1 { + if failAttemptCount.Load() == 1 { payer.resultChan <- &htlcswitch.PaymentResult{ Error: htlcswitch.NewForwardingError( &lnwire.FailIncorrectDetails{}, @@ -3953,12 +3979,13 @@ func TestSendMPPaymentFailed(t *testing.T) { // Simple mocking the rest. controlTower.On("FailPayment", identifier, failureReason, - ).Return(nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mock the payment's - // Terminated() to be True. - payment.On("Terminated").Return(true) + ).Return(nil).Run(func(_ mock.Arguments) { + failed.Store(true) }) + // Mock the payment to return the failure reason. + payment.On("GetFailureReason").Return(&failureReason) + payer.On("SendHTLC", mock.Anything, mock.Anything, mock.Anything, ).Return(nil) @@ -3983,208 +4010,7 @@ func TestSendMPPaymentFailed(t *testing.T) { // methods are called as expected. require.Error(t, err, "expected send payment error") require.EqualValues(t, [32]byte{}, p, "preimage not match") - - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - sessionSource.AssertExpectations(t) - session.AssertExpectations(t) - missionControl.AssertExpectations(t) - payment.AssertExpectations(t) -} - -// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in -// terminal state, even if we have shards in flight, we still fail the payment -// and exit. This test mainly focuses on testing the logic of the method -// resumePayment is implemented as expected. -func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { - const startingBlockHeight = 101 - - // Create mockers to initialize the router. - controlTower := &mockControlTower{} - sessionSource := &mockPaymentSessionSource{} - missionControl := &mockMissionControl{} - payer := &mockPaymentAttemptDispatcher{} - chain := newMockChain(startingBlockHeight) - chainView := newMockChainView(chain) - testGraph := createDummyTestGraph(t) - - // Define the behavior of the mockers to the point where we can - // successfully start the router. - controlTower.On("FetchInFlightPayments").Return( - []*channeldb.MPPayment{}, nil, - ) - payer.On("CleanStore", mock.Anything).Return(nil) - - // Create and start the router. - router, err := New(Config{ - Control: controlTower, - SessionSource: sessionSource, - MissionControl: missionControl, - Payer: payer, - - // TODO(yy): create new mocks for the chain and chainview. - Chain: chain, - ChainView: chainView, - - // TODO(yy): mock the graph once it's changed into interface. - Graph: testGraph.graph, - - Clock: clock.NewTestClock(time.Unix(1, 0)), - GraphPruneInterval: time.Hour * 2, - NextPaymentID: func() (uint64, error) { - next := atomic.AddUint64(&uniquePaymentID, 1) - return next, nil - }, - - IsAlias: func(scid lnwire.ShortChannelID) bool { - return false - }, - }) - require.NoError(t, err, "failed to create router") - - // Make sure the router can start and stop without error. - require.NoError(t, router.Start(), "router failed to start") - t.Cleanup(func() { - require.NoError(t, router.Stop(), "router failed to stop") - }) - - // Once the router is started, check that the mocked methods are called - // as expected. - controlTower.AssertExpectations(t) - payer.AssertExpectations(t) - - // Mock the methods to the point where we are inside the function - // resumePayment. - paymentAmt := lnwire.MilliSatoshi(10000) - req := createDummyLightningPayment( - t, testGraph.aliasMap["c"], paymentAmt, - ) - identifier := lntypes.Hash(req.Identifier()) - session := &mockPaymentSession{} - sessionSource.On("NewPaymentSession", req).Return(session, nil) - controlTower.On("InitPayment", identifier, mock.Anything).Return(nil) - - // Mock the InFlightHTLCs. - var htlcs []channeldb.HTLCAttempt - - // Make a mock MPPayment. - payment := &mockMPPayment{} - payment.On("InFlightHTLCs").Return(htlcs). - On("GetStatus").Return(channeldb.StatusInFlight). - On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}). - On("Terminated").Return(false). - On("NeedWaitAttempts").Return(false, nil) - - // Mock FetchPayment to return the payment. - controlTower.On("FetchPayment", identifier).Return(payment, nil) - - // Create a route that can send 1/4 of the total amount. This value - // will be returned by calling RequestRoute. - shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap) - require.NoError(t, err, "failed to create route") - session.On("RequestRoute", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(shard, nil) - - // Make a new htlc attempt with zero fee and append it to the payment's - // HTLCs when calling RegisterAttempt. - controlTower.On("RegisterAttempt", - identifier, mock.Anything, - ).Return(nil) - - // Create a buffered chan and it will be returned by GetAttemptResult. - payer.resultChan = make(chan *htlcswitch.PaymentResult, 10) - - // We use the getPaymentResultCnt to track how many times we called - // GetAttemptResult. As shard launch is sequential, and we fail the - // first shard that calls GetAttemptResult, we may end up with different - // counts since the lifecycle itself is asynchronous. To avoid flakes - // due to this undeterminsitic behavior, we'll compare the final - // getPaymentResultCnt with other counters to create a final test - // expectation. - getPaymentResultCnt := 0 - payer.On("GetAttemptResult", - mock.Anything, identifier, mock.Anything, - ).Run(func(args mock.Arguments) { - // Before the mock method is returned, we send the result to - // the read-only chan. - - // Update the counter. - getPaymentResultCnt++ - - // We fail the first attempt with terminal error. - if getPaymentResultCnt == 1 { - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: htlcswitch.NewForwardingError( - &lnwire.FailIncorrectDetails{}, - 1, - ), - } - return - } - - // For the rest of the attempts we'll simulate that a network - // result update_fail_htlc has been received. This way the - // payment will fail cleanly. - payer.resultChan <- &htlcswitch.PaymentResult{ - Error: htlcswitch.NewForwardingError( - &lnwire.FailTemporaryChannelFailure{}, - 1, - ), - } - }) - - // Mock the FailAttempt method to fail (at least once). - var failedAttempt channeldb.HTLCAttempt - controlTower.On("FailAttempt", - identifier, mock.Anything, mock.Anything, - ).Return(&failedAttempt, nil) - - // Setup ReportPaymentFail to return nil reason and error so the - // payment won't fail. - failureReason := channeldb.FailureReasonPaymentDetails - missionControl.On("ReportPaymentFail", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(&failureReason, nil) - - // Simple mocking the rest. - cntFail := 0 - controlTower.On("FailPayment", - identifier, failureReason, - ).Return(nil).Run(func(args mock.Arguments) { - // Whenever this method is invoked, we will mock the payment's - // Terminated() to be True. - payment.On("Terminated").Return(true) - }) - - payer.On("SendHTLC", - mock.Anything, mock.Anything, mock.Anything, - ).Return(nil).Run(func(args mock.Arguments) { - cntFail++ - }) - - // Call the actual method SendPayment on router. This is place inside a - // goroutine so we can set a timeout for the whole test, in case - // anything goes wrong and the test never finishes. - done := make(chan struct{}) - var p lntypes.Hash - go func() { - p, _, err = router.SendPayment(req) - close(done) - }() - - select { - case <-done: - case <-time.After(testTimeout): - t.Fatalf("SendPayment didn't exit") - } - - // Finally, validate the returned values and check that the mock - // methods are called as expected. - require.Error(t, err, "expected send payment error") - require.EqualValues(t, [32]byte{}, p, "preimage not match") - require.GreaterOrEqual(t, getPaymentResultCnt, 1) - require.Equal(t, getPaymentResultCnt, cntFail) + require.GreaterOrEqual(t, failAttemptCount.Load(), uint32(1)) controlTower.AssertExpectations(t) payer.AssertExpectations(t)