routing+routerrpc: cancelable context in SendPaymentV2

In this commit we set up the payment loop context
according to user-provided parameters. The
`cancelable` parameter indicates whether the user
is able to interrupt the payment loop by cancelling
the server stream context. We'll additionally wrap
the context in a deadline if the user provided a
payment timeout.
We remove the timeout channel from payment_lifecycle.go
in favor of the deadline context.
This commit is contained in:
Slyghtning 2024-05-16 16:38:51 +02:00
parent e729084149
commit bba01cf634
No known key found for this signature in database
GPG key ID: F82D456EA023C9BF
4 changed files with 163 additions and 69 deletions

View file

@ -360,13 +360,25 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
return err
}
// The payment context is influenced by two user-provided parameters,
// the cancelable flag and the payment attempt timeout.
// If the payment is cancelable, we will use the stream context as the
// payment context. That way, if the user ends the stream, the payment
// loop will be canceled.
// The second context parameter is the timeout. If the user provides a
// timeout, we will additionally wrap the context in a deadline. If the
// user provided 'cancelable' and ends the stream before the timeout is
// reached the payment will be canceled.
ctx := context.Background()
if req.Cancelable {
ctx = stream.Context()
}
// Send the payment asynchronously.
s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker)
s.cfg.Router.SendPaymentAsync(ctx, payment, paySession, shardTracker)
// Track the payment and return.
return s.trackPayment(
sub, payHash, stream, req.NoInflightUpdates,
)
return s.trackPayment(sub, payHash, stream, req.NoInflightUpdates)
}
// EstimateRouteFee allows callers to obtain an expected value w.r.t how much it

View file

@ -1,6 +1,7 @@
package routing
import (
"context"
"errors"
"fmt"
"time"
@ -29,7 +30,6 @@ type paymentLifecycle struct {
identifier lntypes.Hash
paySession PaymentSession
shardTracker shards.ShardTracker
timeoutChan <-chan time.Time
currentHeight int32
// quit is closed to signal the sub goroutines of the payment lifecycle
@ -52,7 +52,7 @@ type paymentLifecycle struct {
// newPaymentLifecycle initiates a new payment lifecycle and returns it.
func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, paySession PaymentSession,
shardTracker shards.ShardTracker, timeout time.Duration,
shardTracker shards.ShardTracker,
currentHeight int32) *paymentLifecycle {
p := &paymentLifecycle{
@ -69,13 +69,6 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
// Mount the result collector.
p.resultCollector = p.collectResultAsync
// If a timeout is specified, create a timeout channel. If no timeout is
// specified, the channel is left nil and will never abort the payment
// loop.
if timeout != 0 {
p.timeoutChan = time.After(timeout)
}
return p
}
@ -167,7 +160,9 @@ func (p *paymentLifecycle) decideNextStep(
}
// resumePayment resumes the paymentLifecycle from the current state.
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte,
*route.Route, error) {
// When the payment lifecycle loop exits, we make sure to signal any
// sub goroutine of the HTLC attempt to exit, then wait for them to
// return.
@ -221,18 +216,17 @@ lifecycle:
// We now proceed our lifecycle with the following tasks in
// order,
// 1. check timeout.
// 1. check context.
// 2. request route.
// 3. create HTLC attempt.
// 4. send HTLC attempt.
// 5. collect HTLC attempt result.
//
// Before we attempt any new shard, we'll check to see if
// either we've gone past the payment attempt timeout, or the
// router is exiting. In either case, we'll stop this payment
// attempt short. If a timeout is not applicable, timeoutChan
// will be nil.
if err := p.checkTimeout(); err != nil {
// Before we attempt any new shard, we'll check to see if we've
// gone past the payment attempt timeout, or if the context was
// cancelled, or the router is exiting. In any of these cases,
// we'll stop this payment attempt short.
if err := p.checkContext(ctx); err != nil {
return exitWithErr(err)
}
@ -318,19 +312,30 @@ lifecycle:
return [32]byte{}, nil, *failure
}
// checkTimeout checks whether the payment has reached its timeout.
func (p *paymentLifecycle) checkTimeout() error {
// checkContext checks whether the payment context has been canceled.
// Cancellation occurs manually or if the context times out.
func (p *paymentLifecycle) checkContext(ctx context.Context) error {
select {
case <-p.timeoutChan:
log.Warnf("payment attempt not completed before timeout")
case <-ctx.Done():
// If the context was canceled, we'll mark the payment as
// failed. There are two cases to distinguish here: Either a
// user-provided timeout was reached, or the context was
// canceled, either due to a manual cancellation or due to an
// unknown error.
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
log.Warnf("Payment attempt not completed before "+
"timeout, id=%s", p.identifier.String())
} else {
log.Warnf("Payment attempt context canceled, id=%s",
p.identifier.String())
}
// By marking the payment failed, depending on whether it has
// inflight HTLCs or not, its status will now either be
// `StatusInflight` or `StatusFailed`. In either case, no more
// HTLCs will be attempted.
err := p.router.cfg.Control.FailPayment(
p.identifier, channeldb.FailureReasonTimeout,
)
reason := channeldb.FailureReasonTimeout
err := p.router.cfg.Control.FailPayment(p.identifier, reason)
if err != nil {
return fmt.Errorf("FailPayment got %w", err)
}

View file

@ -1,6 +1,7 @@
package routing
import (
"context"
"sync/atomic"
"testing"
"time"
@ -88,7 +89,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) {
// Create a test payment lifecycle with no fee limit and no timeout.
p := newPaymentLifecycle(
rt, noFeeLimit, paymentHash, mockPaymentSession,
mockShardTracker, 0, 0,
mockShardTracker, 0,
)
// Create a mock payment which is returned from mockControlTower.
@ -151,9 +152,9 @@ type resumePaymentResult struct {
err error
}
// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error
// is returned.
func sendPaymentAndAssertFailed(t *testing.T,
// sendPaymentAndAssertError calls `resumePayment` and asserts that an error is
// returned.
func sendPaymentAndAssertError(t *testing.T, ctx context.Context,
p *paymentLifecycle, errExpected error) {
resultChan := make(chan *resumePaymentResult, 1)
@ -161,7 +162,7 @@ func sendPaymentAndAssertFailed(t *testing.T,
// We now make a call to `resumePayment` and expect it to return the
// error.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(ctx)
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
@ -189,7 +190,7 @@ func sendPaymentAndAssertSucceeded(t *testing.T,
// We now make a call to `resumePayment` and expect it to return the
// preimage.
go func() {
preimage, _, err := p.resumePayment()
preimage, _, err := p.resumePayment(context.Background())
resultChan <- &resumePaymentResult{
preimage: preimage,
err: err,
@ -278,6 +279,10 @@ func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo {
func TestCheckTimeoutTimedOut(t *testing.T) {
t.Parallel()
deadline := time.Now().Add(time.Nanosecond)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
p := createTestPaymentLifecycle()
// Mock the control tower's `FailPayment` method.
@ -288,14 +293,11 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
// Mount the mocked control tower.
p.router.cfg.Control = ct
// Make the timeout happen instantly.
p.timeoutChan = time.After(1 * time.Nanosecond)
// Sleep one millisecond to make sure it timed out.
time.Sleep(1 * time.Millisecond)
// Call the function and expect no error.
err := p.checkTimeout()
err := p.checkContext(ctx)
require.NoError(t, err)
// Assert that `FailPayment` is called as expected.
@ -313,13 +315,15 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
p.router.cfg.Control = ct
// Make the timeout happen instantly.
p.timeoutChan = time.After(1 * time.Nanosecond)
deadline = time.Now().Add(time.Nanosecond)
ctx, cancel = context.WithDeadline(context.Background(), deadline)
defer cancel()
// Sleep one millisecond to make sure it timed out.
time.Sleep(1 * time.Millisecond)
// Call the function and expect an error.
err = p.checkTimeout()
err = p.checkContext(ctx)
require.ErrorIs(t, err, errDummy)
// Assert that `FailPayment` is called as expected.
@ -331,10 +335,13 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
func TestCheckTimeoutOnRouterQuit(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
p := createTestPaymentLifecycle()
close(p.router.quit)
err := p.checkTimeout()
err := p.checkContext(ctx)
require.ErrorIs(t, err, ErrRouterShuttingDown)
}
@ -627,7 +634,7 @@ func TestResumePaymentFailOnFetchPayment(t *testing.T) {
m.control.On("FetchPayment", p.identifier).Return(nil, errDummy)
// Send the payment and assert it failed.
sendPaymentAndAssertFailed(t, p, errDummy)
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)
@ -656,14 +663,15 @@ func TestResumePaymentFailOnTimeout(t *testing.T) {
}
m.payment.On("GetState").Return(ps).Once()
// NOTE: GetStatus is only used to populate the logs which is
// not critical so we loosen the checks on how many times it's
// been called.
// NOTE: GetStatus is only used to populate the logs which is not
// critical, so we loosen the checks on how many times it's been called.
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
// 3. make the timeout happen instantly and sleep one millisecond to
// make sure it timed out.
p.timeoutChan = time.After(1 * time.Nanosecond)
deadline := time.Now().Add(time.Nanosecond)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
time.Sleep(1 * time.Millisecond)
// 4. the payment should be failed with reason timeout.
@ -683,7 +691,7 @@ func TestResumePaymentFailOnTimeout(t *testing.T) {
m.payment.On("TerminalInfo").Return(nil, &reason)
// Send the payment and assert it failed with the timeout reason.
sendPaymentAndAssertFailed(t, p, reason)
sendPaymentAndAssertError(t, ctx, p, reason)
// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)
@ -721,7 +729,65 @@ func TestResumePaymentFailOnTimeoutErr(t *testing.T) {
close(p.router.quit)
// Send the payment and assert it failed when router is shutting down.
sendPaymentAndAssertFailed(t, p, ErrRouterShuttingDown)
sendPaymentAndAssertError(
t, context.Background(), p, ErrRouterShuttingDown,
)
// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)
}
// TestResumePaymentFailContextCancel checks that the lifecycle fails when the
// context is canceled.
//
// NOTE: No parallel test because it overwrites global variables.
//
//nolint:paralleltest
func TestResumePaymentFailContextCancel(t *testing.T) {
// Create a test paymentLifecycle with the initial two calls mocked.
p, m := setupTestPaymentLifecycle(t)
// Create the cancelable payment context.
ctx, cancel := context.WithCancel(context.Background())
paymentAmt := lnwire.MilliSatoshi(10000)
// We now enter the payment lifecycle loop.
//
// 1. calls `FetchPayment` and return the payment.
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
// 2. calls `GetState` and return the state.
ps := &channeldb.MPPaymentState{
RemainingAmt: paymentAmt,
}
m.payment.On("GetState").Return(ps).Once()
// NOTE: GetStatus is only used to populate the logs which is not
// critical, so we loosen the checks on how many times it's been called.
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
// 3. Cancel the context and skip the FailPayment error to trigger the
// context cancellation of the payment.
cancel()
m.control.On(
"FailPayment", p.identifier, channeldb.FailureReasonTimeout,
).Return(nil).Once()
// 5. decideNextStep now returns stepExit.
m.payment.On("AllowMoreAttempts").Return(false, nil).Once().
On("NeedWaitAttempts").Return(false, nil).Once()
// 6. Control tower deletes failed attempts.
m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once()
// 7. We will observe FailureReasonError if the context was cancelled.
reason := channeldb.FailureReasonError
m.payment.On("TerminalInfo").Return(nil, &reason)
// Send the payment and assert it failed with FailureReasonError.
sendPaymentAndAssertError(t, ctx, p, reason)
// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)
@ -759,7 +825,7 @@ func TestResumePaymentFailOnStepErr(t *testing.T) {
m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once()
// Send the payment and assert it failed.
sendPaymentAndAssertFailed(t, p, errDummy)
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)
@ -803,7 +869,7 @@ func TestResumePaymentFailOnRequestRouteErr(t *testing.T) {
).Return(nil, errDummy).Once()
// Send the payment and assert it failed.
sendPaymentAndAssertFailed(t, p, errDummy)
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)
@ -863,7 +929,7 @@ func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) {
).Return(nil, errDummy).Once()
// Send the payment and assert it failed.
sendPaymentAndAssertFailed(t, p, errDummy)
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)
@ -955,7 +1021,7 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) {
).Return(nil, errDummy).Once()
// Send the payment and assert it failed.
sendPaymentAndAssertFailed(t, p, errDummy)
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
// Expected collectResultAsync to not be called.
require.Zero(t, m.collectResultsCount)

View file

@ -2,6 +2,7 @@ package routing
import (
"bytes"
"context"
"fmt"
"math"
"runtime"
@ -715,13 +716,15 @@ func (r *ChannelRouter) Start() error {
// result for the in-flight attempt is received.
paySession := r.cfg.SessionSource.NewPaymentSessionEmpty()
// We pass in a zero timeout value, to indicate we
// We pass in a context without a timeout, to indicate we
// don't need it to time out. It will stop immediately
// after the existing attempt has finished anyway. We
// also set a zero fee limit, as no more routes should
// be tried.
noTimeout := time.Duration(0)
_, _, err := r.sendPayment(
0, payment.Info.PaymentIdentifier, 0,
context.Background(), 0,
payment.Info.PaymentIdentifier, noTimeout,
paySession, shardTracker,
)
if err != nil {
@ -2406,18 +2409,16 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
log.Tracef("Dispatching SendPayment for lightning payment: %v",
spewPayment(payment))
// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
return r.sendPayment(
payment.FeeLimit, payment.Identifier(),
context.Background(), payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker,
)
}
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
// result needs to be retrieved via the control tower.
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
ps PaymentSession, st shards.ShardTracker) {
func (r *ChannelRouter) SendPaymentAsync(ctx context.Context,
payment *LightningPayment, ps PaymentSession, st shards.ShardTracker) {
// Since this is the first time this payment is being made, we pass nil
// for the existing attempt.
@ -2429,7 +2430,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
spewPayment(payment))
_, _, err := r.sendPayment(
payment.FeeLimit, payment.Identifier(),
ctx, payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, ps, st,
)
if err != nil {
@ -2604,9 +2605,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// - nil payment session (since we already have a route).
// - no payment timeout.
// - no current block height.
p := newPaymentLifecycle(
r, 0, paymentIdentifier, nil, shardTracker, 0, 0,
)
p := newPaymentLifecycle(r, 0, paymentIdentifier, nil, shardTracker, 0)
// We found a route to try, create a new HTLC attempt to try.
//
@ -2699,11 +2698,23 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// carry out its execution. After restarts, it is safe, and assumed, that the
// router will call this method for every payment still in-flight according to
// the ControlTower.
func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, timeout time.Duration,
paySession PaymentSession,
func (r *ChannelRouter) sendPayment(ctx context.Context,
feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
paymentAttemptTimeout time.Duration, paySession PaymentSession,
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
// If the user provides a timeout, we will additionally wrap the context
// in a deadline.
cancel := func() {}
if paymentAttemptTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, paymentAttemptTimeout)
}
// Since resumePayment is a blocking call, we'll cancel this
// context if the payment completes before the optional
// deadline.
defer cancel()
// We'll also fetch the current block height, so we can properly
// calculate the required HTLC time locks within the route.
_, currentHeight, err := r.cfg.Chain.GetBestBlock()
@ -2714,11 +2725,11 @@ func (r *ChannelRouter) sendPayment(feeLimit lnwire.MilliSatoshi,
// Now set up a paymentLifecycle struct with these params, such that we
// can resume the payment from the current state.
p := newPaymentLifecycle(
r, feeLimit, identifier, paySession,
shardTracker, timeout, currentHeight,
r, feeLimit, identifier, paySession, shardTracker,
currentHeight,
)
return p.resumePayment()
return p.resumePayment(ctx)
}
// extractChannelUpdate examines the error and extracts the channel update.