mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-13 11:09:23 +01:00
routing+channeldb: make MPPayment into an interface
This commit turns `MPPayment` into an interface inside `routing`. Having this interface gives us the benefit to write more granular unit tests inside payment lifecycle. As seen from the modified unit tests, several hacky ways of testing the `SendPayment` method is now replaced by a mock over `MPPayment`.
This commit is contained in:
parent
c412ab5ccb
commit
34d0e5d4c5
8 changed files with 257 additions and 167 deletions
|
@ -268,3 +268,10 @@ issues:
|
|||
# if the returned value doesn't match the type, so there's no need to
|
||||
# check the convert.
|
||||
- forcetypeassert
|
||||
|
||||
- path: mock*
|
||||
linters:
|
||||
# forcetypeassert is skipped for the mock because the test would fail
|
||||
# if the returned value doesn't match the type, so there's no need to
|
||||
# check the convert.
|
||||
- forcetypeassert
|
||||
|
|
|
@ -446,6 +446,26 @@ func (m *MPPayment) NeedWaitAttempts() (bool, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// GetState returns the internal state of the payment.
|
||||
func (m *MPPayment) GetState() *MPPaymentState {
|
||||
return m.State
|
||||
}
|
||||
|
||||
// Status returns the current status of the payment.
|
||||
func (m *MPPayment) GetStatus() PaymentStatus {
|
||||
return m.Status
|
||||
}
|
||||
|
||||
// GetPayment returns all the HTLCs for this payment.
|
||||
func (m *MPPayment) GetHTLCs() []HTLCAttempt {
|
||||
return m.HTLCs
|
||||
}
|
||||
|
||||
// GetFailureReason returns the failure reason.
|
||||
func (m *MPPayment) GetFailureReason() *FailureReason {
|
||||
return m.FailureReason
|
||||
}
|
||||
|
||||
// serializeHTLCSettleInfo serializes the details of a settled htlc.
|
||||
func serializeHTLCSettleInfo(w io.Writer, s *HTLCSettleInfo) error {
|
||||
if _, err := w.Write(s.Preimage[:]); err != nil {
|
||||
|
|
|
@ -9,6 +9,32 @@ import (
|
|||
"github.com/lightningnetwork/lnd/queue"
|
||||
)
|
||||
|
||||
// dbMPPayment is an interface derived from channeldb.MPPayment that is used by
|
||||
// the payment lifecycle.
|
||||
type dbMPPayment interface {
|
||||
// GetState returns the current state of the payment.
|
||||
GetState() *channeldb.MPPaymentState
|
||||
|
||||
// Terminated returns true if the payment is in a final state.
|
||||
Terminated() bool
|
||||
|
||||
// GetStatus returns the current status of the payment.
|
||||
GetStatus() channeldb.PaymentStatus
|
||||
|
||||
// NeedWaitAttempts specifies whether the payment needs to wait for the
|
||||
// outcome of an attempt.
|
||||
NeedWaitAttempts() (bool, error)
|
||||
|
||||
// GetHTLCs returns all HTLCs of this payment.
|
||||
GetHTLCs() []channeldb.HTLCAttempt
|
||||
|
||||
// InFlightHTLCs returns all HTLCs that are in flight.
|
||||
InFlightHTLCs() []channeldb.HTLCAttempt
|
||||
|
||||
// GetFailureReason returns the reason the payment failed.
|
||||
GetFailureReason() *channeldb.FailureReason
|
||||
}
|
||||
|
||||
// ControlTower tracks all outgoing payments made, whose primary purpose is to
|
||||
// prevent duplicate payments to the same payment hash. In production, a
|
||||
// persistent implementation is preferred so that tracking can survive across
|
||||
|
@ -44,7 +70,7 @@ type ControlTower interface {
|
|||
|
||||
// FetchPayment fetches the payment corresponding to the given payment
|
||||
// hash.
|
||||
FetchPayment(paymentHash lntypes.Hash) (*channeldb.MPPayment, error)
|
||||
FetchPayment(paymentHash lntypes.Hash) (dbMPPayment, error)
|
||||
|
||||
// FailPayment transitions a payment into the Failed state, and records
|
||||
// the ultimate reason the payment failed. Note that this should only
|
||||
|
@ -224,7 +250,7 @@ func (p *controlTower) FailAttempt(paymentHash lntypes.Hash,
|
|||
|
||||
// FetchPayment fetches the payment corresponding to the given payment hash.
|
||||
func (p *controlTower) FetchPayment(paymentHash lntypes.Hash) (
|
||||
*channeldb.MPPayment, error) {
|
||||
dbMPPayment, error) {
|
||||
|
||||
return p.db.FetchPayment(paymentHash)
|
||||
}
|
||||
|
|
|
@ -130,9 +130,9 @@ func TestControlTowerSubscribeSuccess(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
require.Equalf(t, channeldb.StatusSucceeded, result.Status,
|
||||
require.Equalf(t, channeldb.StatusSucceeded, result.GetStatus(),
|
||||
"subscriber %v failed, want %s, got %s", i,
|
||||
channeldb.StatusSucceeded, result.Status)
|
||||
channeldb.StatusSucceeded, result.GetStatus())
|
||||
|
||||
settle, _ := result.TerminalInfo()
|
||||
if settle.Preimage != preimg {
|
||||
|
@ -259,7 +259,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) {
|
|||
|
||||
result1 := results[info1.PaymentIdentifier]
|
||||
require.Equal(
|
||||
t, channeldb.StatusSucceeded, result1.Status,
|
||||
t, channeldb.StatusSucceeded, result1.GetStatus(),
|
||||
"unexpected payment state payment 1",
|
||||
)
|
||||
|
||||
|
@ -278,7 +278,7 @@ func TestPaymentControlSubscribeAllSuccess(t *testing.T) {
|
|||
|
||||
result2 := results[info2.PaymentIdentifier]
|
||||
require.Equal(
|
||||
t, channeldb.StatusSucceeded, result2.Status,
|
||||
t, channeldb.StatusSucceeded, result2.GetStatus(),
|
||||
"unexpected payment state payment 2",
|
||||
)
|
||||
|
||||
|
@ -486,7 +486,7 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
|
|||
}
|
||||
}
|
||||
|
||||
if result.Status == channeldb.StatusSucceeded {
|
||||
if result.GetStatus() == channeldb.StatusSucceeded {
|
||||
t.Fatal("unexpected payment state")
|
||||
}
|
||||
|
||||
|
@ -511,9 +511,9 @@ func testPaymentControlSubscribeFail(t *testing.T, registerAttempt,
|
|||
len(result.HTLCs))
|
||||
}
|
||||
|
||||
require.Equalf(t, channeldb.StatusFailed, result.Status,
|
||||
require.Equalf(t, channeldb.StatusFailed, result.GetStatus(),
|
||||
"subscriber %v failed, want %s, got %s", i,
|
||||
channeldb.StatusFailed, result.Status)
|
||||
channeldb.StatusFailed, result.GetStatus())
|
||||
|
||||
if *result.FailureReason != channeldb.FailureReasonTimeout {
|
||||
t.Fatal("unexpected failure reason")
|
||||
|
|
|
@ -491,7 +491,7 @@ func (m *mockControlTowerOld) FailPayment(phash lntypes.Hash,
|
|||
}
|
||||
|
||||
func (m *mockControlTowerOld) FetchPayment(phash lntypes.Hash) (
|
||||
*channeldb.MPPayment, error) {
|
||||
dbMPPayment, error) {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
@ -750,10 +750,8 @@ func (m *mockControlTower) FailPayment(phash lntypes.Hash,
|
|||
}
|
||||
|
||||
func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
|
||||
*channeldb.MPPayment, error) {
|
||||
dbMPPayment, error) {
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
args := m.Called(phash)
|
||||
|
||||
// Type assertion on nil will fail, so we check and return here.
|
||||
|
@ -761,15 +759,7 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
|
|||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
// Make a copy of the payment here to avoid data race.
|
||||
p := args.Get(0).(*channeldb.MPPayment)
|
||||
payment := &channeldb.MPPayment{
|
||||
Info: p.Info,
|
||||
FailureReason: p.FailureReason,
|
||||
}
|
||||
payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs))
|
||||
copy(payment.HTLCs, p.HTLCs)
|
||||
|
||||
payment := args.Get(0).(*mockMPPayment)
|
||||
return payment, args.Error(1)
|
||||
}
|
||||
|
||||
|
@ -794,6 +784,54 @@ func (m *mockControlTower) SubscribeAllPayments() (
|
|||
return args.Get(0).(ControlTowerSubscriber), args.Error(1)
|
||||
}
|
||||
|
||||
type mockMPPayment struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
var _ dbMPPayment = (*mockMPPayment)(nil)
|
||||
|
||||
func (m *mockMPPayment) GetState() *channeldb.MPPaymentState {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*channeldb.MPPaymentState)
|
||||
}
|
||||
|
||||
func (m *mockMPPayment) GetStatus() channeldb.PaymentStatus {
|
||||
args := m.Called()
|
||||
return args.Get(0).(channeldb.PaymentStatus)
|
||||
}
|
||||
|
||||
func (m *mockMPPayment) Terminated() bool {
|
||||
args := m.Called()
|
||||
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *mockMPPayment) NeedWaitAttempts() (bool, error) {
|
||||
args := m.Called()
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockMPPayment) GetHTLCs() []channeldb.HTLCAttempt {
|
||||
args := m.Called()
|
||||
return args.Get(0).([]channeldb.HTLCAttempt)
|
||||
}
|
||||
|
||||
func (m *mockMPPayment) InFlightHTLCs() []channeldb.HTLCAttempt {
|
||||
args := m.Called()
|
||||
return args.Get(0).([]channeldb.HTLCAttempt)
|
||||
}
|
||||
|
||||
func (m *mockMPPayment) GetFailureReason() *channeldb.FailureReason {
|
||||
args := m.Called()
|
||||
|
||||
reason := args.Get(0)
|
||||
if reason == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return reason.(*channeldb.FailureReason)
|
||||
}
|
||||
|
||||
type mockLink struct {
|
||||
htlcswitch.ChannelLink
|
||||
bandwidth lnwire.MilliSatoshi
|
||||
|
|
|
@ -86,7 +86,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
|
|||
// exitWithErr is a helper closure that logs and returns an error.
|
||||
exitWithErr := func(err error) ([32]byte, *route.Route, error) {
|
||||
log.Errorf("Payment %v with status=%v failed: %v",
|
||||
p.identifier, payment.Status, err)
|
||||
p.identifier, payment.GetStatus(), err)
|
||||
return [32]byte{}, nil, err
|
||||
}
|
||||
|
||||
|
@ -112,7 +112,7 @@ lifecycle:
|
|||
return exitWithErr(err)
|
||||
}
|
||||
|
||||
ps := payment.State
|
||||
ps := payment.GetState()
|
||||
remainingFees := p.calcFeeBudget(ps.FeesPaid)
|
||||
|
||||
log.Debugf("Payment %v in state terminate=%v, "+
|
||||
|
@ -127,7 +127,7 @@ lifecycle:
|
|||
if payment.Terminated() {
|
||||
// Find the first successful shard and return
|
||||
// the preimage and route.
|
||||
for _, a := range payment.HTLCs {
|
||||
for _, a := range payment.GetHTLCs() {
|
||||
if a.Settle == nil {
|
||||
continue
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ lifecycle:
|
|||
}
|
||||
|
||||
// Payment failed.
|
||||
return exitWithErr(*payment.FailureReason)
|
||||
return exitWithErr(*payment.GetFailureReason())
|
||||
}
|
||||
|
||||
// If we either reached a terminal error condition (but had
|
||||
|
|
|
@ -791,12 +791,6 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
|
|||
}
|
||||
}
|
||||
|
||||
func makeActiveAttempt(total, fee int) channeldb.HTLCAttempt {
|
||||
return channeldb.HTLCAttempt{
|
||||
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
|
||||
}
|
||||
}
|
||||
|
||||
func makeSettledAttempt(total, fee int,
|
||||
preimage lntypes.Preimage) channeldb.HTLCAttempt {
|
||||
|
||||
|
|
|
@ -1120,15 +1120,16 @@ func TestSendPaymentErrorPathPruning(t *testing.T) {
|
|||
p, err := ctx.router.cfg.Control.FetchPayment(payHash)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 2, len(p.HTLCs), "expected two attempts")
|
||||
htlcs := p.GetHTLCs()
|
||||
require.Equal(t, 2, len(htlcs), "expected two attempts")
|
||||
|
||||
// We expect the first attempt to have failed with a
|
||||
// TemporaryChannelFailure, the second with UnknownNextPeer.
|
||||
msg := p.HTLCs[0].Failure.Message
|
||||
msg := htlcs[0].Failure.Message
|
||||
_, ok := msg.(*lnwire.FailTemporaryChannelFailure)
|
||||
require.True(t, ok, "unexpected fail message")
|
||||
|
||||
msg = p.HTLCs[1].Failure.Message
|
||||
msg = htlcs[1].Failure.Message
|
||||
_, ok = msg.(*lnwire.FailUnknownNextPeer)
|
||||
require.True(t, ok, "unexpected fail message")
|
||||
|
||||
|
@ -3470,29 +3471,64 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
|||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||
}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
// Mock the InFlightHTLCs.
|
||||
var (
|
||||
htlcs []channeldb.HTLCAttempt
|
||||
numAttempts atomic.Uint32
|
||||
)
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment",
|
||||
identifier,
|
||||
).Return(payment, nil).Run(func(args mock.Arguments) {
|
||||
// When number of attempts made is less than 4, we will mock
|
||||
// the payment's methods to allow the lifecycle to continue.
|
||||
if numAttempts.Load() < 4 {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, terminate the lifecycle.
|
||||
payment.On("Terminated").Return(true).
|
||||
On("NeedWaitAttempts").Return(true, nil)
|
||||
})
|
||||
|
||||
// Mock SettleAttempt.
|
||||
preimage := lntypes.Preimage{1, 2, 3}
|
||||
settledAttempt := makeSettledAttempt(
|
||||
int(paymentAmt/4), 0, preimage,
|
||||
)
|
||||
|
||||
controlTower.On("SettleAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
|
||||
payment.On("GetHTLCs").Return(
|
||||
[]channeldb.HTLCAttempt{settledAttempt},
|
||||
)
|
||||
})
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
shard, err := createTestRoute(paymentAmt/4, testGraph.aliasMap)
|
||||
require.NoError(t, err, "failed to create route")
|
||||
|
||||
session.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(shard, nil)
|
||||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
// Increase the counter whenever an attempt is made.
|
||||
numAttempts.Add(1)
|
||||
})
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
|
@ -3509,30 +3545,12 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
|||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
missionControl.On("ReportPaymentSuccess",
|
||||
mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
// Mock SettleAttempt by changing one of the HTLCs to be settled.
|
||||
preimage := lntypes.Preimage{1, 2, 3}
|
||||
settledAttempt := makeSettledAttempt(
|
||||
int(paymentAmt/4), 0, preimage,
|
||||
)
|
||||
controlTower.On("SettleAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt settled and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle == nil {
|
||||
attempt.Settle = &channeldb.HTLCSettleInfo{
|
||||
Preimage: preimage,
|
||||
}
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
}
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
})
|
||||
|
||||
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
|
@ -3565,6 +3583,7 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
|||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentSucceedOnExtraShards tests that we need extra attempts if
|
||||
|
@ -3639,13 +3658,34 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
|||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||
}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
// Mock the InFlightHTLCs.
|
||||
var (
|
||||
htlcs []channeldb.HTLCAttempt
|
||||
numAttempts atomic.Uint32
|
||||
failAttemptCount atomic.Uint32
|
||||
)
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment",
|
||||
identifier,
|
||||
).Return(payment, nil).Run(func(args mock.Arguments) {
|
||||
// When number of attempts made is less than 6, we will mock
|
||||
// the payment's methods to allow the lifecycle to continue.
|
||||
if numAttempts.Load() < 6 {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, terminate the lifecycle.
|
||||
payment.On("Terminated").Return(true).
|
||||
On("NeedWaitAttempts").Return(true, nil)
|
||||
})
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
|
@ -3657,11 +3697,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
|||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
// Increase the counter whenever an attempt is made.
|
||||
numAttempts.Add(1)
|
||||
})
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
|
@ -3670,7 +3710,6 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
|||
// We use the failAttemptCount to track how many attempts we want to
|
||||
// fail. Each time the following mock method is called, the count gets
|
||||
// updated.
|
||||
failAttemptCount := 0
|
||||
payer.On("GetAttemptResult",
|
||||
mock.Anything, identifier, mock.Anything,
|
||||
).Run(func(args mock.Arguments) {
|
||||
|
@ -3678,11 +3717,11 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
|||
// the read-only chan.
|
||||
|
||||
// Update the counter.
|
||||
failAttemptCount++
|
||||
failAttemptCount.Add(1)
|
||||
|
||||
// We will make the first two attempts failed with temporary
|
||||
// error.
|
||||
if failAttemptCount <= 2 {
|
||||
if failAttemptCount.Load() <= 2 {
|
||||
payer.resultChan <- &htlcswitch.PaymentResult{
|
||||
Error: htlcswitch.NewForwardingError(
|
||||
&lnwire.FailTemporaryChannelFailure{},
|
||||
|
@ -3700,20 +3739,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
|||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt as failed and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle != nil || attempt.Failure != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attempt.Failure = &channeldb.HTLCFailInfo{}
|
||||
failedAttempt = attempt
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
})
|
||||
).Return(&failedAttempt, nil)
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
|
@ -3737,20 +3763,13 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
|||
controlTower.On("SettleAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&settledAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt settled and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle != nil || attempt.Failure != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attempt.Settle = &channeldb.HTLCSettleInfo{
|
||||
Preimage: preimage,
|
||||
}
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
// Whenever this method is invoked, we will mock the payment's
|
||||
// GetHTLCs() to return the settled htlc.
|
||||
payment.On("GetHTLCs").Return(
|
||||
[]channeldb.HTLCAttempt{settledAttempt},
|
||||
)
|
||||
})
|
||||
|
||||
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
|
@ -3779,6 +3798,7 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
|||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentFailed tests that when one of the shard fails with a
|
||||
|
@ -3853,12 +3873,18 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||
}
|
||||
// Mock the InFlightHTLCs.
|
||||
var htlcs []channeldb.HTLCAttempt
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
|
||||
On("GetStatus").Return(channeldb.StatusInFlight).
|
||||
On("Terminated").Return(false).
|
||||
On("NeedWaitAttempts").Return(false, nil)
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
|
@ -3871,12 +3897,9 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
})
|
||||
).Return(nil)
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
|
@ -3918,43 +3941,24 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt as failed and exit.
|
||||
for i, attempt := range payment.HTLCs {
|
||||
if attempt.Settle != nil || attempt.Failure != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
attempt.Failure = &channeldb.HTLCFailInfo{}
|
||||
failedAttempt = attempt
|
||||
payment.HTLCs[i] = attempt
|
||||
return
|
||||
}
|
||||
})
|
||||
).Return(&failedAttempt, nil)
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
var called bool
|
||||
failureReason := channeldb.FailureReasonPaymentDetails
|
||||
missionControl.On("ReportPaymentFail",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(&failureReason, nil).Run(func(args mock.Arguments) {
|
||||
// We only return the terminal error once, thus when the method
|
||||
// is called, we will return it with a nil error.
|
||||
if called {
|
||||
args[0] = nil
|
||||
return
|
||||
}
|
||||
|
||||
// If it's the first time calling this method, we will return a
|
||||
// terminal error.
|
||||
payment.FailureReason = &failureReason
|
||||
called = true
|
||||
})
|
||||
).Return(&failureReason, nil)
|
||||
|
||||
// Simple mocking the rest.
|
||||
controlTower.On("FailPayment", identifier, failureReason).Return(nil)
|
||||
controlTower.On("FailPayment",
|
||||
identifier, failureReason,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mock the payment's
|
||||
// Terminated() to be True.
|
||||
payment.On("Terminated").Return(true)
|
||||
})
|
||||
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
@ -3985,6 +3989,7 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
|||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendMPPaymentFailedWithShardsInFlight tests that when the payment is in
|
||||
|
@ -4059,12 +4064,18 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
|||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// The following mocked methods are called inside resumePayment. Note
|
||||
// that the payment object below will determine the state of the
|
||||
// paymentLifecycle.
|
||||
payment := &channeldb.MPPayment{
|
||||
Info: &channeldb.PaymentCreationInfo{Value: paymentAmt},
|
||||
}
|
||||
// Mock the InFlightHTLCs.
|
||||
var htlcs []channeldb.HTLCAttempt
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetStatus").Return(channeldb.StatusInFlight).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
|
||||
On("Terminated").Return(false).
|
||||
On("NeedWaitAttempts").Return(false, nil)
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil)
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
|
@ -4077,12 +4088,9 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
|||
|
||||
// Make a new htlc attempt with zero fee and append it to the payment's
|
||||
// HTLCs when calling RegisterAttempt.
|
||||
activeAttempt := makeActiveAttempt(int(paymentAmt/4), 0)
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
payment.HTLCs = append(payment.HTLCs, activeAttempt)
|
||||
})
|
||||
).Return(nil)
|
||||
|
||||
// Create a buffered chan and it will be returned by GetAttemptResult.
|
||||
payer.resultChan = make(chan *htlcswitch.PaymentResult, 10)
|
||||
|
@ -4130,28 +4138,25 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
|||
var failedAttempt channeldb.HTLCAttempt
|
||||
controlTower.On("FailAttempt",
|
||||
identifier, mock.Anything, mock.Anything,
|
||||
).Return(&failedAttempt, nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mark the first
|
||||
// active attempt as failed and exit.
|
||||
failedAttempt = payment.HTLCs[0]
|
||||
failedAttempt.Failure = &channeldb.HTLCFailInfo{}
|
||||
payment.HTLCs[0] = failedAttempt
|
||||
})
|
||||
).Return(&failedAttempt, nil)
|
||||
|
||||
// Setup ReportPaymentFail to return nil reason and error so the
|
||||
// payment won't fail.
|
||||
failureReason := channeldb.FailureReasonPaymentDetails
|
||||
cntReportPaymentFail := 0
|
||||
missionControl.On("ReportPaymentFail",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(&failureReason, nil).Run(func(args mock.Arguments) {
|
||||
payment.FailureReason = &failureReason
|
||||
cntReportPaymentFail++
|
||||
})
|
||||
).Return(&failureReason, nil)
|
||||
|
||||
// Simple mocking the rest.
|
||||
cntFail := 0
|
||||
controlTower.On("FailPayment", identifier, failureReason).Return(nil)
|
||||
controlTower.On("FailPayment",
|
||||
identifier, failureReason,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
// Whenever this method is invoked, we will mock the payment's
|
||||
// Terminated() to be True.
|
||||
payment.On("Terminated").Return(true)
|
||||
})
|
||||
|
||||
payer.On("SendHTLC",
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
|
@ -4179,7 +4184,6 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
|||
require.Error(t, err, "expected send payment error")
|
||||
require.EqualValues(t, [32]byte{}, p, "preimage not match")
|
||||
require.GreaterOrEqual(t, getPaymentResultCnt, 1)
|
||||
require.Equal(t, getPaymentResultCnt, cntReportPaymentFail)
|
||||
require.Equal(t, getPaymentResultCnt, cntFail)
|
||||
|
||||
controlTower.AssertExpectations(t)
|
||||
|
@ -4187,6 +4191,7 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) {
|
|||
sessionSource.AssertExpectations(t)
|
||||
session.AssertExpectations(t)
|
||||
missionControl.AssertExpectations(t)
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestBlockDifferenceFix tests if when the router is behind on blocks, the
|
||||
|
|
Loading…
Add table
Reference in a new issue