diff --git a/routing/heap.go b/routing/heap.go index f155fecd6..80336fd0c 100644 --- a/routing/heap.go +++ b/routing/heap.go @@ -21,6 +21,10 @@ type nodeWithDist struct { // amount that includes also the fees for subsequent hops. amountToReceive lnwire.MilliSatoshi + // incomingCltv is the expected cltv value for the incoming htlc of this + // node. This value does not include the final cltv. + incomingCltv uint32 + // fee is the fee that this node is charging for forwarding. fee lnwire.MilliSatoshi } diff --git a/routing/pathfind.go b/routing/pathfind.go index 45853ad68..f7afa93d8 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -393,6 +393,11 @@ type RestrictParams struct { // OutgoingChannelID is the channel that needs to be taken to the first // hop. If nil, any channel may be used. OutgoingChannelID *uint64 + + // CltvLimit is the maximum time lock of the route excluding the final + // ctlv. After path finding is complete, the caller needs to increase + // all cltv expiry heights with the required final cltv delta. + CltvLimit *uint32 } // findPath attempts to find a path from the source node within the @@ -479,6 +484,7 @@ func findPath(g *graphParams, r *RestrictParams, source, target Vertex, node: targetNode, amountToReceive: amt, fee: 0, + incomingCltv: 0, } // We'll use this map as a series of "next" hop pointers. So to get @@ -575,6 +581,14 @@ func findPath(g *graphParams, r *RestrictParams, source, target Vertex, timeLockDelta = edge.TimeLockDelta } + incomingCltv := toNodeDist.incomingCltv + + uint32(timeLockDelta) + + // Check that we have cltv limit and that we are within it. + if r.CltvLimit != nil && incomingCltv > *r.CltvLimit { + return + } + // amountToReceive is the amount that the node that is added to // the distance map needs to receive from a (to be found) // previous node in the route. That previous node will need to @@ -622,6 +636,7 @@ func findPath(g *graphParams, r *RestrictParams, source, target Vertex, node: fromNode, amountToReceive: amountToReceive, fee: fee, + incomingCltv: incomingCltv, } next[fromVertex] = edge diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index c7fa4564b..d0e508f66 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2021,3 +2021,111 @@ func TestRestrictOutgoingChannel(t *testing.T) { "but channel %v was selected instead", route.Hops[0].ChannelID) } } + +// TestCltvLimit asserts that a cltv limit is obeyed by the path finding +// algorithm. +func TestCltvLimit(t *testing.T) { + t.Run("no limit", func(t *testing.T) { testCltvLimit(t, 0, 1) }) + t.Run("no path", func(t *testing.T) { testCltvLimit(t, 50, 0) }) + t.Run("force high cost", func(t *testing.T) { testCltvLimit(t, 80, 3) }) +} + +func testCltvLimit(t *testing.T, limit uint32, expectedChannel uint64) { + t.Parallel() + + // Set up a test graph with three possible paths to the target. The path + // through a is the lowest cost with a high time lock (144). The path + // through b has a higher cost but a lower time lock (100). That path + // through c and d (two hops) has the same case as the path through b, + // but the total time lock is lower (60). + testChannels := []*testChannel{ + symmetricTestChannel("roasbeef", "a", 100000, &testChannelPolicy{}, 1), + symmetricTestChannel("a", "target", 100000, &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 10000, + MinHTLC: 1, + }), + symmetricTestChannel("roasbeef", "b", 100000, &testChannelPolicy{}, 2), + symmetricTestChannel("b", "target", 100000, &testChannelPolicy{ + Expiry: 100, + FeeBaseMsat: 20000, + MinHTLC: 1, + }), + symmetricTestChannel("roasbeef", "c", 100000, &testChannelPolicy{}, 3), + symmetricTestChannel("c", "d", 100000, &testChannelPolicy{ + Expiry: 30, + FeeBaseMsat: 10000, + MinHTLC: 1, + }), + symmetricTestChannel("d", "target", 100000, &testChannelPolicy{ + Expiry: 30, + FeeBaseMsat: 10000, + MinHTLC: 1, + }), + } + + testGraphInstance, err := createTestGraphFromChannels(testChannels) + if err != nil { + t.Fatalf("unable to create graph: %v", err) + } + defer testGraphInstance.cleanUp() + + sourceNode, err := testGraphInstance.graph.SourceNode() + if err != nil { + t.Fatalf("unable to fetch source node: %v", err) + } + sourceVertex := Vertex(sourceNode.PubKeyBytes) + + ignoredEdges := make(map[EdgeLocator]struct{}) + ignoredVertexes := make(map[Vertex]struct{}) + + paymentAmt := lnwire.NewMSatFromSatoshis(100) + target := testGraphInstance.aliasMap["target"] + + // Find the best path given the cltv limit. + var cltvLimit *uint32 + if limit != 0 { + cltvLimit = &limit + } + + path, err := findPath( + &graphParams{ + graph: testGraphInstance.graph, + }, + &RestrictParams{ + IgnoredNodes: ignoredVertexes, + IgnoredEdges: ignoredEdges, + FeeLimit: noFeeLimit, + CltvLimit: cltvLimit, + }, + sourceVertex, target, paymentAmt, + ) + if expectedChannel == 0 { + // Finish test if we expect no route. + if IsError(err, ErrNoPathFound) { + return + } + t.Fatal("expected no path to be found") + } + if err != nil { + t.Fatalf("unable to find path: %v", err) + } + + const ( + startingHeight = 100 + finalHopCLTV = 1 + ) + route, err := newRoute( + paymentAmt, sourceVertex, path, startingHeight, finalHopCLTV, + ) + if err != nil { + t.Fatalf("unable to create path: %v", err) + } + + // Assert that the route starts with the expected channel. + if route.Hops[0].ChannelID != expectedChannel { + t.Fatalf("expected route to pass through channel %v, "+ + "but channel %v was selected instead", expectedChannel, + route.Hops[0].ChannelID) + } +}