routing: introduce stateStep to manage payment lifecycle

This commit adds a new struct, `stateStep`, to decide the workflow
inside `resumePayment`.

It also refactors `collectResultAsync` introducing a new channel
`resultCollected`. This channel is used to signal the payment
lifecycle that an HTLC attempt result is ready to be processed.
This commit is contained in:
yyforyongyu 2023-03-08 00:58:14 +08:00
parent e8c0226e1c
commit 3c5c37b693
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
3 changed files with 343 additions and 177 deletions

View File

@ -3,7 +3,6 @@ package routing
import (
"errors"
"fmt"
"sync"
"time"
"github.com/btcsuite/btcd/btcec/v2"
@ -18,9 +17,6 @@ import (
"github.com/lightningnetwork/lnd/routing/shards"
)
// errShardHandlerExiting is returned from the shardHandler when it exits.
var errShardHandlerExiting = errors.New("shard handler exiting")
// paymentLifecycle holds all information about the current state of a payment
// needed to resume it from any point.
type paymentLifecycle struct {
@ -32,18 +28,15 @@ type paymentLifecycle struct {
timeoutChan <-chan time.Time
currentHeight int32
// shardErrors is a channel where errors collected by calling
// collectResultAsync will be delivered. These results are meant to be
// inspected by calling waitForShard or checkShards, and the channel
// doesn't need to be initiated if the caller is using the sync
// collectResult directly.
// TODO(yy): delete.
shardErrors chan error
// quit is closed to signal the sub goroutines of the payment lifecycle
// to stop.
quit chan struct{}
wg sync.WaitGroup
// resultCollected is used to signal that the result of an attempt has
// been collected. A nil error means the attempt is either successful
// or failed with temporary error. Otherwise, we should exit the
// lifecycle loop as a terminal error has occurred.
resultCollected chan error
}
// newPaymentLifecycle initiates a new payment lifecycle and returns it.
@ -53,14 +46,14 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
currentHeight int32) *paymentLifecycle {
p := &paymentLifecycle{
router: r,
feeLimit: feeLimit,
identifier: identifier,
paySession: paySession,
shardTracker: shardTracker,
currentHeight: currentHeight,
shardErrors: make(chan error),
quit: make(chan struct{}),
router: r,
feeLimit: feeLimit,
identifier: identifier,
paySession: paySession,
shardTracker: shardTracker,
currentHeight: currentHeight,
quit: make(chan struct{}),
resultCollected: make(chan error, 1),
}
// If a timeout is specified, create a timeout channel. If no timeout is
@ -92,6 +85,74 @@ func (p *paymentLifecycle) calcFeeBudget(
return budget
}
// stateStep defines an action to be taken in our payment lifecycle. On each
// iteration of the lifecycle loop we either skip, proceed, or exit, see
// details below.
type stateStep uint8
const (
// stepSkip is used when we need to skip the rest of the current lifecycle
// iteration and jump to the next one.
stepSkip stateStep = iota
// stepProceed is used when we can proceed with the current lifecycle
// iteration.
stepProceed
// stepExit is used when we need to quit the lifecycle loop.
stepExit
)
// decideNextStep inspects the given payment and tells the lifecycle loop what
// to do next. It returns,
//   - stepProceed when more HTLC attempts may be made,
//   - stepSkip after an in-flight attempt's result has been collected, so
//     the caller should re-evaluate the payment from the top of the loop,
//   - stepExit when the lifecycle is finished or an error is returned.
func (p *paymentLifecycle) decideNextStep(
	payment dbMPPayment) (stateStep, error) {

	// Ask the payment whether new HTLC attempts can still be made.
	canAttempt, err := payment.AllowMoreAttempts()
	if err != nil {
		return stepExit, err
	}

	// More attempts are allowed, so carry on with the current iteration.
	if canAttempt {
		return stepProceed, nil
	}

	// No new attempts can be made, find out whether there are results we
	// still have to wait for.
	mustWait, err := payment.NeedWaitAttempts()
	if err != nil {
		return stepExit, err
	}

	// Nothing to send and nothing to wait for, the lifecycle is done.
	if !mustWait {
		return stepExit, nil
	}

	log.Tracef("Waiting for attempt results for payment %v",
		p.identifier)

	// Block until a single HTLC attempt result arrives or the router
	// shuts down.
	//
	// NOTE: `p.quit` is not checked here since `decideNextStep` runs in
	// the same goroutine as `resumePayment`.
	select {
	case resultErr := <-p.resultCollected:
		// A non-nil error is terminal, exit with it.
		if resultErr != nil {
			return stepExit, resultErr
		}

		log.Tracef("Received attempt result for payment %v",
			p.identifier)

	case <-p.router.quit:
		return stepExit, ErrRouterShuttingDown
	}

	// A result was handled, skip the rest of this iteration.
	return stepSkip, nil
}
// resumePayment resumes the paymentLifecycle from the current state.
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
// When the payment lifecycle loop exits, we make sure to signal any
@ -127,20 +188,12 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
// critical error during path finding.
lifecycle:
for {
// Start by quickly checking if there are any outcomes already
// available to handle before we reevaluate our state.
if err := p.checkShards(); err != nil {
return exitWithErr(err)
}
// We update the payment state on every iteration. Since the
// payment state is affected by multiple goroutines (ie,
// collectResultAsync), it is NOT guaranteed that we always
// have the latest state here. This is fine as long as the
// state is consistent as a whole.
// Fetch the latest payment from db.
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
payment, err = p.router.cfg.Control.FetchPayment(p.identifier)
if err != nil {
return exitWithErr(err)
}
@ -153,53 +206,14 @@ lifecycle:
p.identifier, payment.Terminated(),
ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees)
// TODO(yy): sanity check all the states to make sure
// everything is expected.
// We have a terminal condition and no active shards, we are
// ready to exit.
if payment.Terminated() {
// Find the first successful shard and return
// the preimage and route.
for _, a := range payment.GetHTLCs() {
if a.Settle == nil {
continue
}
err := p.router.cfg.Control.DeleteFailedAttempts(
p.identifier,
)
if err != nil {
log.Errorf("Error deleting failed "+
"payment attempts for "+
"payment %v: %v", p.identifier,
err)
}
return a.Settle.Preimage, &a.Route, nil
}
// Payment failed.
return exitWithErr(*payment.GetFailureReason())
}
// If we either reached a terminal error condition (but had
// active shards still) or there is no remaining value to send,
// we'll wait for a shard outcome.
wait, err := payment.NeedWaitAttempts()
if err != nil {
return exitWithErr(err)
}
if wait {
// We still have outstanding shards, so wait for a new
// outcome to be available before re-evaluating our
// state.
if err := p.waitForShard(); err != nil {
return exitWithErr(err)
}
continue lifecycle
}
// We now proceed our lifecycle with the following tasks in
// order,
// 1. check timeout.
// 2. request route.
// 3. create HTLC attempt.
// 4. send HTLC attempt.
// 5. collect HTLC attempt result.
//
// Before we attempt any new shard, we'll check to see if
// either we've gone past the payment attempt timeout, or the
// router is exiting. In either case, we'll stop this payment
@ -209,6 +223,30 @@ lifecycle:
return exitWithErr(err)
}
// Now decide the next step of the current lifecycle.
step, err := p.decideNextStep(payment)
if err != nil {
return exitWithErr(err)
}
switch step {
// Exit the for loop and return below.
case stepExit:
break lifecycle
// Continue the for loop and skip the rest.
case stepSkip:
continue lifecycle
// Continue the for loop and proceed the rest.
case stepProceed:
// Unknown step received, exit with an error.
default:
err = fmt.Errorf("unknown step: %v", step)
return exitWithErr(err)
}
// Now request a route to be used to create our HTLC attempt.
rt, err := p.requestRoute(ps)
if err != nil {
@ -241,6 +279,27 @@ lifecycle:
p.collectResultAsync(attempt)
}
}
// Once we are out the lifecycle loop, it means we've reached a
// terminal condition. We either return the settled preimage or the
// payment's failure reason.
//
// Optionally delete the failed attempts from the database.
err = p.router.cfg.Control.DeleteFailedAttempts(p.identifier)
if err != nil {
log.Errorf("Error deleting failed htlc attempts for payment "+
"%v: %v", p.identifier, err)
}
// Find the first successful shard and return the preimage and route.
for _, a := range payment.GetHTLCs() {
if a.Settle != nil {
return a.Settle.Preimage, &a.Route, nil
}
}
// Otherwise return the payment failure reason.
return [32]byte{}, nil, *payment.GetFailureReason()
}
// checkTimeout checks whether the payment has reached its timeout.
@ -332,46 +391,9 @@ func (p *paymentLifecycle) requestRoute(
return nil, nil
}
// stop signals any active shard goroutine to exit and waits for them to exit.
// stop signals any active shard goroutine to exit.
func (p *paymentLifecycle) stop() {
close(p.quit)
p.wg.Wait()
}
// waitForShard blocks until one of the outstanding shards reports its outcome
// on `shardErrors`, or until either the lifecycle or the router is asked to
// quit.
func (p *paymentLifecycle) waitForShard() error {
	select {
	case <-p.router.quit:
		return ErrRouterShuttingDown

	case <-p.quit:
		return errShardHandlerExiting

	case shardErr := <-p.shardErrors:
		return shardErr
	}
}
// checkShards drains the shard outcomes that are already available without
// blocking. It returns the first non-nil shard error it receives, an exit
// error if the lifecycle or the router is quitting, or nil once nothing more
// is immediately ready.
func (p *paymentLifecycle) checkShards() error {
	for {
		select {
		case <-p.router.quit:
			return ErrRouterShuttingDown

		case <-p.quit:
			return errShardHandlerExiting

		case shardErr := <-p.shardErrors:
			// A nil outcome is not a failure, keep polling for
			// more outcomes.
			if shardErr == nil {
				continue
			}

			return shardErr

		default:
			// Nothing is ready right now.
			return nil
		}
	}
}
// attemptResult holds the HTLC attempt and a possible error returned from
@ -388,38 +410,33 @@ type attemptResult struct {
}
// collectResultAsync launches a goroutine that will wait for the result of the
// given HTLC attempt to be available then handle its result. It will fail the
// payment with the control tower if a terminal error is encountered.
// given HTLC attempt to be available then handle its result. Once received, it
// will send the resulting error (nil if there is none) to channel
// `resultCollected` to indicate there's a result.
func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) {
// errToSend is the error to be sent to sh.shardErrors.
var errToSend error
// handleResultErr is a function closure must be called using defer. It
// finishes collecting result by updating the payment state and send
// the error (or nil) to sh.shardErrors.
handleResultErr := func() {
// Send the error or quit.
select {
case p.shardErrors <- errToSend:
case <-p.router.quit:
case <-p.quit:
}
p.wg.Done()
}
p.wg.Add(1)
go func() {
defer handleResultErr()
// Block until the result is available.
_, err := p.collectResult(attempt)
if err != nil {
log.Errorf("Error collecting result for attempt %v "+
"in payment %v: %v", attempt.AttemptID,
p.identifier, err)
}
errToSend = err
log.Debugf("Result collected for attempt %v in payment %v",
attempt.AttemptID, p.identifier)
// Once the result is collected, we signal it by writing the
// error to `resultCollected`.
select {
// Send the signal or quit.
case p.resultCollected <- err:
case <-p.quit:
log.Debugf("Lifecycle exiting while collecting "+
"result for payment %v", p.identifier)
case <-p.router.quit:
return
}
}()

View File

@ -21,6 +21,10 @@ import (
const stepTimeout = 5 * time.Second
var (
dummyErr = errors.New("dummy")
)
// createTestRoute builds a route a->b->c paying the given amt to c.
func createTestRoute(amt lnwire.MilliSatoshi,
aliasMap map[string]route.Vertex) (*route.Route, error) {
@ -1112,3 +1116,119 @@ func TestRequestRouteFailPaymentError(t *testing.T) {
// Assert that `FailPayment` is called as expected.
ct.AssertExpectations(t)
}
// TestDecideNextStep checks the method `decideNextStep` behaves as expected.
// Each case mocks the return values of `AllowMoreAttempts` and, optionally,
// `NeedWaitAttempts`, then asserts the returned step and error.
func TestDecideNextStep(t *testing.T) {
	t.Parallel()

	// mockReturn is used to hold the return values from AllowMoreAttempts
	// or NeedWaitAttempts.
	type mockReturn struct {
		allowOrWait bool
		err         error
	}

	testCases := []struct {
		name              string
		allowMoreAttempts *mockReturn

		// needWaitAttempts is left nil when `NeedWaitAttempts` is not
		// expected to be called.
		needWaitAttempts *mockReturn

		// Whether a nil result is queued on `resultCollected` before
		// `decideNextStep` is called.
		closeResultChan bool

		// Whether the router has quit.
		routerQuit bool

		expectedStep stateStep
		expectedErr  error
	}{
		{
			name:              "allow more attempts",
			allowMoreAttempts: &mockReturn{true, nil},
			expectedStep:      stepProceed,
			expectedErr:       nil,
		},
		{
			name:              "error on allow more attempts",
			allowMoreAttempts: &mockReturn{false, dummyErr},
			expectedStep:      stepExit,
			expectedErr:       dummyErr,
		},
		{
			name:              "no wait and exit",
			allowMoreAttempts: &mockReturn{false, nil},
			needWaitAttempts:  &mockReturn{false, nil},
			expectedStep:      stepExit,
			expectedErr:       nil,
		},
		{
			name:              "wait returns an error",
			allowMoreAttempts: &mockReturn{false, nil},
			needWaitAttempts:  &mockReturn{false, dummyErr},
			expectedStep:      stepExit,
			expectedErr:       dummyErr,
		},
		{
			name:              "wait and exit on result chan",
			allowMoreAttempts: &mockReturn{false, nil},
			needWaitAttempts:  &mockReturn{true, nil},
			closeResultChan:   true,
			expectedStep:      stepSkip,
			expectedErr:       nil,
		},
		{
			name:              "wait and exit on router quit",
			allowMoreAttempts: &mockReturn{false, nil},
			needWaitAttempts:  &mockReturn{true, nil},
			routerQuit:        true,
			expectedStep:      stepExit,
			expectedErr:       ErrRouterShuttingDown,
		},
	}

	for _, tc := range testCases {
		tc := tc

		// Create a test paymentLifecycle.
		p := createTestPaymentLifecycle()

		// Make a mock payment.
		payment := &mockMPPayment{}

		// Mock the method AllowMoreAttempts.
		payment.On("AllowMoreAttempts").Return(
			tc.allowMoreAttempts.allowOrWait,
			tc.allowMoreAttempts.err,
		).Once()

		// Mock the method NeedWaitAttempts.
		if tc.needWaitAttempts != nil {
			payment.On("NeedWaitAttempts").Return(
				tc.needWaitAttempts.allowOrWait,
				tc.needWaitAttempts.err,
			).Once()
		}

		// Send a nil error to `resultCollected` if requested.
		if tc.closeResultChan {
			p.resultCollected = make(chan error, 1)
			p.resultCollected <- nil
		}

		// Quit the router if requested.
		if tc.routerQuit {
			close(p.router.quit)
		}

		// Once the setup is finished, run the test cases.
		t.Run(tc.name, func(t *testing.T) {
			step, err := p.decideNextStep(payment)
			require.Equal(t, tc.expectedStep, step)

			// NOTE: `require.ErrorIs` takes the actual error
			// first and the target second.
			require.ErrorIs(t, err, tc.expectedErr)
		})

		// Check the payment's methods are called as expected.
		payment.AssertExpectations(t)
	}
}

View File

@ -3470,34 +3470,44 @@ func TestSendMPPaymentSucceed(t *testing.T) {
session := &mockPaymentSession{}
sessionSource.On("NewPaymentSession", req).Return(session, nil)
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
// Mock the InFlightHTLCs.
var (
htlcs []channeldb.HTLCAttempt
numAttempts atomic.Uint32
settled atomic.Bool
numParts = uint32(4)
)
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
On("Terminated").Return(false)
controlTower.On("FetchPayment", identifier).Return(payment, nil).Once()
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment",
identifier,
).Return(payment, nil).Run(func(args mock.Arguments) {
// When number of attempts made is less than 4, we will mock
// the payment's methods to allow the lifecycle to continue.
if numAttempts.Load() < 4 {
payment.On("Terminated").Return(false).Times(2).
On("NeedWaitAttempts").Return(false, nil).Once()
return
}
controlTower.On("FetchPayment", identifier).Return(payment, nil).
Run(func(args mock.Arguments) {
// When number of attempts made is less than 4, we will
// mock the payment's methods to allow the lifecycle to
// continue.
if numAttempts.Load() < numParts {
payment.On("AllowMoreAttempts").Return(true, nil).Once()
return
}
// Otherwise, terminate the lifecycle.
payment.On("Terminated").Return(true).
On("NeedWaitAttempts").Return(true, nil)
})
if !settled.Load() {
fmt.Println("wait")
payment.On("AllowMoreAttempts").Return(false, nil).Once()
payment.On("NeedWaitAttempts").Return(true, nil).Once()
// We add another attempt to the counter to
// unblock next time.
return
}
payment.On("AllowMoreAttempts").Return(false, nil).
On("NeedWaitAttempts").Return(false, nil)
})
// Mock SettleAttempt.
preimage := lntypes.Preimage{1, 2, 3}
@ -3511,6 +3521,10 @@ func TestSendMPPaymentSucceed(t *testing.T) {
payment.On("GetHTLCs").Return(
[]channeldb.HTLCAttempt{settledAttempt},
)
// We want to at least wait for one settlement.
if numAttempts.Load() > 1 {
settled.Store(true)
}
})
// Create a route that can send 1/4 of the total amount. This value
@ -3527,7 +3541,6 @@ func TestSendMPPaymentSucceed(t *testing.T) {
controlTower.On("RegisterAttempt",
identifier, mock.Anything,
).Return(nil).Run(func(args mock.Arguments) {
// Increase the counter whenever an attempt is made.
numAttempts.Add(1)
})
@ -3663,29 +3676,40 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
htlcs []channeldb.HTLCAttempt
numAttempts atomic.Uint32
failAttemptCount atomic.Uint32
settled atomic.Bool
)
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
On("Terminated").Return(false)
controlTower.On("FetchPayment", identifier).Return(payment, nil).Once()
// Mock FetchPayment to return the payment.
controlTower.On("FetchPayment",
identifier,
).Return(payment, nil).Run(func(args mock.Arguments) {
// When number of attempts made is less than 6, we will mock
// the payment's methods to allow the lifecycle to continue.
if numAttempts.Load() < 6 {
payment.On("Terminated").Return(false).Times(2).
On("NeedWaitAttempts").Return(false, nil).Once()
return
}
controlTower.On("FetchPayment", identifier).Return(payment, nil).
Run(func(args mock.Arguments) {
// When number of attempts made is less than 6, we will
// mock the payment's methods to allow the lifecycle to
// continue.
attempts := numAttempts.Load()
if attempts < 6 {
payment.On("AllowMoreAttempts").Return(true, nil).Once()
return
}
// Otherwise, terminate the lifecycle.
payment.On("Terminated").Return(true).
On("NeedWaitAttempts").Return(true, nil)
})
if !settled.Load() {
payment.On("AllowMoreAttempts").Return(false, nil).Once()
payment.On("NeedWaitAttempts").Return(true, nil).Once()
// We add another attempt to the counter to
// unblock next time.
numAttempts.Add(1)
return
}
payment.On("AllowMoreAttempts").Return(false, nil).
On("NeedWaitAttempts").Return(false, nil)
})
// Create a route that can send 1/4 of the total amount. This value
// will be returned by calling RequestRoute.
@ -3768,6 +3792,10 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
payment.On("GetHTLCs").Return(
[]channeldb.HTLCAttempt{settledAttempt},
)
if numAttempts.Load() > 1 {
settled.Store(true)
}
})
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
@ -3885,8 +3913,8 @@ func TestSendMPPaymentFailed(t *testing.T) {
// Make a mock MPPayment.
payment := &mockMPPayment{}
payment.On("InFlightHTLCs").Return(htlcs).Once()
payment.On("GetStatus").Return(channeldb.StatusInFlight).Once()
payment.On("GetState").Return(&channeldb.MPPaymentState{})
payment.On("Terminated").Return(false)
controlTower.On("FetchPayment", identifier).Return(payment, nil).Once()
// Mock the sequential FetchPayment to return the payment.
@ -3895,21 +3923,20 @@ func TestSendMPPaymentFailed(t *testing.T) {
// We want to at least send out all parts in order to
// wait for them later.
if numAttempts.Load() < numParts {
payment.On("Terminated").Return(false).Times(2).
On("NeedWaitAttempts").Return(false, nil).Once()
payment.On("AllowMoreAttempts").Return(true, nil).Once()
return
}
// Wait if the payment wasn't failed yet.
if !failed.Load() {
payment.On("Terminated").Return(false).Times(2).
payment.On("AllowMoreAttempts").Return(false, nil).Once().
On("NeedWaitAttempts").Return(true, nil).Once()
return
}
payment.On("Terminated").Return(true).
On("GetHTLCs").Return(htlcs).Once()
payment.On("AllowMoreAttempts").Return(false, nil).
On("GetHTLCs").Return(htlcs).Once().
On("NeedWaitAttempts").Return(false, nil).Once()
})
// Create a route that can send 1/4 of the total amount. This value
@ -3990,6 +4017,8 @@ func TestSendMPPaymentFailed(t *testing.T) {
mock.Anything, mock.Anything, mock.Anything,
).Return(nil)
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
// Call the actual method SendPayment on router. This is place inside a
// goroutine so we can set a timeout for the whole test, in case
// anything goes wrong and the test never finishes.