mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 21:35:24 +01:00
routing: add methods checkTimeout
and requestRoute
This commit refactors the `resumePayment` method by adding the methods `checkTimeout` and `requestRoute` so it's easier to understand the flow and reason about the error handling.
This commit is contained in:
parent
7209c65ccf
commit
703ea08316
@ -673,6 +673,12 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
|
||||
activeShards, height uint32) (*route.Route, error) {
|
||||
|
||||
args := m.Called(maxAmt, feeLimit, activeShards, height)
|
||||
|
||||
// Type assertion on nil will fail, so we check and return here.
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return args.Get(0).(*route.Route), args.Error(1)
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,13 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/go-errors/errors"
|
||||
sphinx "github.com/lightningnetwork/lightning-onion"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
@ -204,70 +205,19 @@ lifecycle:
|
||||
// router is exiting. In either case, we'll stop this payment
|
||||
// attempt short. If a timeout is not applicable, timeoutChan
|
||||
// will be nil.
|
||||
select {
|
||||
case <-p.timeoutChan:
|
||||
log.Warnf("payment attempt not completed before " +
|
||||
"timeout")
|
||||
|
||||
// By marking the payment failed with the control
|
||||
// tower, no further shards will be launched and we'll
|
||||
// return with an error the moment all active shards
|
||||
// have finished.
|
||||
saveErr := p.router.cfg.Control.FailPayment(
|
||||
p.identifier, channeldb.FailureReasonTimeout,
|
||||
)
|
||||
if saveErr != nil {
|
||||
return exitWithErr(saveErr)
|
||||
}
|
||||
|
||||
continue lifecycle
|
||||
|
||||
case <-p.router.quit:
|
||||
return exitWithErr(ErrRouterShuttingDown)
|
||||
|
||||
// Fall through if we haven't hit our time limit.
|
||||
default:
|
||||
if err := p.checkTimeout(); err != nil {
|
||||
return exitWithErr(err)
|
||||
}
|
||||
|
||||
// Create a new payment attempt from the given payment session.
|
||||
rt, err := p.paySession.RequestRoute(
|
||||
ps.RemainingAmt, remainingFees,
|
||||
uint32(ps.NumAttemptsInFlight),
|
||||
uint32(p.currentHeight),
|
||||
)
|
||||
// Now request a route to be used to create our HTLC attempt.
|
||||
rt, err := p.requestRoute(ps)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to find route for payment %v: %v",
|
||||
p.identifier, err)
|
||||
return exitWithErr(err)
|
||||
}
|
||||
|
||||
routeErr, ok := err.(noRouteError)
|
||||
if !ok {
|
||||
return exitWithErr(err)
|
||||
}
|
||||
|
||||
// There is no route to try, and we have no active
|
||||
// shards. This means that there is no way for us to
|
||||
// send the payment, so mark it failed with no route.
|
||||
if ps.NumAttemptsInFlight == 0 {
|
||||
failureCode := routeErr.FailureReason()
|
||||
log.Debugf("Marking payment %v permanently "+
|
||||
"failed with no route: %v",
|
||||
p.identifier, failureCode)
|
||||
|
||||
saveErr := p.router.cfg.Control.FailPayment(
|
||||
p.identifier, failureCode,
|
||||
)
|
||||
if saveErr != nil {
|
||||
return exitWithErr(saveErr)
|
||||
}
|
||||
|
||||
continue lifecycle
|
||||
}
|
||||
|
||||
// We still have active shards, we'll wait for an
|
||||
// outcome to be available before retrying.
|
||||
if err := p.waitForShard(); err != nil {
|
||||
return exitWithErr(err)
|
||||
}
|
||||
// NOTE: might cause an infinite loop, see notes in
|
||||
// `requestRoute` for details.
|
||||
if rt == nil {
|
||||
continue lifecycle
|
||||
}
|
||||
|
||||
@ -293,6 +243,95 @@ lifecycle:
|
||||
}
|
||||
}
|
||||
|
||||
// checkTimeout checks whether the payment has reached its timeout.
|
||||
func (p *paymentLifecycle) checkTimeout() error {
|
||||
select {
|
||||
case <-p.timeoutChan:
|
||||
log.Warnf("payment attempt not completed before timeout")
|
||||
|
||||
// By marking the payment failed, depending on whether it has
|
||||
// inflight HTLCs or not, its status will now either be
|
||||
// `StatusInflight` or `StatusFailed`. In either case, no more
|
||||
// HTLCs will be attempted.
|
||||
err := p.router.cfg.Control.FailPayment(
|
||||
p.identifier, channeldb.FailureReasonTimeout,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("FailPayment got %w", err)
|
||||
}
|
||||
|
||||
case <-p.router.quit:
|
||||
return fmt.Errorf("check payment timeout got: %w",
|
||||
ErrRouterShuttingDown)
|
||||
|
||||
// Fall through if we haven't hit our time limit.
|
||||
default:
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// requestRoute is responsible for finding a route to be used to create an HTLC
|
||||
// attempt.
|
||||
func (p *paymentLifecycle) requestRoute(
|
||||
ps *channeldb.MPPaymentState) (*route.Route, error) {
|
||||
|
||||
remainingFees := p.calcFeeBudget(ps.FeesPaid)
|
||||
|
||||
// Query our payment session to construct a route.
|
||||
rt, err := p.paySession.RequestRoute(
|
||||
ps.RemainingAmt, remainingFees,
|
||||
uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight),
|
||||
)
|
||||
|
||||
// Exit early if there's no error.
|
||||
if err == nil {
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
// Otherwise we need to handle the error.
|
||||
log.Warnf("Failed to find route for payment %v: %v", p.identifier, err)
|
||||
|
||||
// If the error belongs to `noRouteError` set, it means a non-critical
|
||||
// error has happened during path finding and we might be able to find
|
||||
// another route during next HTLC attempt. Otherwise, we'll return the
|
||||
// critical error found.
|
||||
var routeErr noRouteError
|
||||
if !errors.As(err, &routeErr) {
|
||||
return nil, fmt.Errorf("requestRoute got: %w", err)
|
||||
}
|
||||
|
||||
// There is no route to try, and we have no active shards. This means
|
||||
// that there is no way for us to send the payment, so mark it failed
|
||||
// with no route.
|
||||
//
|
||||
// NOTE: if we have zero `numShardsInFlight`, it means all the HTLC
|
||||
// attempts have failed. Otherwise, if there are still inflight
|
||||
// attempts, we might enter an infinite loop in our lifecycle if
|
||||
// there's still remaining amount since we will keep adding new HTLC
|
||||
// attempts and they all fail with `noRouteError`.
|
||||
//
|
||||
// TODO(yy): further check the error returned here. It's the
|
||||
// `paymentSession`'s responsibility to find a route for us with best
|
||||
// effort. When it cannot find a path, we need to treat it as a
|
||||
// terminal condition and fail the payment no matter it has inflight
|
||||
// HTLCs or not.
|
||||
if ps.NumAttemptsInFlight == 0 {
|
||||
failureCode := routeErr.FailureReason()
|
||||
log.Debugf("Marking payment %v permanently failed with no "+
|
||||
"route: %v", p.identifier, failureCode)
|
||||
|
||||
err := p.router.cfg.Control.FailPayment(
|
||||
p.identifier, failureCode,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("FailPayment got: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// stop signals any active shard goroutine to exit and waits for them to exit.
|
||||
func (p *paymentLifecycle) stop() {
|
||||
close(p.quit)
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -818,3 +819,296 @@ func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckTimeoutTimedOut checks that when the payment times out, it is
|
||||
// marked as failed.
|
||||
func TestCheckTimeoutTimedOut(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := createTestPaymentLifecycle()
|
||||
|
||||
// Mock the control tower's `FailPayment` method.
|
||||
ct := &mockControlTower{}
|
||||
ct.On("FailPayment",
|
||||
p.identifier, channeldb.FailureReasonTimeout).Return(nil)
|
||||
|
||||
// Mount the mocked control tower.
|
||||
p.router.cfg.Control = ct
|
||||
|
||||
// Make the timeout happens instantly.
|
||||
p.timeoutChan = time.After(1 * time.Nanosecond)
|
||||
|
||||
// Sleep one millisecond to make sure it timed out.
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
|
||||
// Call the function and expect no error.
|
||||
err := p.checkTimeout()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert that `FailPayment` is called as expected.
|
||||
ct.AssertExpectations(t)
|
||||
|
||||
// We now test that when `FailPayment` returns an error, it's returned
|
||||
// by the function too.
|
||||
//
|
||||
// Mock `FailPayment` to return a dummy error.
|
||||
dummyErr := errors.New("dummy")
|
||||
ct = &mockControlTower{}
|
||||
ct.On("FailPayment",
|
||||
p.identifier, channeldb.FailureReasonTimeout).Return(dummyErr)
|
||||
|
||||
// Mount the mocked control tower.
|
||||
p.router.cfg.Control = ct
|
||||
|
||||
// Make the timeout happens instantly.
|
||||
p.timeoutChan = time.After(1 * time.Nanosecond)
|
||||
|
||||
// Sleep one millisecond to make sure it timed out.
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
|
||||
// Call the function and expect an error.
|
||||
err = p.checkTimeout()
|
||||
require.ErrorIs(t, err, dummyErr)
|
||||
|
||||
// Assert that `FailPayment` is called as expected.
|
||||
ct.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestCheckTimeoutOnRouterQuit checks that when the router has quit, an error
|
||||
// is returned from checkTimeout.
|
||||
func TestCheckTimeoutOnRouterQuit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := createTestPaymentLifecycle()
|
||||
|
||||
close(p.router.quit)
|
||||
err := p.checkTimeout()
|
||||
require.ErrorIs(t, err, ErrRouterShuttingDown)
|
||||
}
|
||||
|
||||
// createTestPaymentLifecycle creates a `paymentLifecycle` using the mocks.
|
||||
func createTestPaymentLifecycle() *paymentLifecycle {
|
||||
paymentHash := lntypes.Hash{1, 2, 3}
|
||||
quitChan := make(chan struct{})
|
||||
rt := &ChannelRouter{
|
||||
cfg: &Config{},
|
||||
quit: quitChan,
|
||||
}
|
||||
|
||||
return &paymentLifecycle{
|
||||
router: rt,
|
||||
identifier: paymentHash,
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestRouteSucceed checks that `requestRoute` can successfully request
|
||||
// a route.
|
||||
func TestRequestRouteSucceed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := createTestPaymentLifecycle()
|
||||
|
||||
// Create a mock payment session and a dummy route.
|
||||
paySession := &mockPaymentSession{}
|
||||
dummyRoute := &route.Route{}
|
||||
|
||||
// Mount the mocked payment session.
|
||||
p.paySession = paySession
|
||||
|
||||
// Create a dummy payment state.
|
||||
ps := &channeldb.MPPaymentState{
|
||||
NumAttemptsInFlight: 1,
|
||||
RemainingAmt: 1,
|
||||
FeesPaid: 100,
|
||||
}
|
||||
|
||||
// Mock remainingFees to be 1.
|
||||
p.feeLimit = ps.FeesPaid + 1
|
||||
|
||||
// Mock the paySession's `RequestRoute` method to return no error.
|
||||
paySession.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(dummyRoute, nil)
|
||||
|
||||
result, err := p.requestRoute(ps)
|
||||
require.NoError(t, err, "expect no error")
|
||||
require.Equal(t, dummyRoute, result, "returned route not matched")
|
||||
|
||||
// Assert that `RequestRoute` is called as expected.
|
||||
paySession.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestRequestRouteHandleCriticalErr checks that `requestRoute` can
|
||||
// successfully handle a critical error returned from payment session.
|
||||
func TestRequestRouteHandleCriticalErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := createTestPaymentLifecycle()
|
||||
|
||||
// Create a mock payment session.
|
||||
paySession := &mockPaymentSession{}
|
||||
|
||||
// Mount the mocked payment session.
|
||||
p.paySession = paySession
|
||||
|
||||
// Create a dummy payment state.
|
||||
ps := &channeldb.MPPaymentState{
|
||||
NumAttemptsInFlight: 1,
|
||||
RemainingAmt: 1,
|
||||
FeesPaid: 100,
|
||||
}
|
||||
|
||||
// Mock remainingFees to be 1.
|
||||
p.feeLimit = ps.FeesPaid + 1
|
||||
|
||||
// Mock the paySession's `RequestRoute` method to return an error.
|
||||
dummyErr := errors.New("dummy")
|
||||
paySession.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil, dummyErr)
|
||||
|
||||
result, err := p.requestRoute(ps)
|
||||
|
||||
// Expect an error is returned since it's critical.
|
||||
require.ErrorIs(t, err, dummyErr, "error not matched")
|
||||
require.Nil(t, result, "expected no route returned")
|
||||
|
||||
// Assert that `RequestRoute` is called as expected.
|
||||
paySession.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestRequestRouteHandleNoRouteErr checks that `requestRoute` can successfully
|
||||
// handle the `noRouteError` returned from payment session.
|
||||
func TestRequestRouteHandleNoRouteErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := createTestPaymentLifecycle()
|
||||
|
||||
// Create a mock payment session.
|
||||
paySession := &mockPaymentSession{}
|
||||
|
||||
// Mount the mocked payment session.
|
||||
p.paySession = paySession
|
||||
|
||||
// Create a dummy payment state.
|
||||
ps := &channeldb.MPPaymentState{
|
||||
NumAttemptsInFlight: 1,
|
||||
RemainingAmt: 1,
|
||||
FeesPaid: 100,
|
||||
}
|
||||
|
||||
// Mock remainingFees to be 1.
|
||||
p.feeLimit = ps.FeesPaid + 1
|
||||
|
||||
// Mock the paySession's `RequestRoute` method to return an error.
|
||||
paySession.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil, errNoTlvPayload)
|
||||
|
||||
result, err := p.requestRoute(ps)
|
||||
|
||||
// Expect no error is returned since it's not critical.
|
||||
require.NoError(t, err, "expected no error")
|
||||
require.Nil(t, result, "expected no route returned")
|
||||
|
||||
// Assert that `RequestRoute` is called as expected.
|
||||
paySession.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestRequestRouteFailPaymentSucceed checks that `requestRoute` fails the
|
||||
// payment when received an `noRouteError` returned from payment session while
|
||||
// it has no inflight attempts.
|
||||
func TestRequestRouteFailPaymentSucceed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := createTestPaymentLifecycle()
|
||||
|
||||
// Create a mock payment session.
|
||||
paySession := &mockPaymentSession{}
|
||||
|
||||
// Mock the control tower's `FailPayment` method.
|
||||
ct := &mockControlTower{}
|
||||
ct.On("FailPayment",
|
||||
p.identifier, errNoTlvPayload.FailureReason(),
|
||||
).Return(nil)
|
||||
|
||||
// Mount the mocked control tower and payment session.
|
||||
p.router.cfg.Control = ct
|
||||
p.paySession = paySession
|
||||
|
||||
// Create a dummy payment state with zero inflight attempts.
|
||||
ps := &channeldb.MPPaymentState{
|
||||
NumAttemptsInFlight: 0,
|
||||
RemainingAmt: 1,
|
||||
FeesPaid: 100,
|
||||
}
|
||||
|
||||
// Mock remainingFees to be 1.
|
||||
p.feeLimit = ps.FeesPaid + 1
|
||||
|
||||
// Mock the paySession's `RequestRoute` method to return an error.
|
||||
paySession.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil, errNoTlvPayload)
|
||||
|
||||
result, err := p.requestRoute(ps)
|
||||
|
||||
// Expect no error is returned since it's not critical.
|
||||
require.NoError(t, err, "expected no error")
|
||||
require.Nil(t, result, "expected no route returned")
|
||||
|
||||
// Assert that `RequestRoute` is called as expected.
|
||||
paySession.AssertExpectations(t)
|
||||
|
||||
// Assert that `FailPayment` is called as expected.
|
||||
ct.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestRequestRouteFailPaymentError checks that `requestRoute` returns the
|
||||
// error from calling `FailPayment`.
|
||||
func TestRequestRouteFailPaymentError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := createTestPaymentLifecycle()
|
||||
|
||||
// Create a mock payment session.
|
||||
paySession := &mockPaymentSession{}
|
||||
|
||||
// Mock the control tower's `FailPayment` method.
|
||||
ct := &mockControlTower{}
|
||||
dummyErr := errors.New("dummy")
|
||||
ct.On("FailPayment",
|
||||
p.identifier, errNoTlvPayload.FailureReason(),
|
||||
).Return(dummyErr)
|
||||
|
||||
// Mount the mocked control tower and payment session.
|
||||
p.router.cfg.Control = ct
|
||||
p.paySession = paySession
|
||||
|
||||
// Create a dummy payment state with zero inflight attempts.
|
||||
ps := &channeldb.MPPaymentState{
|
||||
NumAttemptsInFlight: 0,
|
||||
RemainingAmt: 1,
|
||||
FeesPaid: 100,
|
||||
}
|
||||
|
||||
// Mock remainingFees to be 1.
|
||||
p.feeLimit = ps.FeesPaid + 1
|
||||
|
||||
// Mock the paySession's `RequestRoute` method to return an error.
|
||||
paySession.On("RequestRoute",
|
||||
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil, errNoTlvPayload)
|
||||
|
||||
result, err := p.requestRoute(ps)
|
||||
|
||||
// Expect an error is returned.
|
||||
require.ErrorIs(t, err, dummyErr, "error not matched")
|
||||
require.Nil(t, result, "expected no route returned")
|
||||
|
||||
// Assert that `RequestRoute` is called as expected.
|
||||
paySession.AssertExpectations(t)
|
||||
|
||||
// Assert that `FailPayment` is called as expected.
|
||||
ct.AssertExpectations(t)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user