From 7d29ab905c6226d750726bf943a1830c2a9c3266 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Fri, 30 Sep 2022 09:03:27 +0200 Subject: [PATCH] routing: return *unifiedPolicyEdge in getPolicy We encapsulate the capacity inside a unifiedPolicyEdge for later usage. The meaning of "policy" has changed now, which will be refactored in the next commmit. --- routing/pathfind.go | 14 +++++++------- routing/router.go | 9 +++++---- routing/unified_policies.go | 15 ++++++++------- routing/unified_policies_test.go | 4 +++- 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/routing/pathfind.go b/routing/pathfind.go index 0d0a1e920..fb351bd88 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -628,7 +628,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // satisfy our specific requirements. processEdge := func(fromVertex route.Vertex, fromFeatures *lnwire.FeatureVector, - edge *channeldb.CachedEdgePolicy, toNodeDist *nodeWithDist) { + edge *unifiedPolicyEdge, toNodeDist *nodeWithDist) { edgesExpanded++ @@ -666,8 +666,8 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, var fee lnwire.MilliSatoshi var timeLockDelta uint16 if fromVertex != source { - fee = edge.ComputeFee(amountToSend) - timeLockDelta = edge.TimeLockDelta + fee = edge.policy.ComputeFee(amountToSend) + timeLockDelta = edge.policy.TimeLockDelta } incomingCltv := toNodeDist.incomingCltv + int32(timeLockDelta) @@ -744,9 +744,9 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Every edge should have a positive time lock delta. If we // encounter a zero delta, log a warning line. - if edge.TimeLockDelta == 0 { + if edge.policy.TimeLockDelta == 0 { log.Warnf("Channel %v has zero cltv delta", - edge.ChannelID) + edge.policy.ChannelID) } // Calculate the total routing info size if this hop were to be @@ -767,7 +767,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, LegacyPayload: !supportsTlv, } - payloadSize = hop.PayloadSize(edge.ChannelID) + payloadSize = hop.PayloadSize(edge.policy.ChannelID) } routingInfoSize := toNodeDist.routingInfoSize + payloadSize @@ -788,7 +788,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, amountToReceive: amountToReceive, incomingCltv: incomingCltv, probability: probability, - nextHop: edge, + nextHop: edge.policy, routingInfoSize: routingInfoSize, } distance[fromVertex] = withDist diff --git a/routing/router.go b/routing/router.go index 0588d5122..696d9cb9c 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2837,10 +2837,11 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // Add fee for this hop. if !localChan { - runningAmt += policy.ComputeFee(runningAmt) + runningAmt += policy.policy.ComputeFee(runningAmt) } - log.Tracef("Select channel %v at position %v", policy.ChannelID, i) + log.Tracef("Select channel %v at position %v", + policy.policy.ChannelID, i) edges[i] = unifiedPolicy } @@ -2862,12 +2863,12 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, if i > 0 { // Decrease the amount to send while going forward. - receiverAmt -= policy.ComputeFeeFromIncoming( + receiverAmt -= policy.policy.ComputeFeeFromIncoming( receiverAmt, ) } - pathEdges = append(pathEdges, policy) + pathEdges = append(pathEdges, policy.policy) } // Build and return the final route. diff --git a/routing/unified_policies.go b/routing/unified_policies.go index ec7e594d5..c2cc1b373 100644 --- a/routing/unified_policies.go +++ b/routing/unified_policies.go @@ -133,7 +133,7 @@ type unifiedPolicy struct { // specific amount to send. It differentiates between local and network // channels. func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, - bandwidthHints bandwidthHints) *channeldb.CachedEdgePolicy { + bandwidthHints bandwidthHints) *unifiedPolicyEdge { if u.localChan { return u.getPolicyLocal(amt, bandwidthHints) @@ -145,10 +145,10 @@ func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, // getPolicyLocal returns the optimal policy to use for this local connection // given a specific amount to send. func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, - bandwidthHints bandwidthHints) *channeldb.CachedEdgePolicy { + bandwidthHints bandwidthHints) *unifiedPolicyEdge { var ( - bestPolicy *channeldb.CachedEdgePolicy + bestPolicy *unifiedPolicyEdge maxBandwidth lnwire.MilliSatoshi ) @@ -192,7 +192,7 @@ func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, maxBandwidth = bandwidth // Update best policy. - bestPolicy = edge.policy + bestPolicy = &unifiedPolicyEdge{policy: edge.policy} } return bestPolicy @@ -202,7 +202,7 @@ func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, // a specific amount to send. The goal is to return a policy that maximizes the // probability of a successful forward in a non-strict forwarding context. func (u *unifiedPolicy) getPolicyNetwork( - amt lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { + amt lnwire.MilliSatoshi) *unifiedPolicyEdge { var ( bestPolicy *channeldb.CachedEdgePolicy @@ -255,8 +255,9 @@ func (u *unifiedPolicy) getPolicyNetwork( // get forwarded. Because we penalize pair-wise, there won't be a second // chance for this node pair. But this is all only needed for nodes that // have distinct policies for channels to the same peer. - modifiedPolicy := *bestPolicy - modifiedPolicy.TimeLockDelta = maxTimelock + policyCopy := *bestPolicy + modifiedPolicy := unifiedPolicyEdge{policy: &policyCopy} + modifiedPolicy.policy.TimeLockDelta = maxTimelock return &modifiedPolicy } diff --git a/routing/unified_policies_test.go b/routing/unified_policies_test.go index abdc56b62..8c8d9bbd8 100644 --- a/routing/unified_policies_test.go +++ b/routing/unified_policies_test.go @@ -39,12 +39,14 @@ func TestUnifiedPolicies(t *testing.T) { u.addPolicy(fromNode, &p1, 7) u.addPolicy(fromNode, &p2, 7) - checkPolicy := func(policy *channeldb.CachedEdgePolicy, + checkPolicy := func(unifiedPolicy *unifiedPolicyEdge, feeBase lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, timeLockDelta uint16) { t.Helper() + policy := unifiedPolicy.policy + if policy.FeeBaseMSat != feeBase { t.Fatalf("expected fee base %v, got %v", feeBase, policy.FeeBaseMSat)