diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 086825568..a193c654a 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -1,10 +1,14 @@ package routing import ( + "fmt" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" ) // bandwidthHints provides hints about the currently available balance in our @@ -19,6 +23,42 @@ type bandwidthHints interface { // returned. availableChanBandwidth(channelID uint64, amount lnwire.MilliSatoshi) (lnwire.MilliSatoshi, bool) + + // firstHopCustomBlob returns the custom blob for the first hop of the + // payment, if available. + firstHopCustomBlob() fn.Option[tlv.Blob] +} + +// TlvTrafficShaper is an interface that allows the sender to determine if a +// payment should be carried by a channel based on the TLV records that may be +// present in the `update_add_htlc` message or the channel commitment itself. +type TlvTrafficShaper interface { + AuxHtlcModifier + + // ShouldHandleTraffic is called in order to check if the channel + // identified by the provided channel ID may have external mechanisms + // that would allow it to carry out the payment. + ShouldHandleTraffic(cid lnwire.ShortChannelID, + fundingBlob fn.Option[tlv.Blob]) (bool, error) + + // PaymentBandwidth returns the available bandwidth for a custom channel + // decided by the given channel aux blob and HTLC blob. A return value + // of 0 means there is no bandwidth available. To find out if a channel + // is a custom channel that should be handled by the traffic shaper, the + // HandleTraffic method should be called first. + PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob], + linkBandwidth lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) +} + +// AuxHtlcModifier is an interface that allows the sender to modify the outgoing +// HTLC of a payment by changing the amount or the wire message tlv records. +type AuxHtlcModifier interface { + // ProduceHtlcExtraData is a function that, based on the previous extra + // data blob of an HTLC, may produce a different blob or modify the + // amount of bitcoin this htlc should carry. + ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, + htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, + lnwire.CustomRecords, error) } // getLinkQuery is the function signature used to lookup a link. @@ -29,8 +69,10 @@ type getLinkQuery func(lnwire.ShortChannelID) ( // uses the link lookup provided to query the link for our latest local channel // balances. type bandwidthManager struct { - getLink getLinkQuery - localChans map[lnwire.ShortChannelID]struct{} + getLink getLinkQuery + localChans map[lnwire.ShortChannelID]struct{} + firstHopBlob fn.Option[tlv.Blob] + trafficShaper fn.Option[TlvTrafficShaper] } // newBandwidthManager creates a bandwidth manager for the source node provided @@ -40,11 +82,14 @@ type bandwidthManager struct { // allows us to reduce the number of extraneous attempts as we can skip channels // that are inactive, or just don't have enough bandwidth to carry the payment. func newBandwidthManager(graph Graph, sourceNode route.Vertex, - linkQuery getLinkQuery) (*bandwidthManager, error) { + linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob], + trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) { manager := &bandwidthManager{ - getLink: linkQuery, - localChans: make(map[lnwire.ShortChannelID]struct{}), + getLink: linkQuery, + localChans: make(map[lnwire.ShortChannelID]struct{}), + firstHopBlob: firstHopBlob, + trafficShaper: trafficShaper, } // First, we'll collect the set of outbound edges from the target @@ -89,17 +134,111 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, return 0 } - // If our link isn't currently in a state where it can add another - // outgoing htlc, treat the link as unusable. - if err := link.MayAddOutgoingHtlc(amount); err != nil { - log.Warnf("ShortChannelID=%v: cannot add outgoing htlc: %v", - cid, err) + // bandwidthResult is an inline type that we'll use to pass the + // bandwidth result from the external traffic shaper to the main logic + // below. + type bandwidthResult struct { + // bandwidth is the available bandwidth for the channel as + // reported by the external traffic shaper. If the external + // traffic shaper is not handling the channel, this value will + // be fn.None + bandwidth fn.Option[lnwire.MilliSatoshi] + + // htlcAmount is the amount we're going to use to check if we + // can add another HTLC to the channel. If the external traffic + // shaper is handling the channel, we'll use 0 to just sanity + // check the number of HTLCs on the channel, since we don't know + // the actual HTLC amount that will be sent. + htlcAmount fn.Option[lnwire.MilliSatoshi] + } + + var ( + // We will pass the link bandwidth to the external traffic + // shaper. This is the current best estimate for the available + // bandwidth for the link. + linkBandwidth = link.Bandwidth() + + bandwidthErr = func(err error) fn.Result[bandwidthResult] { + return fn.Err[bandwidthResult](err) + } + ) + + result, err := fn.MapOptionZ( + b.trafficShaper, + func(ts TlvTrafficShaper) fn.Result[bandwidthResult] { + fundingBlob := link.FundingCustomBlob() + shouldHandle, err := ts.ShouldHandleTraffic( + cid, fundingBlob, + ) + if err != nil { + return bandwidthErr(fmt.Errorf("traffic "+ + "shaper failed to decide whether to "+ + "handle traffic: %w", err)) + } + + log.Debugf("ShortChannelID=%v: external traffic "+ + "shaper is handling traffic: %v", cid, + shouldHandle) + + // If this channel isn't handled by the external traffic + // shaper, we'll return early. + if !shouldHandle { + return fn.Ok(bandwidthResult{}) + } + + // Ask for a specific bandwidth to be used for the + // channel. + commitmentBlob := link.CommitmentCustomBlob() + auxBandwidth, err := ts.PaymentBandwidth( + b.firstHopBlob, commitmentBlob, linkBandwidth, + ) + if err != nil { + return bandwidthErr(fmt.Errorf("failed to get "+ + "bandwidth from external traffic "+ + "shaper: %w", err)) + } + + log.Debugf("ShortChannelID=%v: external traffic "+ + "shaper reported available bandwidth: %v", cid, + auxBandwidth) + + // We don't know the actual HTLC amount that will be + // sent using the custom channel. But we'll still want + // to make sure we can add another HTLC, using the + // MayAddOutgoingHtlc method below. Passing 0 into that + // method will use the minimum HTLC value for the + // channel, which is okay to just check we don't exceed + // the max number of HTLCs on the channel. A proper + // balance check is done elsewhere. + return fn.Ok(bandwidthResult{ + bandwidth: fn.Some(auxBandwidth), + htlcAmount: fn.Some[lnwire.MilliSatoshi](0), + }) + }, + ).Unpack() + if err != nil { + log.Errorf("ShortChannelID=%v: failed to get bandwidth from "+ + "external traffic shaper: %v", cid, err) + return 0 } - // Otherwise, we'll return the current best estimate for the available - // bandwidth for the link. - return link.Bandwidth() + htlcAmount := result.htlcAmount.UnwrapOr(amount) + + // If our link isn't currently in a state where it can add another + // outgoing htlc, treat the link as unusable. + if err := link.MayAddOutgoingHtlc(htlcAmount); err != nil { + log.Warnf("ShortChannelID=%v: cannot add outgoing "+ + "htlc with amount %v: %v", cid, htlcAmount, err) + return 0 + } + + // If the external traffic shaper determined the bandwidth, we'll return + // that value, even if it is zero (which would mean no bandwidth is + // available on that channel). + reportedBandwidth := result.bandwidth.UnwrapOr(linkBandwidth) + + return reportedBandwidth } // availableChanBandwidth returns the total available bandwidth for a channel @@ -116,3 +255,9 @@ func (b *bandwidthManager) availableChanBandwidth(channelID uint64, return b.getBandwidth(shortID, amount), true } + +// firstHopCustomBlob returns the custom blob for the first hop of the payment, +// if available. +func (b *bandwidthManager) firstHopCustomBlob() fn.Option[tlv.Blob] { + return b.firstHopBlob +} diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index ef12d6973..4876f09c7 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -5,8 +5,10 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -115,6 +117,8 @@ func TestBandwidthManager(t *testing.T) { m, err := newBandwidthManager( g, sourceNode.pubkey, testCase.linkQuery, + fn.None[[]byte](), + fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), ) require.NoError(t, err) @@ -126,3 +130,35 @@ func TestBandwidthManager(t *testing.T) { }) } } + +type mockTrafficShaper struct{} + +// ShouldHandleTraffic is called in order to check if the channel identified +// by the provided channel ID may have external mechanisms that would +// allow it to carry out the payment. +func (*mockTrafficShaper) ShouldHandleTraffic(_ lnwire.ShortChannelID, + _ fn.Option[tlv.Blob]) (bool, error) { + + return true, nil +} + +// PaymentBandwidth returns the available bandwidth for a custom channel +// decided by the given channel aux blob and HTLC blob. A return value +// of 0 means there is no bandwidth available. To find out if a channel +// is a custom channel that should be handled by the traffic shaper, the +// HandleTraffic method should be called first. +func (*mockTrafficShaper) PaymentBandwidth(_, _ fn.Option[tlv.Blob], + linkBandwidth lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) { + + return linkBandwidth, nil +} + +// ProduceHtlcExtraData is a function that, based on the previous extra +// data blob of an HTLC, may produce a different blob or modify the +// amount of bitcoin this htlc should carry. +func (*mockTrafficShaper) ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, + _ lnwire.CustomRecords) (lnwire.MilliSatoshi, lnwire.CustomRecords, + error) { + + return totalAmount, nil, nil +} diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index ee6fed295..785fa1a50 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -8,9 +8,11 @@ import ( "time" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/zpay32" "github.com/stretchr/testify/require" ) @@ -35,6 +37,10 @@ func (m *mockBandwidthHints) availableChanBandwidth(channelID uint64, return balance, ok } +func (m *mockBandwidthHints) firstHopCustomBlob() fn.Option[tlv.Blob] { + return fn.None[tlv.Blob]() +} + // integratedRoutingContext defines the context in which integrated routing // tests run. type integratedRoutingContext struct { @@ -227,6 +233,9 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, // Find a route. route, err := session.RequestRoute( amtRemaining, lnwire.MaxMilliSatoshi, inFlightHtlcs, 0, + lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3}, + }, ) if err != nil { return attempts, err diff --git a/routing/integrated_routing_test.go b/routing/integrated_routing_test.go index 0e873d664..4a2447b48 100644 --- a/routing/integrated_routing_test.go +++ b/routing/integrated_routing_test.go @@ -296,7 +296,7 @@ func testMppSend(t *testing.T, testCase *mppSendTestCase) { case err == nil && testCase.expectedFailure: t.Fatal("expected payment to fail") case err != nil && !testCase.expectedFailure: - t.Fatal("expected payment to succeed") + t.Fatalf("expected payment to succeed, got %v", err) } if len(attempts) != testCase.expectedAttempts { diff --git a/routing/mock_test.go b/routing/mock_test.go index 306c18210..99d56c68b 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -9,12 +9,14 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/shards" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/mock" ) @@ -104,7 +106,8 @@ type mockPaymentSessionSourceOld struct { var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil) func (m *mockPaymentSessionSourceOld) NewPaymentSession( - _ *LightningPayment) (PaymentSession, error) { + _ *LightningPayment, _ fn.Option[tlv.Blob], + _ fn.Option[TlvTrafficShaper]) (PaymentSession, error) { return &mockPaymentSessionOld{ routes: m.routes, @@ -166,7 +169,8 @@ type mockPaymentSessionOld struct { var _ PaymentSession = (*mockPaymentSessionOld)(nil) func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, - _, height uint32) (*route.Route, error) { + _, height uint32, _ lnwire.CustomRecords) (*route.Route, + error) { if m.release != nil { m.release <- struct{}{} @@ -630,9 +634,10 @@ type mockPaymentSessionSource struct { var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) func (m *mockPaymentSessionSource) NewPaymentSession( - payment *LightningPayment) (PaymentSession, error) { + payment *LightningPayment, firstHopBlob fn.Option[tlv.Blob], + tlvShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { - args := m.Called(payment) + args := m.Called(payment, firstHopBlob, tlvShaper) return args.Get(0).(PaymentSession), args.Error(1) } @@ -690,9 +695,12 @@ type mockPaymentSession struct { var _ PaymentSession = (*mockPaymentSession)(nil) func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32) (*route.Route, error) { + activeShards, height uint32, + firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) { - args := m.Called(maxAmt, feeLimit, activeShards, height) + args := m.Called( + maxAmt, feeLimit, activeShards, height, firstHopCustomRecords, + ) // Type assertion on nil will fail, so we check and return here. if args.Get(0) == nil { @@ -897,6 +905,14 @@ func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error { return m.mayAddOutgoingErr } +func (m *mockLink) FundingCustomBlob() fn.Option[tlv.Blob] { + return fn.None[tlv.Blob]() +} + +func (m *mockLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { + return fn.None[tlv.Blob]() +} + type mockShardTracker struct { mock.Mock } diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 2c7fd3b21..96ff79608 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -366,6 +366,7 @@ func (p *paymentLifecycle) requestRoute( rt, err := p.paySession.RequestRoute( ps.RemainingAmt, remainingFees, uint32(ps.NumAttemptsInFlight), uint32(p.currentHeight), + p.firstHopCustomRecords, ) // Exit early if there's no error. diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 4df27523f..315c1bad5 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lntest/wait" @@ -28,7 +29,11 @@ func createTestPaymentLifecycle() *paymentLifecycle { paymentHash := lntypes.Hash{1, 2, 3} quitChan := make(chan struct{}) rt := &ChannelRouter{ - cfg: &Config{}, + cfg: &Config{ + TrafficShaper: fn.Some[TlvTrafficShaper]( + &mockTrafficShaper{}, + ), + }, quit: quitChan, } @@ -78,6 +83,9 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { Payer: mockPayer, Clock: mockClock, MissionControl: mockMissionControl, + TrafficShaper: fn.Some[TlvTrafficShaper]( + &mockTrafficShaper{}, + ), }, quit: quitChan, } @@ -372,6 +380,7 @@ func TestRequestRouteSucceed(t *testing.T) { // Mock the paySession's `RequestRoute` method to return no error. paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, + mock.Anything, ).Return(dummyRoute, nil) result, err := p.requestRoute(ps) @@ -408,6 +417,7 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) { // Mock the paySession's `RequestRoute` method to return an error. paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, + mock.Anything, ).Return(nil, errDummy) result, err := p.requestRoute(ps) @@ -442,6 +452,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { // type. m.paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, + mock.Anything, ).Return(nil, errNoTlvPayload) // The payment should be failed with reason no route. @@ -489,6 +500,7 @@ func TestRequestRouteFailPaymentError(t *testing.T) { // Mock the paySession's `RequestRoute` method to return an error. paySession.On("RequestRoute", mock.Anything, mock.Anything, mock.Anything, mock.Anything, + mock.Anything, ).Return(nil, errNoTlvPayload) result, err := p.requestRoute(ps) @@ -865,7 +877,7 @@ func TestResumePaymentFailOnRequestRouteErr(t *testing.T) { // 4. mock requestRoute to return an error. m.paySession.On("RequestRoute", paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), - uint32(p.currentHeight), + uint32(p.currentHeight), mock.Anything, ).Return(nil, errDummy).Once() // Send the payment and assert it failed. @@ -911,7 +923,7 @@ func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) { // 4. mock requestRoute to return an route. m.paySession.On("RequestRoute", paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), - uint32(p.currentHeight), + uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() // 5. mock shardTracker used in `createNewPaymentAttempt` to return an @@ -971,7 +983,7 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) { // 4. mock requestRoute to return an route. m.paySession.On("RequestRoute", paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), - uint32(p.currentHeight), + uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() // 5. mock `registerAttempt` to return an attempt. @@ -1063,7 +1075,7 @@ func TestResumePaymentSuccess(t *testing.T) { // 1.4. mock requestRoute to return an route. m.paySession.On("RequestRoute", paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), - uint32(p.currentHeight), + uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() // 1.5. mock `registerAttempt` to return an attempt. @@ -1164,7 +1176,7 @@ func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { // 1.4. mock requestRoute to return an route. m.paySession.On("RequestRoute", paymentAmt, p.feeLimit, uint32(ps.NumAttemptsInFlight), - uint32(p.currentHeight), + uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() // Create two attempt IDs here. @@ -1226,7 +1238,7 @@ func TestResumePaymentSuccessWithTwoAttempts(t *testing.T) { // 2.4. mock requestRoute to return an route. m.paySession.On("RequestRoute", paymentAmt/2, p.feeLimit, uint32(ps.NumAttemptsInFlight), - uint32(p.currentHeight), + uint32(p.currentHeight), mock.Anything, ).Return(rt, nil).Once() // 2.5. mock `registerAttempt` to return an attempt. diff --git a/routing/payment_session.go b/routing/payment_session.go index 00b4ab70e..b3c131cc3 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -139,7 +139,9 @@ type PaymentSession interface { // A noRouteError is returned if a non-critical error is encountered // during path finding. RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32) (*route.Route, error) + activeShards, height uint32, + firstHopCustomRecords lnwire.CustomRecords) (*route.Route, + error) // UpdateAdditionalEdge takes an additional channel edge policy // (private channels) and applies the update from the message. Returns @@ -243,7 +245,8 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex, // NOTE: This function is safe for concurrent access. // NOTE: Part of the PaymentSession interface. func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, - activeShards, height uint32) (*route.Route, error) { + activeShards, height uint32, + firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) { if p.empty { return nil, errEmptyPaySession @@ -284,9 +287,9 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // client-side MTU that we'll attempt to respect at all times. maxShardActive := p.payment.MaxShardAmt != nil if maxShardActive && maxAmt > *p.payment.MaxShardAmt { - p.log.Debug("Clamping payment attempt from %v to %v due to "+ - "max shard size of %v", maxAmt, - *p.payment.MaxShardAmt, maxAmt) + p.log.Debugf("Clamping payment attempt from %v to %v due to "+ + "max shard size of %v", maxAmt, *p.payment.MaxShardAmt, + maxAmt) maxAmt = *p.payment.MaxShardAmt } diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 46e7a42aa..c89d6a8e5 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -4,8 +4,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/zpay32" ) @@ -49,12 +51,14 @@ type SessionSource struct { // view from Mission Control. An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the // payment's destination. -func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( - PaymentSession, error) { +func (m *SessionSource) NewPaymentSession(p *LightningPayment, + firstHopBlob fn.Option[tlv.Blob], + trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { getBandwidthHints := func(graph Graph) (bandwidthHints, error) { return newBandwidthManager( graph, m.SourceNode.PubKeyBytes, m.GetLink, + firstHopBlob, trafficShaper, ) } diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index f6873aa75..34f835682 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -235,6 +235,9 @@ func TestRequestRoute(t *testing.T) { route, err := session.RequestRoute( payment.Amount, payment.FeeLimit, 0, height, + lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType + 123: []byte{1, 2, 3}, + }, ) if err != nil { t.Fatal(err) diff --git a/routing/router.go b/routing/router.go index de5f45fd8..e5769151b 100644 --- a/routing/router.go +++ b/routing/router.go @@ -29,6 +29,7 @@ import ( "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/shards" + "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/zpay32" ) @@ -154,7 +155,10 @@ type PaymentSessionSource interface { // routes to the given target. An optional set of routing hints can be // provided in order to populate additional edges to explore when // finding a path to the payment's destination. - NewPaymentSession(p *LightningPayment) (PaymentSession, error) + NewPaymentSession(p *LightningPayment, + firstHopBlob fn.Option[tlv.Blob], + trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, + error) // NewPaymentSessionEmpty creates a new paymentSession instance that is // empty, and will be exhausted immediately. Used for failure reporting @@ -290,6 +294,10 @@ type Config struct { // // TODO(yy): remove it once the root cause of stuck payments is found. ClosedSCIDs map[lnwire.ShortChannelID]struct{} + + // TrafficShaper is an optional traffic shaper that can be used to + // control the outgoing channel of a payment. + TrafficShaper fn.Option[TlvTrafficShaper] } // EdgeLocator is a struct used to identify a specific edge. @@ -517,6 +525,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, // eliminate certain routes early on in the path finding process. bandwidthHints, err := newBandwidthManager( r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, + fn.None[tlv.Blob](), r.cfg.TrafficShaper, ) if err != nil { return nil, 0, err @@ -1009,10 +1018,29 @@ func spewPayment(payment *LightningPayment) lnutils.LogClosure { func (r *ChannelRouter) PreparePayment(payment *LightningPayment) ( PaymentSession, shards.ShardTracker, error) { + // Assemble any custom data we want to send to the first hop only. + var firstHopData fn.Option[tlv.Blob] + if len(payment.FirstHopCustomRecords) > 0 { + if err := payment.FirstHopCustomRecords.Validate(); err != nil { + return nil, nil, fmt.Errorf("invalid first hop custom "+ + "records: %w", err) + } + + firstHopBlob, err := payment.FirstHopCustomRecords.Serialize() + if err != nil { + return nil, nil, fmt.Errorf("unable to serialize "+ + "first hop custom records: %w", err) + } + + firstHopData = fn.Some(firstHopBlob) + } + // Before starting the HTLC routing attempt, we'll create a fresh // payment session which will report our errors back to mission // control. - paySession, err := r.cfg.SessionSource.NewPaymentSession(payment) + paySession, err := r.cfg.SessionSource.NewPaymentSession( + payment, firstHopData, r.cfg.TrafficShaper, + ) if err != nil { return nil, nil, err } @@ -1277,6 +1305,11 @@ func (r *ChannelRouter) sendPayment(ctx context.Context, return [32]byte{}, nil, err } + // Validate the custom records before we attempt to send the payment. + if err := firstHopCustomRecords.Validate(); err != nil { + return [32]byte{}, nil, err + } + // Now set up a paymentLifecycle struct with these params, such that we // can resume the payment from the current state. p := newPaymentLifecycle( @@ -1327,7 +1360,7 @@ func (e ErrNoChannel) Error() string { // outgoing channel, use the outgoingChan parameter. func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], hops []route.Vertex, outgoingChan *uint64, finalCltvDelta int32, - payAddr *[32]byte, _ fn.Option[[]byte]) (*route.Route, + payAddr *[32]byte, firstHopBlob fn.Option[[]byte]) (*route.Route, error) { log.Tracef("BuildRoute called: hopsCount=%v, amt=%v", len(hops), amt) @@ -1342,7 +1375,8 @@ func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. bandwidthHints, err := newBandwidthManager( - r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, + r.cfg.RoutingGraph, r.cfg.SelfNode, r.cfg.GetLink, firstHopBlob, + r.cfg.TrafficShaper, ) if err != nil { return nil, err diff --git a/routing/router_test.go b/routing/router_test.go index 676fde4cc..7726091fc 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -164,6 +164,9 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, Clock: clock.NewTestClock(time.Unix(1, 0)), ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate, ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[TlvTrafficShaper]( + &mockTrafficShaper{}, + ), }) require.NoError(t, router.Start(), "unable to start router") @@ -2189,7 +2192,8 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), }} // Register mockers with the expected method calls. @@ -2273,7 +2277,8 @@ func TestSendToRouteSkipTempErrNonMPP(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), }} // Expect an error to be returned. @@ -2328,7 +2333,8 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), }} // Create the error to be returned. @@ -2411,7 +2417,8 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), }} // Create the error to be returned. @@ -2498,7 +2505,8 @@ func TestSendToRouteTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), }} // Create the error to be returned.