multi: use wire records on payment and intercept flows

This commit is contained in:
George Tsagkarelis 2024-04-16 12:29:15 +02:00 committed by Oliver Gugger
parent aa86020b84
commit 878f964a33
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
10 changed files with 108 additions and 75 deletions

View File

@ -12,6 +12,7 @@ import (
"github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -645,15 +646,16 @@ func (f *interceptedForward) Packet() InterceptedPacket {
ChanID: f.packet.incomingChanID, ChanID: f.packet.incomingChanID,
HtlcID: f.packet.incomingHTLCID, HtlcID: f.packet.incomingHTLCID,
}, },
OutgoingChanID: f.packet.outgoingChanID, OutgoingChanID: f.packet.outgoingChanID,
Hash: f.htlc.PaymentHash, Hash: f.htlc.PaymentHash,
OutgoingExpiry: f.htlc.Expiry, OutgoingExpiry: f.htlc.Expiry,
OutgoingAmount: f.htlc.Amount, OutgoingAmount: f.htlc.Amount,
IncomingAmount: f.packet.incomingAmount, IncomingAmount: f.packet.incomingAmount,
IncomingExpiry: f.packet.incomingTimeout, IncomingExpiry: f.packet.incomingTimeout,
CustomRecords: f.packet.customRecords, InOnionCustomRecords: f.packet.inOnionCustomRecords,
OnionBlob: f.htlc.OnionBlob, OnionBlob: f.htlc.OnionBlob,
AutoFailHeight: f.autoFailHeight, AutoFailHeight: f.autoFailHeight,
InWireCustomRecords: f.packet.inWireCustomRecords,
} }
} }
@ -723,6 +725,8 @@ func (f *interceptedForward) ResumeModified(
} }
} }
log.Tracef("Forwarding packet %v", lnutils.SpewLogClosure(f.packet))
// Forward to the switch. A link quit channel isn't needed, because we // Forward to the switch. A link quit channel isn't needed, because we
// are on a different thread now. // are on a different thread now.
return f.htlcSwitch.ForwardPackets(nil, f.packet) return f.htlcSwitch.ForwardPackets(nil, f.packet)

View File

@ -357,13 +357,17 @@ type InterceptedPacket struct {
// IncomingAmount is the amount of the accepted htlc. // IncomingAmount is the amount of the accepted htlc.
IncomingAmount lnwire.MilliSatoshi IncomingAmount lnwire.MilliSatoshi
// CustomRecords are user-defined records in the custom type range that // InOnionCustomRecords are user-defined records in the custom type
// were included in the payload. // range that were included in the payload.
CustomRecords record.CustomSet InOnionCustomRecords record.CustomSet
// OnionBlob is the onion packet for the next hop // OnionBlob is the onion packet for the next hop
OnionBlob [lnwire.OnionPacketSize]byte OnionBlob [lnwire.OnionPacketSize]byte
// InWireCustomRecords are user-defined p2p wire message records that
// were defined by the peer that forwarded this HTLC to us.
InWireCustomRecords lnwire.CustomRecords
// AutoFailHeight is the block height at which this intercept will be // AutoFailHeight is the block height at which this intercept will be
// failed back automatically. // failed back automatically.
AutoFailHeight int32 AutoFailHeight int32

View File

@ -3630,7 +3630,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
} }
// Otherwise, it was already processed, we can // Otherwise, it was already processed, we can
// can collect it and continue. // collect it and continue.
addMsg := &lnwire.UpdateAddHTLC{ addMsg := &lnwire.UpdateAddHTLC{
Expiry: fwdInfo.OutgoingCTLV, Expiry: fwdInfo.OutgoingCTLV,
Amount: fwdInfo.AmountToForward, Amount: fwdInfo.AmountToForward,
@ -3650,19 +3650,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
inboundFee := l.cfg.FwrdingPolicy.InboundFee inboundFee := l.cfg.FwrdingPolicy.InboundFee
//nolint:lll
updatePacket := &htlcPacket{ updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(), incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex, incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop, outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef, sourceRef: pd.SourceRef,
incomingAmount: pd.Amount, incomingAmount: pd.Amount,
amount: addMsg.Amount, amount: addMsg.Amount,
htlc: addMsg, htlc: addMsg,
obfuscator: obfuscator, obfuscator: obfuscator,
incomingTimeout: pd.Timeout, incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV, outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(), inOnionCustomRecords: pld.CustomRecords(),
inboundFee: inboundFee, inboundFee: inboundFee,
inWireCustomRecords: pd.CustomRecords.Copy(),
} }
switchPackets = append( switchPackets = append(
switchPackets, updatePacket, switchPackets, updatePacket,
@ -3718,19 +3720,21 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
if fwdPkg.State == channeldb.FwdStateLockedIn { if fwdPkg.State == channeldb.FwdStateLockedIn {
inboundFee := l.cfg.FwrdingPolicy.InboundFee inboundFee := l.cfg.FwrdingPolicy.InboundFee
//nolint:lll
updatePacket := &htlcPacket{ updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(), incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex, incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop, outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef, sourceRef: pd.SourceRef,
incomingAmount: pd.Amount, incomingAmount: pd.Amount,
amount: addMsg.Amount, amount: addMsg.Amount,
htlc: addMsg, htlc: addMsg,
obfuscator: obfuscator, obfuscator: obfuscator,
incomingTimeout: pd.Timeout, incomingTimeout: pd.Timeout,
outgoingTimeout: fwdInfo.OutgoingCTLV, outgoingTimeout: fwdInfo.OutgoingCTLV,
customRecords: pld.CustomRecords(), inOnionCustomRecords: pld.CustomRecords(),
inboundFee: inboundFee, inboundFee: inboundFee,
inWireCustomRecords: pd.CustomRecords.Copy(),
} }
fwdPkg.FwdFilter.Set(idx) fwdPkg.FwdFilter.Set(idx)

View File

@ -94,9 +94,13 @@ type htlcPacket struct {
// link. // link.
outgoingTimeout uint32 outgoingTimeout uint32
// customRecords are user-defined records in the custom type range that // inOnionCustomRecords are user-defined records in the custom type
// were included in the payload. // range that were included in the onion payload.
customRecords record.CustomSet inOnionCustomRecords record.CustomSet
// inWireCustomRecords are custom type range TLVs that are included
// in the incoming update_add_htlc wire message.
inWireCustomRecords lnwire.CustomRecords
// originalOutgoingChanID is used when sending back failure messages. // originalOutgoingChanID is used when sending back failure messages.
// It is only used for forwarded Adds on option_scid_alias channels. // It is only used for forwarded Adds on option_scid_alias channels.

View File

@ -89,7 +89,7 @@ func (r *forwardInterceptor) onIntercept(
OutgoingExpiry: htlc.OutgoingExpiry, OutgoingExpiry: htlc.OutgoingExpiry,
IncomingAmountMsat: uint64(htlc.IncomingAmount), IncomingAmountMsat: uint64(htlc.IncomingAmount),
IncomingExpiry: htlc.IncomingExpiry, IncomingExpiry: htlc.IncomingExpiry,
CustomRecords: htlc.CustomRecords, CustomRecords: htlc.InOnionCustomRecords,
OnionBlob: htlc.OnionBlob[:], OnionBlob: htlc.OnionBlob[:],
AutoFailHeight: htlc.AutoFailHeight, AutoFailHeight: htlc.AutoFailHeight,
} }

View File

@ -207,6 +207,7 @@ func PayDescsFromRemoteLogUpdates(chanID lnwire.ShortChannelID, height uint64,
Index: uint16(i), Index: uint16(i),
}, },
BlindingPoint: wireMsg.BlindingPoint, BlindingPoint: wireMsg.BlindingPoint,
CustomRecords: wireMsg.CustomRecords.Copy(),
} }
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:])
@ -1154,6 +1155,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate,
LogIndex: logUpdate.LogIndex, LogIndex: logUpdate.LogIndex,
addCommitHeightRemote: commitHeight, addCommitHeightRemote: commitHeight,
BlindingPoint: wireMsg.BlindingPoint, BlindingPoint: wireMsg.BlindingPoint,
CustomRecords: wireMsg.CustomRecords.Copy(),
} }
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:])
@ -1359,6 +1361,7 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd
LogIndex: logUpdate.LogIndex, LogIndex: logUpdate.LogIndex,
addCommitHeightLocal: commitHeight, addCommitHeightLocal: commitHeight,
BlindingPoint: wireMsg.BlindingPoint, BlindingPoint: wireMsg.BlindingPoint,
CustomRecords: wireMsg.CustomRecords.Copy(),
} }
pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob))
copy(pd.OnionBlob, wireMsg.OnionBlob[:]) copy(pd.OnionBlob, wireMsg.OnionBlob[:])
@ -3403,6 +3406,7 @@ func (lc *LightningChannel) createCommitDiff(newCommit *commitment,
Expiry: pd.Timeout, Expiry: pd.Timeout,
PaymentHash: pd.RHash, PaymentHash: pd.RHash,
BlindingPoint: pd.BlindingPoint, BlindingPoint: pd.BlindingPoint,
CustomRecords: pd.CustomRecords.Copy(),
} }
copy(htlc.OnionBlob[:], pd.OnionBlob) copy(htlc.OnionBlob[:], pd.OnionBlob)
logUpdate.UpdateMsg = htlc logUpdate.UpdateMsg = htlc
@ -3543,6 +3547,7 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate {
Expiry: pd.Timeout, Expiry: pd.Timeout,
PaymentHash: pd.RHash, PaymentHash: pd.RHash,
BlindingPoint: pd.BlindingPoint, BlindingPoint: pd.BlindingPoint,
CustomRecords: pd.CustomRecords.Copy(),
} }
copy(htlc.OnionBlob[:], pd.OnionBlob) copy(htlc.OnionBlob[:], pd.OnionBlob)
logUpdate.UpdateMsg = htlc logUpdate.UpdateMsg = htlc
@ -5620,6 +5625,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
Expiry: pd.Timeout, Expiry: pd.Timeout,
PaymentHash: pd.RHash, PaymentHash: pd.RHash,
BlindingPoint: pd.BlindingPoint, BlindingPoint: pd.BlindingPoint,
CustomRecords: pd.CustomRecords.Copy(),
} }
copy(htlc.OnionBlob[:], pd.OnionBlob) copy(htlc.OnionBlob[:], pd.OnionBlob)
logUpdate.UpdateMsg = htlc logUpdate.UpdateMsg = htlc
@ -5965,9 +5971,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC,
OnionBlob: htlc.OnionBlob[:], OnionBlob: htlc.OnionBlob[:],
OpenCircuitKey: openKey, OpenCircuitKey: openKey,
BlindingPoint: htlc.BlindingPoint, BlindingPoint: htlc.BlindingPoint,
// TODO(guggero): Add custom records from HTLC here once we have CustomRecords: htlc.CustomRecords.Copy(),
// the custom records in the HTLC struct (later commits in this
// PR).
} }
} }
@ -6028,9 +6032,7 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64,
HtlcIndex: lc.updateLogs.Remote.htlcCounter, HtlcIndex: lc.updateLogs.Remote.htlcCounter,
OnionBlob: htlc.OnionBlob[:], OnionBlob: htlc.OnionBlob[:],
BlindingPoint: htlc.BlindingPoint, BlindingPoint: htlc.BlindingPoint,
// TODO(guggero): Add custom records from HTLC here once we have CustomRecords: htlc.CustomRecords.Copy(),
// the custom records in the HTLC struct (later commits in this
// PR).
} }
localACKedIndex := lc.commitChains.Remote.tail().messageIndices.Local localACKedIndex := lc.commitChains.Remote.tail().messageIndices.Local

View File

@ -25,12 +25,13 @@ var ErrPaymentLifecycleExiting = errors.New("payment lifecycle exiting")
// paymentLifecycle holds all information about the current state of a payment // paymentLifecycle holds all information about the current state of a payment
// needed to resume if from any point. // needed to resume if from any point.
type paymentLifecycle struct { type paymentLifecycle struct {
router *ChannelRouter router *ChannelRouter
feeLimit lnwire.MilliSatoshi feeLimit lnwire.MilliSatoshi
identifier lntypes.Hash identifier lntypes.Hash
paySession PaymentSession paySession PaymentSession
shardTracker shards.ShardTracker shardTracker shards.ShardTracker
currentHeight int32 currentHeight int32
firstHopCustomRecords lnwire.CustomRecords
// quit is closed to signal the sub goroutines of the payment lifecycle // quit is closed to signal the sub goroutines of the payment lifecycle
// to stop. // to stop.
@ -52,18 +53,19 @@ type paymentLifecycle struct {
// newPaymentLifecycle initiates a new payment lifecycle and returns it. // newPaymentLifecycle initiates a new payment lifecycle and returns it.
func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
identifier lntypes.Hash, paySession PaymentSession, identifier lntypes.Hash, paySession PaymentSession,
shardTracker shards.ShardTracker, shardTracker shards.ShardTracker, currentHeight int32,
currentHeight int32) *paymentLifecycle { firstHopCustomRecords lnwire.CustomRecords) *paymentLifecycle {
p := &paymentLifecycle{ p := &paymentLifecycle{
router: r, router: r,
feeLimit: feeLimit, feeLimit: feeLimit,
identifier: identifier, identifier: identifier,
paySession: paySession, paySession: paySession,
shardTracker: shardTracker, shardTracker: shardTracker,
currentHeight: currentHeight, currentHeight: currentHeight,
quit: make(chan struct{}), quit: make(chan struct{}),
resultCollected: make(chan error, 1), resultCollected: make(chan error, 1),
firstHopCustomRecords: firstHopCustomRecords,
} }
// Mount the result collector. // Mount the result collector.
@ -677,9 +679,10 @@ func (p *paymentLifecycle) sendAttempt(
// this packet will be used to route the payment through the network, // this packet will be used to route the payment through the network,
// starting with the first-hop. // starting with the first-hop.
htlcAdd := &lnwire.UpdateAddHTLC{ htlcAdd := &lnwire.UpdateAddHTLC{
Amount: rt.TotalAmount, Amount: rt.TotalAmount,
Expiry: rt.TotalTimeLock, Expiry: rt.TotalTimeLock,
PaymentHash: *attempt.Hash, PaymentHash: *attempt.Hash,
CustomRecords: p.firstHopCustomRecords,
} }
// Generate the raw encoded sphinx packet to be included along // Generate the raw encoded sphinx packet to be included along

View File

@ -89,7 +89,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) {
// Create a test payment lifecycle with no fee limit and no timeout. // Create a test payment lifecycle with no fee limit and no timeout.
p := newPaymentLifecycle( p := newPaymentLifecycle(
rt, noFeeLimit, paymentHash, mockPaymentSession, rt, noFeeLimit, paymentHash, mockPaymentSession,
mockShardTracker, 0, mockShardTracker, 0, nil,
) )
// Create a mock payment which is returned from mockControlTower. // Create a mock payment which is returned from mockControlTower.

View File

@ -865,6 +865,11 @@ type LightningPayment struct {
// fail. // fail.
DestCustomRecords record.CustomSet DestCustomRecords record.CustomSet
// FirstHopCustomRecords are the TLV records that are to be sent to the
// first hop of this payment. These records will be transmitted via the
// wire message and therefore do not affect the onion payload size.
FirstHopCustomRecords lnwire.CustomRecords
// MaxParts is the maximum number of partial payments that may be used // MaxParts is the maximum number of partial payments that may be used
// to complete the full amount. // to complete the full amount.
MaxParts uint32 MaxParts uint32
@ -948,6 +953,7 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
return r.sendPayment( return r.sendPayment(
context.Background(), payment.FeeLimit, payment.Identifier(), context.Background(), payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, paySession, shardTracker, payment.PayAttemptTimeout, paySession, shardTracker,
payment.FirstHopCustomRecords,
) )
} }
@ -968,6 +974,7 @@ func (r *ChannelRouter) SendPaymentAsync(ctx context.Context,
_, _, err := r.sendPayment( _, _, err := r.sendPayment(
ctx, payment.FeeLimit, payment.Identifier(), ctx, payment.FeeLimit, payment.Identifier(),
payment.PayAttemptTimeout, ps, st, payment.PayAttemptTimeout, ps, st,
payment.FirstHopCustomRecords,
) )
if err != nil { if err != nil {
log.Errorf("Payment %x failed: %v", log.Errorf("Payment %x failed: %v",
@ -1141,7 +1148,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// - nil payment session (since we already have a route). // - nil payment session (since we already have a route).
// - no payment timeout. // - no payment timeout.
// - no current block height. // - no current block height.
p := newPaymentLifecycle(r, 0, paymentIdentifier, nil, shardTracker, 0) p := newPaymentLifecycle(
r, 0, paymentIdentifier, nil, shardTracker, 0, nil,
)
// We found a route to try, create a new HTLC attempt to try. // We found a route to try, create a new HTLC attempt to try.
// //
@ -1237,7 +1246,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
func (r *ChannelRouter) sendPayment(ctx context.Context, func (r *ChannelRouter) sendPayment(ctx context.Context,
feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash, feeLimit lnwire.MilliSatoshi, identifier lntypes.Hash,
paymentAttemptTimeout time.Duration, paySession PaymentSession, paymentAttemptTimeout time.Duration, paySession PaymentSession,
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) { shardTracker shards.ShardTracker,
firstHopCustomRecords lnwire.CustomRecords) ([32]byte, *route.Route,
error) {
// If the user provides a timeout, we will additionally wrap the context // If the user provides a timeout, we will additionally wrap the context
// in a deadline. // in a deadline.
@ -1262,7 +1273,7 @@ func (r *ChannelRouter) sendPayment(ctx context.Context,
// can resume the payment from the current state. // can resume the payment from the current state.
p := newPaymentLifecycle( p := newPaymentLifecycle(
r, feeLimit, identifier, paySession, shardTracker, r, feeLimit, identifier, paySession, shardTracker,
currentHeight, currentHeight, firstHopCustomRecords,
) )
return p.resumePayment(ctx) return p.resumePayment(ctx)
@ -1465,7 +1476,7 @@ func (r *ChannelRouter) resumePayments() error {
noTimeout := time.Duration(0) noTimeout := time.Duration(0)
_, _, err := r.sendPayment( _, _, err := r.sendPayment(
context.Background(), 0, payHash, noTimeout, paySession, context.Background(), 0, payHash, noTimeout, paySession,
shardTracker, shardTracker, nil,
) )
if err != nil { if err != nil {
log.Errorf("Resuming payment %v failed: %v", payHash, log.Errorf("Resuming payment %v failed: %v", payHash,

View File

@ -101,10 +101,11 @@ func (p *preimageBeacon) SubscribeUpdates(
ChanID: chanID, ChanID: chanID,
HtlcID: htlc.HtlcIndex, HtlcID: htlc.HtlcIndex,
}, },
OutgoingChanID: payload.FwdInfo.NextHop, OutgoingChanID: payload.FwdInfo.NextHop,
OutgoingExpiry: payload.FwdInfo.OutgoingCTLV, OutgoingExpiry: payload.FwdInfo.OutgoingCTLV,
OutgoingAmount: payload.FwdInfo.AmountToForward, OutgoingAmount: payload.FwdInfo.AmountToForward,
CustomRecords: payload.CustomRecords(), InOnionCustomRecords: payload.CustomRecords(),
InWireCustomRecords: htlc.CustomRecords,
} }
copy(packet.OnionBlob[:], nextHopOnionBlob) copy(packet.OnionBlob[:], nextHopOnionBlob)