routing+channeldb: make MPPayment into an interface

This commit turns `MPPayment` into an interface inside `routing`. Having
this interface gives us the benefit to write more granular unit tests
inside payment lifecycle. As seen from the modified unit tests, several
hacky ways of testing the `SendPayment` method is now replaced by a mock
over `MPPayment`.
This commit is contained in:
yyforyongyu 2023-02-09 12:51:43 +08:00 committed by Olaoluwa Osuntokun
parent c412ab5ccb
commit 34d0e5d4c5
8 changed files with 257 additions and 167 deletions

View file

@ -268,3 +268,10 @@ issues:
# if the returned value doesn't match the type, so there's no need to
# check the convert.
- forcetypeassert
- path: mock*
linters:
# forcetypeassert is skipped for the mock because the test would fail
# if the returned value doesn't match the type, so there's no need to
# check the convert.
- forcetypeassert

View file

@ -446,6 +446,26 @@ func (m *MPPayment) NeedWaitAttempts() (bool, error) {
}
}
// GetState returns the internal state of the payment.
func (m *MPPayment) GetState() *MPPaymentState {
return m.State
}
// Status returns the current status of the payment.
func (m *MPPayment) GetStatus() PaymentStatus {
return m.Status
}
// GetPayment returns all the HTLCs for this payment.
func (m *MPPayment) GetHTLCs() []HTLCAttempt {
return m.HTLCs
}
// GetFailureReason returns the failure reason.
func (m *MPPayment) GetFailureReason() *FailureReason {
return m.FailureReason
}
// serializeHTLCSettleInfo serializes the details of a settled htlc.
func serializeHTLCSettleInfo(w io.Writer, s *HTLCSettleInfo) error {
if _, err := w.Write(s.Preimage[:]); err != nil {

View file

@ -9,6 +9,32 @@ import (
"github.com/lightningnetwork/lnd/queue"
)
// dbMPPayment is an interface derived from channeldb.MPPayment that is used by
// the payment lifecycle.
type dbMPPayment interface {
// GetState returns the current state of the payment.
GetState() *channeldb.MPPaymentState
// Terminated returns true if the payment is in a final state.
Terminated() bool
// GetStatus returns the current status of the payment.
GetStatus() channeldb.PaymentStatus
// NeedWaitAttempts specifies whether the payment needs to wait for the
// outcome of an attempt.
NeedWaitAttempts() (bool, error)
// GetHTLCs returns all HTLCs of this payment.
GetHTLCs() []channeldb.HTLCAttempt
// InFlightHTLCs returns all HTLCs that are in flight.
InFlightHTLCs() []channeldb.HTLCAttempt
// GetFailureReason returns the reason the payment failed.
GetFailureReason() *channeldb.FailureReason
}
// ControlTower tracks all outgoing payments made, whose primary purpose is to
// prevent duplicate payments to the same payment hash. In production, a
// persistent implementation is preferred so that tracking can survive across
@ -44,7 +70,7 @@ type ControlTower interface {
// FetchPayment fetches the payment corresponding to the given payment
// hash.
FetchPayment(paymentHash lntypes.Hash) (*channeldb.MPPayment, error)
FetchPayment(paymentHash lntypes.Hash) (dbMPPayment, error)
// FailPayment transitions a payment into the Failed state, and records
// the ultimate reason the payment failed. Note that this should only
@ -224,7 +250,7 @@ func (p *controlTower) FailAttempt(paymentHash lntypes.Hash,
// FetchPayment fetches the payment corresponding to the given payment hash.
func (p *controlTower) FetchPayment(paymentHash lntypes.Hash) (
*channeldb.MPPayment, error) {
dbMPPayment, error) {
return p.db.FetchPayment(paymentHash)
}

View file

@ -130,9 +130,9 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
}
}
require.Equalf(t, channeldb.StatusSucceeded, result.Status,
require.Equalf(t, channeldb.StatusSucceeded, result.GetStatus(),
"subscriber %v failed, want %s, got %s", i,
channeldb.StatusSucceeded, result.Status)
channeldb.StatusSucceeded, result.GetStatus())
settle, _ := result.TerminalInfo()
if settle.Preimage != preimg {
@ -259,7 +259,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) {
result1 := results[info1.PaymentIdentifier]
require.Equal(
t, channeldb.StatusSucceeded, result1.Status,
t, channeldb.StatusSucceeded, result1.GetStatus(),
"unexpected payment state payment 1",
)
@ -278,7 +278,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) {
result2 := results[info2.PaymentIdentifier]
require.Equal(
t, channeldb.StatusSucceeded, result2.Status,
t, channeldb.StatusSucceeded, result2.GetStatus(),
"unexpected payment state payment 2",
)
@ -486,7 +486,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
}
}
if result.Status == channeldb.StatusSucceeded {
if result.GetStatus() == channeldb.StatusSucceeded {
t.Fatal("unexpected payment state")
}
@ -511,9 +511,9 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
len(result.HTLCs))
}
require.Equalf(t, channeldb.StatusFailed, result.Status,
require.Equalf(t, channeldb.StatusFailed, result.GetStatus(),
"subscriber %v failed, want %s, got %s", i,
channeldb.StatusFailed, result.Status)
channeldb.StatusFailed, result.GetStatus())
if *result.FailureReason != channeldb.FailureReasonTimeout {
t.Fatal("unexpected failure reason")

View file

@ -491,7 +491,7 @@ func (m *mockControlTowerOld) FailPayment(phash lntypes.Hash,
}
func (m *mockControlTowerOld) FetchPayment(phash lntypes.Hash) (
*channeldb.MPPayment, error) {
dbMPPayment, error) {
m.Lock()
defer m.Unlock()
@ -750,10 +750,8 @@ func (m *mockControlTower) FailPayment(phash lntypes.Hash,
}
func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
*channeldb.MPPayment, error) {
dbMPPayment, error) {
m.Lock()
defer m.Unlock()
args := m.Called(phash)
// Type assertion on nil will fail, so we check and return here.
@ -761,15 +759,7 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
return nil, args.Error(1)
}
// Make a copy of the payment here to avoid data race.
p := args.Get(0).(*channeldb.MPPayment)
payment := &channeldb.MPPayment{
Info: p.Info,
FailureReason: p.FailureReason,
}
payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs))
copy(payment.HTLCs, p.HTLCs)
payment := args.Get(0).(*mockMPPayment)
return payment, args.Error(1)
}
@ -794,6 +784,54 @@ func (m *mockControlTower) SubscribeAllPayments() (
return args.Get(0).(ControlTowerSubscriber), args.Error(1)
}
type mockMPPayment struct {
mock.Mock
}
var _ dbMPPayment = (*mockMPPayment)(nil)
func (m *mockMPPayment) GetState() *channeldb.MPPaymentState {
args := m.Called()
return args.Get(0).(*channeldb.MPPaymentState)
}
func (m *mockMPPayment) GetStatus() channeldb.PaymentStatus {
args := m.Called()
return args.Get(0).(channeldb.PaymentStatus)
}
func (m *mockMPPayment) Terminated() bool {
args := m.Called()
return args.Bool(0)
}
func (m *mockMPPayment) NeedWaitAttempts() (bool, error) {
args := m.Called()
return args.Bool(0), args.Error(1)
}
func (m *mockMPPayment) GetHTLCs() []channeldb.HTLCAttempt {
args := m.Called()
return args.Get(0).([]channeldb.HTLCAttempt)
}
func (m *mockMPPayment) InFlightHTLCs() []channeldb.HTLCAttempt {
args := m.Called()
return args.Get(0).([]channeldb.HTLCAttempt)
}
func (m *mockMPPayment) GetFailureReason() *channeldb.FailureReason {
args := m.Called()
reason := args.Get(0)
if reason == nil {
return nil
}
return reason.(*channeldb.FailureReason)
}
type mockLink struct {
htlcswitch.ChannelLink
bandwidth lnwire.MilliSatoshi

View file

@ -86,7 +86,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
// exitWithErr is a helper closure that logs and returns an error.
exitWithErr := func(err error) ([32]byte, *route.Route, error) {
log.Errorf("Payment %v with status=%v failed: %v",
p.identifier, payment.Status, err)
p.identifier, payment.GetStatus(), err)
return [32]byte{}, nil, err
}
@ -112,7 +112,7 @@ lifecycle:
return exitWithErr(err)
}
ps := payment.State
ps := payment.GetState()
remainingFees := p.calcFeeBudget(ps.FeesPaid)
log.Debugf("Payment %v in state terminate=%v, "+
@ -127,7 +127,7 @@ lifecycle:
if payment.Terminated() {
// Find the first successful shard and return
// the preimage and route.
for _, a := range payment.HTLCs {
for _, a := range payment.GetHTLCs() {
if a.Settle == nil {
continue
}
@ -146,7 +146,7 @@ lifecycle:
}
// Payment failed.
return exitWithErr(*payment.FailureReason)
return exitWithErr(*payment.GetFailureReason())
}
// If we either reached a terminal error condition (but had

View file

@ -791,12 +791,6 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
}
}
func makeActiveAttempt(total, fee int) channeldb.HTLCAttempt {
return channeldb.HTLCAttempt{
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
}
}
func makeSettledAttempt(total, fee int,
preimage lntypes.Preimage) channeldb.HTLCAttempt {

View file

@ -1120,15 +1120,16 @@ func TestSendPaymentErrorPathPruning(t *testing.T) {
p, err := ctx.router.cfg.Control.FetchPayment(payHash)
require.NoError(t, err)
require.Equal(t, 2, len(p.HTLCs), "expected two attempts")
htlcs := p.GetHTLCs()
require.Equal(t, 2, len(htlcs), "expected two attempts")
// We expect the first attempt to have failed with a
// TemporaryChannelFailure, the second with UnknownNextPeer.
msg := p.HTLCs[0].Failure.Message
msg := htlcs[0].Failure.Message
_, ok := msg.(*lnwire.FailTemporaryChannelFailure)
require.True(t, ok, "unexpected fail message")
msg = p.HTLCs[1].Failure.Message
msg = htlcs[1].Failure.Message
_, ok = msg.(*lnwire.FailUnknownNextPeer)
require.True(t, ok, "unexpected fail message")
@ -3470,29 +3471,64 @@ func TestSendMPPaymentSucceed(t *testing.T) {
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Mock the InFlightHTLCs.
var (
htlcs []channeldb.HTLCAttempt
numAttempts atomic.Uint32
)
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment",
identifier,
).Return(payment, nil).Run(func(args mock.Arguments) {
// When number of attempts made is less than 4, we will mock
// the payment's methods to allow the lifecycle to continue.
if numAttempts.Load() < 4 {
payment.On("Terminated").Return(false).Times(2).
On("NeedWaitAttempts").Return(false, nil).Once()
return
}
// Otherwise, terminate the lifecycle.
payment.On("Terminated").Return(true).
On("NeedWaitAttempts").Return(true, nil)
})
// Mock SettleAttempt.
preimage := lntypes.Preimage{1, 2, 3}
settledAttempt := makeSettledAttempt(
int(paymentAmt/4), 0, preimage,
)
controlTower.On("SettleAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
payment.On("GetHTLCs").Return(
[]channeldb.HTLCAttempt{settledAttempt},
)
})
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
require.NoError(t, err, "failed to create route")
session.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(shard, nil)
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
// Increase the counter whenever an attempt is made.
numAttempts.Add(1)
})
// Create a buffered chan and it will be returned by GetAttemptResult.
@ -3509,30 +3545,12 @@ func TestSendMPPaymentSucceed(t *testing.T) {
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
missionControl.On("ReportPaymentSuccess",
mock.Anything, mock.Anything,
).Return(nil)
// Mock SettleAttempt by changing one of the HTLCs to be settled.
preimage := lntypes.Preimage{1, 2, 3}
settledAttempt := makeSettledAttempt(
int(paymentAmt/4), 0, preimage,
)
controlTower.On("SettleAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt settled and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle == nil {
attempt.Settle = &channeldb.HTLCSettleInfo{
Preimage: preimage,
}
payment.HTLCs[i] = attempt
return
}
}
).Return(nil).Run(func(args mock.Arguments) {
})
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
// Call the actual method SendPayment on router. This is place inside a
@ -3565,6 +3583,7 @@ func TestSendMPPaymentSucceed(t *testing.T) {
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}
// TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if
@ -3639,13 +3658,34 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Mock the InFlightHTLCs.
var (
htlcs []channeldb.HTLCAttempt
numAttempts atomic.Uint32
failAttemptCount atomic.Uint32
)
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment",
identifier,
).Return(payment, nil).Run(func(args mock.Arguments) {
// When number of attempts made is less than 6, we will mock
// the payment's methods to allow the lifecycle to continue.
if numAttempts.Load() < 6 {
payment.On("Terminated").Return(false).Times(2).
On("NeedWaitAttempts").Return(false, nil).Once()
return
}
// Otherwise, terminate the lifecycle.
payment.On("Terminated").Return(true).
On("NeedWaitAttempts").Return(true, nil)
})
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
@ -3657,11 +3697,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
// Increase the counter whenever an attempt is made.
numAttempts.Add(1)
})
// Create a buffered chan and it will be returned by GetAttemptResult.
@ -3670,7 +3710,6 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
// We use the failAttemptCount to track how many attempts we want to
// fail. Each time the following mock method is called, the count gets
// updated.
failAttemptCount := 0
payer.On("GetAttemptResult",
mock.Anything, identifier, mock.Anything,
).Run(func(args mock.Arguments) {
@ -3678,11 +3717,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
// the read-only chan.
// Update the counter.
failAttemptCount++
failAttemptCount.Add(1)
// We will make the first two attempts failed with temporary
// error.
if failAttemptCount <= 2 {
if failAttemptCount.Load() <= 2 {
payer.resultChan <- &htlcswitch.PaymentResult{
Error: htlcswitch.NewForwardingError(
&lnwire.FailTemporaryChannelFailure{},
@ -3700,20 +3739,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt as failed and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle != nil || attempt.Failure != nil {
continue
}
attempt.Failure = &channeldb.HTLCFailInfo{}
failedAttempt = attempt
payment.HTLCs[i] = attempt
return
}
})
).Return(&failedAttempt, nil)
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
@ -3737,20 +3763,13 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
controlTower.On("SettleAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt settled and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle != nil || attempt.Failure != nil {
continue
}
attempt.Settle = &channeldb.HTLCSettleInfo{
Preimage: preimage,
}
payment.HTLCs[i] = attempt
return
}
// Whenever this method is invoked, we will mock the payment's
// GetHTLCs() to return the settled htlc.
payment.On("GetHTLCs").Return(
[]channeldb.HTLCAttempt{settledAttempt},
)
})
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
// Call the actual method SendPayment on router. This is place inside a
@ -3779,6 +3798,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}
// TestSendMPPaymentFailed tests that when one of the shard fails with a
@ -3853,12 +3873,18 @@ func TestSendMPPaymentFailed(t *testing.T) {
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
// Mock the InFlightHTLCs.
var htlcs []channeldb.HTLCAttempt
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
On("GetStatus").Return(channeldb.StatusInFlight).
On("Terminated").Return(false).
On("NeedWaitAttempts").Return(false, nil)
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
@ -3871,12 +3897,9 @@ func TestSendMPPaymentFailed(t *testing.T) {
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
})
).Return(nil)
// Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
@ -3918,43 +3941,24 @@ func TestSendMPPaymentFailed(t *testing.T) {
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt as failed and exit.
for i, attempt := range payment.HTLCs {
if attempt.Settle != nil || attempt.Failure != nil {
continue
}
attempt.Failure = &channeldb.HTLCFailInfo{}
failedAttempt = attempt
payment.HTLCs[i] = attempt
return
}
})
).Return(&failedAttempt, nil)
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
var called bool
failureReason := channeldb.FailureReasonPaymentDetails
missionControl.On("ReportPaymentFail",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(&failureReason, nil).Run(func(args mock.Arguments) {
// We only return the terminal error once, thus when the method
// is called, we will return it with a nil error.
if called {
args[0] = nil
return
}
// If it's the first time calling this method, we will return a
// terminal error.
payment.FailureReason = &failureReason
called = true
})
).Return(&failureReason, nil)
// Simple mocking the rest.
controlTower.On("FailPayment", identifier, failureReason).Return(nil)
controlTower.On("FailPayment",
identifier, failureReason,
).Return(nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mock the payment's
// Terminated() to be True.
payment.On("Terminated").Return(true)
})
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
@ -3985,6 +3989,7 @@ func TestSendMPPaymentFailed(t *testing.T) {
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}
// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in
@ -4059,12 +4064,18 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// The following mocked methods are called inside resumePayment. Note
// that the payment object below will determine the state of the
// paymentLifecycle.
payment := &channeldb.MPPayment{
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
}
// Mock the InFlightHTLCs.
var htlcs []channeldb.HTLCAttempt
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetStatus").Return(channeldb.StatusInFlight).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
On("Terminated").Return(false).
On("NeedWaitAttempts").Return(false, nil)
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment", identifier).Return(payment, nil)
// Create a route that can send 1/4 of the total amount. This value
@ -4077,12 +4088,9 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
// Make a new htlc attempt with zero fee and append it to the payment's
// HTLCs when calling RegisterAttempt.
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
payment.HTLCs = append(payment.HTLCs, activeAttempt)
})
).Return(nil)
// Create a buffered chan and it will be returned by GetAttemptResult.
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
@ -4130,28 +4138,25 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
var failedAttempt channeldb.HTLCAttempt
controlTower.On("FailAttempt",
identifier, mock.Anything, mock.Anything,
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mark the first
// active attempt as failed and exit.
failedAttempt = payment.HTLCs[0]
failedAttempt.Failure = &channeldb.HTLCFailInfo{}
payment.HTLCs[0] = failedAttempt
})
).Return(&failedAttempt, nil)
// Setup ReportPaymentFail to return nil reason and error so the
// payment won't fail.
failureReason := channeldb.FailureReasonPaymentDetails
cntReportPaymentFail := 0
missionControl.On("ReportPaymentFail",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(&failureReason, nil).Run(func(args mock.Arguments) {
payment.FailureReason = &failureReason
cntReportPaymentFail++
})
).Return(&failureReason, nil)
// Simple mocking the rest.
cntFail := 0
controlTower.On("FailPayment", identifier, failureReason).Return(nil)
controlTower.On("FailPayment",
identifier, failureReason,
).Return(nil).Run(func(args mock.Arguments) {
// Whenever this method is invoked, we will mock the payment's
// Terminated() to be True.
payment.On("Terminated").Return(true)
})
payer.On("SendHTLC",
mock.Anything, mock.Anything, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
@ -4179,7 +4184,6 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
require.Error(t, err, "expected send payment error")
require.EqualValues(t, [32]byte{}, p, "preimage not match")
require.GreaterOrEqual(t, getPaymentResultCnt, 1)
require.Equal(t, getPaymentResultCnt, cntReportPaymentFail)
require.Equal(t, getPaymentResultCnt, cntFail)
controlTower.AssertExpectations(t)
@ -4187,6 +4191,7 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
sessionSource.AssertExpectations(t)
session.AssertExpectations(t)
missionControl.AssertExpectations(t)
payment.AssertExpectations(t)
}
// TestBlockDifferenceFix tests if when the router is behind on blocks, the