mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 21:35:24 +01:00
routing: introduce stateStep
to manage payment lifecycle
This commit adds a new struct, `stateStep`, to decide the workflow inside `resumePayment`. It also refactors `collectResultAsync` introducing a new channel `resultCollected`. This channel is used to signal the payment lifecycle that an HTLC attempt result is ready to be processed.
This commit is contained in:
parent
e8c0226e1c
commit
3c5c37b693
@ -3,7 +3,6 @@ package routing
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
@ -18,9 +17,6 @@ import (
|
||||
"github.com/lightningnetwork/lnd/routing/shards"
|
||||
)
|
||||
|
||||
// errShardHandlerExiting is returned from the shardHandler when it exits.
|
||||
var errShardHandlerExiting = errors.New("shard handler exiting")
|
||||
|
||||
// paymentLifecycle holds all information about the current state of a payment
|
||||
// needed to resume if from any point.
|
||||
type paymentLifecycle struct {
|
||||
@ -32,18 +28,15 @@ type paymentLifecycle struct {
|
||||
timeoutChan <-chan time.Time
|
||||
currentHeight int32
|
||||
|
||||
// shardErrors is a channel where errors collected by calling
|
||||
// collectResultAsync will be delivered. These results are meant to be
|
||||
// inspected by calling waitForShard or checkShards, and the channel
|
||||
// doesn't need to be initiated if the caller is using the sync
|
||||
// collectResult directly.
|
||||
// TODO(yy): delete.
|
||||
shardErrors chan error
|
||||
|
||||
// quit is closed to signal the sub goroutines of the payment lifecycle
|
||||
// to stop.
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// resultCollected is used to signal that the result of an attempt has
|
||||
// been collected. A nil error means the attempt is either successful
|
||||
// or failed with temporary error. Otherwise, we should exit the
|
||||
// lifecycle loop as a terminal error has occurred.
|
||||
resultCollected chan error
|
||||
}
|
||||
|
||||
// newPaymentLifecycle initiates a new payment lifecycle and returns it.
|
||||
@ -53,14 +46,14 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
|
||||
currentHeight int32) *paymentLifecycle {
|
||||
|
||||
p := &paymentLifecycle{
|
||||
router: r,
|
||||
feeLimit: feeLimit,
|
||||
identifier: identifier,
|
||||
paySession: paySession,
|
||||
shardTracker: shardTracker,
|
||||
currentHeight: currentHeight,
|
||||
shardErrors: make(chan error),
|
||||
quit: make(chan struct{}),
|
||||
router: r,
|
||||
feeLimit: feeLimit,
|
||||
identifier: identifier,
|
||||
paySession: paySession,
|
||||
shardTracker: shardTracker,
|
||||
currentHeight: currentHeight,
|
||||
quit: make(chan struct{}),
|
||||
resultCollected: make(chan error, 1),
|
||||
}
|
||||
|
||||
// If a timeout is specified, create a timeout channel. If no timeout is
|
||||
@ -92,6 +85,74 @@ func (p *paymentLifecycle) calcFeeBudget(
|
||||
return budget
|
||||
}
|
||||
|
||||
// stateStep defines an action to be taken in our payment lifecycle. We either
|
||||
// quit, continue, or exit the lifecycle, see details below.
|
||||
type stateStep uint8
|
||||
|
||||
const (
|
||||
// stepSkip is used when we need to skip the current lifecycle and jump
|
||||
// to the next one.
|
||||
stepSkip stateStep = iota
|
||||
|
||||
// stepProceed is used when we can proceed the current lifecycle.
|
||||
stepProceed
|
||||
|
||||
// stepExit is used when we need to quit the current lifecycle.
|
||||
stepExit
|
||||
)
|
||||
|
||||
// decideNextStep is used to determine the next step in the payment lifecycle.
|
||||
func (p *paymentLifecycle) decideNextStep(
|
||||
payment dbMPPayment) (stateStep, error) {
|
||||
|
||||
// Check whether we could make new HTLC attempts.
|
||||
allow, err := payment.AllowMoreAttempts()
|
||||
if err != nil {
|
||||
return stepExit, err
|
||||
}
|
||||
|
||||
if !allow {
|
||||
// Check whether we need to wait for results.
|
||||
wait, err := payment.NeedWaitAttempts()
|
||||
if err != nil {
|
||||
return stepExit, err
|
||||
}
|
||||
|
||||
// If we are not allowed to make new HTLC attempts and there's
|
||||
// no need to wait, the lifecycle is done and we can exit.
|
||||
if !wait {
|
||||
return stepExit, nil
|
||||
}
|
||||
|
||||
log.Tracef("Waiting for attempt results for payment %v",
|
||||
p.identifier)
|
||||
|
||||
// Otherwise we wait for one HTLC attempt then continue
|
||||
// the lifecycle.
|
||||
//
|
||||
// NOTE: we don't check `p.quit` since `decideNextStep` is
|
||||
// running in the same goroutine as `resumePayment`.
|
||||
select {
|
||||
case err := <-p.resultCollected:
|
||||
// If an error is returned, exit with it.
|
||||
if err != nil {
|
||||
return stepExit, err
|
||||
}
|
||||
|
||||
log.Tracef("Received attempt result for payment %v",
|
||||
p.identifier)
|
||||
|
||||
case <-p.router.quit:
|
||||
return stepExit, ErrRouterShuttingDown
|
||||
}
|
||||
|
||||
return stepSkip, nil
|
||||
}
|
||||
|
||||
// Otherwise we need to make more attempts.
|
||||
return stepProceed, nil
|
||||
}
|
||||
|
||||
// resumePayment resumes the paymentLifecycle from the current state.
|
||||
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
|
||||
// When the payment lifecycle loop exits, we make sure to signal any
|
||||
@ -127,20 +188,12 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
|
||||
// critical error during path finding.
|
||||
lifecycle:
|
||||
for {
|
||||
// Start by quickly checking if there are any outcomes already
|
||||
// available to handle before we reevaluate our state.
|
||||
if err := p.checkShards(); err != nil {
|
||||
return exitWithErr(err)
|
||||
}
|
||||
|
||||
// We update the payment state on every iteration. Since the
|
||||
// payment state is affected by multiple goroutines (ie,
|
||||
// collectResultAsync), it is NOT guaranteed that we always
|
||||
// have the latest state here. This is fine as long as the
|
||||
// state is consistent as a whole.
|
||||
|
||||
// Fetch the latest payment from db.
|
||||
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
|
||||
payment, err = p.router.cfg.Control.FetchPayment(p.identifier)
|
||||
if err != nil {
|
||||
return exitWithErr(err)
|
||||
}
|
||||
@ -153,53 +206,14 @@ lifecycle:
|
||||
p.identifier, payment.Terminated(),
|
||||
ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees)
|
||||
|
||||
// TODO(yy): sanity check all the states to make sure
|
||||
// everything is expected.
|
||||
// We have a terminal condition and no active shards, we are
|
||||
// ready to exit.
|
||||
if payment.Terminated() {
|
||||
// Find the first successful shard and return
|
||||
// the preimage and route.
|
||||
for _, a := range payment.GetHTLCs() {
|
||||
if a.Settle == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err := p.router.cfg.Control.DeleteFailedAttempts(
|
||||
p.identifier,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Error deleting failed "+
|
||||
"payment attempts for "+
|
||||
"payment %v: %v", p.identifier,
|
||||
err)
|
||||
}
|
||||
|
||||
return a.Settle.Preimage, &a.Route, nil
|
||||
}
|
||||
|
||||
// Payment failed.
|
||||
return exitWithErr(*payment.GetFailureReason())
|
||||
}
|
||||
|
||||
// If we either reached a terminal error condition (but had
|
||||
// active shards still) or there is no remaining value to send,
|
||||
// we'll wait for a shard outcome.
|
||||
wait, err := payment.NeedWaitAttempts()
|
||||
if err != nil {
|
||||
return exitWithErr(err)
|
||||
}
|
||||
|
||||
if wait {
|
||||
// We still have outstanding shards, so wait for a new
|
||||
// outcome to be available before re-evaluating our
|
||||
// state.
|
||||
if err := p.waitForShard(); err != nil {
|
||||
return exitWithErr(err)
|
||||
}
|
||||
continue lifecycle
|
||||
}
|
||||
|
||||
// We now proceed our lifecycle with the following tasks in
|
||||
// order,
|
||||
// 1. check timeout.
|
||||
// 2. request route.
|
||||
// 3. create HTLC attempt.
|
||||
// 4. send HTLC attempt.
|
||||
// 5. collect HTLC attempt result.
|
||||
//
|
||||
// Before we attempt any new shard, we'll check to see if
|
||||
// either we've gone past the payment attempt timeout, or the
|
||||
// router is exiting. In either case, we'll stop this payment
|
||||
@ -209,6 +223,30 @@ lifecycle:
|
||||
return exitWithErr(err)
|
||||
}
|
||||
|
||||
// Now decide the next step of the current lifecycle.
|
||||
step, err := p.decideNextStep(payment)
|
||||
if err != nil {
|
||||
return exitWithErr(err)
|
||||
}
|
||||
|
||||
switch step {
|
||||
// Exit the for loop and return below.
|
||||
case stepExit:
|
||||
break lifecycle
|
||||
|
||||
// Continue the for loop and skip the rest.
|
||||
case stepSkip:
|
||||
continue lifecycle
|
||||
|
||||
// Continue the for loop and proceed the rest.
|
||||
case stepProceed:
|
||||
|
||||
// Unknown step received, exit with an error.
|
||||
default:
|
||||
err = fmt.Errorf("unknown step: %v", step)
|
||||
return exitWithErr(err)
|
||||
}
|
||||
|
||||
// Now request a route to be used to create our HTLC attempt.
|
||||
rt, err := p.requestRoute(ps)
|
||||
if err != nil {
|
||||
@ -241,6 +279,27 @@ lifecycle:
|
||||
p.collectResultAsync(attempt)
|
||||
}
|
||||
}
|
||||
|
||||
// Once we are out the lifecycle loop, it means we've reached a
|
||||
// terminal condition. We either return the settled preimage or the
|
||||
// payment's failure reason.
|
||||
//
|
||||
// Optionally delete the failed attempts from the database.
|
||||
err = p.router.cfg.Control.DeleteFailedAttempts(p.identifier)
|
||||
if err != nil {
|
||||
log.Errorf("Error deleting failed htlc attempts for payment "+
|
||||
"%v: %v", p.identifier, err)
|
||||
}
|
||||
|
||||
// Find the first successful shard and return the preimage and route.
|
||||
for _, a := range payment.GetHTLCs() {
|
||||
if a.Settle != nil {
|
||||
return a.Settle.Preimage, &a.Route, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise return the payment failure reason.
|
||||
return [32]byte{}, nil, *payment.GetFailureReason()
|
||||
}
|
||||
|
||||
// checkTimeout checks whether the payment has reached its timeout.
|
||||
@ -332,46 +391,9 @@ func (p *paymentLifecycle) requestRoute(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// stop signals any active shard goroutine to exit and waits for them to exit.
|
||||
// stop signals any active shard goroutine to exit.
|
||||
func (p *paymentLifecycle) stop() {
|
||||
close(p.quit)
|
||||
p.wg.Wait()
|
||||
}
|
||||
|
||||
// waitForShard blocks until any of the outstanding shards return.
|
||||
func (p *paymentLifecycle) waitForShard() error {
|
||||
select {
|
||||
case err := <-p.shardErrors:
|
||||
return err
|
||||
|
||||
case <-p.quit:
|
||||
return errShardHandlerExiting
|
||||
|
||||
case <-p.router.quit:
|
||||
return ErrRouterShuttingDown
|
||||
}
|
||||
}
|
||||
|
||||
// checkShards is a non-blocking method that check if any shards has finished
|
||||
// their execution.
|
||||
func (p *paymentLifecycle) checkShards() error {
|
||||
for {
|
||||
select {
|
||||
case err := <-p.shardErrors:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case <-p.quit:
|
||||
return errShardHandlerExiting
|
||||
|
||||
case <-p.router.quit:
|
||||
return ErrRouterShuttingDown
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// attemptResult holds the HTLC attempt and a possible error returned from
|
||||
@ -388,38 +410,33 @@ type attemptResult struct {
|
||||
}
|
||||
|
||||
// collectResultAsync launches a goroutine that will wait for the result of the
|
||||
// given HTLC attempt to be available then handle its result. It will fail the
|
||||
// payment with the control tower if a terminal error is encountered.
|
||||
// given HTLC attempt to be available then handle its result. Once received, it
|
||||
// will send a nil error to channel `resultCollected` to indicate there's an
|
||||
// result.
|
||||
func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) {
|
||||
// errToSend is the error to be sent to sh.shardErrors.
|
||||
var errToSend error
|
||||
|
||||
// handleResultErr is a function closure must be called using defer. It
|
||||
// finishes collecting result by updating the payment state and send
|
||||
// the error (or nil) to sh.shardErrors.
|
||||
handleResultErr := func() {
|
||||
// Send the error or quit.
|
||||
select {
|
||||
case p.shardErrors <- errToSend:
|
||||
case <-p.router.quit:
|
||||
case <-p.quit:
|
||||
}
|
||||
|
||||
p.wg.Done()
|
||||
}
|
||||
|
||||
p.wg.Add(1)
|
||||
go func() {
|
||||
defer handleResultErr()
|
||||
|
||||
// Block until the result is available.
|
||||
_, err := p.collectResult(attempt)
|
||||
if err != nil {
|
||||
log.Errorf("Error collecting result for attempt %v "+
|
||||
"in payment %v: %v", attempt.AttemptID,
|
||||
p.identifier, err)
|
||||
}
|
||||
|
||||
errToSend = err
|
||||
log.Debugf("Result collected for attempt %v in payment %v",
|
||||
attempt.AttemptID, p.identifier)
|
||||
|
||||
// Once the result is collected, we signal it by writing the
|
||||
// error to `resultCollected`.
|
||||
select {
|
||||
// Send the signal or quit.
|
||||
case p.resultCollected <- err:
|
||||
|
||||
case <-p.quit:
|
||||
log.Debugf("Lifecycle exiting while collecting "+
|
||||
"result for payment %v", p.identifier)
|
||||
|
||||
case <-p.router.quit:
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
@ -21,6 +21,10 @@ import (
|
||||
|
||||
const stepTimeout = 5 * time.Second
|
||||
|
||||
var (
|
||||
dummyErr = errors.New("dummy")
|
||||
)
|
||||
|
||||
// createTestRoute builds a route a->b->c paying the given amt to c.
|
||||
func createTestRoute(amt lnwire.MilliSatoshi,
|
||||
aliasMap map[string]route.Vertex) (*route.Route, error) {
|
||||
@ -1112,3 +1116,119 @@ func TestRequestRouteFailPaymentError(t *testing.T) {
|
||||
// Assert that `FailPayment` is called as expected.
|
||||
ct.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestDecideNextStep checks the method `decideNextStep` behaves as expected.
|
||||
func TestDecideNextStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// mockReturn is used to hold the return values from AllowMoreAttempts
|
||||
// or NeedWaitAttempts.
|
||||
type mockReturn struct {
|
||||
allowOrWait bool
|
||||
err error
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
allowMoreAttempts *mockReturn
|
||||
needWaitAttempts *mockReturn
|
||||
|
||||
// When the attemptResultChan has returned.
|
||||
closeResultChan bool
|
||||
|
||||
// Whether the router has quit.
|
||||
routerQuit bool
|
||||
|
||||
expectedStep stateStep
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "allow more attempts",
|
||||
allowMoreAttempts: &mockReturn{true, nil},
|
||||
expectedStep: stepProceed,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "error on allow more attempts",
|
||||
allowMoreAttempts: &mockReturn{false, dummyErr},
|
||||
expectedStep: stepExit,
|
||||
expectedErr: dummyErr,
|
||||
},
|
||||
{
|
||||
name: "no wait and exit",
|
||||
allowMoreAttempts: &mockReturn{false, nil},
|
||||
needWaitAttempts: &mockReturn{false, nil},
|
||||
expectedStep: stepExit,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "wait returns an error",
|
||||
allowMoreAttempts: &mockReturn{false, nil},
|
||||
needWaitAttempts: &mockReturn{false, dummyErr},
|
||||
expectedStep: stepExit,
|
||||
expectedErr: dummyErr,
|
||||
},
|
||||
|
||||
{
|
||||
name: "wait and exit on result chan",
|
||||
allowMoreAttempts: &mockReturn{false, nil},
|
||||
needWaitAttempts: &mockReturn{true, nil},
|
||||
closeResultChan: true,
|
||||
expectedStep: stepSkip,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "wait and exit on router quit",
|
||||
allowMoreAttempts: &mockReturn{false, nil},
|
||||
needWaitAttempts: &mockReturn{true, nil},
|
||||
routerQuit: true,
|
||||
expectedStep: stepExit,
|
||||
expectedErr: ErrRouterShuttingDown,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
// Create a test paymentLifecycle.
|
||||
p := createTestPaymentLifecycle()
|
||||
|
||||
// Make a mock payment.
|
||||
payment := &mockMPPayment{}
|
||||
|
||||
// Mock the method AllowMoreAttempts.
|
||||
payment.On("AllowMoreAttempts").Return(
|
||||
tc.allowMoreAttempts.allowOrWait,
|
||||
tc.allowMoreAttempts.err,
|
||||
).Once()
|
||||
|
||||
// Mock the method NeedWaitAttempts.
|
||||
if tc.needWaitAttempts != nil {
|
||||
payment.On("NeedWaitAttempts").Return(
|
||||
tc.needWaitAttempts.allowOrWait,
|
||||
tc.needWaitAttempts.err,
|
||||
).Once()
|
||||
}
|
||||
|
||||
// Send a nil error to the attemptResultChan if requested.
|
||||
if tc.closeResultChan {
|
||||
p.resultCollected = make(chan error, 1)
|
||||
p.resultCollected <- nil
|
||||
}
|
||||
|
||||
// Quit the router if requested.
|
||||
if tc.routerQuit {
|
||||
close(p.router.quit)
|
||||
}
|
||||
|
||||
// Once the setup is finished, run the test cases.
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
step, err := p.decideNextStep(payment)
|
||||
require.Equal(t, tc.expectedStep, step)
|
||||
require.ErrorIs(t, tc.expectedErr, err)
|
||||
})
|
||||
|
||||
// Check the payment's methods are called as expected.
|
||||
payment.AssertExpectations(t)
|
||||
}
|
||||
}
|
||||
|
@ -3470,34 +3470,44 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
||||
session := &mockPaymentSession{}
|
||||
sessionSource.On("NewPaymentSession", req).Return(session, nil)
|
||||
controlTower.On("InitPayment", identifier, mock.Anything).Return(nil)
|
||||
|
||||
// Mock the InFlightHTLCs.
|
||||
var (
|
||||
htlcs []channeldb.HTLCAttempt
|
||||
numAttempts atomic.Uint32
|
||||
settled atomic.Bool
|
||||
numParts = uint32(4)
|
||||
)
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
|
||||
On("Terminated").Return(false)
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil).Once()
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment",
|
||||
identifier,
|
||||
).Return(payment, nil).Run(func(args mock.Arguments) {
|
||||
// When number of attempts made is less than 4, we will mock
|
||||
// the payment's methods to allow the lifecycle to continue.
|
||||
if numAttempts.Load() < 4 {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||
return
|
||||
}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
// When number of attempts made is less than 4, we will
|
||||
// mock the payment's methods to allow the lifecycle to
|
||||
// continue.
|
||||
if numAttempts.Load() < numParts {
|
||||
payment.On("AllowMoreAttempts").Return(true, nil).Once()
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, terminate the lifecycle.
|
||||
payment.On("Terminated").Return(true).
|
||||
On("NeedWaitAttempts").Return(true, nil)
|
||||
})
|
||||
if !settled.Load() {
|
||||
fmt.Println("wait")
|
||||
payment.On("AllowMoreAttempts").Return(false, nil).Once()
|
||||
payment.On("NeedWaitAttempts").Return(true, nil).Once()
|
||||
// We add another attempt to the counter to
|
||||
// unblock next time.
|
||||
return
|
||||
}
|
||||
|
||||
payment.On("AllowMoreAttempts").Return(false, nil).
|
||||
On("NeedWaitAttempts").Return(false, nil)
|
||||
})
|
||||
|
||||
// Mock SettleAttempt.
|
||||
preimage := lntypes.Preimage{1, 2, 3}
|
||||
@ -3511,6 +3521,10 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
||||
payment.On("GetHTLCs").Return(
|
||||
[]channeldb.HTLCAttempt{settledAttempt},
|
||||
)
|
||||
// We want to at least wait for one settlement.
|
||||
if numAttempts.Load() > 1 {
|
||||
settled.Store(true)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
@ -3527,7 +3541,6 @@ func TestSendMPPaymentSucceed(t *testing.T) {
|
||||
controlTower.On("RegisterAttempt",
|
||||
identifier, mock.Anything,
|
||||
).Return(nil).Run(func(args mock.Arguments) {
|
||||
// Increase the counter whenever an attempt is made.
|
||||
numAttempts.Add(1)
|
||||
})
|
||||
|
||||
@ -3663,29 +3676,40 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
htlcs []channeldb.HTLCAttempt
|
||||
numAttempts atomic.Uint32
|
||||
failAttemptCount atomic.Uint32
|
||||
settled atomic.Bool
|
||||
)
|
||||
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0})
|
||||
On("GetState").Return(&channeldb.MPPaymentState{FeesPaid: 0}).
|
||||
On("Terminated").Return(false)
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil).Once()
|
||||
|
||||
// Mock FetchPayment to return the payment.
|
||||
controlTower.On("FetchPayment",
|
||||
identifier,
|
||||
).Return(payment, nil).Run(func(args mock.Arguments) {
|
||||
// When number of attempts made is less than 6, we will mock
|
||||
// the payment's methods to allow the lifecycle to continue.
|
||||
if numAttempts.Load() < 6 {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||
return
|
||||
}
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil).
|
||||
Run(func(args mock.Arguments) {
|
||||
// When number of attempts made is less than 4, we will
|
||||
// mock the payment's methods to allow the lifecycle to
|
||||
// continue.
|
||||
attempts := numAttempts.Load()
|
||||
if attempts < 6 {
|
||||
payment.On("AllowMoreAttempts").Return(true, nil).Once()
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, terminate the lifecycle.
|
||||
payment.On("Terminated").Return(true).
|
||||
On("NeedWaitAttempts").Return(true, nil)
|
||||
})
|
||||
if !settled.Load() {
|
||||
payment.On("AllowMoreAttempts").Return(false, nil).Once()
|
||||
payment.On("NeedWaitAttempts").Return(true, nil).Once()
|
||||
// We add another attempt to the counter to
|
||||
// unblock next time.
|
||||
numAttempts.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
payment.On("AllowMoreAttempts").Return(false, nil).
|
||||
On("NeedWaitAttempts").Return(false, nil)
|
||||
})
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
// will be returned by calling RequestRoute.
|
||||
@ -3768,6 +3792,10 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) {
|
||||
payment.On("GetHTLCs").Return(
|
||||
[]channeldb.HTLCAttempt{settledAttempt},
|
||||
)
|
||||
|
||||
if numAttempts.Load() > 1 {
|
||||
settled.Store(true)
|
||||
}
|
||||
})
|
||||
|
||||
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
|
||||
@ -3885,8 +3913,8 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
||||
// Make a mock MPPayment.
|
||||
payment := &mockMPPayment{}
|
||||
payment.On("InFlightHTLCs").Return(htlcs).Once()
|
||||
payment.On("GetStatus").Return(channeldb.StatusInFlight).Once()
|
||||
payment.On("GetState").Return(&channeldb.MPPaymentState{})
|
||||
payment.On("Terminated").Return(false)
|
||||
controlTower.On("FetchPayment", identifier).Return(payment, nil).Once()
|
||||
|
||||
// Mock the sequential FetchPayment to return the payment.
|
||||
@ -3895,21 +3923,20 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
||||
// We want to at least send out all parts in order to
|
||||
// wait for them later.
|
||||
if numAttempts.Load() < numParts {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||
payment.On("AllowMoreAttempts").Return(true, nil).Once()
|
||||
return
|
||||
}
|
||||
|
||||
// Wait if the payment wasn't failed yet.
|
||||
if !failed.Load() {
|
||||
payment.On("Terminated").Return(false).Times(2).
|
||||
payment.On("AllowMoreAttempts").Return(false, nil).Once().
|
||||
On("NeedWaitAttempts").Return(true, nil).Once()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
payment.On("Terminated").Return(true).
|
||||
On("GetHTLCs").Return(htlcs).Once()
|
||||
payment.On("AllowMoreAttempts").Return(false, nil).
|
||||
On("GetHTLCs").Return(htlcs).Once().
|
||||
On("NeedWaitAttempts").Return(false, nil).Once()
|
||||
})
|
||||
|
||||
// Create a route that can send 1/4 of the total amount. This value
|
||||
@ -3990,6 +4017,8 @@ func TestSendMPPaymentFailed(t *testing.T) {
|
||||
mock.Anything, mock.Anything, mock.Anything,
|
||||
).Return(nil)
|
||||
|
||||
controlTower.On("DeleteFailedAttempts", identifier).Return(nil)
|
||||
|
||||
// Call the actual method SendPayment on router. This is place inside a
|
||||
// goroutine so we can set a timeout for the whole test, in case
|
||||
// anything goes wrong and the test never finishes.
|
||||
|
Loading…
Reference in New Issue
Block a user