From ff30ff40bfb0f60ee8097df64fbee90cb5b60b01 Mon Sep 17 00:00:00 2001 From: ziggie Date: Sat, 3 Feb 2024 11:59:29 +0000 Subject: [PATCH] multi: Fix final hop payload size for blinded rt. The final hop size is calculated differently therefore we extract the logic in its own function and also account for the case where the final hop might be a blinded hop. --- lnrpc/routerrpc/router_backend.go | 1 + routing/pathfind.go | 70 +++++++++---- routing/pathfind_test.go | 162 +++++++++++++++++++++++++++++- 3 files changed, 211 insertions(+), 22 deletions(-) diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index d8f1aa04a..0422f2a9f 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -390,6 +390,7 @@ func (r *RouterBackend) parseQueryRoutesRequest(in *lnrpc.QueryRoutesRequest) ( DestCustomRecords: record.CustomSet(in.DestCustomRecords), CltvLimit: cltvLimit, DestFeatures: destinationFeatures, + BlindedPayment: blindedPmt, } // Pass along an outgoing channel restriction if specified. diff --git a/routing/pathfind.go b/routing/pathfind.go index 81bdefc34..9380c7c7d 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -411,6 +411,10 @@ type RestrictParams struct { // Metadata is additional data that is sent along with the payment to // the payee. Metadata []byte + + // BlindedPayment is necessary to determine the hop size of the + // last/exit hop. + BlindedPayment *BlindedPayment } // PathFindingConfig defines global parameters that control the trade-off in @@ -635,23 +639,10 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } } - // Build a preliminary destination hop structure to obtain the payload - // size. - var mpp *record.MPP - if r.PaymentAddr != nil { - mpp = record.NewMPP(amt, *r.PaymentAddr) - } - - finalHop := route.Hop{ - AmtToForward: amt, - OutgoingTimeLock: uint32(finalHtlcExpiry), - CustomRecords: r.DestCustomRecords, - LegacyPayload: !features.HasFeature( - lnwire.TLVOnionPayloadOptional, - ), - MPP: mpp, - Metadata: r.Metadata, - } + // The payload size of the final hop differ from intermediate hops + // and depends on whether the destination is blinded or not. + lastHopPayloadSize := lastHopPayloadSize(r, finalHtlcExpiry, amt, + !features.HasFeature(lnwire.TLVOnionPayloadOptional)) // We can't always assume that the end destination is publicly // advertised to the network so we'll manually include the target node. @@ -669,7 +660,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, amountToReceive: amt, incomingCltv: finalHtlcExpiry, probability: 1, - routingInfoSize: finalHop.PayloadSize(0), + routingInfoSize: lastHopPayloadSize, } // Calculate the absolute cltv limit. Use uint64 to prevent an overflow @@ -1112,3 +1103,46 @@ func getProbabilityBasedDist(weight int64, probability float64, return int64(dist) } + +// lastHopPayloadSize calculates the payload size of the final hop in a route. +// It depends on the tlv types which are present and also whether the hop is +// part of a blinded route or not. +func lastHopPayloadSize(r *RestrictParams, finalHtlcExpiry int32, + amount lnwire.MilliSatoshi, legacy bool) uint64 { + + if r.BlindedPayment != nil { + blindedPath := r.BlindedPayment.BlindedPath.BlindedHops + blindedPoint := r.BlindedPayment.BlindedPath.BlindingPoint + + encryptedData := blindedPath[len(blindedPath)-1].CipherText + finalHop := route.Hop{ + AmtToForward: amount, + OutgoingTimeLock: uint32(finalHtlcExpiry), + LegacyPayload: false, + EncryptedData: encryptedData, + } + if len(blindedPath) == 1 { + finalHop.BlindingPoint = blindedPoint + } + + // The final hop does not have a short chanID set. + return finalHop.PayloadSize(0) + } + + var mpp *record.MPP + if r.PaymentAddr != nil { + mpp = record.NewMPP(amount, *r.PaymentAddr) + } + + finalHop := route.Hop{ + AmtToForward: amount, + OutgoingTimeLock: uint32(finalHtlcExpiry), + CustomRecords: r.DestCustomRecords, + LegacyPayload: legacy, + MPP: mpp, + Metadata: r.Metadata, + } + + // The final hop does not have a short chanID set. + return finalHop.PayloadSize(0) +} diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 9b54b69b8..63ec579e3 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -1326,10 +1326,10 @@ func runPathFindingMaxPayloadRestriction(t *testing.T, useCache bool) { }{ { // The final hop payload size needs to be considered - // as well and because its treated differently than the - // intermediate hops this tests choose to use the legacy - // payload format to have a constant final hop payload - // size. + // as well and because it's treated differently than the + // intermediate hops the following tests choose to use + // the legacy payload format to have a constant final + // hop payload size. name: "route max payload size (1300)", mockedPayloadSize: 1300 - sphinx.LegacyHopDataSize, }, @@ -3451,3 +3451,157 @@ func TestBlindedRouteConstruction(t *testing.T) { require.NoError(t, err) require.Equal(t, expectedRoute, route) } + +// TestLastHopPayloadSize tests the final hop payload size. The final hop +// payload structure differes from the intermediate hop payload for both the +// non-blinded and blinded case. +func TestLastHopPayloadSize(t *testing.T) { + t.Parallel() + + var ( + metadata = []byte{21, 22} + customRecords = map[uint64][]byte{ + record.CustomTypeStart: {1, 2, 3}, + } + sizeEncryptedData = 100 + encrypedData = bytes.Repeat( + []byte{1}, sizeEncryptedData, + ) + _, blindedPoint = btcec.PrivKeyFromBytes([]byte{5}) + paymentAddr = &[32]byte{1} + amtToForward = lnwire.MilliSatoshi(10000) + finalHopExpiry int32 = 144 + + oneHopBlindedPayment = &BlindedPayment{ + BlindedPath: &sphinx.BlindedPath{ + BlindedHops: []*sphinx.BlindedHopInfo{ + { + CipherText: encrypedData, + }, + }, + BlindingPoint: blindedPoint, + }, + } + twoHopBlindedPayment = &BlindedPayment{ + BlindedPath: &sphinx.BlindedPath{ + BlindedHops: []*sphinx.BlindedHopInfo{ + { + CipherText: encrypedData, + }, + { + CipherText: encrypedData, + }, + }, + BlindingPoint: blindedPoint, + }, + } + ) + + testCases := []struct { + name string + restrictions *RestrictParams + finalHopExpiry int32 + amount lnwire.MilliSatoshi + legacy bool + }{ + { + name: "Non blinded final hop", + restrictions: &RestrictParams{ + PaymentAddr: paymentAddr, + DestCustomRecords: customRecords, + Metadata: metadata, + }, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + legacy: false, + }, + { + name: "Non blinded final hop legacy", + restrictions: &RestrictParams{ + // The legacy encoding has no ability to include + // those extra data we expect that this data is + // ignored. + PaymentAddr: paymentAddr, + DestCustomRecords: customRecords, + Metadata: metadata, + }, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + legacy: true, + }, + { + name: "Blinded final hop introduction point", + restrictions: &RestrictParams{ + BlindedPayment: oneHopBlindedPayment, + }, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + }, + { + name: "Blinded final hop of a two hop payment", + restrictions: &RestrictParams{ + BlindedPayment: twoHopBlindedPayment, + }, + amount: amtToForward, + finalHopExpiry: finalHopExpiry, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var mpp *record.MPP + if tc.restrictions.PaymentAddr != nil { + mpp = record.NewMPP( + tc.amount, *tc.restrictions.PaymentAddr, + ) + } + + var finalHop route.Hop + if tc.restrictions.BlindedPayment != nil { + blindedPath := tc.restrictions.BlindedPayment. + BlindedPath.BlindedHops + + blindedPoint := tc.restrictions.BlindedPayment. + BlindedPath.BlindingPoint + + //nolint:lll + finalHop = route.Hop{ + AmtToForward: tc.amount, + OutgoingTimeLock: uint32(tc.finalHopExpiry), + LegacyPayload: false, + EncryptedData: blindedPath[len(blindedPath)-1].CipherText, + } + if len(blindedPath) == 1 { + finalHop.BlindingPoint = blindedPoint + } + } else { + //nolint:lll + finalHop = route.Hop{ + LegacyPayload: tc.legacy, + AmtToForward: tc.amount, + OutgoingTimeLock: uint32(tc.finalHopExpiry), + Metadata: tc.restrictions.Metadata, + MPP: mpp, + CustomRecords: tc.restrictions.DestCustomRecords, + } + } + + payLoad, err := createHopPayload(finalHop, 0, true) + require.NoErrorf(t, err, "failed to create hop payload") + + expectedPayloadSize := lastHopPayloadSize( + tc.restrictions, tc.finalHopExpiry, + tc.amount, tc.legacy, + ) + + require.Equal( + t, expectedPayloadSize, + uint64(payLoad.NumBytes()), + ) + }) + } +}