diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 420259de2..d8f1aa04a 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -15,7 +15,6 @@ import ( "github.com/btcsuite/btcd/wire" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnrpc" @@ -280,7 +279,7 @@ func (r *RouterBackend) parseQueryRoutesRequest(in *lnrpc.QueryRoutesRequest) ( // inside of the path rather than the request's fields. var ( targetPubKey *route.Vertex - routeHintEdges map[route.Vertex][]*models.CachedEdgePolicy + routeHintEdges map[route.Vertex][]routing.AdditionalEdge blindedPmt *routing.BlindedPayment // finalCLTVDelta varies depending on whether we're sending to diff --git a/routing/blinding.go b/routing/blinding.go index 50d3b6f79..61d303a0c 100644 --- a/routing/blinding.go +++ b/routing/blinding.go @@ -99,7 +99,7 @@ func (b *BlindedPayment) toRouteHints() RouteHints { hintCount := len(b.BlindedPath.BlindedHops) - 1 hints := make( - map[route.Vertex][]*models.CachedEdgePolicy, hintCount, + RouteHints, hintCount, ) // Start at the unblinded introduction node, because our pathfinding @@ -116,25 +116,31 @@ func (b *BlindedPayment) toRouteHints() RouteHints { // will ensure that pathfinding provides sufficient fees/delay for the // blinded portion to the introduction node. firstBlindedHop := b.BlindedPath.BlindedHops[1].BlindedNodePub - hints[fromNode] = []*models.CachedEdgePolicy{ - { - TimeLockDelta: b.CltvExpiryDelta, - MinHTLC: lnwire.MilliSatoshi(b.HtlcMinimum), - MaxHTLC: lnwire.MilliSatoshi(b.HtlcMaximum), - FeeBaseMSat: lnwire.MilliSatoshi(b.BaseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi( - b.ProportionalFee, - ), - ToNodePubKey: func() route.Vertex { - return route.NewVertex( - // The first node in this slice is - // the introduction node, so we start - // at index 1 to get the first blinded - // relaying node. - firstBlindedHop, - ) - }, - ToNodeFeatures: features, + edgePolicy := &models.CachedEdgePolicy{ + TimeLockDelta: b.CltvExpiryDelta, + MinHTLC: lnwire.MilliSatoshi(b.HtlcMinimum), + MaxHTLC: lnwire.MilliSatoshi(b.HtlcMaximum), + FeeBaseMSat: lnwire.MilliSatoshi(b.BaseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi( + b.ProportionalFee, + ), + ToNodePubKey: func() route.Vertex { + return route.NewVertex( + // The first node in this slice is + // the introduction node, so we start + // at index 1 to get the first blinded + // relaying node. + firstBlindedHop, + ) + }, + ToNodeFeatures: features, + } + + hints[fromNode] = []AdditionalEdge{ + &BlindedEdge{ + policy: edgePolicy, + cipherText: b.BlindedPath.BlindedHops[0].CipherText, + blindingPoint: b.BlindedPath.BlindingPoint, }, } @@ -156,15 +162,19 @@ func (b *BlindedPayment) toRouteHints() RouteHints { b.BlindedPath.BlindedHops[nextHopIdx].BlindedNodePub, ) - hint := &models.CachedEdgePolicy{ + edgePolicy := &models.CachedEdgePolicy{ ToNodePubKey: func() route.Vertex { return nextNode }, ToNodeFeatures: features, } - hints[fromNode] = []*models.CachedEdgePolicy{ - hint, + hints[fromNode] = []AdditionalEdge{ + &BlindedEdge{ + policy: edgePolicy, + cipherText: b.BlindedPath.BlindedHops[i]. + CipherText, + }, } } diff --git a/routing/blinding_test.go b/routing/blinding_test.go index 5dc71354f..561ace6fc 100644 --- a/routing/blinding_test.go +++ b/routing/blinding_test.go @@ -1,10 +1,12 @@ package routing import ( + "bytes" "testing" "github.com/btcsuite/btcd/btcec/v2" sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" @@ -94,6 +96,12 @@ func TestBlindedPaymentToHints(t *testing.T) { htlcMin uint64 = 100 htlcMax uint64 = 100_000_000 + sizeEncryptedData = 100 + cipherText = bytes.Repeat( + []byte{1}, sizeEncryptedData, + ) + _, blindedPoint = btcec.PrivKeyFromBytes([]byte{5}) + rawFeatures = lnwire.NewRawFeatureVector( lnwire.AMPOptional, ) @@ -108,6 +116,7 @@ func TestBlindedPaymentToHints(t *testing.T) { blindedPayment := &BlindedPayment{ BlindedPath: &sphinx.BlindedPath{ IntroductionPoint: pk1, + BlindingPoint: blindedPoint, BlindedHops: []*sphinx.BlindedHopInfo{ {}, }, @@ -125,40 +134,52 @@ func TestBlindedPaymentToHints(t *testing.T) { blindedPayment.BlindedPath.BlindedHops = []*sphinx.BlindedHopInfo{ { BlindedNodePub: pkb1, + CipherText: cipherText, }, { BlindedNodePub: pkb2, + CipherText: cipherText, }, { BlindedNodePub: pkb3, + CipherText: cipherText, }, } expected := RouteHints{ v1: { - { - TimeLockDelta: cltvDelta, - MinHTLC: lnwire.MilliSatoshi(htlcMin), - MaxHTLC: lnwire.MilliSatoshi(htlcMax), - FeeBaseMSat: lnwire.MilliSatoshi(baseFee), - FeeProportionalMillionths: lnwire.MilliSatoshi( - ppmFee, - ), - ToNodePubKey: func() route.Vertex { - return vb2 + //nolint:lll + &BlindedEdge{ + policy: &models.CachedEdgePolicy{ + TimeLockDelta: cltvDelta, + MinHTLC: lnwire.MilliSatoshi(htlcMin), + MaxHTLC: lnwire.MilliSatoshi(htlcMax), + FeeBaseMSat: lnwire.MilliSatoshi(baseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi( + ppmFee, + ), + ToNodePubKey: func() route.Vertex { + return vb2 + }, + ToNodeFeatures: features, }, - ToNodeFeatures: features, + blindingPoint: blindedPoint, + cipherText: cipherText, }, }, vb2: { - { - ToNodePubKey: func() route.Vertex { - return vb3 + &BlindedEdge{ + policy: &models.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return vb3 + }, + ToNodeFeatures: features, }, - ToNodeFeatures: features, + cipherText: cipherText, }, }, } + actual := blindedPayment.toRouteHints() require.Equal(t, len(expected), len(actual)) @@ -170,13 +191,24 @@ func TestBlindedPaymentToHints(t *testing.T) { require.Len(t, actualHint, 1) // We can't assert that our functions are equal, so we check - // their output and then mark as nil so that we can use + // their output and then mark them as nil so that we can use // require.Equal for all our other fields. - require.Equal(t, expectedHint[0].ToNodePubKey(), - actualHint[0].ToNodePubKey()) + require.Equal(t, expectedHint[0].EdgePolicy().ToNodePubKey(), + actualHint[0].EdgePolicy().ToNodePubKey()) - actualHint[0].ToNodePubKey = nil - expectedHint[0].ToNodePubKey = nil + actualHint[0].EdgePolicy().ToNodePubKey = nil + expectedHint[0].EdgePolicy().ToNodePubKey = nil + + // The arguments we use for the payload do not matter as long as + // both functions return the same payload. + expectedPayloadSize := expectedHint[0].IntermediatePayloadSize( + 0, 0, false, 0, + ) + actualPayloadSize := actualHint[0].IntermediatePayloadSize( + 0, 0, false, 0, + ) + + require.Equal(t, expectedPayloadSize, actualPayloadSize) require.Equal(t, expectedHint[0], actualHint[0]) } diff --git a/routing/pathfind.go b/routing/pathfind.go index 642dab90b..81bdefc34 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -88,7 +88,7 @@ var ( // of the edge. type edgePolicyWithSource struct { sourceNode route.Vertex - edge *models.CachedEdgePolicy + edge AdditionalEdge } // finalHopParams encapsulates various parameters for route construction that @@ -355,8 +355,9 @@ type graphParams struct { // additionalEdges is an optional set of edges that should be // considered during path finding, that is not already found in the - // channel graph. - additionalEdges map[route.Vertex][]*models.CachedEdgePolicy + // channel graph. These can either be private edges for bolt 11 invoices + // or blinded edges when a payment to a blinded path is made. + additionalEdges map[route.Vertex][]AdditionalEdge // bandwidthHints is an interface that provides bandwidth hints that // can provide a better estimate of the current channel bandwidth than @@ -609,7 +610,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, distance := make(map[route.Vertex]*nodeWithDist, estimatedNodeCount) additionalEdgesWithSrc := make(map[route.Vertex][]*edgePolicyWithSource) - for vertex, outgoingEdgePolicies := range g.additionalEdges { + for vertex, additionalEdges := range g.additionalEdges { // Edges connected to self are always included in the graph, // therefore can be skipped. This prevents us from trying // routes to malformed hop hints. @@ -619,12 +620,13 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Build reverse lookup to find incoming edges. Needed because // search is taken place from target to source. - for _, outgoingEdgePolicy := range outgoingEdgePolicies { + for _, additionalEdge := range additionalEdges { + outgoingEdgePolicy := additionalEdge.EdgePolicy() toVertex := outgoingEdgePolicy.ToNodePubKey() incomingEdgePolicy := &edgePolicyWithSource{ sourceNode: vertex, - edge: outgoingEdgePolicy, + edge: additionalEdge, } additionalEdgesWithSrc[toVertex] = @@ -821,23 +823,30 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // blob. var payloadSize uint64 if fromVertex != source { + // In case the unifiedEdge does not have a payload size + // function supplied we request a graceful shutdown + // because this should never happen. + if edge.hopPayloadSizeFn == nil { + log.Criticalf("No payload size function "+ + "available for edge=%v unable to "+ + "determine payload size: %v", edge, + ErrNoPayLoadSizeFunc) + + return + } + supportsTlv := fromFeatures.HasFeature( lnwire.TLVOnionPayloadOptional, ) - hop := route.Hop{ - AmtToForward: amountToSend, - OutgoingTimeLock: uint32( - toNodeDist.incomingCltv, - ), - LegacyPayload: !supportsTlv, - } - - payloadSize = hop.PayloadSize(edge.policy.ChannelID) + payloadSize = edge.hopPayloadSizeFn( + amountToSend, + uint32(toNodeDist.incomingCltv), + !supportsTlv, edge.policy.ChannelID, + ) } routingInfoSize := toNodeDist.routingInfoSize + payloadSize - // Skip paths that would exceed the maximum routing info size. if routingInfoSize > sphinx.MaxPayloadSize { return @@ -930,9 +939,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // calculations. We set a high capacity to act as if // there is enough liquidity, otherwise the hint would // not have been added by a wallet. + // We also pass the payload size function to the + // graph data so that we calculate the exact payload + // size when evaluating this hop for a route. u.addPolicy( - reverseEdge.sourceNode, reverseEdge.edge, + reverseEdge.sourceNode, + reverseEdge.edge.EdgePolicy(), fakeHopHintCapacity, + reverseEdge.edge.IntermediatePayloadSize, ) } diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 7c4f635df..9b54b69b8 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -746,6 +746,9 @@ func TestPathFinding(t *testing.T) { }, { name: "path finding with additional edges", fn: runPathFindingWithAdditionalEdges, + }, { + name: "path finding max payload restriction", + fn: runPathFindingMaxPayloadRestriction, }, { name: "path finding with redundant additional edges", fn: runPathFindingWithRedundantAdditionalEdges, @@ -1204,7 +1207,7 @@ func runPathFindingWithAdditionalEdges(t *testing.T, useCache bool) { // Create the channel edge going from songoku to doge and include it in // our map of additional edges. - songokuToDoge := &models.CachedEdgePolicy{ + songokuToDogePolicy := &models.CachedEdgePolicy{ ToNodePubKey: func() route.Vertex { return doge.PubKeyBytes }, @@ -1215,8 +1218,10 @@ func runPathFindingWithAdditionalEdges(t *testing.T, useCache bool) { TimeLockDelta: 9, } - additionalEdges := map[route.Vertex][]*models.CachedEdgePolicy{ - graph.aliasMap["songoku"]: {songokuToDoge}, + additionalEdges := map[route.Vertex][]AdditionalEdge{ + graph.aliasMap["songoku"]: {&PrivateEdge{ + policy: songokuToDogePolicy, + }}, } find := func(r *RestrictParams) ( @@ -1266,6 +1271,122 @@ func runPathFindingWithAdditionalEdges(t *testing.T, useCache bool) { assertExpectedPath(t, graph.aliasMap, path, "songoku", "doge") } +// runPathFindingMaxPayloadRestriction tests the maximum size of a sphinx +// package when creating a route. So we make sure the pathfinder does not return +// a route which is greater than the maximum sphinx package size of 1300 bytes +// defined in BOLT04. +func runPathFindingMaxPayloadRestriction(t *testing.T, useCache bool) { + graph, err := parseTestGraph(t, useCache, basicGraphFilePath) + require.NoError(t, err, "unable to create graph") + + sourceNode, err := graph.graph.SourceNode() + require.NoError(t, err, "unable to fetch source node") + + paymentAmt := lnwire.NewMSatFromSatoshis(100) + + // Create a node doge which is not visible in the graph. + dogePubKeyHex := "03dd46ff29a6941b4a2607525b043ec9b020b3f318a1bf281" + + "536fd7011ec59c882" + dogePubKeyBytes, err := hex.DecodeString(dogePubKeyHex) + require.NoError(t, err, "unable to decode public key") + dogePubKey, err := btcec.ParsePubKey(dogePubKeyBytes) + require.NoError(t, err, "unable to parse public key from bytes") + + doge := &channeldb.LightningNode{} + doge.AddPubKey(dogePubKey) + doge.Alias = "doge" + copy(doge.PubKeyBytes[:], dogePubKeyBytes) + graph.aliasMap["doge"] = doge.PubKeyBytes + + const ( + chanID uint64 = 1337 + finalHtlcExpiry int32 = 0 + ) + + // Create the channel edge going from songoku to doge and later add it + // with the mocked size function to the graph data. + songokuToDogePolicy := &models.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return doge.PubKeyBytes + }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), + ChannelID: chanID, + FeeBaseMSat: 1, + FeeProportionalMillionths: 1000, + TimeLockDelta: 9, + } + + // The route has 2 hops. The exit hop (doge) and the hop + // (songoku -> doge). The desired path looks like this: + // source -> songoku -> doge + tests := []struct { + name string + mockedPayloadSize uint64 + err error + }{ + { + // The final hop payload size needs to be considered + // as well and because its treated differently than the + // intermediate hops this tests choose to use the legacy + // payload format to have a constant final hop payload + // size. + name: "route max payload size (1300)", + mockedPayloadSize: 1300 - sphinx.LegacyHopDataSize, + }, + { + // We increase the enrypted data size by one byte. + name: "route 1 bytes bigger than max " + + "payload", + mockedPayloadSize: 1300 - sphinx.LegacyHopDataSize + 1, + err: errNoPathFound, + }, + } + + for _, testCase := range tests { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + restrictions := *noRestrictions + // No tlv payload, this makes sure the final hop uses + // the legacy payload. + restrictions.DestFeatures = lnwire.EmptyFeatureVector() + + // Create the mocked AdditionalEdge and mock the + // corresponding calls. + mockedEdge := &mockAdditionalEdge{} + + mockedEdge.On("EdgePolicy").Return(songokuToDogePolicy) + + mockedEdge.On("IntermediatePayloadSize", + paymentAmt, uint32(finalHtlcExpiry), true, + chanID).Once(). + Return(testCase.mockedPayloadSize) + + additionalEdges := map[route.Vertex][]AdditionalEdge{ + graph.aliasMap["songoku"]: {mockedEdge}, + } + + path, err := dbFindPath( + graph.graph, additionalEdges, + &mockBandwidthHints{}, &restrictions, + testPathFindingConfig, sourceNode.PubKeyBytes, + doge.PubKeyBytes, paymentAmt, 0, + finalHtlcExpiry, + ) + require.ErrorIs(t, err, testCase.err) + + if err == nil { + assertExpectedPath(t, graph.aliasMap, path, + "songoku", "doge") + } + + mockedEdge.AssertExpectations(t) + }) + } +} + // runPathFindingWithRedundantAdditionalEdges asserts that we are able to find // paths to nodes ignoring additional edges that are already known by self node. func runPathFindingWithRedundantAdditionalEdges(t *testing.T, useCache bool) { @@ -1290,7 +1411,7 @@ func runPathFindingWithRedundantAdditionalEdges(t *testing.T, useCache bool) { // Create the channel edge going from alice to bob and include it in // our map of additional edges. - aliceToBob := &models.CachedEdgePolicy{ + aliceToBobPolicy := &models.CachedEdgePolicy{ ToNodePubKey: func() route.Vertex { return target }, @@ -1301,8 +1422,10 @@ func runPathFindingWithRedundantAdditionalEdges(t *testing.T, useCache bool) { TimeLockDelta: 9, } - additionalEdges := map[route.Vertex][]*models.CachedEdgePolicy{ - ctx.source: {aliceToBob}, + additionalEdges := map[route.Vertex][]AdditionalEdge{ + ctx.source: {&PrivateEdge{ + policy: aliceToBobPolicy, + }}, } path, err := dbFindPath( @@ -2402,7 +2525,8 @@ func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, path []*models.CachedEdgePolicy, nodeAliases ...string) { if len(path) != len(nodeAliases) { - t.Fatal("number of hops and number of aliases do not match") + t.Fatalf("number of hops=(%v) and number of aliases=(%v) do "+ + "not match", len(path), len(nodeAliases)) } for i, hop := range path { @@ -3072,7 +3196,7 @@ func (c *pathFindingTestContext) assertPath(path []*models.CachedEdgePolicy, // dbFindPath calls findPath after getting a db transaction from the database // graph. func dbFindPath(graph *channeldb.ChannelGraph, - additionalEdges map[route.Vertex][]*models.CachedEdgePolicy, + additionalEdges map[route.Vertex][]AdditionalEdge, bandwidthHints bandwidthHints, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, @@ -3230,8 +3354,8 @@ func TestBlindedRouteConstruction(t *testing.T) { edges := []*models.CachedEdgePolicy{ aliceBobEdge, bobCarolEdge, - carolDaveEdge, - daveEveEdge, + carolDaveEdge.EdgePolicy(), + daveEveEdge.EdgePolicy(), } // Total timelock for the route should include: diff --git a/routing/payment_session.go b/routing/payment_session.go index a04a4de55..61496e915 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -163,7 +163,7 @@ type PaymentSession interface { // loop if payment attempts take long enough. An additional set of edges can // also be provided to assist in reaching the payment's destination. type paymentSession struct { - additionalEdges map[route.Vertex][]*models.CachedEdgePolicy + additionalEdges map[route.Vertex][]AdditionalEdge getBandwidthHints func(routingGraph) (bandwidthHints, error) @@ -441,11 +441,12 @@ func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, } for _, edge := range edges { - if edge.ChannelID != channelID { + policy := edge.EdgePolicy() + if policy.ChannelID != channelID { continue } - return edge + return policy } return nil diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 229c932f3..b96a2294b 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -95,9 +95,9 @@ func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession { // RouteHintsToEdges converts a list of invoice route hints to an edge map that // can be passed into pathfinding. func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( - map[route.Vertex][]*models.CachedEdgePolicy, error) { + map[route.Vertex][]AdditionalEdge, error) { - edges := make(map[route.Vertex][]*models.CachedEdgePolicy) + edges := make(map[route.Vertex][]AdditionalEdge) // Traverse through all of the available hop hints and include them in // our edges map, indexed by the public key of the channel's starting @@ -127,7 +127,7 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( // Finally, create the channel edge from the hop hint // and add it to list of edges corresponding to the node // at the start of the channel. - edge := &models.CachedEdgePolicy{ + edgePolicy := &models.CachedEdgePolicy{ ToNodePubKey: func() route.Vertex { return endNode.PubKeyBytes }, @@ -142,6 +142,10 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( TimeLockDelta: hopHint.CLTVExpiryDelta, } + edge := &PrivateEdge{ + policy: edgePolicy, + } + v := route.NewVertex(hopHint.NodeID) edges[v] = append(edges[v], edge) } diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 1c199ff40..67a285159 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -138,7 +138,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { require.Equal(t, 1, len(policies), "should have 1 edge policy") // Check that the policy has been created as expected. - policy := policies[0] + policy := policies[0].EdgePolicy() require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch") require.Equal(t, oldExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch", diff --git a/routing/router.go b/routing/router.go index c602573eb..00c0ded57 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1954,7 +1954,7 @@ type RouteRequest struct { // RouteHints is an alias type for a set of route hints, with the source node // as the map's key and the details of the hint(s) in the edge policy. -type RouteHints map[route.Vertex][]*models.CachedEdgePolicy +type RouteHints map[route.Vertex][]AdditionalEdge // NewRouteRequest produces a new route request for a regular payment or one // to a blinded route, validating that the target, routeHints and finalExpiry diff --git a/routing/router_test.go b/routing/router_test.go index 227899c83..12c8ff729 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3954,7 +3954,7 @@ func TestNewRouteRequest(t *testing.T) { name: "hints and blinded", blindedPayment: blindedMultiHop, routeHints: make( - map[route.Vertex][]*models.CachedEdgePolicy, + map[route.Vertex][]AdditionalEdge, ), err: ErrHintsAndBlinded, }, diff --git a/routing/unified_edges.go b/routing/unified_edges.go index aee168348..c828e9a6e 100644 --- a/routing/unified_edges.go +++ b/routing/unified_edges.go @@ -40,9 +40,12 @@ func newNodeEdgeUnifier(sourceNode, toNode route.Vertex, } // addPolicy adds a single channel policy. Capacity may be zero if unknown -// (light clients). +// (light clients). We expect a non-nil payload size function and will request a +// graceful shutdown if it is not provided as this indicates that edges are +// incorrectly specified. func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex, - edge *models.CachedEdgePolicy, capacity btcutil.Amount) { + edge *models.CachedEdgePolicy, capacity btcutil.Amount, + hopPayloadSizeFn PayloadSizeFunc) { localChan := fromNode == u.sourceNode @@ -62,9 +65,20 @@ func (u *nodeEdgeUnifier) addPolicy(fromNode route.Vertex, u.edgeUnifiers[fromNode] = unifier } + // In case no payload size function was provided a graceful shutdown + // is requested, because this function is not used as intended. + if hopPayloadSizeFn == nil { + log.Criticalf("No payloadsize function was provided for the "+ + "edge (chanid=%v) when adding it to the edge unifier "+ + "of node: %v", edge.ChannelID, fromNode) + + return + } + unifier.edges = append(unifier.edges, &unifiedEdge{ - policy: edge, - capacity: capacity, + policy: edge, + capacity: capacity, + hopPayloadSizeFn: hopPayloadSizeFn, }) } @@ -79,9 +93,13 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error { return nil } - // Add this policy to the corresponding edgeUnifier. + // Add this policy to the corresponding edgeUnifier. We default + // to the clear hop payload size function because + // `addGraphPolicies` is only used for cleartext intermediate + // hops in a route. u.addPolicy( channel.OtherNode, channel.InPolicy, channel.Capacity, + defaultHopPayloadSize, ) return nil @@ -96,6 +114,12 @@ func (u *nodeEdgeUnifier) addGraphPolicies(g routingGraph) error { type unifiedEdge struct { policy *models.CachedEdgePolicy capacity btcutil.Amount + + // hopPayloadSize supplies an edge with the ability to calculate the + // exact payload size if this edge would be included in a route. This + // is needed because hops of a blinded path differ in their payload + // structure compared to cleartext hops. + hopPayloadSizeFn PayloadSizeFunc } // amtInRange checks whether an amount falls within the valid range for a @@ -202,6 +226,7 @@ func (u *edgeUnifier) getEdgeLocal(amt lnwire.MilliSatoshi, log.Debugf("Skipped edge %v: not enough bandwidth, "+ "bandwidth=%v, amt=%v", edge.policy.ChannelID, bandwidth, amt) + continue } @@ -214,14 +239,16 @@ func (u *edgeUnifier) getEdgeLocal(amt lnwire.MilliSatoshi, log.Debugf("Skipped edge %v: not max bandwidth, "+ "bandwidth=%v, maxBandwidth=%v", bandwidth, maxBandwidth) + continue } maxBandwidth = bandwidth // Update best edge. bestEdge = &unifiedEdge{ - policy: edge.policy, - capacity: edge.capacity, + policy: edge.policy, + capacity: edge.capacity, + hopPayloadSizeFn: edge.hopPayloadSizeFn, } } @@ -234,10 +261,11 @@ func (u *edgeUnifier) getEdgeLocal(amt lnwire.MilliSatoshi, // forwarding context. func (u *edgeUnifier) getEdgeNetwork(amt lnwire.MilliSatoshi) *unifiedEdge { var ( - bestPolicy *models.CachedEdgePolicy - maxFee lnwire.MilliSatoshi - maxTimelock uint16 - maxCapMsat lnwire.MilliSatoshi + bestPolicy *models.CachedEdgePolicy + maxFee lnwire.MilliSatoshi + maxTimelock uint16 + maxCapMsat lnwire.MilliSatoshi + hopPayloadSizeFn PayloadSizeFunc ) for _, edge := range u.edges { @@ -274,7 +302,6 @@ func (u *edgeUnifier) getEdgeNetwork(amt lnwire.MilliSatoshi) *unifiedEdge { maxTimelock = lntypes.Max( maxTimelock, edge.policy.TimeLockDelta, ) - // Use the policy that results in the highest fee for this // specific amount. fee := edge.policy.ComputeFee(amt) @@ -282,11 +309,17 @@ func (u *edgeUnifier) getEdgeNetwork(amt lnwire.MilliSatoshi) *unifiedEdge { log.Debugf("Skipped edge %v due to it produces less "+ "fee: fee=%v, maxFee=%v", edge.policy.ChannelID, fee, maxFee) + continue } maxFee = fee bestPolicy = edge.policy + // The payload size function for edges to a connected peer is + // always the same hence there is not need to find the maximum. + // This also counts for blinded edges where we only have one + // edge to a blinded peer. + hopPayloadSizeFn = edge.hopPayloadSizeFn } // Return early if no channel matches. @@ -308,6 +341,7 @@ func (u *edgeUnifier) getEdgeNetwork(amt lnwire.MilliSatoshi) *unifiedEdge { modifiedEdge := unifiedEdge{policy: &policyCopy} modifiedEdge.policy.TimeLockDelta = maxTimelock modifiedEdge.capacity = maxCapMsat.ToSatoshis() + modifiedEdge.hopPayloadSizeFn = hopPayloadSizeFn return &modifiedEdge } diff --git a/routing/unified_edges_test.go b/routing/unified_edges_test.go index 9b603d78c..043447f52 100644 --- a/routing/unified_edges_test.go +++ b/routing/unified_edges_test.go @@ -41,15 +41,17 @@ func TestNodeEdgeUnifier(t *testing.T) { c2 := btcutil.Amount(8) unifierFilled := newNodeEdgeUnifier(source, toNode, nil) - unifierFilled.addPolicy(fromNode, &p1, c1) - unifierFilled.addPolicy(fromNode, &p2, c2) + unifierFilled.addPolicy(fromNode, &p1, c1, defaultHopPayloadSize) + unifierFilled.addPolicy(fromNode, &p2, c2, defaultHopPayloadSize) unifierNoCapacity := newNodeEdgeUnifier(source, toNode, nil) - unifierNoCapacity.addPolicy(fromNode, &p1, 0) - unifierNoCapacity.addPolicy(fromNode, &p2, 0) + unifierNoCapacity.addPolicy(fromNode, &p1, 0, defaultHopPayloadSize) + unifierNoCapacity.addPolicy(fromNode, &p2, 0, defaultHopPayloadSize) unifierNoInfo := newNodeEdgeUnifier(source, toNode, nil) - unifierNoInfo.addPolicy(fromNode, &models.CachedEdgePolicy{}, 0) + unifierNoInfo.addPolicy( + fromNode, &models.CachedEdgePolicy{}, 0, defaultHopPayloadSize, + ) tests := []struct { name string