mirror of
https://github.com/lightningnetwork/lnd.git
synced 2024-11-19 09:53:54 +01:00
routing: allow route to self
This commit is contained in:
parent
81b7798c03
commit
f8e9efbf99
@ -381,11 +381,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
|
|||||||
|
|
||||||
// We can't always assume that the end destination is publicly
|
// We can't always assume that the end destination is publicly
|
||||||
// advertised to the network so we'll manually include the target node.
|
// advertised to the network so we'll manually include the target node.
|
||||||
// The target node charges no fee. Distance is set to 0, because this
|
// The target node charges no fee. Distance is set to 0, because this is
|
||||||
// is the starting point of the graph traversal. We are searching
|
// the starting point of the graph traversal. We are searching backwards
|
||||||
// backwards to get the fees first time right and correctly match
|
// to get the fees first time right and correctly match channel
|
||||||
// channel bandwidth.
|
// bandwidth.
|
||||||
distance[target] = &nodeWithDist{
|
//
|
||||||
|
// Don't record the initial partial path in the distance map and reserve
|
||||||
|
// that key for the source key in the case we route to ourselves.
|
||||||
|
partialPath := &nodeWithDist{
|
||||||
dist: 0,
|
dist: 0,
|
||||||
weight: 0,
|
weight: 0,
|
||||||
node: target,
|
node: target,
|
||||||
@ -530,9 +533,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
|
|||||||
// TODO(roasbeef): also add path caching
|
// TODO(roasbeef): also add path caching
|
||||||
// * similar to route caching, but doesn't factor in the amount
|
// * similar to route caching, but doesn't factor in the amount
|
||||||
|
|
||||||
// The partial path that we start out with is a path that consists of
|
routeToSelf := source == target
|
||||||
// just the target node.
|
|
||||||
partialPath := distance[target]
|
|
||||||
for {
|
for {
|
||||||
nodesVisited++
|
nodesVisited++
|
||||||
|
|
||||||
@ -555,6 +556,15 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
|
|||||||
// Expand all connections using the optimal policy for each
|
// Expand all connections using the optimal policy for each
|
||||||
// connection.
|
// connection.
|
||||||
for fromNode, unifiedPolicy := range u.policies {
|
for fromNode, unifiedPolicy := range u.policies {
|
||||||
|
// The target node is not recorded in the distance map.
|
||||||
|
// Therefore we need to have this check to prevent
|
||||||
|
// creating a cycle. Only when we intend to route to
|
||||||
|
// self, we allow this cycle to form. In that case we'll
|
||||||
|
// also break out of the search loop below.
|
||||||
|
if !routeToSelf && fromNode == target {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Apply last hop restriction if set.
|
// Apply last hop restriction if set.
|
||||||
if r.LastHop != nil &&
|
if r.LastHop != nil &&
|
||||||
pivot == target && fromNode != *r.LastHop {
|
pivot == target && fromNode != *r.LastHop {
|
||||||
@ -610,6 +620,9 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
|
|||||||
// Advance current node.
|
// Advance current node.
|
||||||
currentNode = currentNodeWithDist.nextHop.Node.PubKeyBytes
|
currentNode = currentNodeWithDist.nextHop.Node.PubKeyBytes
|
||||||
|
|
||||||
|
// Check stop condition at the end of this loop. This prevents
|
||||||
|
// breaking out too soon for self-payments that have target set
|
||||||
|
// to source.
|
||||||
if currentNode == target {
|
if currentNode == target {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -2224,6 +2224,53 @@ func TestNoCycle(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRouteToSelf tests that it is possible to find a route to the self node.
|
||||||
|
func TestRouteToSelf(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testChannels := []*testChannel{
|
||||||
|
symmetricTestChannel("source", "a", 100000, &testChannelPolicy{
|
||||||
|
Expiry: 144,
|
||||||
|
FeeBaseMsat: 500,
|
||||||
|
}, 1),
|
||||||
|
symmetricTestChannel("source", "b", 100000, &testChannelPolicy{
|
||||||
|
Expiry: 144,
|
||||||
|
FeeBaseMsat: 1000,
|
||||||
|
}, 2),
|
||||||
|
symmetricTestChannel("a", "b", 100000, &testChannelPolicy{
|
||||||
|
Expiry: 144,
|
||||||
|
FeeBaseMsat: 1000,
|
||||||
|
}, 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := newPathFindingTestContext(t, testChannels, "source")
|
||||||
|
defer ctx.cleanup()
|
||||||
|
|
||||||
|
paymentAmt := lnwire.NewMSatFromSatoshis(100)
|
||||||
|
target := ctx.source
|
||||||
|
|
||||||
|
// Find the best path to self. We expect this to be source->a->source,
|
||||||
|
// because a charges the lowest forwarding fee.
|
||||||
|
path, err := ctx.findPath(target, paymentAmt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to find path: %v", err)
|
||||||
|
}
|
||||||
|
ctx.assertPath(path, []uint64{1, 1})
|
||||||
|
|
||||||
|
outgoingChanID := uint64(1)
|
||||||
|
lastHop := ctx.keyFromAlias("b")
|
||||||
|
ctx.restrictParams.OutgoingChannelID = &outgoingChanID
|
||||||
|
ctx.restrictParams.LastHop = &lastHop
|
||||||
|
|
||||||
|
// Find the best path to self given that we want to go out via channel 1
|
||||||
|
// and return through node b.
|
||||||
|
path, err = ctx.findPath(target, paymentAmt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to find path: %v", err)
|
||||||
|
}
|
||||||
|
ctx.assertPath(path, []uint64{1, 3, 2})
|
||||||
|
}
|
||||||
|
|
||||||
type pathFindingTestContext struct {
|
type pathFindingTestContext struct {
|
||||||
t *testing.T
|
t *testing.T
|
||||||
graphParams graphParams
|
graphParams graphParams
|
||||||
@ -2291,3 +2338,17 @@ func (c *pathFindingTestContext) findPath(target route.Vertex,
|
|||||||
c.source, target, amt,
|
c.source, target, amt,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, expected []uint64) {
|
||||||
|
if len(path) != len(expected) {
|
||||||
|
c.t.Fatalf("expected path of length %v, but got %v",
|
||||||
|
len(expected), len(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, edge := range path {
|
||||||
|
if edge.ChannelID != expected[i] {
|
||||||
|
c.t.Fatalf("expected hop %v to be channel %v, "+
|
||||||
|
"but got %v", i, expected[i], edge.ChannelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user