routing: add methods checkTimeout and requestRoute

This commit refactors the `resumePayment` method by adding the methods
`checkTimeout` and `requestRoute` so it's easier to understand the flow
and reason about the error handling.
This commit is contained in:
yyforyongyu 2022-06-27 05:02:30 +08:00
parent 7209c65ccf
commit 703ea08316
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
3 changed files with 400 additions and 61 deletions

View File

@ -673,6 +673,12 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
activeShards, height uint32) (*route.Route, error) {
args := m.Called(maxAmt, feeLimit, activeShards, height)
// Type assertion on nil will fail, so we check and return here.
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*route.Route), args.Error(1)
}

View File

@ -1,12 +1,13 @@
package routing
import (
"errors"
"fmt"
"sync"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models"
@ -204,70 +205,19 @@ lifecycle:
// router is exiting. In either case, we'll stop this payment
// attempt short. If a timeout is not applicable, timeoutChan
// will be nil.
select {
case <-p.timeoutChan:
log.Warnf("payment attempt not completed before " +
"timeout")
// By marking the payment failed with the control
// tower, no further shards will be launched and we'll
// return with an error the moment all active shards
// have finished.
saveErr := p.router.cfg.Control.FailPayment(
p.identifier, channeldb.FailureReasonTimeout,
)
if saveErr != nil {
return exitWithErr(saveErr)
}
continue lifecycle
case <-p.router.quit:
return exitWithErr(ErrRouterShuttingDown)
// Fall through if we haven't hit our time limit.
default:
if err := p.checkTimeout(); err != nil {
return exitWithErr(err)
}
// Create a new payment attempt from the given payment session.
rt, err := p.paySession.RequestRoute(
ps.RemainingAmt, remainingFees,
uint32(ps.NumAttemptsInFlight),
uint32(p.currentHeight),
)
// Now request a route to be used to create our HTLC attempt.
rt, err := p.requestRoute(ps)
if err != nil {
log.Warnf("Failed to find route for payment %v: %v",
p.identifier, err)
return exitWithErr(err)
}
routeErr, ok := err.(noRouteError)
if !ok {
return exitWithErr(err)
}
// There is no route to try, and we have no active
// shards. This means that there is no way for us to
// send the payment, so mark it failed with no route.
if ps.NumAttemptsInFlight == 0 {
failureCode := routeErr.FailureReason()
log.Debugf("Marking payment %v permanently "+
"failed with no route: %v",
p.identifier, failureCode)
saveErr := p.router.cfg.Control.FailPayment(
p.identifier, failureCode,
)
if saveErr != nil {
return exitWithErr(saveErr)
}
continue lifecycle
}
// We still have active shards, we'll wait for an
// outcome to be available before retrying.
if err := p.waitForShard(); err != nil {
return exitWithErr(err)
}
// NOTE: might cause an infinite loop, see notes in
// `requestRoute` for details.
if rt == nil {
continue lifecycle
}
@ -293,6 +243,95 @@ lifecycle:
}
}
// checkTimeout performs a non-blocking check of the payment's timeout channel
// and of the router's quit channel.
//
// It returns:
//   - an error wrapping `ErrRouterShuttingDown` if the router's quit channel
//     is ready,
//   - the wrapped error from `FailPayment` if the timeout has fired and
//     marking the payment as failed did not succeed,
//   - nil otherwise, i.e. when neither channel is ready, or when the timeout
//     fired and the payment was successfully marked as failed.
func (p *paymentLifecycle) checkTimeout() error {
	select {
	// The router is exiting, report it to the caller.
	case <-p.router.quit:
		return fmt.Errorf("check payment timeout got: %w",
			ErrRouterShuttingDown)

	// The timeout has fired, it is handled right after the select.
	case <-p.timeoutChan:

	// Neither signal is ready, so we haven't hit our time limit and there
	// is nothing more to do.
	default:
		return nil
	}

	log.Warnf("payment attempt not completed before timeout")

	// Mark the payment as failed with the control tower. Depending on
	// whether the payment still has inflight HTLCs or not, its status will
	// now either be `StatusInflight` or `StatusFailed`. In either case, no
	// more HTLCs will be attempted.
	failErr := p.router.cfg.Control.FailPayment(
		p.identifier, channeldb.FailureReasonTimeout,
	)
	if failErr != nil {
		return fmt.Errorf("FailPayment got %w", failErr)
	}

	return nil
}
// requestRoute asks the payment session for a route that can be used to
// create the next HTLC attempt.
//
// The possible outcomes are:
//   - (route, nil) when the payment session found a route,
//   - (nil, err) when the payment session returned an error that is not a
//     `noRouteError`, or when marking the payment as failed did not succeed,
//   - (nil, nil) when the payment session returned a `noRouteError`. If the
//     payment has no inflight attempts at that point, it is also marked as
//     failed using the failure reason carried by the error.
func (p *paymentLifecycle) requestRoute(
	ps *channeldb.MPPaymentState) (*route.Route, error) {

	// Gather the arguments for the path finding query.
	feeBudget := p.calcFeeBudget(ps.FeesPaid)
	activeShards := uint32(ps.NumAttemptsInFlight)
	height := uint32(p.currentHeight)

	// Query our payment session to construct a route, and hand it back
	// right away if that worked.
	rt, err := p.paySession.RequestRoute(
		ps.RemainingAmt, feeBudget, activeShards, height,
	)
	if err == nil {
		return rt, nil
	}

	log.Warnf("Failed to find route for payment %v: %v", p.identifier, err)

	// Errors from the `noRouteError` set are non-critical: path finding
	// failed this time, but a later HTLC attempt might still find a route.
	// Anything else is critical and is passed on to the caller.
	var noRoute noRouteError
	if !errors.As(err, &noRoute) {
		return nil, fmt.Errorf("requestRoute got: %w", err)
	}

	// We still have inflight attempts, so the payment is not failed here
	// and the caller gets neither a route nor an error.
	//
	// NOTE: in this case we might enter an infinite loop in our lifecycle
	// if there's still remaining amount, since we will keep adding new
	// HTLC attempts and they all fail with `noRouteError`.
	//
	// TODO(yy): further check the error returned here. It's the
	// `paymentSession`'s responsibility to find a route for us with best
	// effort. When it cannot find a path, we need to treat it as a
	// terminal condition and fail the payment no matter it has inflight
	// HTLCs or not.
	if ps.NumAttemptsInFlight != 0 {
		return nil, nil
	}

	// There is no route to try and there are no inflight attempts, which
	// means all the HTLC attempts have failed. There is no way for us to
	// send the payment, so mark it failed with no route.
	reason := noRoute.FailureReason()
	log.Debugf("Marking payment %v permanently failed with no "+
		"route: %v", p.identifier, reason)

	failErr := p.router.cfg.Control.FailPayment(p.identifier, reason)
	if failErr != nil {
		return nil, fmt.Errorf("FailPayment got: %w", failErr)
	}

	return nil, nil
}
// stop signals any active shard goroutine to exit and waits for them to exit.
func (p *paymentLifecycle) stop() {
close(p.quit)

View File

@ -15,6 +15,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
@ -818,3 +819,296 @@ func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo {
},
}
}
// TestCheckTimeoutTimedOut checks that when the payment times out, it is
// marked as failed, and that an error returned from `FailPayment` is passed
// on by `checkTimeout`.
func TestCheckTimeoutTimedOut(t *testing.T) {
	t.Parallel()

	p := createTestPaymentLifecycle()

	// firedTimeout returns a timeout channel that already holds a value,
	// so the timeout case in `checkTimeout` is ready immediately. This
	// avoids relying on a real timer plus a sleep, which would make the
	// test depend on scheduling and slow it down.
	firedTimeout := func() <-chan time.Time {
		c := make(chan time.Time, 1)
		c <- time.Now()

		return c
	}

	// Mock the control tower's `FailPayment` method.
	ct := &mockControlTower{}
	ct.On("FailPayment",
		p.identifier, channeldb.FailureReasonTimeout).Return(nil)

	// Mount the mocked control tower.
	p.router.cfg.Control = ct

	// Make the timeout happen instantly.
	p.timeoutChan = firedTimeout()

	// Call the function and expect no error.
	err := p.checkTimeout()
	require.NoError(t, err)

	// Assert that `FailPayment` is called as expected.
	ct.AssertExpectations(t)

	// We now test that when `FailPayment` returns an error, it's returned
	// by the function too.
	//
	// Mock `FailPayment` to return a dummy error.
	dummyErr := errors.New("dummy")
	ct = &mockControlTower{}
	ct.On("FailPayment",
		p.identifier, channeldb.FailureReasonTimeout).Return(dummyErr)

	// Mount the mocked control tower.
	p.router.cfg.Control = ct

	// The previous call consumed the timeout value, so arm a fresh
	// channel to make the timeout happen instantly again.
	p.timeoutChan = firedTimeout()

	// Call the function and expect an error.
	err = p.checkTimeout()
	require.ErrorIs(t, err, dummyErr)

	// Assert that `FailPayment` is called as expected.
	ct.AssertExpectations(t)
}
// TestCheckTimeoutOnRouterQuit verifies that `checkTimeout` returns an error
// matching `ErrRouterShuttingDown` once the router's quit channel has been
// closed.
func TestCheckTimeoutOnRouterQuit(t *testing.T) {
	t.Parallel()

	lifecycle := createTestPaymentLifecycle()

	// Simulate the router shutting down before the check is made.
	close(lifecycle.router.quit)

	require.ErrorIs(t, lifecycle.checkTimeout(), ErrRouterShuttingDown)
}
// createTestPaymentLifecycle builds a minimal `paymentLifecycle` for tests. It
// carries a fixed payment identifier and a router that only has an empty
// config and an open quit channel; tests mount whatever mocks they need on
// top of it.
func createTestPaymentLifecycle() *paymentLifecycle {
	return &paymentLifecycle{
		router: &ChannelRouter{
			cfg:  &Config{},
			quit: make(chan struct{}),
		},
		identifier: lntypes.Hash{1, 2, 3},
	}
}
// TestRequestRouteSucceed checks that `requestRoute` can successfully request
// a route.
func TestRequestRouteSucceed(t *testing.T) {
t.Parallel()
p := createTestPaymentLifecycle()
// Create a mock payment session and a dummy route.
paySession := &mockPaymentSession{}
dummyRoute := &route.Route{}
// Mount the mocked payment session.
p.paySession = paySession
// Create a dummy payment state.
ps := &channeldb.MPPaymentState{
NumAttemptsInFlight: 1,
RemainingAmt: 1,
FeesPaid: 100,
}
// Mock remainingFees to be 1.
p.feeLimit = ps.FeesPaid + 1
// Mock the paySession's `RequestRoute` method to return no error.
paySession.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(dummyRoute, nil)
result, err := p.requestRoute(ps)
require.NoError(t, err, "expect no error")
require.Equal(t, dummyRoute, result, "returned route not matched")
// Assert that `RequestRoute` is called as expected.
paySession.AssertExpectations(t)
}
// TestRequestRouteHandleCriticalErr checks that `requestRoute` can
// successfully handle a critical error returned from payment session.
func TestRequestRouteHandleCriticalErr(t *testing.T) {
t.Parallel()
p := createTestPaymentLifecycle()
// Create a mock payment session.
paySession := &mockPaymentSession{}
// Mount the mocked payment session.
p.paySession = paySession
// Create a dummy payment state.
ps := &channeldb.MPPaymentState{
NumAttemptsInFlight: 1,
RemainingAmt: 1,
FeesPaid: 100,
}
// Mock remainingFees to be 1.
p.feeLimit = ps.FeesPaid + 1
// Mock the paySession's `RequestRoute` method to return an error.
dummyErr := errors.New("dummy")
paySession.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(nil, dummyErr)
result, err := p.requestRoute(ps)
// Expect an error is returned since it's critical.
require.ErrorIs(t, err, dummyErr, "error not matched")
require.Nil(t, result, "expected no route returned")
// Assert that `RequestRoute` is called as expected.
paySession.AssertExpectations(t)
}
// TestRequestRouteHandleNoRouteErr checks that `requestRoute` can successfully
// handle the `noRouteError` returned from payment session.
func TestRequestRouteHandleNoRouteErr(t *testing.T) {
t.Parallel()
p := createTestPaymentLifecycle()
// Create a mock payment session.
paySession := &mockPaymentSession{}
// Mount the mocked payment session.
p.paySession = paySession
// Create a dummy payment state.
ps := &channeldb.MPPaymentState{
NumAttemptsInFlight: 1,
RemainingAmt: 1,
FeesPaid: 100,
}
// Mock remainingFees to be 1.
p.feeLimit = ps.FeesPaid + 1
// Mock the paySession's `RequestRoute` method to return an error.
paySession.On("RequestRoute",
mock.Anything, mock.Anything, mock.Anything, mock.Anything,
).Return(nil, errNoTlvPayload)
result, err := p.requestRoute(ps)
// Expect no error is returned since it's not critical.
require.NoError(t, err, "expected no error")
require.Nil(t, result, "expected no route returned")
// Assert that `RequestRoute` is called as expected.
paySession.AssertExpectations(t)
}
// TestRequestRouteFailPaymentSucceed verifies that `requestRoute` marks the
// payment as failed when the payment session returns a `noRouteError` and
// the payment has no inflight attempts, and that it still returns neither a
// route nor an error when `FailPayment` succeeds.
func TestRequestRouteFailPaymentSucceed(t *testing.T) {
	t.Parallel()

	lifecycle := createTestPaymentLifecycle()

	// A dummy payment state with zero inflight attempts.
	state := &channeldb.MPPaymentState{
		NumAttemptsInFlight: 0,
		RemainingAmt:        1,
		FeesPaid:            100,
	}

	// The mocked payment session fails any `RequestRoute` call with
	// `errNoTlvPayload`.
	session := &mockPaymentSession{}
	session.On("RequestRoute",
		mock.Anything, mock.Anything, mock.Anything, mock.Anything,
	).Return(nil, errNoTlvPayload)

	// The mocked control tower expects `FailPayment` to be called with
	// the failure reason of that error, and reports success.
	controlTower := &mockControlTower{}
	controlTower.On("FailPayment",
		lifecycle.identifier, errNoTlvPayload.FailureReason(),
	).Return(nil)

	// Mount the mocks. The fee limit is set one above the fees already
	// paid, which per the original test is meant to leave a remaining fee
	// budget of 1.
	lifecycle.paySession = session
	lifecycle.router.cfg.Control = controlTower
	lifecycle.feeLimit = state.FeesPaid + 1

	gotRoute, err := lifecycle.requestRoute(state)

	// Expect no error is returned since it's not critical.
	require.NoError(t, err, "expected no error")
	require.Nil(t, gotRoute, "expected no route returned")

	// Assert that `RequestRoute` and `FailPayment` are called as
	// expected.
	session.AssertExpectations(t)
	controlTower.AssertExpectations(t)
}
// TestRequestRouteFailPaymentError verifies that when `requestRoute` tries to
// mark the payment as failed and `FailPayment` itself returns an error, that
// error is returned (wrapped) together with a nil route.
func TestRequestRouteFailPaymentError(t *testing.T) {
	t.Parallel()

	lifecycle := createTestPaymentLifecycle()

	// A dummy payment state with zero inflight attempts.
	state := &channeldb.MPPaymentState{
		NumAttemptsInFlight: 0,
		RemainingAmt:        1,
		FeesPaid:            100,
	}

	// The mocked payment session fails any `RequestRoute` call with
	// `errNoTlvPayload`.
	session := &mockPaymentSession{}
	session.On("RequestRoute",
		mock.Anything, mock.Anything, mock.Anything, mock.Anything,
	).Return(nil, errNoTlvPayload)

	// The mocked control tower expects `FailPayment` to be called with
	// the failure reason of that error, and answers with a dummy error.
	failErr := errors.New("dummy")
	controlTower := &mockControlTower{}
	controlTower.On("FailPayment",
		lifecycle.identifier, errNoTlvPayload.FailureReason(),
	).Return(failErr)

	// Mount the mocks. The fee limit is set one above the fees already
	// paid, which per the original test is meant to leave a remaining fee
	// budget of 1.
	lifecycle.paySession = session
	lifecycle.router.cfg.Control = controlTower
	lifecycle.feeLimit = state.FeesPaid + 1

	gotRoute, err := lifecycle.requestRoute(state)

	// Expect an error is returned.
	require.ErrorIs(t, err, failErr, "error not matched")
	require.Nil(t, gotRoute, "expected no route returned")

	// Assert that `RequestRoute` and `FailPayment` are called as
	// expected.
	session.AssertExpectations(t)
	controlTower.AssertExpectations(t)
}