mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 13:27:56 +01:00
1637 lines
49 KiB
Go
1637 lines
49 KiB
Go
package routing
|
|
|
|
import (
|
|
"context"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/btcsuite/btcd/btcec/v2"
|
|
"github.com/go-errors/errors"
|
|
"github.com/lightningnetwork/lnd/channeldb"
|
|
"github.com/lightningnetwork/lnd/fn"
|
|
"github.com/lightningnetwork/lnd/htlcswitch"
|
|
"github.com/lightningnetwork/lnd/lnmock"
|
|
"github.com/lightningnetwork/lnd/lntest/wait"
|
|
"github.com/lightningnetwork/lnd/lntypes"
|
|
"github.com/lightningnetwork/lnd/lnwire"
|
|
"github.com/lightningnetwork/lnd/record"
|
|
"github.com/lightningnetwork/lnd/routing/route"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// errDummy is a placeholder error which the mockers hand back whenever a test
// needs a mocked call to fail.
var (
	errDummy = errors.New("dummy")
)
|
|
|
|
// createTestPaymentLifecycle creates a `paymentLifecycle` without mocks.
|
|
func createTestPaymentLifecycle() *paymentLifecycle {
|
|
paymentHash := lntypes.Hash{1, 2, 3}
|
|
quitChan := make(chan struct{})
|
|
rt := &ChannelRouter{
|
|
cfg: &Config{
|
|
TrafficShaper: fn.Some[TlvTrafficShaper](
|
|
&mockTrafficShaper{},
|
|
),
|
|
},
|
|
quit: quitChan,
|
|
}
|
|
|
|
return &paymentLifecycle{
|
|
router: rt,
|
|
identifier: paymentHash,
|
|
}
|
|
}
|
|
|
|
// mockers wraps a list of mocked interfaces used inside payment lifecycle.
|
|
type mockers struct {
|
|
shard *mockShard
|
|
shardTracker *mockShardTracker
|
|
control *mockControlTower
|
|
paySession *mockPaymentSession
|
|
payer *mockPaymentAttemptDispatcher
|
|
clock *lnmock.MockClock
|
|
missionControl *mockMissionControl
|
|
|
|
// collectResultsCount is the number of times the collectResultAsync
|
|
// has been called.
|
|
collectResultsCount int
|
|
|
|
// payment is the mocked `dbMPPayment` used in the test.
|
|
payment *mockMPPayment
|
|
}
|
|
|
|
// newTestPaymentLifecycle creates a `paymentLifecycle` using the mocks. It
|
|
// also asserts the mockers are called as expected when the test is finished.
|
|
func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) {
|
|
paymentHash := lntypes.Hash{1, 2, 3}
|
|
quitChan := make(chan struct{})
|
|
|
|
// Create a mock shard to be return from `NewShard`.
|
|
mockShard := &mockShard{}
|
|
|
|
// Create a list of mocks and add it to the router config.
|
|
mockControlTower := &mockControlTower{}
|
|
mockPayer := &mockPaymentAttemptDispatcher{}
|
|
mockClock := &lnmock.MockClock{}
|
|
mockMissionControl := &mockMissionControl{}
|
|
|
|
// Make a channel router.
|
|
rt := &ChannelRouter{
|
|
cfg: &Config{
|
|
Control: mockControlTower,
|
|
Payer: mockPayer,
|
|
Clock: mockClock,
|
|
MissionControl: mockMissionControl,
|
|
TrafficShaper: fn.Some[TlvTrafficShaper](
|
|
&mockTrafficShaper{},
|
|
),
|
|
},
|
|
quit: quitChan,
|
|
}
|
|
|
|
// Create mockers to init a payment lifecycle.
|
|
mockPaymentSession := &mockPaymentSession{}
|
|
mockShardTracker := &mockShardTracker{}
|
|
|
|
// Create a test payment lifecycle with no fee limit and no timeout.
|
|
p := newPaymentLifecycle(
|
|
rt, noFeeLimit, paymentHash, mockPaymentSession,
|
|
mockShardTracker, 0, nil,
|
|
)
|
|
|
|
// Create a mock payment which is returned from mockControlTower.
|
|
mockPayment := &mockMPPayment{}
|
|
|
|
mockers := &mockers{
|
|
shard: mockShard,
|
|
shardTracker: mockShardTracker,
|
|
control: mockControlTower,
|
|
paySession: mockPaymentSession,
|
|
payer: mockPayer,
|
|
clock: mockClock,
|
|
missionControl: mockMissionControl,
|
|
payment: mockPayment,
|
|
}
|
|
|
|
// Overwrite the collectResultAsync to focus on testing the payment
|
|
// lifecycle within the goroutine.
|
|
resultCollector := func(attempt *channeldb.HTLCAttempt) {
|
|
mockers.collectResultsCount++
|
|
}
|
|
p.resultCollector = resultCollector
|
|
|
|
// Validate the mockers are called as expected before exiting the test.
|
|
t.Cleanup(func() {
|
|
mockShard.AssertExpectations(t)
|
|
mockShardTracker.AssertExpectations(t)
|
|
mockControlTower.AssertExpectations(t)
|
|
mockPaymentSession.AssertExpectations(t)
|
|
mockPayer.AssertExpectations(t)
|
|
mockClock.AssertExpectations(t)
|
|
mockMissionControl.AssertExpectations(t)
|
|
mockPayment.AssertExpectations(t)
|
|
})
|
|
|
|
return p, mockers
|
|
}
|
|
|
|
// setupTestPaymentLifecycle creates a new `paymentLifecycle` and mocks the
|
|
// initial steps of the payment lifecycle so we can enter into the loop
|
|
// directly.
|
|
func setupTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) {
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
// Mock the first two calls.
|
|
m.control.On("FetchPayment", p.identifier).Return(
|
|
m.payment, nil,
|
|
).Once()
|
|
|
|
htlcs := []channeldb.HTLCAttempt{}
|
|
m.payment.On("InFlightHTLCs").Return(htlcs).Once()
|
|
|
|
return p, m
|
|
}
|
|
|
|
// resumePaymentResult is used to hold the returned values from
|
|
// `resumePayment`.
|
|
type resumePaymentResult struct {
|
|
preimage lntypes.Hash
|
|
err error
|
|
}
|
|
|
|
// sendPaymentAndAssertError calls `resumePayment` and asserts that an error is
|
|
// returned.
|
|
func sendPaymentAndAssertError(t *testing.T, ctx context.Context,
|
|
p *paymentLifecycle, errExpected error) {
|
|
|
|
resultChan := make(chan *resumePaymentResult, 1)
|
|
|
|
// We now make a call to `resumePayment` and expect it to return the
|
|
// error.
|
|
go func() {
|
|
preimage, _, err := p.resumePayment(ctx)
|
|
resultChan <- &resumePaymentResult{
|
|
preimage: preimage,
|
|
err: err,
|
|
}
|
|
}()
|
|
|
|
// Validate the returned values or timeout.
|
|
select {
|
|
case r := <-resultChan:
|
|
require.ErrorIs(t, r.err, errExpected, "expected error")
|
|
require.Empty(t, r.preimage, "preimage should be empty")
|
|
|
|
case <-time.After(testTimeout):
|
|
require.Fail(t, "timeout waiting for result")
|
|
}
|
|
}
|
|
|
|
// sendPaymentAndAssertSucceeded calls `resumePayment` and asserts that the
|
|
// returned preimage is correct.
|
|
func sendPaymentAndAssertSucceeded(t *testing.T,
|
|
p *paymentLifecycle, expected lntypes.Preimage) {
|
|
|
|
resultChan := make(chan *resumePaymentResult, 1)
|
|
|
|
// We now make a call to `resumePayment` and expect it to return the
|
|
// preimage.
|
|
go func() {
|
|
preimage, _, err := p.resumePayment(context.Background())
|
|
resultChan <- &resumePaymentResult{
|
|
preimage: preimage,
|
|
err: err,
|
|
}
|
|
}()
|
|
|
|
// Validate the returned values or timeout.
|
|
select {
|
|
case r := <-resultChan:
|
|
require.NoError(t, r.err, "unexpected error")
|
|
require.EqualValues(t, expected, r.preimage,
|
|
"preimage not match")
|
|
|
|
case <-time.After(testTimeout):
|
|
require.Fail(t, "timeout waiting for result")
|
|
}
|
|
}
|
|
|
|
// createDummyRoute builds a route a->b->c paying the given amt to c.
|
|
func createDummyRoute(t *testing.T, amt lnwire.MilliSatoshi) *route.Route {
|
|
t.Helper()
|
|
|
|
priv, err := btcec.NewPrivateKey()
|
|
require.NoError(t, err, "failed to create private key")
|
|
hop1 := route.NewVertex(priv.PubKey())
|
|
|
|
priv, err = btcec.NewPrivateKey()
|
|
require.NoError(t, err, "failed to create private key")
|
|
hop2 := route.NewVertex(priv.PubKey())
|
|
|
|
hopFee := lnwire.NewMSatFromSatoshis(3)
|
|
hops := []*route.Hop{
|
|
{
|
|
ChannelID: 1,
|
|
PubKeyBytes: hop1,
|
|
LegacyPayload: true,
|
|
AmtToForward: amt + hopFee,
|
|
},
|
|
{
|
|
ChannelID: 2,
|
|
PubKeyBytes: hop2,
|
|
LegacyPayload: true,
|
|
AmtToForward: amt,
|
|
},
|
|
}
|
|
|
|
priv, err = btcec.NewPrivateKey()
|
|
require.NoError(t, err, "failed to create private key")
|
|
source := route.NewVertex(priv.PubKey())
|
|
|
|
// We create a simple route that we will supply every time the router
|
|
// requests one.
|
|
rt, err := route.NewRouteFromHops(amt+2*hopFee, 100, source, hops)
|
|
require.NoError(t, err, "failed to create route")
|
|
|
|
return rt
|
|
}
|
|
|
|
func makeSettledAttempt(t *testing.T, total int,
|
|
preimage lntypes.Preimage) *channeldb.HTLCAttempt {
|
|
|
|
return &channeldb.HTLCAttempt{
|
|
HTLCAttemptInfo: makeAttemptInfo(t, total),
|
|
Settle: &channeldb.HTLCSettleInfo{Preimage: preimage},
|
|
}
|
|
}
|
|
|
|
func makeFailedAttempt(t *testing.T, total int) *channeldb.HTLCAttempt {
|
|
return &channeldb.HTLCAttempt{
|
|
HTLCAttemptInfo: makeAttemptInfo(t, total),
|
|
Failure: &channeldb.HTLCFailInfo{
|
|
Reason: channeldb.HTLCFailInternal,
|
|
},
|
|
}
|
|
}
|
|
|
|
func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo {
|
|
rt := createDummyRoute(t, lnwire.MilliSatoshi(amt))
|
|
return channeldb.HTLCAttemptInfo{
|
|
Route: *rt,
|
|
}
|
|
}
|
|
|
|
// TestCheckTimeoutTimedOut checks that when the payment times out, it is
|
|
// marked as failed.
|
|
func TestCheckTimeoutTimedOut(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
deadline := time.Now().Add(time.Nanosecond)
|
|
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
|
defer cancel()
|
|
|
|
p := createTestPaymentLifecycle()
|
|
|
|
// Mock the control tower's `FailPayment` method.
|
|
ct := &mockControlTower{}
|
|
ct.On("FailPayment",
|
|
p.identifier, channeldb.FailureReasonTimeout).Return(nil)
|
|
|
|
// Mount the mocked control tower.
|
|
p.router.cfg.Control = ct
|
|
|
|
// Sleep one millisecond to make sure it timed out.
|
|
time.Sleep(1 * time.Millisecond)
|
|
|
|
// Call the function and expect no error.
|
|
err := p.checkContext(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// Assert that `FailPayment` is called as expected.
|
|
ct.AssertExpectations(t)
|
|
|
|
// We now test that when `FailPayment` returns an error, it's returned
|
|
// by the function too.
|
|
//
|
|
// Mock `FailPayment` to return a dummy error.
|
|
ct = &mockControlTower{}
|
|
ct.On("FailPayment",
|
|
p.identifier, channeldb.FailureReasonTimeout).Return(errDummy)
|
|
|
|
// Mount the mocked control tower.
|
|
p.router.cfg.Control = ct
|
|
|
|
// Make the timeout happens instantly.
|
|
deadline = time.Now().Add(time.Nanosecond)
|
|
ctx, cancel = context.WithDeadline(context.Background(), deadline)
|
|
defer cancel()
|
|
|
|
// Sleep one millisecond to make sure it timed out.
|
|
time.Sleep(1 * time.Millisecond)
|
|
|
|
// Call the function and expect an error.
|
|
err = p.checkContext(ctx)
|
|
require.ErrorIs(t, err, errDummy)
|
|
|
|
// Assert that `FailPayment` is called as expected.
|
|
ct.AssertExpectations(t)
|
|
}
|
|
|
|
// TestCheckTimeoutOnRouterQuit checks that when the router has quit, an error
|
|
// is returned from checkTimeout.
|
|
func TestCheckTimeoutOnRouterQuit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
p := createTestPaymentLifecycle()
|
|
|
|
close(p.router.quit)
|
|
err := p.checkContext(ctx)
|
|
require.ErrorIs(t, err, ErrRouterShuttingDown)
|
|
}
|
|
|
|
// TestRequestRouteSucceed checks that `requestRoute` can successfully request
|
|
// a route.
|
|
func TestRequestRouteSucceed(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
p := createTestPaymentLifecycle()
|
|
|
|
// Create a mock payment session and a dummy route.
|
|
paySession := &mockPaymentSession{}
|
|
dummyRoute := &route.Route{}
|
|
|
|
// Mount the mocked payment session.
|
|
p.paySession = paySession
|
|
|
|
// Create a dummy payment state.
|
|
ps := &channeldb.MPPaymentState{
|
|
NumAttemptsInFlight: 1,
|
|
RemainingAmt: 1,
|
|
FeesPaid: 100,
|
|
}
|
|
|
|
// Mock remainingFees to be 1.
|
|
p.feeLimit = ps.FeesPaid + 1
|
|
|
|
// Mock the paySession's `RequestRoute` method to return no error.
|
|
paySession.On("RequestRoute",
|
|
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
|
mock.Anything,
|
|
).Return(dummyRoute, nil)
|
|
|
|
result, err := p.requestRoute(ps)
|
|
require.NoError(t, err, "expect no error")
|
|
require.Equal(t, dummyRoute, result, "returned route not matched")
|
|
|
|
// Assert that `RequestRoute` is called as expected.
|
|
paySession.AssertExpectations(t)
|
|
}
|
|
|
|
// TestRequestRouteHandleCriticalErr checks that `requestRoute` can
|
|
// successfully handle a critical error returned from payment session.
|
|
func TestRequestRouteHandleCriticalErr(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
p := createTestPaymentLifecycle()
|
|
|
|
// Create a mock payment session.
|
|
paySession := &mockPaymentSession{}
|
|
|
|
// Mount the mocked payment session.
|
|
p.paySession = paySession
|
|
|
|
// Create a dummy payment state.
|
|
ps := &channeldb.MPPaymentState{
|
|
NumAttemptsInFlight: 1,
|
|
RemainingAmt: 1,
|
|
FeesPaid: 100,
|
|
}
|
|
|
|
// Mock remainingFees to be 1.
|
|
p.feeLimit = ps.FeesPaid + 1
|
|
|
|
// Mock the paySession's `RequestRoute` method to return an error.
|
|
paySession.On("RequestRoute",
|
|
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
|
mock.Anything,
|
|
).Return(nil, errDummy)
|
|
|
|
result, err := p.requestRoute(ps)
|
|
|
|
// Expect an error is returned since it's critical.
|
|
require.ErrorIs(t, err, errDummy, "error not matched")
|
|
require.Nil(t, result, "expected no route returned")
|
|
|
|
// Assert that `RequestRoute` is called as expected.
|
|
paySession.AssertExpectations(t)
|
|
}
|
|
|
|
// TestRequestRouteHandleNoRouteErr checks that `requestRoute` can successfully
|
|
// handle the `noRouteError` returned from payment session.
|
|
func TestRequestRouteHandleNoRouteErr(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a paymentLifecycle with mockers.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
// Create a dummy payment state.
|
|
ps := &channeldb.MPPaymentState{
|
|
NumAttemptsInFlight: 1,
|
|
RemainingAmt: 1,
|
|
FeesPaid: 100,
|
|
}
|
|
|
|
// Mock remainingFees to be 1.
|
|
p.feeLimit = ps.FeesPaid + 1
|
|
|
|
// Mock the paySession's `RequestRoute` method to return a NoRouteErr
|
|
// type.
|
|
m.paySession.On("RequestRoute",
|
|
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
|
mock.Anything,
|
|
).Return(nil, errNoTlvPayload)
|
|
|
|
// The payment should be failed with reason no route.
|
|
m.control.On("FailPayment",
|
|
p.identifier, channeldb.FailureReasonNoRoute,
|
|
).Return(nil).Once()
|
|
|
|
result, err := p.requestRoute(ps)
|
|
|
|
// Expect no error is returned since it's not critical.
|
|
require.NoError(t, err, "expected no error")
|
|
require.Nil(t, result, "expected no route returned")
|
|
}
|
|
|
|
// TestRequestRouteFailPaymentError checks that `requestRoute` returns the
|
|
// error from calling `FailPayment`.
|
|
func TestRequestRouteFailPaymentError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
p := createTestPaymentLifecycle()
|
|
|
|
// Create a mock payment session.
|
|
paySession := &mockPaymentSession{}
|
|
|
|
// Mock the control tower's `FailPayment` method.
|
|
ct := &mockControlTower{}
|
|
ct.On("FailPayment",
|
|
p.identifier, errNoTlvPayload.FailureReason(),
|
|
).Return(errDummy)
|
|
|
|
// Mount the mocked control tower and payment session.
|
|
p.router.cfg.Control = ct
|
|
p.paySession = paySession
|
|
|
|
// Create a dummy payment state with zero inflight attempts.
|
|
ps := &channeldb.MPPaymentState{
|
|
NumAttemptsInFlight: 0,
|
|
RemainingAmt: 1,
|
|
FeesPaid: 100,
|
|
}
|
|
|
|
// Mock remainingFees to be 1.
|
|
p.feeLimit = ps.FeesPaid + 1
|
|
|
|
// Mock the paySession's `RequestRoute` method to return an error.
|
|
paySession.On("RequestRoute",
|
|
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
|
mock.Anything,
|
|
).Return(nil, errNoTlvPayload)
|
|
|
|
result, err := p.requestRoute(ps)
|
|
|
|
// Expect an error is returned.
|
|
require.ErrorIs(t, err, errDummy, "error not matched")
|
|
require.Nil(t, result, "expected no route returned")
|
|
|
|
// Assert that `RequestRoute` is called as expected.
|
|
paySession.AssertExpectations(t)
|
|
|
|
// Assert that `FailPayment` is called as expected.
|
|
ct.AssertExpectations(t)
|
|
}
|
|
|
|
// TestDecideNextStep checks the method `decideNextStep` behaves as expected.
|
|
func TestDecideNextStep(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// mockReturn is used to hold the return values from AllowMoreAttempts
|
|
// or NeedWaitAttempts.
|
|
type mockReturn struct {
|
|
allowOrWait bool
|
|
err error
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
allowMoreAttempts *mockReturn
|
|
needWaitAttempts *mockReturn
|
|
|
|
// When the attemptResultChan has returned.
|
|
closeResultChan bool
|
|
|
|
// Whether the router has quit.
|
|
routerQuit bool
|
|
|
|
expectedStep stateStep
|
|
expectedErr error
|
|
}{
|
|
{
|
|
name: "allow more attempts",
|
|
allowMoreAttempts: &mockReturn{true, nil},
|
|
expectedStep: stepProceed,
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "error on allow more attempts",
|
|
allowMoreAttempts: &mockReturn{false, errDummy},
|
|
expectedStep: stepExit,
|
|
expectedErr: errDummy,
|
|
},
|
|
{
|
|
name: "no wait and exit",
|
|
allowMoreAttempts: &mockReturn{false, nil},
|
|
needWaitAttempts: &mockReturn{false, nil},
|
|
expectedStep: stepExit,
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "wait returns an error",
|
|
allowMoreAttempts: &mockReturn{false, nil},
|
|
needWaitAttempts: &mockReturn{false, errDummy},
|
|
expectedStep: stepExit,
|
|
expectedErr: errDummy,
|
|
},
|
|
|
|
{
|
|
name: "wait and exit on result chan",
|
|
allowMoreAttempts: &mockReturn{false, nil},
|
|
needWaitAttempts: &mockReturn{true, nil},
|
|
closeResultChan: true,
|
|
expectedStep: stepSkip,
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "wait and exit on router quit",
|
|
allowMoreAttempts: &mockReturn{false, nil},
|
|
needWaitAttempts: &mockReturn{true, nil},
|
|
routerQuit: true,
|
|
expectedStep: stepExit,
|
|
expectedErr: ErrRouterShuttingDown,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
|
|
// Create a test paymentLifecycle.
|
|
p := createTestPaymentLifecycle()
|
|
|
|
// Make a mock payment.
|
|
payment := &mockMPPayment{}
|
|
|
|
// Mock the method AllowMoreAttempts.
|
|
payment.On("AllowMoreAttempts").Return(
|
|
tc.allowMoreAttempts.allowOrWait,
|
|
tc.allowMoreAttempts.err,
|
|
).Once()
|
|
|
|
// Mock the method NeedWaitAttempts.
|
|
if tc.needWaitAttempts != nil {
|
|
payment.On("NeedWaitAttempts").Return(
|
|
tc.needWaitAttempts.allowOrWait,
|
|
tc.needWaitAttempts.err,
|
|
).Once()
|
|
}
|
|
|
|
// Send a nil error to the attemptResultChan if requested.
|
|
if tc.closeResultChan {
|
|
p.resultCollected = make(chan error, 1)
|
|
p.resultCollected <- nil
|
|
}
|
|
|
|
// Quit the router if requested.
|
|
if tc.routerQuit {
|
|
close(p.router.quit)
|
|
}
|
|
|
|
// Once the setup is finished, run the test cases.
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
step, err := p.decideNextStep(payment)
|
|
require.Equal(t, tc.expectedStep, step)
|
|
require.ErrorIs(t, tc.expectedErr, err)
|
|
})
|
|
|
|
// Check the payment's methods are called as expected.
|
|
payment.AssertExpectations(t)
|
|
}
|
|
}
|
|
|
|
// TestResumePaymentFailOnFetchPayment checks when we fail to fetch the
|
|
// payment, the error is returned.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentFailOnFetchPayment(t *testing.T) {
|
|
// Create a test paymentLifecycle.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
// Mock an error returned.
|
|
m.control.On("FetchPayment", p.identifier).Return(nil, errDummy)
|
|
|
|
// Send the payment and assert it failed.
|
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
|
|
|
// Expected collectResultAsync to not be called.
|
|
require.Zero(t, m.collectResultsCount)
|
|
}
|
|
|
|
// TestResumePaymentFailOnTimeout checks that when timeout is reached, the
|
|
// payment is failed.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentFailOnTimeout(t *testing.T) {
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := setupTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
|
|
|
// We now enter the payment lifecycle loop.
|
|
//
|
|
// 1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 2. calls `GetState` and return the state.
|
|
ps := &channeldb.MPPaymentState{
|
|
RemainingAmt: paymentAmt,
|
|
}
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// NOTE: GetStatus is only used to populate the logs which is not
|
|
// critical, so we loosen the checks on how many times it's been called.
|
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
|
|
|
// 3. make the timeout happens instantly and sleep one millisecond to
|
|
// make sure it timed out.
|
|
deadline := time.Now().Add(time.Nanosecond)
|
|
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
|
defer cancel()
|
|
time.Sleep(1 * time.Millisecond)
|
|
|
|
// 4. the payment should be failed with reason timeout.
|
|
m.control.On("FailPayment",
|
|
p.identifier, channeldb.FailureReasonTimeout,
|
|
).Return(nil).Once()
|
|
|
|
// 5. decideNextStep now returns stepExit.
|
|
m.payment.On("AllowMoreAttempts").Return(false, nil).Once().
|
|
On("NeedWaitAttempts").Return(false, nil).Once()
|
|
|
|
// 6. control tower deletes failed attempts.
|
|
m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once()
|
|
|
|
// 7. the payment returns the failed reason.
|
|
reason := channeldb.FailureReasonTimeout
|
|
m.payment.On("TerminalInfo").Return(nil, &reason)
|
|
|
|
// Send the payment and assert it failed with the timeout reason.
|
|
sendPaymentAndAssertError(t, ctx, p, reason)
|
|
|
|
// Expected collectResultAsync to not be called.
|
|
require.Zero(t, m.collectResultsCount)
|
|
}
|
|
|
|
// TestResumePaymentFailOnTimeoutErr checks that the lifecycle fails when an
|
|
// error is returned from `checkTimeout`.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentFailOnTimeoutErr(t *testing.T) {
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := setupTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
|
|
|
// We now enter the payment lifecycle loop.
|
|
//
|
|
// 1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 2. calls `GetState` and return the state.
|
|
ps := &channeldb.MPPaymentState{
|
|
RemainingAmt: paymentAmt,
|
|
}
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// NOTE: GetStatus is only used to populate the logs which is
|
|
// not critical so we loosen the checks on how many times it's
|
|
// been called.
|
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
|
|
|
// 3. quit the router to return an error.
|
|
close(p.router.quit)
|
|
|
|
// Send the payment and assert it failed when router is shutting down.
|
|
sendPaymentAndAssertError(
|
|
t, context.Background(), p, ErrRouterShuttingDown,
|
|
)
|
|
|
|
// Expected collectResultAsync to not be called.
|
|
require.Zero(t, m.collectResultsCount)
|
|
}
|
|
|
|
// TestResumePaymentFailContextCancel checks that the lifecycle fails when the
|
|
// context is canceled.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentFailContextCancel(t *testing.T) {
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := setupTestPaymentLifecycle(t)
|
|
|
|
// Create the cancelable payment context.
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
|
|
|
// We now enter the payment lifecycle loop.
|
|
//
|
|
// 1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 2. calls `GetState` and return the state.
|
|
ps := &channeldb.MPPaymentState{
|
|
RemainingAmt: paymentAmt,
|
|
}
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// NOTE: GetStatus is only used to populate the logs which is not
|
|
// critical, so we loosen the checks on how many times it's been called.
|
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
|
|
|
// 3. Cancel the context and skip the FailPayment error to trigger the
|
|
// context cancellation of the payment.
|
|
cancel()
|
|
|
|
m.control.On(
|
|
"FailPayment", p.identifier, channeldb.FailureReasonCanceled,
|
|
).Return(nil).Once()
|
|
|
|
// 4. decideNextStep now returns stepExit.
|
|
m.payment.On("AllowMoreAttempts").Return(false, nil).Once().
|
|
On("NeedWaitAttempts").Return(false, nil).Once()
|
|
|
|
// 5. Control tower deletes failed attempts.
|
|
m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once()
|
|
|
|
// 6. We will observe FailureReasonError if the context was cancelled.
|
|
reason := channeldb.FailureReasonError
|
|
m.payment.On("TerminalInfo").Return(nil, &reason)
|
|
|
|
// Send the payment and assert it failed with the timeout reason.
|
|
sendPaymentAndAssertError(t, ctx, p, reason)
|
|
|
|
// Expected collectResultAsync to not be called.
|
|
require.Zero(t, m.collectResultsCount)
|
|
}
|
|
|
|
// TestResumePaymentFailOnStepErr checks that the lifecycle fails when an
|
|
// error is returned from `decideNextStep`.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentFailOnStepErr(t *testing.T) {
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := setupTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
|
|
|
// We now enter the payment lifecycle loop.
|
|
//
|
|
// 1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 2. calls `GetState` and return the state.
|
|
ps := &channeldb.MPPaymentState{
|
|
RemainingAmt: paymentAmt,
|
|
}
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// NOTE: GetStatus is only used to populate the logs which is
|
|
// not critical so we loosen the checks on how many times it's
|
|
// been called.
|
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
|
|
|
// 3. decideNextStep now returns an error.
|
|
m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once()
|
|
|
|
// Send the payment and assert it failed.
|
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
|
|
|
// Expected collectResultAsync to not be called.
|
|
require.Zero(t, m.collectResultsCount)
|
|
}
|
|
|
|
// TestResumePaymentFailOnRequestRouteErr checks that the lifecycle fails when
|
|
// an error is returned from `requestRoute`.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentFailOnRequestRouteErr(t *testing.T) {
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := setupTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
|
|
|
// We now enter the payment lifecycle loop.
|
|
//
|
|
// 1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 2. calls `GetState` and return the state.
|
|
ps := &channeldb.MPPaymentState{
|
|
RemainingAmt: paymentAmt,
|
|
}
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// NOTE: GetStatus is only used to populate the logs which is
|
|
// not critical so we loosen the checks on how many times it's
|
|
// been called.
|
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
|
|
|
// 3. decideNextStep now returns stepProceed.
|
|
m.payment.On("AllowMoreAttempts").Return(true, nil).Once()
|
|
|
|
// 4. mock requestRoute to return an error.
|
|
m.paySession.On("RequestRoute",
|
|
paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight),
|
|
uint32(p.currentHeight), mock.Anything,
|
|
).Return(nil, errDummy).Once()
|
|
|
|
// Send the payment and assert it failed.
|
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
|
|
|
// Expected collectResultAsync to not be called.
|
|
require.Zero(t, m.collectResultsCount)
|
|
}
|
|
|
|
// TestResumePaymentFailOnRegisterAttemptErr checks that the lifecycle fails
|
|
// when an error is returned from `registerAttempt`.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) {
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := setupTestPaymentLifecycle(t)
|
|
|
|
// Create a dummy route that will be returned by `RequestRoute`.
|
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
|
rt := createDummyRoute(t, paymentAmt)
|
|
|
|
// We now enter the payment lifecycle loop.
|
|
//
|
|
// 1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 2. calls `GetState` and return the state.
|
|
ps := &channeldb.MPPaymentState{
|
|
RemainingAmt: paymentAmt,
|
|
}
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// NOTE: GetStatus is only used to populate the logs which is
|
|
// not critical so we loosen the checks on how many times it's
|
|
// been called.
|
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
|
|
|
// 3. decideNextStep now returns stepProceed.
|
|
m.payment.On("AllowMoreAttempts").Return(true, nil).Once()
|
|
|
|
// 4. mock requestRoute to return an route.
|
|
m.paySession.On("RequestRoute",
|
|
paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight),
|
|
uint32(p.currentHeight), mock.Anything,
|
|
).Return(rt, nil).Once()
|
|
|
|
// 5. mock shardTracker used in `createNewPaymentAttempt` to return an
|
|
// error.
|
|
//
|
|
// Mock NextPaymentID to always return the attemptID.
|
|
attemptID := uint64(1)
|
|
p.router.cfg.NextPaymentID = func() (uint64, error) {
|
|
return attemptID, nil
|
|
}
|
|
|
|
// Return an error to end the lifecycle.
|
|
m.shardTracker.On("NewShard",
|
|
attemptID, true,
|
|
).Return(nil, errDummy).Once()
|
|
|
|
// Send the payment and assert it failed.
|
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
|
|
|
// Expected collectResultAsync to not be called.
|
|
require.Zero(t, m.collectResultsCount)
|
|
}
|
|
|
|
// TestResumePaymentFailOnSendAttemptErr checks that the lifecycle fails when
|
|
// an error is returned from `sendAttempt`.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentFailOnSendAttemptErr(t *testing.T) {
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := setupTestPaymentLifecycle(t)
|
|
|
|
// Create a dummy route that will be returned by `RequestRoute`.
|
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
|
rt := createDummyRoute(t, paymentAmt)
|
|
|
|
// We now enter the payment lifecycle loop.
|
|
//
|
|
// 1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 2. calls `GetState` and return the state.
|
|
ps := &channeldb.MPPaymentState{
|
|
RemainingAmt: paymentAmt,
|
|
}
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// NOTE: GetStatus is only used to populate the logs which is
|
|
// not critical so we loosen the checks on how many times it's
|
|
// been called.
|
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
|
|
|
// 3. decideNextStep now returns stepProceed.
|
|
m.payment.On("AllowMoreAttempts").Return(true, nil).Once()
|
|
|
|
// 4. mock requestRoute to return an route.
|
|
m.paySession.On("RequestRoute",
|
|
paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight),
|
|
uint32(p.currentHeight), mock.Anything,
|
|
).Return(rt, nil).Once()
|
|
|
|
// 5. mock `registerAttempt` to return an attempt.
|
|
//
|
|
// Mock NextPaymentID to always return the attemptID.
|
|
attemptID := uint64(1)
|
|
p.router.cfg.NextPaymentID = func() (uint64, error) {
|
|
return attemptID, nil
|
|
}
|
|
|
|
// Mock shardTracker to return the mock shard.
|
|
m.shardTracker.On("NewShard",
|
|
attemptID, true,
|
|
).Return(m.shard, nil).Once()
|
|
|
|
// Mock the methods on the shard.
|
|
m.shard.On("MPP").Return(&record.MPP{}).Twice().
|
|
On("AMP").Return(nil).Once().
|
|
On("Hash").Return(p.identifier).Once()
|
|
|
|
// Mock the time and expect it to be call twice.
|
|
m.clock.On("Now").Return(time.Now()).Twice()
|
|
|
|
// We now register attempt and return no error.
|
|
m.control.On("RegisterAttempt",
|
|
p.identifier, mock.Anything,
|
|
).Return(nil).Once()
|
|
|
|
// 6. mock `sendAttempt` to return an error.
|
|
m.payer.On("SendHTLC",
|
|
mock.Anything, attemptID, mock.Anything,
|
|
).Return(errDummy).Once()
|
|
|
|
// The above error will end up being handled by `handleSwitchErr`, in
|
|
// which we'd fail the payment, cancel the shard and fail the attempt.
|
|
//
|
|
// `FailPayment` should be called with an internal reason.
|
|
reason := channeldb.FailureReasonError
|
|
m.control.On("FailPayment", p.identifier, reason).Return(nil).Once()
|
|
|
|
// `CancelShard` should be called with the attemptID.
|
|
m.shardTracker.On("CancelShard", attemptID).Return(nil).Once()
|
|
|
|
// Mock `FailAttempt` to return a dummy error to exit the loop.
|
|
m.control.On("FailAttempt",
|
|
p.identifier, attemptID, mock.Anything,
|
|
).Return(nil, errDummy).Once()
|
|
|
|
// Send the payment and assert it failed.
|
|
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
|
|
|
|
// Expected collectResultAsync to not be called.
|
|
require.Zero(t, m.collectResultsCount)
|
|
}
|
|
|
|
// TestResumePaymentSuccess checks that a normal payment flow that is
|
|
// succeeded.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentSuccess(t *testing.T) {
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := setupTestPaymentLifecycle(t)
|
|
|
|
// Create a dummy route that will be returned by `RequestRoute`.
|
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
|
rt := createDummyRoute(t, paymentAmt)
|
|
|
|
// We now enter the payment lifecycle loop.
|
|
//
|
|
// 1.1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 1.2. calls `GetState` and return the state.
|
|
ps := &channeldb.MPPaymentState{
|
|
RemainingAmt: paymentAmt,
|
|
}
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// NOTE: GetStatus is only used to populate the logs which is
|
|
// not critical so we loosen the checks on how many times it's
|
|
// been called.
|
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
|
|
|
// 1.3. decideNextStep now returns stepProceed.
|
|
m.payment.On("AllowMoreAttempts").Return(true, nil).Once()
|
|
|
|
// 1.4. mock requestRoute to return an route.
|
|
m.paySession.On("RequestRoute",
|
|
paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight),
|
|
uint32(p.currentHeight), mock.Anything,
|
|
).Return(rt, nil).Once()
|
|
|
|
// 1.5. mock `registerAttempt` to return an attempt.
|
|
//
|
|
// Mock NextPaymentID to always return the attemptID.
|
|
attemptID := uint64(1)
|
|
p.router.cfg.NextPaymentID = func() (uint64, error) {
|
|
return attemptID, nil
|
|
}
|
|
|
|
// Mock shardTracker to return the mock shard.
|
|
m.shardTracker.On("NewShard",
|
|
attemptID, true,
|
|
).Return(m.shard, nil).Once()
|
|
|
|
// Mock the methods on the shard.
|
|
m.shard.On("MPP").Return(&record.MPP{}).Twice().
|
|
On("AMP").Return(nil).Once().
|
|
On("Hash").Return(p.identifier).Once()
|
|
|
|
// Mock the time and expect it to be called.
|
|
m.clock.On("Now").Return(time.Now())
|
|
|
|
// We now register attempt and return no error.
|
|
m.control.On("RegisterAttempt",
|
|
p.identifier, mock.Anything,
|
|
).Return(nil).Once()
|
|
|
|
// 1.6. mock `sendAttempt` to succeed, which brings us into the next
|
|
// iteration of the lifecycle.
|
|
m.payer.On("SendHTLC",
|
|
mock.Anything, attemptID, mock.Anything,
|
|
).Return(nil).Once()
|
|
|
|
// We now enter the second iteration of the lifecycle loop.
|
|
//
|
|
// 2.1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 2.2. calls `GetState` and return the state.
|
|
m.payment.On("GetState").Return(ps).Run(func(args mock.Arguments) {
|
|
ps.RemainingAmt = 0
|
|
}).Once()
|
|
|
|
// 2.3. decideNextStep now returns stepExit and exits the loop.
|
|
m.payment.On("AllowMoreAttempts").Return(false, nil).Once().
|
|
On("NeedWaitAttempts").Return(false, nil).Once()
|
|
|
|
// We should perform an optional deletion over failed attempts.
|
|
m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once()
|
|
|
|
// Finally, mock the `TerminalInfo` to return the settled attempt.
|
|
// Create a SettleAttempt.
|
|
testPreimage := lntypes.Preimage{1, 2, 3}
|
|
settledAttempt := makeSettledAttempt(t, int(paymentAmt), testPreimage)
|
|
m.payment.On("TerminalInfo").Return(settledAttempt, nil).Once()
|
|
|
|
// Send the payment and assert the preimage is matched.
|
|
sendPaymentAndAssertSucceeded(t, p, testPreimage)
|
|
|
|
// Expected collectResultAsync to called.
|
|
require.Equal(t, 1, m.collectResultsCount)
|
|
}
|
|
|
|
// TestResumePaymentSuccessWithTwoAttempts checks a successful payment flow
|
|
// with two HTLC attempts.
|
|
//
|
|
// NOTE: No parallel test because it overwrites global variables.
|
|
//
|
|
//nolint:paralleltest
|
|
func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) {
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := setupTestPaymentLifecycle(t)
|
|
|
|
// Create a dummy route that will be returned by `RequestRoute`.
|
|
paymentAmt := lnwire.MilliSatoshi(10000)
|
|
rt := createDummyRoute(t, paymentAmt/2)
|
|
|
|
// We now enter the payment lifecycle loop.
|
|
//
|
|
// 1.1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 1.2. calls `GetState` and return the state.
|
|
ps := &channeldb.MPPaymentState{
|
|
RemainingAmt: paymentAmt,
|
|
}
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// NOTE: GetStatus is only used to populate the logs which is
|
|
// not critical so we loosen the checks on how many times it's
|
|
// been called.
|
|
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
|
|
|
|
// 1.3. decideNextStep now returns stepProceed.
|
|
m.payment.On("AllowMoreAttempts").Return(true, nil).Once()
|
|
|
|
// 1.4. mock requestRoute to return an route.
|
|
m.paySession.On("RequestRoute",
|
|
paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight),
|
|
uint32(p.currentHeight), mock.Anything,
|
|
).Return(rt, nil).Once()
|
|
|
|
// Create two attempt IDs here.
|
|
attemptID1 := uint64(1)
|
|
attemptID2 := uint64(2)
|
|
|
|
// 1.5. mock `registerAttempt` to return an attempt.
|
|
//
|
|
// Mock NextPaymentID to return the first attemptID on the first call
|
|
// and the second attemptID on the second call.
|
|
var numAttempts atomic.Uint64
|
|
p.router.cfg.NextPaymentID = func() (uint64, error) {
|
|
numAttempts.Add(1)
|
|
if numAttempts.Load() == 1 {
|
|
return attemptID1, nil
|
|
}
|
|
|
|
return attemptID2, nil
|
|
}
|
|
|
|
// Mock shardTracker to return the mock shard.
|
|
m.shardTracker.On("NewShard",
|
|
attemptID1, false,
|
|
).Return(m.shard, nil).Once()
|
|
|
|
// Mock the methods on the shard.
|
|
m.shard.On("MPP").Return(&record.MPP{}).Twice().
|
|
On("AMP").Return(nil).Once().
|
|
On("Hash").Return(p.identifier).Once()
|
|
|
|
// Mock the time and expect it to be called.
|
|
m.clock.On("Now").Return(time.Now())
|
|
|
|
// We now register attempt and return no error.
|
|
m.control.On("RegisterAttempt",
|
|
p.identifier, mock.Anything,
|
|
).Return(nil).Run(func(args mock.Arguments) {
|
|
ps.RemainingAmt = paymentAmt / 2
|
|
}).Once()
|
|
|
|
// 1.6. mock `sendAttempt` to succeed, which brings us into the next
|
|
// iteration of the lifecycle where we mock a temp failure.
|
|
m.payer.On("SendHTLC",
|
|
mock.Anything, attemptID1, mock.Anything,
|
|
).Return(nil).Once()
|
|
|
|
// We now enter the second iteration of the lifecycle loop.
|
|
//
|
|
// 2.1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 2.2. calls `GetState` and return the state.
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// 2.3. decideNextStep now returns stepProceed so we can send the
|
|
// second attempt.
|
|
m.payment.On("AllowMoreAttempts").Return(true, nil).Once()
|
|
|
|
// 2.4. mock requestRoute to return an route.
|
|
m.paySession.On("RequestRoute",
|
|
paymentAmt/2, p.feeLimit, uint32(ps.NumAttemptsInFlight),
|
|
uint32(p.currentHeight), mock.Anything,
|
|
).Return(rt, nil).Once()
|
|
|
|
// 2.5. mock `registerAttempt` to return an attempt.
|
|
//
|
|
// Mock shardTracker to return the mock shard.
|
|
m.shardTracker.On("NewShard",
|
|
attemptID2, true,
|
|
).Return(m.shard, nil).Once()
|
|
|
|
// Mock the methods on the shard.
|
|
m.shard.On("MPP").Return(&record.MPP{}).Twice().
|
|
On("AMP").Return(nil).Once().
|
|
On("Hash").Return(p.identifier).Once()
|
|
|
|
// We now register attempt and return no error.
|
|
m.control.On("RegisterAttempt",
|
|
p.identifier, mock.Anything,
|
|
).Return(nil).Once()
|
|
|
|
// 2.6. mock `sendAttempt` to succeed, which brings us into the next
|
|
// iteration of the lifecycle.
|
|
m.payer.On("SendHTLC",
|
|
mock.Anything, attemptID2, mock.Anything,
|
|
).Return(nil).Once()
|
|
|
|
// We now enter the third iteration of the lifecycle loop.
|
|
//
|
|
// 3.1. calls `FetchPayment` and return the payment.
|
|
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
|
|
|
|
// 3.2. calls `GetState` and return the state.
|
|
m.payment.On("GetState").Return(ps).Once()
|
|
|
|
// 3.3. decideNextStep now returns stepExit to exit the loop.
|
|
m.payment.On("AllowMoreAttempts").Return(false, nil).Once().
|
|
On("NeedWaitAttempts").Return(false, nil).Once()
|
|
|
|
// We should perform an optional deletion over failed attempts.
|
|
m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once()
|
|
|
|
// Finally, mock the `TerminalInfo` to return the settled attempt.
|
|
// Create a SettleAttempt.
|
|
testPreimage := lntypes.Preimage{1, 2, 3}
|
|
settledAttempt := makeSettledAttempt(t, int(paymentAmt), testPreimage)
|
|
m.payment.On("TerminalInfo").Return(settledAttempt, nil).Once()
|
|
|
|
// Send the payment and assert the preimage is matched.
|
|
sendPaymentAndAssertSucceeded(t, p, testPreimage)
|
|
|
|
// Expected collectResultAsync to called.
|
|
require.Equal(t, 2, m.collectResultsCount)
|
|
}
|
|
|
|
// TestCollectResultExitOnErr checks that when there's an error returned from
|
|
// htlcswitch via `GetAttemptResult`, it's handled and returned.
|
|
func TestCollectResultExitOnErr(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := 10_000
|
|
attempt := makeFailedAttempt(t, paymentAmt)
|
|
|
|
// Mock shardTracker to return the payment hash.
|
|
m.shardTracker.On("GetHash",
|
|
attempt.AttemptID,
|
|
).Return(p.identifier, nil).Once()
|
|
|
|
// Mock the htlcswitch to return a dummy error.
|
|
m.payer.On("GetAttemptResult",
|
|
attempt.AttemptID, p.identifier, mock.Anything,
|
|
).Return(nil, errDummy).Once()
|
|
|
|
// The above error will end up being handled by `handleSwitchErr`, in
|
|
// which we'd fail the payment, cancel the shard and fail the attempt.
|
|
//
|
|
// `FailPayment` should be called with an internal reason.
|
|
reason := channeldb.FailureReasonError
|
|
m.control.On("FailPayment", p.identifier, reason).Return(nil).Once()
|
|
|
|
// `CancelShard` should be called with the attemptID.
|
|
m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once()
|
|
|
|
// Mock `FailAttempt` to return a dummy error.
|
|
m.control.On("FailAttempt",
|
|
p.identifier, attempt.AttemptID, mock.Anything,
|
|
).Return(nil, errDummy).Once()
|
|
|
|
// Mock the clock to return a current time.
|
|
m.clock.On("Now").Return(time.Now())
|
|
|
|
// Now call the method under test.
|
|
result, err := p.collectResult(attempt)
|
|
require.ErrorIs(t, err, errDummy, "expected dummy error")
|
|
require.Nil(t, result, "expected nil attempt")
|
|
}
|
|
|
|
// TestCollectResultExitOnResultErr checks that when there's an error returned
|
|
// from htlcswitch via the result channel, it's handled and returned.
|
|
func TestCollectResultExitOnResultErr(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := 10_000
|
|
attempt := makeFailedAttempt(t, paymentAmt)
|
|
|
|
// Mock shardTracker to return the payment hash.
|
|
m.shardTracker.On("GetHash",
|
|
attempt.AttemptID,
|
|
).Return(p.identifier, nil).Once()
|
|
|
|
// Mock the htlcswitch to return a the result chan.
|
|
resultChan := make(chan *htlcswitch.PaymentResult, 1)
|
|
m.payer.On("GetAttemptResult",
|
|
attempt.AttemptID, p.identifier, mock.Anything,
|
|
).Return(resultChan, nil).Once().Run(func(args mock.Arguments) {
|
|
// Send an error to the result chan.
|
|
resultChan <- &htlcswitch.PaymentResult{
|
|
Error: errDummy,
|
|
}
|
|
})
|
|
|
|
// The above error will end up being handled by `handleSwitchErr`, in
|
|
// which we'd fail the payment, cancel the shard and fail the attempt.
|
|
//
|
|
// `FailPayment` should be called with an internal reason.
|
|
reason := channeldb.FailureReasonError
|
|
m.control.On("FailPayment", p.identifier, reason).Return(nil).Once()
|
|
|
|
// `CancelShard` should be called with the attemptID.
|
|
m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once()
|
|
|
|
// Mock `FailAttempt` to return a dummy error.
|
|
m.control.On("FailAttempt",
|
|
p.identifier, attempt.AttemptID, mock.Anything,
|
|
).Return(nil, errDummy).Once()
|
|
|
|
// Mock the clock to return a current time.
|
|
m.clock.On("Now").Return(time.Now())
|
|
|
|
// Now call the method under test.
|
|
result, err := p.collectResult(attempt)
|
|
require.ErrorIs(t, err, errDummy, "expected dummy error")
|
|
require.Nil(t, result, "expected nil attempt")
|
|
}
|
|
|
|
// TestCollectResultExitOnSwitcQuit checks that when the htlcswitch is shutting
|
|
// down an error is returned.
|
|
func TestCollectResultExitOnSwitchQuit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := 10_000
|
|
attempt := makeFailedAttempt(t, paymentAmt)
|
|
|
|
// Mock shardTracker to return the payment hash.
|
|
m.shardTracker.On("GetHash",
|
|
attempt.AttemptID,
|
|
).Return(p.identifier, nil).Once()
|
|
|
|
// Mock the htlcswitch to return a the result chan.
|
|
resultChan := make(chan *htlcswitch.PaymentResult, 1)
|
|
m.payer.On("GetAttemptResult",
|
|
attempt.AttemptID, p.identifier, mock.Anything,
|
|
).Return(resultChan, nil).Once().Run(func(args mock.Arguments) {
|
|
// Close the result chan to simulate a htlcswitch quit.
|
|
close(resultChan)
|
|
})
|
|
|
|
// Now call the method under test.
|
|
result, err := p.collectResult(attempt)
|
|
require.ErrorIs(t, err, htlcswitch.ErrSwitchExiting,
|
|
"expected switch exit")
|
|
require.Nil(t, result, "expected nil attempt")
|
|
}
|
|
|
|
// TestCollectResultExitOnRouterQuit checks that when the channel router is
|
|
// shutting down an error is returned.
|
|
func TestCollectResultExitOnRouterQuit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := 10_000
|
|
attempt := makeFailedAttempt(t, paymentAmt)
|
|
|
|
// Mock shardTracker to return the payment hash.
|
|
m.shardTracker.On("GetHash",
|
|
attempt.AttemptID,
|
|
).Return(p.identifier, nil).Once()
|
|
|
|
// Mock the htlcswitch to return a the result chan.
|
|
resultChan := make(chan *htlcswitch.PaymentResult, 1)
|
|
m.payer.On("GetAttemptResult",
|
|
attempt.AttemptID, p.identifier, mock.Anything,
|
|
).Return(resultChan, nil).Once().Run(func(args mock.Arguments) {
|
|
// Close the channel router.
|
|
close(p.router.quit)
|
|
})
|
|
|
|
// Now call the method under test.
|
|
result, err := p.collectResult(attempt)
|
|
require.ErrorIs(t, err, ErrRouterShuttingDown, "expected router exit")
|
|
require.Nil(t, result, "expected nil attempt")
|
|
}
|
|
|
|
// TestCollectResultExitOnLifecycleQuit checks that when the payment lifecycle
|
|
// is shutting down an error is returned.
|
|
func TestCollectResultExitOnLifecycleQuit(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := 10_000
|
|
attempt := makeFailedAttempt(t, paymentAmt)
|
|
|
|
// Mock shardTracker to return the payment hash.
|
|
m.shardTracker.On("GetHash",
|
|
attempt.AttemptID,
|
|
).Return(p.identifier, nil).Once()
|
|
|
|
// Mock the htlcswitch to return a the result chan.
|
|
resultChan := make(chan *htlcswitch.PaymentResult, 1)
|
|
m.payer.On("GetAttemptResult",
|
|
attempt.AttemptID, p.identifier, mock.Anything,
|
|
).Return(resultChan, nil).Once().Run(func(args mock.Arguments) {
|
|
// Stop the lifecycle.
|
|
p.stop()
|
|
})
|
|
|
|
// Now call the method under test.
|
|
result, err := p.collectResult(attempt)
|
|
require.ErrorIs(t, err, ErrPaymentLifecycleExiting,
|
|
"expected lifecycle exit")
|
|
require.Nil(t, result, "expected nil attempt")
|
|
}
|
|
|
|
// TestCollectResultExitOnSettleErr checks that when settling the attempt
|
|
// fails an error is returned.
|
|
func TestCollectResultExitOnSettleErr(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := 10_000
|
|
preimage := lntypes.Preimage{1}
|
|
attempt := makeSettledAttempt(t, paymentAmt, preimage)
|
|
|
|
// Mock shardTracker to return the payment hash.
|
|
m.shardTracker.On("GetHash",
|
|
attempt.AttemptID,
|
|
).Return(p.identifier, nil).Once()
|
|
|
|
// Mock the htlcswitch to return a the result chan.
|
|
resultChan := make(chan *htlcswitch.PaymentResult, 1)
|
|
m.payer.On("GetAttemptResult",
|
|
attempt.AttemptID, p.identifier, mock.Anything,
|
|
).Return(resultChan, nil).Once().Run(func(args mock.Arguments) {
|
|
// Send the preimage to the result chan.
|
|
resultChan <- &htlcswitch.PaymentResult{
|
|
Preimage: preimage,
|
|
}
|
|
})
|
|
|
|
// Once the result is received, `ReportPaymentSuccess` should be
|
|
// called.
|
|
m.missionControl.On("ReportPaymentSuccess",
|
|
attempt.AttemptID, &attempt.Route,
|
|
).Return(nil).Once()
|
|
|
|
// Now mock an error being returned from `SettleAttempt`.
|
|
m.control.On("SettleAttempt",
|
|
p.identifier, attempt.AttemptID, mock.Anything,
|
|
).Return(nil, errDummy).Once()
|
|
|
|
// Mock the clock to return a current time.
|
|
m.clock.On("Now").Return(time.Now())
|
|
|
|
// Now call the method under test.
|
|
result, err := p.collectResult(attempt)
|
|
require.ErrorIs(t, err, errDummy, "expected settle error")
|
|
require.Nil(t, result, "expected nil attempt")
|
|
}
|
|
|
|
// TestCollectResultSuccess checks a successful htlc settlement.
|
|
func TestCollectResultSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := 10_000
|
|
preimage := lntypes.Preimage{1}
|
|
attempt := makeSettledAttempt(t, paymentAmt, preimage)
|
|
|
|
// Mock shardTracker to return the payment hash.
|
|
m.shardTracker.On("GetHash",
|
|
attempt.AttemptID,
|
|
).Return(p.identifier, nil).Once()
|
|
|
|
// Mock the htlcswitch to return a the result chan.
|
|
resultChan := make(chan *htlcswitch.PaymentResult, 1)
|
|
m.payer.On("GetAttemptResult",
|
|
attempt.AttemptID, p.identifier, mock.Anything,
|
|
).Return(resultChan, nil).Once().Run(func(args mock.Arguments) {
|
|
// Send the preimage to the result chan.
|
|
resultChan <- &htlcswitch.PaymentResult{
|
|
Preimage: preimage,
|
|
}
|
|
})
|
|
|
|
// Once the result is received, `ReportPaymentSuccess` should be
|
|
// called.
|
|
m.missionControl.On("ReportPaymentSuccess",
|
|
attempt.AttemptID, &attempt.Route,
|
|
).Return(nil).Once()
|
|
|
|
// Now the settled htlc being returned from `SettleAttempt`.
|
|
m.control.On("SettleAttempt",
|
|
p.identifier, attempt.AttemptID, mock.Anything,
|
|
).Return(attempt, nil).Once()
|
|
|
|
// Mock the clock to return a current time.
|
|
m.clock.On("Now").Return(time.Now())
|
|
|
|
// Now call the method under test.
|
|
result, err := p.collectResult(attempt)
|
|
require.NoError(t, err, "expected no error")
|
|
require.Equal(t, preimage, result.attempt.Settle.Preimage,
|
|
"preimage mismatch")
|
|
}
|
|
|
|
// TestCollectResultAsyncSuccess checks a successful htlc settlement.
|
|
func TestCollectResultAsyncSuccess(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a test paymentLifecycle with the initial two calls mocked.
|
|
p, m := newTestPaymentLifecycle(t)
|
|
|
|
paymentAmt := 10_000
|
|
preimage := lntypes.Preimage{1}
|
|
attempt := makeSettledAttempt(t, paymentAmt, preimage)
|
|
|
|
// Mock shardTracker to return the payment hash.
|
|
m.shardTracker.On("GetHash",
|
|
attempt.AttemptID,
|
|
).Return(p.identifier, nil).Once()
|
|
|
|
// Mock the htlcswitch to return a the result chan.
|
|
resultChan := make(chan *htlcswitch.PaymentResult, 1)
|
|
m.payer.On("GetAttemptResult",
|
|
attempt.AttemptID, p.identifier, mock.Anything,
|
|
).Return(resultChan, nil).Once().Run(func(args mock.Arguments) {
|
|
// Send the preimage to the result chan.
|
|
resultChan <- &htlcswitch.PaymentResult{
|
|
Preimage: preimage,
|
|
}
|
|
})
|
|
|
|
// Once the result is received, `ReportPaymentSuccess` should be
|
|
// called.
|
|
m.missionControl.On("ReportPaymentSuccess",
|
|
attempt.AttemptID, &attempt.Route,
|
|
).Return(nil).Once()
|
|
|
|
// Now the settled htlc being returned from `SettleAttempt`.
|
|
m.control.On("SettleAttempt",
|
|
p.identifier, attempt.AttemptID, mock.Anything,
|
|
).Return(attempt, nil).Once()
|
|
|
|
// Mock the clock to return a current time.
|
|
m.clock.On("Now").Return(time.Now())
|
|
|
|
// Now call the method under test.
|
|
p.collectResultAsync(attempt)
|
|
|
|
// Assert the result is returned within 5 seconds.
|
|
var err error
|
|
waitErr := wait.NoError(func() error {
|
|
err = <-p.resultCollected
|
|
return nil
|
|
}, testTimeout)
|
|
require.NoError(t, waitErr, "timeout waiting for result")
|
|
|
|
// Assert that a nil error is received.
|
|
require.NoError(t, err, "expected no error")
|
|
}
|