routing: fix tests after main refactor

Delete TestSendMPPaymentFailedWithShardsInFlight as it seems to be the
same test as TestSendMPPaymentFailed.
This commit is contained in:
bitromortac 2023-09-27 09:36:55 -04:00 committed by yyforyongyu
parent 173900c8dc
commit 9a0db291b5
No known key found for this signature in database
GPG key ID: 9BCD95C4FF296868
2 changed files with 46 additions and 220 deletions

View file

@ -732,7 +732,7 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
select {
case err := <-paymentResult:
require.Equal(t, test.paymentErr, err)
require.ErrorIs(t, err, test.paymentErr)
case <-time.After(stepTimeout):
fatal("got no payment result")

View file

@ -3874,18 +3874,43 @@ func TestSendMPPaymentFailed(t *testing.T) {
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// Mock the InFlightHTLCs.
var htlcs []channeldb.HTLCAttempt
var (
htlcs []channeldb.HTLCAttempt
numAttempts atomic.Uint32
failAttemptCount atomic.Uint32
failed atomic.Bool
numParts = uint32(4)
)
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
On("GetStatus").Return(channeldb.StatusInFlight).
On("Terminated").Return(false).
On("NeedWaitAttempts").Return(false, nil)
payment.On("InFlightHTLCs").Return(htlcs).Once()
payment.On("GetStatus").Return(channeldb.StatusInFlight).Once()
payment.On("GetState").Return(&channeldb.MPPaymentState{})
controlTower.On("FetchPayment", identifier).Return(payment, nil).Once()
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Mock the sequential FetchPayment to return the payment.
controlTower.On("FetchPayment", identifier).Return(payment, nil).Run(
func(_ mock.Arguments) {
// We want to at least send out all parts in order to
// wait for them later.
if numAttempts.Load() < numParts {
payment.On("Terminated").Return(false).Times(2).
On("NeedWaitAttempts").Return(false, nil).Once()
return
}
// Wait if the payment wasn't failed yet.
if !failed.Load() {
payment.On("Terminated").Return(false).Times(2).
On("NeedWaitAttempts").Return(true, nil).Once()
return
}
payment.On("Terminated").Return(true).
On("GetHTLCs").Return(htlcs).Once()
})
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
@ -3899,7 +3924,9 @@ func TestSendMPPaymentFailed(t *testing.T) {
// HTLCs when calling RegisterAttempt.
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil)
).Return(nil).Run(func(args mock.Arguments) {
numAttempts.Add(1)
})
// Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
@ -3907,18 +3934,17 @@ func TestSendMPPaymentFailed(t *testing.T) {
// We use the failAttemptCount to track how many attempts we want to
// fail. Each time the following mock method is called, the count gets
// updated.
failAttemptCount := 0
payer.On("GetAttemptResult",
mock.Anything, identifier, mock.Anything,
).Run(func(args mock.Arguments) {
).Run(func(_ mock.Arguments) {
// Before the mock method is returned, we send the result to
// the read-only chan.
// Update the counter.
failAttemptCount++
failAttemptCount.Add(1)
// We fail the first attempt with terminal error.
if failAttemptCount == 1 {
if failAttemptCount.Load() == 1 {
payer.resultChan <- &htlcswitch.PaymentResult{
Error: htlcswitch.NewForwardingError(
&lnwire.FailIncorrectDetails{},
@ -3953,12 +3979,13 @@ func TestSendMPPaymentFailed(t *testing.T) {
// Simple mocking the rest.
controlTower.On("FailPayment",
identifier, failureReason,
).Return(nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mock the payment's
// Terminated() to be True.
payment.On("Terminated").Return(true)
).Return(nil).Run(func(_ mock.Arguments) {
failed.Store(true)
})
// Mock the payment to return the failure reason.
payment.On("GetFailureReason").Return(&failureReason)
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
@ -3983,208 +4010,7 @@ func TestSendMPPaymentFailed(t *testing.T) {
// methods are called as expected.
require.Error(t, err, "expected send payment error")
require.EqualValues(t, [32]byte{}, p, "preimage not match")
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}
// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in
// terminal state, even if we have shards in flight, we still fail the payment
// and exit. This test mainly focuses on testing the logic of the method
// resumePayment is implemented as expected.
func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
const startingBlockHeight = 101
// Create mockers to initialize the router.
controlTower := &mockControlTower{}
sessionSource := &mockPaymentSessionSource{}
missionControl := &mockMissionControl{}
payer := &mockPaymentAttemptDispatcher{}
chain := newMockChain(startingBlockHeight)
chainView := newMockChainView(chain)
testGraph := createDummyTestGraph(t)
// Define the behavior of the mockers to the point where we can
// successfully start the router.
controlTower.On("FetchInFlightPayments").Return(
[]*channeldb.MPPayment{}, nil,
)
payer.On("CleanStore", mock.Anything).Return(nil)
// Create and start the router.
router, err := New(Config{
Control: controlTower,
SessionSource: sessionSource,
MissionControl: missionControl,
Payer: payer,
// TODO(yy): create new mocks for the chain and chainview.
Chain: chain,
ChainView: chainView,
// TODO(yy): mock the graph once it's changed into interface.
Graph: testGraph.graph,
Clock: clock.NewTestClock(time.Unix(1, 0)),
GraphPruneInterval: time.Hour * 2,
NextPaymentID: func() (uint64, error) {
next := atomic.AddUint64(&uniquePaymentID, 1)
return next, nil
},
IsAlias: func(scid lnwire.ShortChannelID) bool {
return false
},
})
require.NoError(t, err, "failed to create router")
// Make sure the router can start and stop without error.
require.NoError(t, router.Start(), "router failed to start")
t.Cleanup(func() {
require.NoError(t, router.Stop(), "router failed to stop")
})
// Once the router is started, check that the mocked methods are called
// as expected.
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)
// Mock the methods to the point where we are inside the function
// resumePayment.
paymentAmt := lnwire.MilliSatoshi(10000)
req := createDummyLightningPayment(
t, testGraph.aliasMap["c"], paymentAmt,
)
identifier := lntypes.Hash(req.Identifier())
session := &mockPaymentSession{}
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// Mock the InFlightHTLCs.
var htlcs []channeldb.HTLCAttempt
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetStatus").Return(channeldb.StatusInFlight).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
On("Terminated").Return(false).
On("NeedWaitAttempts").Return(false, nil)
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
require.NoError(t, err, "failed to create route")
session.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(shard, nil)
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil)
// Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
// We use the getPaymentResultCnt to track how many times we called
// GetAttemptResult. As shard launch is sequential, and we fail the
// first shard that calls GetAttemptResult, we may end up with different
// counts since the lifecycle itself is asynchronous. To avoid flakes
// due to this undeterminsitic behavior, we'll compare the final
// getPaymentResultCnt with other counters to create a final test
// expectation.
getPaymentResultCnt := 0
payer.On("GetAttemptResult",
mock.Anything, identifier, mock.Anything,
).Run(func(args mock.Arguments) {
// Before the mock method is returned, we send the result to
// the read-only chan.
// Update the counter.
getPaymentResultCnt++
// We fail the first attempt with terminal error.
if getPaymentResultCnt == 1 {
payer.resultChan <- &htlcswitch.PaymentResult{
Error: htlcswitch.NewForwardingError(
&lnwire.FailIncorrectDetails{},
1,
),
}
return
}
// For the rest of the attempts we'll simulate that a network
// result update_fail_htlc has been received. This way the
// payment will fail cleanly.
payer.resultChan <- &htlcswitch.PaymentResult{
Error: htlcswitch.NewForwardingError(
&lnwire.FailTemporaryChannelFailure{},
1,
),
}
})
// Mock the FailAttempt method to fail (at least once).
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil)
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
failureReason := channeldb.FailureReasonPaymentDetails
missionControl.On("ReportPaymentFail",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(&failureReason, nil)
// Simple mocking the rest.
cntFail := 0
controlTower.On("FailPayment",
identifier, failureReason,
).Return(nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mock the payment's
// Terminated() to be True.
payment.On("Terminated").Return(true)
})
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
cntFail++
})
// Call the actual method SendPayment on router. This is place inside a
// goroutine so we can set a timeout for the whole test, in case
// anything goes wrong and the test never finishes.
done := make(chan struct{})
var p lntypes.Hash
go func() {
p, _, err = router.SendPayment(req)
close(done)
}()
select {
case <-done:
case <-time.After(testTimeout):
t.Fatalf("SendPayment didn't exit")
}
// Finally, validate the returned values and check that the mock
// methods are called as expected.
require.Error(t, err, "expected send payment error")
require.EqualValues(t, [32]byte{}, p, "preimage not match")
require.GreaterOrEqual(t, getPaymentResultCnt, 1)
require.Equal(t, getPaymentResultCnt, cntFail)
require.GreaterOrEqual(t, failAttemptCount.Load(), uint32(1))
controlTower.AssertExpectations(t)
payer.AssertExpectations(t)