mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-13 11:09:23 +01:00
routing: fix tests after main refactor
Delete TestSendMPPaymentFailedWithShardsInFlight as it seems to be the same test as TestSendMPPaymentFailed.
This commit is contained in:
parent
173900c8dc
commit
9a0db291b5
2 changed files with 46 additions and 220 deletions
|
@ -732,7 +732,7 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
|
|||
|
||||
select {
|
||||
case err := <-paymentResult:
|
||||
require.Equal(t, test.paymentErr, err)
|
||||
require.ErrorIs(t, err, test.paymentErr)
|
||||
|
||||
case <-time.After(stepTimeout):
|
||||
fatal("got no payment result")
|
||||
|
|
|
@ -3874,18 +3874,43 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// Mock the InFlightHTLCs.
|
||||
var htlcs []channeldb.HTLCAttempt
|
||||
var (
|
||||
htlcs []channeldb.HTLCAttempt
|
||||
numAttempts atomic.Uint32
|
||||
failAttemptCount atomic.Uint32
|
||||
failed atomic.Bool
|
||||
numParts = uint32(4)
|
||||
)
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
|
||||
On("GetStatus").Return(channeldb.StatusInFlight).
|
||||
On("Terminated").Return(false).
|
||||
On("NeedWaitAttempts").Return(false, nil)
|
||||
payment.On("InFlightHTLCs").Return(htlcs).Once()
|
||||
payment.On("GetStatus").Return(channeldb.StatusInFlight).Once()
|
||||
payment.On("GetState").Return(&channeldb.MPPaymentState{})
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil).Once()
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
// Mock the sequential FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil).Run(
|
||||
func(_ mock.Arguments) {
|
||||
// We want to at least send out all parts in order to
|
||||
// wait for them later.
|
||||
if numAttempts.Load() < numParts {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||
return
|
||||
}
|
||||
|
||||
// Wait if the payment wasn't failed yet.
|
||||
if !failed.Load() {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
On("NeedWaitAttempts").Return(true, nil).Once()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
payment.On("Terminated").Return(true).
|
||||
On("GetHTLCs").Return(htlcs).Once()
|
||||
})
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
|
@ -3899,7 +3924,9 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||
// HTLCs when calling RegisterAttempt.
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil)
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
numAttempts.Add(1)
|
||||
})
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
|
@ -3907,18 +3934,17 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||
// We use the failAttemptCount to track how many attempts we want to
|
||||
// fail. Each time the following mock method is called, the count gets
|
||||
// updated.
|
||||
failAttemptCount := 0
|
||||
payer.On("GetAttemptResult",
|
||||
mock.Anything, identifier, mock.Anything,
|
||||
).Run(func(args mock.Arguments) {
|
||||
).Run(func(_ mock.Arguments) {
|
||||
// Before the mock method is returned, we send the result to
|
||||
// the read-only chan.
|
||||
|
||||
// Update the counter.
|
||||
failAttemptCount++
|
||||
failAttemptCount.Add(1)
|
||||
|
||||
// We fail the first attempt with terminal error.
|
||||
if failAttemptCount == 1 {
|
||||
if failAttemptCount.Load() == 1 {
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: htlcswitch.NewForwardingError(
|
||||
&lnwire.FailIncorrectDetails{},
|
||||
|
@ -3953,12 +3979,13 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||
// Simple mocking the rest.
|
||||
controlTower.On("FailPayment",
|
||||
identifier, failureReason,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mock the payment's
|
||||
// Terminated() to be True.
|
||||
payment.On("Terminated").Return(true)
|
||||
).Return(nil).Run(func(_ mock.Arguments) {
|
||||
failed.Store(true)
|
||||
})
|
||||
|
||||
// Mock the payment to return the failure reason.
|
||||
payment.On("GetFailureReason").Return(&failureReason)
|
||||
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
@ -3983,208 +4010,7 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||
// methods are called as expected.
|
||||
require.Error(t, err, "expected send payment error")
|
||||
require.EqualValues(t, [32]byte{}, p, "preimage not match")
|
||||
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in
|
||||
// terminal state, even if we have shards in flight, we still fail the payment
|
||||
// and exit. This test mainly focuses on testing the logic of the method
|
||||
// resumePayment is implemented as expected.
|
||||
func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
||||
const startingBlockHeight = 101
|
||||
|
||||
// Create mockers to initialize the router.
|
||||
controlTower := &mockControlTower{}
|
||||
sessionSource := &mockPaymentSessionSource{}
|
||||
missionControl := &mockMissionControl{}
|
||||
payer := &mockPaymentAttemptDispatcher{}
|
||||
chain := newMockChain(startingBlockHeight)
|
||||
chainView := newMockChainView(chain)
|
||||
testGraph := createDummyTestGraph(t)
|
||||
|
||||
// Define the behavior of the mockers to the point where we can
|
||||
// successfully start the router.
|
||||
controlTower.On("FetchInFlightPayments").Return(
|
||||
[]*channeldb.MPPayment{}, nil,
|
||||
)
|
||||
payer.On("CleanStore", mock.Anything).Return(nil)
|
||||
|
||||
// Create and start the router.
|
||||
router, err := New(Config{
|
||||
Control: controlTower,
|
||||
SessionSource: sessionSource,
|
||||
MissionControl: missionControl,
|
||||
Payer: payer,
|
||||
|
||||
// TODO(yy): create new mocks for the chain and chainview.
|
||||
Chain: chain,
|
||||
ChainView: chainView,
|
||||
|
||||
// TODO(yy): mock the graph once it's changed into interface.
|
||||
Graph: testGraph.graph,
|
||||
|
||||
Clock: clock.NewTestClock(time.Unix(1, 0)),
|
||||
GraphPruneInterval: time.Hour * 2,
|
||||
NextPaymentID: func() (uint64, error) {
|
||||
next := atomic.AddUint64(&uniquePaymentID, 1)
|
||||
return next, nil
|
||||
},
|
||||
|
||||
IsAlias: func(scid lnwire.ShortChannelID) bool {
|
||||
return false
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "failed to create router")
|
||||
|
||||
// Make sure the router can start and stop without error.
|
||||
require.NoError(t, router.Start(), "router failed to start")
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, router.Stop(), "router failed to stop")
|
||||
})
|
||||
|
||||
// Once the router is started, check that the mocked methods are called
|
||||
// as expected.
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
|
||||
// Mock the methods to the point where we are inside the function
|
||||
// resumePayment.
|
||||
paymentAmt := lnwire.MilliSatoshi(10000)
|
||||
req := createDummyLightningPayment(
|
||||
t, testGraph.aliasMap["c"], paymentAmt,
|
||||
)
|
||||
identifier := lntypes.Hash(req.Identifier())
|
||||
session := &mockPaymentSession{}
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// Mock the InFlightHTLCs.
|
||||
var htlcs []channeldb.HTLCAttempt
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetStatus").Return(channeldb.StatusInFlight).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
|
||||
On("Terminated").Return(false).
|
||||
On("NeedWaitAttempts").Return(false, nil)
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
|
||||
require.NoError(t, err, "failed to create route")
|
||||
session.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(shard, nil)
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
|
||||
// We use the getPaymentResultCnt to track how many times we called
|
||||
// GetAttemptResult. As shard launch is sequential, and we fail the
|
||||
// first shard that calls GetAttemptResult, we may end up with different
|
||||
// counts since the lifecycle itself is asynchronous. To avoid flakes
|
||||
// due to this undeterminsitic behavior, we'll compare the final
|
||||
// getPaymentResultCnt with other counters to create a final test
|
||||
// expectation.
|
||||
getPaymentResultCnt := 0
|
||||
payer.On("GetAttemptResult",
|
||||
mock.Anything, identifier, mock.Anything,
|
||||
).Run(func(args mock.Arguments) {
|
||||
// Before the mock method is returned, we send the result to
|
||||
// the read-only chan.
|
||||
|
||||
// Update the counter.
|
||||
getPaymentResultCnt++
|
||||
|
||||
// We fail the first attempt with terminal error.
|
||||
if getPaymentResultCnt == 1 {
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: htlcswitch.NewForwardingError(
|
||||
&lnwire.FailIncorrectDetails{},
|
||||
1,
|
||||
),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For the rest of the attempts we'll simulate that a network
|
||||
// result update_fail_htlc has been received. This way the
|
||||
// payment will fail cleanly.
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: htlcswitch.NewForwardingError(
|
||||
&lnwire.FailTemporaryChannelFailure{},
|
||||
1,
|
||||
),
|
||||
}
|
||||
})
|
||||
|
||||
// Mock the FailAttempt method to fail (at least once).
|
||||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil)
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
failureReason := channeldb.FailureReasonPaymentDetails
|
||||
missionControl.On("ReportPaymentFail",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(&failureReason, nil)
|
||||
|
||||
// Simple mocking the rest.
|
||||
cntFail := 0
|
||||
controlTower.On("FailPayment",
|
||||
identifier, failureReason,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mock the payment's
|
||||
// Terminated() to be True.
|
||||
payment.On("Terminated").Return(true)
|
||||
})
|
||||
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
cntFail++
|
||||
})
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
// goroutine so we can set a timeout for the whole test, in case
|
||||
// anything goes wrong and the test never finishes.
|
||||
done := make(chan struct{})
|
||||
var p lntypes.Hash
|
||||
go func() {
|
||||
p, _, err = router.SendPayment(req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(testTimeout):
|
||||
t.Fatalf("SendPayment didn't exit")
|
||||
}
|
||||
|
||||
// Finally, validate the returned values and check that the mock
|
||||
// methods are called as expected.
|
||||
require.Error(t, err, "expected send payment error")
|
||||
require.EqualValues(t, [32]byte{}, p, "preimage not match")
|
||||
require.GreaterOrEqual(t, getPaymentResultCnt, 1)
|
||||
require.Equal(t, getPaymentResultCnt, cntFail)
|
||||
require.GreaterOrEqual(t, failAttemptCount.Load(), uint32(1))
|
||||
|
||||
controlTower.AssertExpectations(t)
|
||||
payer.AssertExpectations(t)
|
||||
|
|
Loading…
Add table
Reference in a new issue