routing/payment_lifecycle: use ShardTracker to track shards

We'll let the payment's lifecycle register each shard it's sending with
the ShardTracker, canceling failed shards. This will be the foundation
for correct AMP derivation for each shard we'll send.
This commit is contained in:
Johan T. Halseth 2021-04-12 15:21:59 +02:00
parent 6474b253d6
commit 41ae3530a3
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
4 changed files with 137 additions and 46 deletions

View File

@ -194,6 +194,8 @@ func newRoute(sourceVertex route.Vertex,
}
// Otherwise attach the mpp record if it exists.
// TODO(halseth): move this to payment life cycle,
// where AMP options are set.
if finalHop.paymentAddr != nil {
mpp = record.NewMPP(
finalHop.totalAmt,

View File

@ -12,6 +12,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards"
)
// errShardHandlerExiting is returned from the shardHandler when it exits.
@ -25,6 +26,7 @@ type paymentLifecycle struct {
feeLimit lnwire.MilliSatoshi
paymentHash lntypes.Hash
paySession PaymentSession
shardTracker shards.ShardTracker
timeoutChan <-chan time.Time
currentHeight int32
}
@ -83,10 +85,11 @@ func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) (
// resumePayment resumes the paymentLifecycle from the current state.
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
shardHandler := &shardHandler{
router: p.router,
paymentHash: p.paymentHash,
shardErrors: make(chan error),
quit: make(chan struct{}),
router: p.router,
paymentHash: p.paymentHash,
shardTracker: p.shardTracker,
shardErrors: make(chan error),
quit: make(chan struct{}),
}
// When the payment lifecycle loop exits, we make sure to signal any
@ -246,8 +249,12 @@ lifecycle:
continue lifecycle
}
// If this route will consume the last remeining amount to send
// to the receiver, this will be our last shard (for now).
lastShard := rt.ReceiverAmt() == state.remainingAmt
// We found a route to try, launch a new shard.
attempt, outcome, err := shardHandler.launchShard(rt)
attempt, outcome, err := shardHandler.launchShard(rt, lastShard)
switch {
// We may get a terminal error if we've processed a shard with
// a terminal state (settled or permanent failure), while we
@ -294,8 +301,9 @@ lifecycle:
// shardHandler holds what is necessary to send and collect the result of
// shards.
type shardHandler struct {
paymentHash lntypes.Hash
router *ChannelRouter
paymentHash lntypes.Hash
router *ChannelRouter
shardTracker shards.ShardTracker
// shardErrors is a channel where errors collected by calling
// collectResultAsync will be delivered. These results are meant to be
@ -366,19 +374,20 @@ type launchOutcome struct {
}
// launchShard creates and sends an HTLC attempt along the given route,
// registering it with the control tower before sending it. It returns the
// HTLCAttemptInfo that was created for the shard, along with a launchOutcome.
// The launchOutcome is used to indicate whether the attempt was successfully
// sent. If the launchOutcome wraps a non-nil error, it means that the attempt
// was not sent onto the network, so no result will be available in the future
// for it.
func (p *shardHandler) launchShard(rt *route.Route) (*channeldb.HTLCAttemptInfo,
*launchOutcome, error) {
// registering it with the control tower before sending it. The lastShard
// argument should be true if this shard will consume the remainder of the
// amount to send. It returns the HTLCAttemptInfo that was created for the
// shard, along with a launchOutcome. The launchOutcome is used to indicate
// whether the attempt was successfully sent. If the launchOutcome wraps a
// non-nil error, it means that the attempt was not sent onto the network, so
// no result will be available in the future for it.
func (p *shardHandler) launchShard(rt *route.Route,
lastShard bool) (*channeldb.HTLCAttemptInfo, *launchOutcome, error) {
// Using the route received from the payment session, create a new
// shard to send.
firstHop, htlcAdd, attempt, err := p.createNewPaymentAttempt(
rt,
rt, lastShard,
)
if err != nil {
return nil, nil, err
@ -480,10 +489,17 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) {
func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttemptInfo) (
*shardResult, error) {
// We'll retrieve the hash specific to this shard from the
// shardTracker, since it will be needed to regenerate the circuit
// below.
hash, err := p.shardTracker.GetHash(attempt.AttemptID)
if err != nil {
return nil, err
}
// Regenerate the circuit for this attempt.
_, circuit, err := generateSphinxPacket(
&attempt.Route, p.paymentHash[:],
attempt.SessionKey,
&attempt.Route, hash[:], attempt.SessionKey,
)
if err != nil {
return nil, err
@ -597,7 +613,7 @@ func (p *shardHandler) collectResult(attempt *channeldb.HTLCAttemptInfo) (
}
// createNewPaymentAttempt creates a new payment attempt from the given route.
func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
func (p *shardHandler) createNewPaymentAttempt(rt *route.Route, lastShard bool) (
lnwire.ShortChannelID, *lnwire.UpdateAddHTLC,
*channeldb.HTLCAttemptInfo, error) {
@ -607,12 +623,39 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
return lnwire.ShortChannelID{}, nil, nil, err
}
// We generate a new, unique payment ID that we will use for
// this HTLC.
attemptID, err := p.router.cfg.NextPaymentID()
if err != nil {
return lnwire.ShortChannelID{}, nil, nil, err
}
// Requesst a new shard from the ShardTracker. If this is an AMP
// payment, and this is the last shard, the outstanding shards together
// with ths one will be enough for the receiver to derive all HTLC
// preimages. If this a non-AMP payment, the ShardTracker will return a
// simple shard with the payment's static payment hash.
shard, err := p.shardTracker.NewShard(attemptID, lastShard)
if err != nil {
return lnwire.ShortChannelID{}, nil, nil, err
}
// It this shard carries MPP or AMP options, add them to the last hop
// on the route.
hop := rt.Hops[len(rt.Hops)-1]
if shard.MPP() != nil {
hop.MPP = shard.MPP()
}
if shard.AMP() != nil {
hop.AMP = shard.AMP()
}
// Generate the raw encoded sphinx packet to be included along
// with the htlcAdd message that we send directly to the
// switch.
onionBlob, _, err := generateSphinxPacket(
rt, p.paymentHash[:], sessionKey,
)
hash := shard.Hash()
onionBlob, _, err := generateSphinxPacket(rt, hash[:], sessionKey)
if err != nil {
return lnwire.ShortChannelID{}, nil, nil, err
}
@ -623,7 +666,7 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
htlcAdd := &lnwire.UpdateAddHTLC{
Amount: rt.TotalAmount,
Expiry: rt.TotalTimeLock,
PaymentHash: p.paymentHash,
PaymentHash: hash,
}
copy(htlcAdd.OnionBlob[:], onionBlob)
@ -634,13 +677,6 @@ func (p *shardHandler) createNewPaymentAttempt(rt *route.Route) (
rt.Hops[0].ChannelID,
)
// We generate a new, unique payment ID that we will use for
// this HTLC.
attemptID, err := p.router.cfg.NextPaymentID()
if err != nil {
return lnwire.ShortChannelID{}, nil, nil, err
}
// We now have all the information needed to populate
// the current attempt information.
attempt := &channeldb.HTLCAttemptInfo{
@ -722,6 +758,13 @@ func (p *shardHandler) failAttempt(attempt *channeldb.HTLCAttemptInfo,
p.router.cfg.Clock.Now(),
)
// Now that we are failing this payment attempt, cancel the shard with
// the ShardTracker such that it can derive the correct hash for the
// next attempt.
if err := p.shardTracker.CancelShard(attempt.AttemptID); err != nil {
return nil, err
}
return p.router.cfg.Control.FailAttempt(
p.paymentHash, attempt.AttemptID,
failInfo,

View File

@ -29,6 +29,7 @@ import (
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/chainview"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards"
"github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/zpay32"
)
@ -603,19 +604,40 @@ func (r *ChannelRouter) Start() error {
go func(payment *channeldb.MPPayment) {
defer r.wg.Done()
// Get the hashes used for the outstanding HTLCs.
htlcs := make(map[uint64]lntypes.Hash)
for _, a := range payment.HTLCs {
a := a
hash := payment.Info.PaymentHash
htlcs[a.AttemptID] = hash
}
// Since we are not supporting creating more shards
// after a restart (only receiving the result of the
// shards already outstanding), we create a simple
// shard tracker that will map the attempt IDs to
// hashes used for the HTLCs. This will be enough also
// for AMP payments, since we only need the hashes for
// the individual HTLCs to regenerate the circuits, and
// we don't currently persist the root share necessary
// to re-derive them.
shardTracker := shards.NewSimpleShardTracker(
payment.Info.PaymentHash, htlcs,
)
// We create a dummy, empty payment session such that
// we won't make another payment attempt when the
// result for the in-flight attempt is received.
paySession := r.cfg.SessionSource.NewPaymentSessionEmpty()
// We pass in a zero timeout value, to indicate we
// don't need it to timeout. It will stop immediately
// after the existing attempt has finished anyway. We
// also set a zero fee limit, as no more routes should
// be tried.
_, _, err := r.sendPayment(
payment.Info.Value, 0,
payment.Info.PaymentHash, 0, paySession,
payment.Info.Value, 0, payment.Info.PaymentHash,
0, paySession, shardTracker,
)
if err != nil {
log.Errorf("Resuming payment with hash %v "+
@ -1770,7 +1792,7 @@ type LightningPayment struct {
func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
*route.Route, error) {
paySession, err := r.preparePayment(payment)
paySession, shardTracker, err := r.preparePayment(payment)
if err != nil {
return [32]byte{}, nil, err
}
@ -1782,14 +1804,14 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
// for the existing attempt.
return r.sendPayment(
payment.Amount, payment.FeeLimit, payment.PaymentHash,
payment.PayAttemptTimeout, paySession,
payment.PayAttemptTimeout, paySession, shardTracker,
)
}
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
// result needs to be retrieved via the control tower.
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
paySession, err := r.preparePayment(payment)
paySession, shardTracker, err := r.preparePayment(payment)
if err != nil {
return err
}
@ -1805,7 +1827,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
_, _, err := r.sendPayment(
payment.Amount, payment.FeeLimit, payment.PaymentHash,
payment.PayAttemptTimeout, paySession,
payment.PayAttemptTimeout, paySession, shardTracker,
)
if err != nil {
log.Errorf("Payment with hash %x failed: %v",
@ -1841,14 +1863,14 @@ func spewPayment(payment *LightningPayment) logClosure {
// preparePayment creates the payment session and registers the payment with the
// control tower.
func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
PaymentSession, error) {
PaymentSession, shards.ShardTracker, error) {
// Before starting the HTLC routing attempt, we'll create a fresh
// payment session which will report our errors back to mission
// control.
paySession, err := r.cfg.SessionSource.NewPaymentSession(payment)
if err != nil {
return nil, err
return nil, nil, err
}
// Record this payment hash with the ControlTower, ensuring it is not
@ -1862,12 +1884,18 @@ func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
PaymentRequest: payment.PaymentRequest,
}
// Create a new ShardTracker that we'll use during the life cycle of
// this payment.
shardTracker := shards.NewSimpleShardTracker(
payment.PaymentHash, nil,
)
err = r.cfg.Control.InitPayment(payment.PaymentHash, info)
if err != nil {
return nil, err
return nil, nil, err
}
return paySession, nil
return paySession, shardTracker, nil
}
// SendToRoute attempts to send a payment with the given hash through the
@ -1915,14 +1943,22 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) (
}),
)
// Since the HTLC hashes and preimages are specified manually over the
// RPC for SendToRoute requests, we don't have to worry about creating
// a ShardTracker that can generate hashes for AMP payments. Instead we
// create a simple tracker that can just return the hash for the single
// shard we'll now launch.
shardTracker := shards.NewSimpleShardTracker(hash, nil)
// Launch a shard along the given route.
sh := &shardHandler{
router: r,
paymentHash: hash,
router: r,
paymentHash: hash,
shardTracker: shardTracker,
}
var shardError error
attempt, outcome, err := sh.launchShard(rt)
attempt, outcome, err := sh.launchShard(rt, false)
// With SendToRoute, it can happen that the route exceeds protocol
// constraints. Mark the payment as failed with an internal error.
@ -2007,8 +2043,8 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) (
// the ControlTower.
func (r *ChannelRouter) sendPayment(
totalAmt, feeLimit lnwire.MilliSatoshi, paymentHash lntypes.Hash,
timeout time.Duration,
paySession PaymentSession) ([32]byte, *route.Route, error) {
timeout time.Duration, paySession PaymentSession,
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
// We'll also fetch the current block height so we can properly
// calculate the required HTLC time locks within the route.
@ -2025,6 +2061,7 @@ func (r *ChannelRouter) sendPayment(
feeLimit: feeLimit,
paymentHash: paymentHash,
paySession: paySession,
shardTracker: shardTracker,
currentHeight: currentHeight,
}

View File

@ -2,6 +2,7 @@ package shards
import (
"fmt"
"sync"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/record"
@ -74,6 +75,7 @@ func (s *Shard) AMP() *record.AMP {
type SimpleShardTracker struct {
hash lntypes.Hash
shards map[uint64]lntypes.Hash
sync.Mutex
}
// A compile time check to ensure SimpleShardTracker implements the
@ -100,7 +102,9 @@ func NewSimpleShardTracker(paymentHash lntypes.Hash,
// if it ends up not being used by the overall payment, i.e. if the attempt
// fails.
func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
m.Lock()
m.shards[id] = m.hash
m.Unlock()
return &Shard{
hash: m.hash,
@ -109,14 +113,19 @@ func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
// CancelShard cancel's the shard corresponding to the given attempt ID.
func (m *SimpleShardTracker) CancelShard(id uint64) error {
m.Lock()
delete(m.shards, id)
m.Unlock()
return nil
}
// GetHash retrieves the hash used by the shard of the given attempt ID. This
// will return an error if the attempt ID is unknown.
func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) {
m.Lock()
hash, ok := m.shards[id]
m.Unlock()
if !ok {
return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+
"not found", id)