diff --git a/routing/heap.go b/routing/heap.go index 261ced8f9..995eff21f 100644 --- a/routing/heap.go +++ b/routing/heap.go @@ -3,7 +3,6 @@ package routing import ( "container/heap" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -39,7 +38,7 @@ type nodeWithDist struct { weight int64 // nextHop is the edge this route comes from. - nextHop *models.CachedEdgePolicy + nextHop *unifiedEdge // routingInfoSize is the total size requirement for the payloads field // in the onion packet from this hop towards the final destination. diff --git a/routing/pathfind.go b/routing/pathfind.go index 53eb0dfb4..99ea20422 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -11,7 +11,6 @@ import ( "github.com/btcsuite/btcd/btcutil" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -50,7 +49,7 @@ const ( type pathFinder = func(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( - []*models.CachedEdgePolicy, float64, error) + []*unifiedEdge, float64, error) var ( // DefaultEstimator is the default estimator used for computing @@ -126,7 +125,7 @@ type finalHopParams struct { // NOTE: If a non-nil blinded path is provided it is assumed to have been // validated by the caller. func newRoute(sourceVertex route.Vertex, - pathEdges []*models.CachedEdgePolicy, currentHeight uint32, + pathEdges []*unifiedEdge, currentHeight uint32, finalHop finalHopParams, blindedPath *sphinx.BlindedPath) ( *route.Route, error) { @@ -149,7 +148,7 @@ func newRoute(sourceVertex route.Vertex, for i := pathLength - 1; i >= 0; i-- { // Now we'll start to calculate the items within the per-hop // payload for the hop this edge is leading to. - edge := pathEdges[i] + edge := pathEdges[i].policy // We'll calculate the amounts, timelocks, and fees for each hop // in the route. The base case is the final hop which includes @@ -245,13 +244,15 @@ func newRoute(sourceVertex route.Vertex, // and its policy for the outgoing channel. This policy // is stored as part of the incoming channel of // the next hop. - fee = pathEdges[i+1].ComputeFee(amtToForward) + fee = pathEdges[i+1].policy.ComputeFee(amtToForward) // We'll take the total timelock of the preceding hop as // the outgoing timelock or this hop. Then we'll // increment the total timelock incurred by this hop. outgoingTimeLock = totalTimeLock - totalTimeLock += uint32(pathEdges[i+1].TimeLockDelta) + totalTimeLock += uint32( + pathEdges[i+1].policy.TimeLockDelta, + ) } // Since we're traversing the path backwards atm, we prepend @@ -504,7 +505,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // available bandwidth. func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, - finalHtlcExpiry int32) ([]*models.CachedEdgePolicy, float64, error) { + finalHtlcExpiry int32) ([]*unifiedEdge, float64, error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to @@ -859,7 +860,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, amountToReceive: amountToReceive, incomingCltv: incomingCltv, probability: probability, - nextHop: edge.policy, + nextHop: edge, routingInfoSize: routingInfoSize, } distance[fromVertex] = withDist @@ -1009,7 +1010,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Use the distance map to unravel the forward path from source to // target. - var pathEdges []*models.CachedEdgePolicy + var pathEdges []*unifiedEdge currentNode := source for { // Determine the next hop forward using the next map. @@ -1024,7 +1025,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, pathEdges = append(pathEdges, currentNodeWithDist.nextHop) // Advance current node. - currentNode = currentNodeWithDist.nextHop.ToNodePubKey() + currentNode = currentNodeWithDist.nextHop.policy.ToNodePubKey() // Check stop condition at the end of this loop. This prevents // breaking out too soon for self-payments that have target set @@ -1045,7 +1046,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // route construction does not care where the features are actually // taken from. In the future we may wish to do route construction within // findPath, and avoid using ChannelEdgePolicy altogether. - pathEdges[len(pathEdges)-1].ToNodeFeatures = features + pathEdges[len(pathEdges)-1].policy.ToNodeFeatures = features log.Debugf("Found route: probability=%v, hops=%v, fee=%v", distance[source].probability, len(pathEdges), diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 6bf754c82..350f0e5a6 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -1225,7 +1225,7 @@ func runPathFindingWithAdditionalEdges(t *testing.T, useCache bool) { } find := func(r *RestrictParams) ( - []*models.CachedEdgePolicy, error) { + []*unifiedEdge, error) { return dbFindPath( graph.graph, additionalEdges, &mockBandwidthHints{}, @@ -1437,7 +1437,7 @@ func runPathFindingWithRedundantAdditionalEdges(t *testing.T, useCache bool) { require.NoError(t, err, "unable to find path to bob") require.Len(t, path, 1) - require.Equal(t, realChannelID, path[0].ChannelID, + require.Equal(t, realChannelID, path[0].policy.ChannelID, "additional edge for known edge wasn't ignored") } @@ -1723,8 +1723,17 @@ func TestNewRoute(t *testing.T) { } t.Run(testCase.name, func(t *testing.T) { + var unifiedHops []*unifiedEdge + for _, hop := range testCase.hops { + unifiedHops = append(unifiedHops, + &unifiedEdge{ + policy: hop, + }, + ) + } + route, err := newRoute( - sourceVertex, testCase.hops, startingHeight, + sourceVertex, unifiedHops, startingHeight, finalHopParams{ amt: testCase.paymentAmount, totalAmt: testCase.paymentAmount, @@ -1864,7 +1873,7 @@ func runDestTLVGraphFallback(t *testing.T, useCache bool) { require.NoError(t, err, "unable to fetch source node") find := func(r *RestrictParams, - target route.Vertex) ([]*models.CachedEdgePolicy, error) { + target route.Vertex) ([]*unifiedEdge, error) { return dbFindPath( ctx.graph, nil, &mockBandwidthHints{}, @@ -2522,7 +2531,7 @@ func TestPathFindSpecExample(t *testing.T) { } func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, - path []*models.CachedEdgePolicy, nodeAliases ...string) { + path []*unifiedEdge, nodeAliases ...string) { if len(path) != len(nodeAliases) { t.Fatalf("number of hops=(%v) and number of aliases=(%v) do "+ @@ -2530,9 +2539,10 @@ func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, } for i, hop := range path { - if hop.ToNodePubKey() != aliasMap[nodeAliases[i]] { + if hop.policy.ToNodePubKey() != aliasMap[nodeAliases[i]] { t.Fatalf("expected %v to be pos #%v in hop, instead "+ - "%v was", nodeAliases[i], i, hop.ToNodePubKey()) + "%v was", nodeAliases[i], i, + hop.policy.ToNodePubKey()) } } } @@ -2606,10 +2616,10 @@ func runRestrictOutgoingChannel(t *testing.T, useCache bool) { // Assert that the route starts with channel chanSourceB1, in line with // the specified restriction. - if path[0].ChannelID != chanSourceB1 { + if path[0].policy.ChannelID != chanSourceB1 { t.Fatalf("expected route to pass through channel %v, "+ "but channel %v was selected instead", chanSourceB1, - path[0].ChannelID) + path[0].policy.ChannelID) } // If a direct channel to target is allowed as well, that channel is @@ -2619,7 +2629,7 @@ func runRestrictOutgoingChannel(t *testing.T, useCache bool) { } path, err = ctx.findPath(target, paymentAmt) require.NoError(t, err, "unable to find path") - if path[0].ChannelID != chanSourceTarget { + if path[0].policy.ChannelID != chanSourceTarget { t.Fatalf("expected route to pass through channel %v", chanSourceTarget) } @@ -2658,10 +2668,10 @@ func runRestrictLastHop(t *testing.T, useCache bool) { ctx.restrictParams.LastHop = &lastHop path, err := ctx.findPath(target, paymentAmt) require.NoError(t, err, "unable to find path") - if path[0].ChannelID != 3 { + if path[0].policy.ChannelID != 3 { t.Fatalf("expected route to pass through channel 3, "+ "but channel %v was selected instead", - path[0].ChannelID) + path[0].policy.ChannelID) } } @@ -2941,10 +2951,10 @@ func testProbabilityRouting(t *testing.T, useCache bool, } // Assert that the route passes through the expected channel. - if path[1].ChannelID != expectedChan { + if path[1].policy.ChannelID != expectedChan { t.Fatalf("expected route to pass through channel %v, "+ "but channel %v was selected instead", expectedChan, - path[1].ChannelID) + path[1].policy.ChannelID) } } @@ -3005,10 +3015,10 @@ func runEqualCostRouteSelection(t *testing.T, useCache bool) { t.Fatal(err) } - if path[1].ChannelID != 2 { + if path[1].policy.ChannelID != 2 { t.Fatalf("expected route to pass through channel %v, "+ "but channel %v was selected instead", 2, - path[1].ChannelID) + path[1].policy.ChannelID) } } @@ -3168,7 +3178,7 @@ func (c *pathFindingTestContext) aliasFromKey(pubKey route.Vertex) string { } func (c *pathFindingTestContext) findPath(target route.Vertex, - amt lnwire.MilliSatoshi) ([]*models.CachedEdgePolicy, + amt lnwire.MilliSatoshi) ([]*unifiedEdge, error) { return dbFindPath( @@ -3177,7 +3187,7 @@ func (c *pathFindingTestContext) findPath(target route.Vertex, ) } -func (c *pathFindingTestContext) assertPath(path []*models.CachedEdgePolicy, +func (c *pathFindingTestContext) assertPath(path []*unifiedEdge, expected []uint64) { if len(path) != len(expected) { @@ -3186,9 +3196,10 @@ func (c *pathFindingTestContext) assertPath(path []*models.CachedEdgePolicy, } for i, edge := range path { - if edge.ChannelID != expected[i] { + if edge.policy.ChannelID != expected[i] { c.t.Fatalf("expected hop %v to be channel %v, "+ - "but got %v", i, expected[i], edge.ChannelID) + "but got %v", i, expected[i], + edge.policy.ChannelID) } } } @@ -3200,7 +3211,7 @@ func dbFindPath(graph *channeldb.ChannelGraph, bandwidthHints bandwidthHints, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, - finalHtlcExpiry int32) ([]*models.CachedEdgePolicy, error) { + finalHtlcExpiry int32) ([]*unifiedEdge, error) { sourceNode, err := graph.SourceNode() if err != nil { @@ -3351,11 +3362,11 @@ func TestBlindedRouteConstruction(t *testing.T) { carolDaveEdge := blindedEdges[carolVertex][0] daveEveEdge := blindedEdges[daveBlindedVertex][0] - edges := []*models.CachedEdgePolicy{ - aliceBobEdge, - bobCarolEdge, - carolDaveEdge.EdgePolicy(), - daveEveEdge.EdgePolicy(), + edges := []*unifiedEdge{ + {policy: aliceBobEdge}, + {policy: bobCarolEdge}, + {policy: carolDaveEdge.EdgePolicy()}, + {policy: daveEveEdge.EdgePolicy()}, } // Total timelock for the route should include: diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 67a285159..75b84a51a 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -212,7 +212,7 @@ func TestRequestRoute(t *testing.T) { // Override pathfinder with a mock. session.pathFinder = func(_ *graphParams, r *RestrictParams, _ *PathFindingConfig, _, _ route.Vertex, _ lnwire.MilliSatoshi, - _ float64, _ int32) ([]*models.CachedEdgePolicy, float64, + _ float64, _ int32) ([]*unifiedEdge, float64, error) { // We expect find path to receive a cltv limit excluding the @@ -221,14 +221,16 @@ func TestRequestRoute(t *testing.T) { t.Fatal("wrong cltv limit") } - path := []*models.CachedEdgePolicy{ + path := []*unifiedEdge{ { - ToNodePubKey: func() route.Vertex { - return route.Vertex{} + policy: &models.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return route.Vertex{} + }, + ToNodeFeatures: lnwire.NewFeatureVector( + nil, nil, + ), }, - ToNodeFeatures: lnwire.NewFeatureVector( - nil, nil, - ), }, } diff --git a/routing/router.go b/routing/router.go index 699ee9a33..58e2d28f4 100644 --- a/routing/router.go +++ b/routing/router.go @@ -3218,14 +3218,14 @@ func getRouteUnifiers(source route.Vertex, hops []route.Vertex, // including fees, to send the payment. func getPathEdges(source route.Vertex, receiverAmt lnwire.MilliSatoshi, unifiers []*edgeUnifier, bandwidthHints *bandwidthManager, - hops []route.Vertex) ([]*models.CachedEdgePolicy, + hops []route.Vertex) ([]*unifiedEdge, lnwire.MilliSatoshi, error) { // Now that we arrived at the start of the route and found out the route // total amount, we make a forward pass. Because the amount may have // been increased in the backward pass, fees need to be recalculated and // amount ranges re-checked. - var pathEdges []*models.CachedEdgePolicy + var pathEdges []*unifiedEdge for i, unifier := range unifiers { edge := unifier.getEdge(receiverAmt, bandwidthHints) if edge == nil { @@ -3247,7 +3247,7 @@ func getPathEdges(source route.Vertex, receiverAmt lnwire.MilliSatoshi, ) } - pathEdges = append(pathEdges, edge.policy) + pathEdges = append(pathEdges, edge) } return pathEdges, receiverAmt, nil diff --git a/routing/router_test.go b/routing/router_test.go index 12c8ff729..47bdf7a17 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -2450,8 +2450,8 @@ func TestFindPathFeeWeighting(t *testing.T) { if len(path) != 1 { t.Fatalf("expected path length of 1, instead was: %v", len(path)) } - if path[0].ToNodePubKey() != ctx.aliases["luoji"] { - t.Fatalf("wrong node: %v", path[0].ToNodePubKey()) + if path[0].policy.ToNodePubKey() != ctx.aliases["luoji"] { + t.Fatalf("wrong node: %v", path[0].policy.ToNodePubKey()) } }