diff --git a/channeldb/graph.go b/channeldb/graph.go index 8cf149850..df233f1d9 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -236,6 +236,28 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli }) } +// ForEachNodeChannel iterates through all channels of a given node, executing the +// passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// the connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. +// +// If the caller wishes to re-use an existing boltdb transaction, then it +// should be passed as the first argument. Otherwise the first argument should +// be nil and a fresh transaction will be created to execute the graph +// traversal. +func (c *ChannelGraph) ForEachNodeChannel(tx *bbolt.Tx, nodePub []byte, + cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { + + db := c.db + + return nodeTraversal(tx, nodePub, db, cb) +} + // ForEachNode iterates through all the stored vertices/nodes in the graph, // executing the passed callback with each node encountered. If the callback // returns an error, then the transaction is aborted and the iteration stops @@ -2183,24 +2205,11 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro return updateTime, exists, nil } -// ForEachChannel iterates through all channels of this node, executing the -// passed callback with an edge info structure and the policies of each end -// of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the -// connecting node. If the callback returns an error, then the iteration is -// halted with the error propagated back up to the caller. -// -// Unknown policies are passed into the callback as nil values. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, +// nodeTraversal is used to traverse all channels of a node given by its +// public key and passes channel information into the specified callback. +func nodeTraversal(tx *bbolt.Tx, nodePub []byte, db *DB, cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - nodePub := l.PubKeyBytes[:] - traversal := func(tx *bbolt.Tx) error { nodes := tx.Bucket(nodeBucket) if nodes == nil { @@ -2241,7 +2250,7 @@ func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, if err != nil { return err } - edgeInfo.db = l.db + edgeInfo.db = db outgoingPolicy, err := fetchChanEdgePolicy( edges, chanID, nodePub, nodes, @@ -2256,7 +2265,7 @@ func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, } incomingPolicy, err := fetchChanEdgePolicy( - edges, chanID, otherNode, nodes, + edges, chanID, otherNode[:], nodes, ) if err != nil { return err @@ -2275,7 +2284,7 @@ func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, // If no transaction was provided, then we'll create a new transaction // to execute the transaction within. if tx == nil { - return l.db.View(traversal) + return db.View(traversal) } // Otherwise, we re-use the existing transaction to execute the graph @@ -2283,6 +2292,28 @@ func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, return traversal(tx) } +// ForEachChannel iterates through all channels of this node, executing the +// passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// the connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. +// +// If the caller wishes to re-use an existing boltdb transaction, then it +// should be passed as the first argument. Otherwise the first argument should +// be nil and a fresh transaction will be created to execute the graph +// traversal. +func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, + cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + + nodePub := l.PubKeyBytes[:] + db := l.db + + return nodeTraversal(tx, nodePub, db, cb) +} + // ChannelEdgeInfo represents a fully authenticated channel along with all its // unique attributes. Once an authenticated channel announcement has been // processed on the network, then an instance of ChannelEdgeInfo encapsulating @@ -2450,15 +2481,15 @@ func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) { // OtherNodeKeyBytes returns the node key bytes of the other end of // the channel. func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) ( - []byte, error) { + [33]byte, error) { switch { case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): - return c.NodeKey2Bytes[:], nil + return c.NodeKey2Bytes, nil case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): - return c.NodeKey1Bytes[:], nil + return c.NodeKey1Bytes, nil default: - return nil, fmt.Errorf("node not participating in this channel") + return [33]byte{}, fmt.Errorf("node not participating in this channel") } } diff --git a/routing/heap.go b/routing/heap.go index be7acaf0a..294e3306a 100644 --- a/routing/heap.go +++ b/routing/heap.go @@ -1,8 +1,11 @@ package routing import ( + "container/heap" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" ) // nodeWithDist is a helper struct that couples the distance from the current @@ -12,9 +15,9 @@ type nodeWithDist struct { // current context. dist int64 - // node is the vertex itself. This pointer can be used to explore all - // the outgoing edges (channels) emanating from a node. - node *channeldb.LightningNode + // node is the vertex itself. This can be used to explore all the + // outgoing edges (channels) emanating from a node. + node route.Vertex // amountToReceive is the amount that should be received by this node. // Either as final payment to the final node or as an intermediate @@ -39,6 +42,21 @@ type nodeWithDist struct { // algorithm to keep track of the "closest" node to our source node. type distanceHeap struct { nodes []nodeWithDist + + // pubkeyIndices maps public keys of nodes to their respective index in + // the heap. This is used as a way to avoid db lookups by using heap.Fix + // instead of having duplicate entries on the heap. + pubkeyIndices map[route.Vertex]int +} + +// newDistanceHeap initializes a new distance heap. This is required because +// we must initialize the pubkeyIndices map for path-finding optimizations. +func newDistanceHeap() distanceHeap { + distHeap := distanceHeap{ + pubkeyIndices: make(map[route.Vertex]int), + } + + return distHeap } // Len returns the number of nodes in the priority queue. @@ -59,13 +77,17 @@ func (d *distanceHeap) Less(i, j int) bool { // NOTE: This is part of the heap.Interface implementation. func (d *distanceHeap) Swap(i, j int) { d.nodes[i], d.nodes[j] = d.nodes[j], d.nodes[i] + d.pubkeyIndices[d.nodes[i].node] = i + d.pubkeyIndices[d.nodes[j].node] = j } // Push pushes the passed item onto the priority queue. // // NOTE: This is part of the heap.Interface implementation. func (d *distanceHeap) Push(x interface{}) { - d.nodes = append(d.nodes, x.(nodeWithDist)) + n := x.(nodeWithDist) + d.nodes = append(d.nodes, n) + d.pubkeyIndices[n.node] = len(d.nodes) - 1 } // Pop removes the highest priority item (according to Less) from the priority @@ -76,9 +98,29 @@ func (d *distanceHeap) Pop() interface{} { n := len(d.nodes) x := d.nodes[n-1] d.nodes = d.nodes[0 : n-1] + delete(d.pubkeyIndices, x.node) return x } +// PushOrFix attempts to adjust the position of a given node in the heap. +// If the vertex already exists in the heap, then we must call heap.Fix to +// modify its position and reorder the heap. If the vertex does not already +// exist in the heap, then it is pushed onto the heap. Otherwise, we will end +// up performing more db lookups on the same node in the pathfinding algorithm. +func (d *distanceHeap) PushOrFix(dist nodeWithDist) { + index, ok := d.pubkeyIndices[dist.node] + if !ok { + heap.Push(d, dist) + return + } + + // Change the value at the specified index. + d.nodes[index] = dist + + // Call heap.Fix to reorder the heap. + heap.Fix(d, index) +} + // path represents an ordered set of edges which forms an available path from a // given source node to our destination. During the process of computing the // KSP's from a source to destination, several path swill be considered in the diff --git a/routing/heap_test.go b/routing/heap_test.go index 2ada4232d..4214e965b 100644 --- a/routing/heap_test.go +++ b/routing/heap_test.go @@ -6,7 +6,8 @@ import ( "reflect" "sort" "testing" - "time" + + "github.com/lightningnetwork/lnd/routing/route" ) // TestHeapOrdering ensures that the items inserted into the heap are properly @@ -16,20 +17,33 @@ func TestHeapOrdering(t *testing.T) { // First, create a blank heap, we'll use this to push on randomly // generated items. - var nodeHeap distanceHeap + nodeHeap := newDistanceHeap() - prand.Seed(time.Now().Unix()) + prand.Seed(1) // Create 100 random entries adding them to the heap created above, but // also a list that we'll sort with the entries. const numEntries = 100 sortedEntries := make([]nodeWithDist, 0, numEntries) for i := 0; i < numEntries; i++ { + var pubKey [33]byte + prand.Read(pubKey[:]) + entry := nodeWithDist{ + node: route.Vertex(pubKey), dist: prand.Int63(), } - heap.Push(&nodeHeap, entry) + // Use the PushOrFix method for the initial push to test the scenario + // where entry doesn't exist on the heap. + nodeHeap.PushOrFix(entry) + + // Re-generate this entry's dist field + entry.dist = prand.Int63() + + // Reorder the heap with a PushOrFix call. + nodeHeap.PushOrFix(entry) + sortedEntries = append(sortedEntries, entry) } @@ -47,6 +61,13 @@ func TestHeapOrdering(t *testing.T) { poppedEntries = append(poppedEntries, e) } + // Assert that the pubkeyIndices map is empty after popping all of the + // items off of it. + if len(nodeHeap.pubkeyIndices) != 0 { + t.Fatalf("there are still %d pubkeys in the pubkeyIndices map", + len(nodeHeap.pubkeyIndices)) + } + // Finally, ensure that the items popped from the heap and the items we // sorted are identical at this rate. if !reflect.DeepEqual(poppedEntries, sortedEntries) { diff --git a/routing/pathfind.go b/routing/pathfind.go index 1e3676bc3..7e9515326 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -60,7 +60,7 @@ var ( // of a channel edge. ChannelEdgePolicy only contains to destination node // of the edge. type edgePolicyWithSource struct { - sourceNode *channeldb.LightningNode + sourceNode route.Vertex edge *channeldb.ChannelEdgePolicy } @@ -300,7 +300,7 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, // First we'll initialize an empty heap which'll help us to quickly // locate the next edge we should visit next during our graph // traversal. - var nodeHeap distanceHeap + nodeHeap := newDistanceHeap() // For each node in the graph, we create an entry in the distance map // for the node set with a distance of "infinity". graph.ForEachNode @@ -313,7 +313,7 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, // with a visited map distance[route.Vertex(node.PubKeyBytes)] = nodeWithDist{ dist: infinity, - node: node, + node: route.Vertex(node.PubKeyBytes), } return nil }); err != nil { @@ -324,10 +324,9 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, for vertex, outgoingEdgePolicies := range g.additionalEdges { // We'll also include all the nodes found within the additional // edges that are not known to us yet in the distance map. - node := &channeldb.LightningNode{PubKeyBytes: vertex} distance[vertex] = nodeWithDist{ dist: infinity, - node: node, + node: vertex, } // Build reverse lookup to find incoming edges. Needed because @@ -335,7 +334,7 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, for _, outgoingEdgePolicy := range outgoingEdgePolicies { toVertex := outgoingEdgePolicy.Node.PubKeyBytes incomingEdgePolicy := &edgePolicyWithSource{ - sourceNode: node, + sourceNode: vertex, edge: outgoingEdgePolicy, } @@ -351,11 +350,10 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, // charges no fee. Distance is set to 0, because this is the starting // point of the graph traversal. We are searching backwards to get the // fees first time right and correctly match channel bandwidth. - targetNode := &channeldb.LightningNode{PubKeyBytes: target} distance[target] = nodeWithDist{ dist: 0, weight: 0, - node: targetNode, + node: target, amountToReceive: amt, incomingCltv: 0, probability: 1, @@ -368,11 +366,8 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, // processEdge is a helper closure that will be used to make sure edges // satisfy our specific requirements. - processEdge := func(fromNode *channeldb.LightningNode, - edge *channeldb.ChannelEdgePolicy, - bandwidth lnwire.MilliSatoshi, toNode route.Vertex) { - - fromVertex := route.Vertex(fromNode.PubKeyBytes) + processEdge := func(fromVertex route.Vertex, bandwidth lnwire.MilliSatoshi, + edge *channeldb.ChannelEdgePolicy, toNode route.Vertex) { // If this is not a local channel and it is disabled, we will // skip it. @@ -434,7 +429,7 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, return } - // Compute fee that fromNode is charging. It is based on the + // Compute fee that fromVertex is charging. It is based on the // amount that needs to be sent to the next node in the route. // // Source node has no predecessor to pay a fee. Therefore set @@ -442,7 +437,7 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, // limit check and edge weight. // // Also determine the time lock delta that will be added to the - // route if fromNode is selected. If fromNode is the source + // route if fromVertex is selected. If fromVertex is the source // node, no additional timelock is required. var fee lnwire.MilliSatoshi var timeLockDelta uint16 @@ -485,10 +480,10 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, return } - // By adding fromNode in the route, there will be an extra + // By adding fromVertex in the route, there will be an extra // weight composed of the fee that this node will charge and // the amount that will be locked for timeLockDelta blocks in - // the HTLC that is handed out to fromNode. + // the HTLC that is handed out to fromVertex. weight := edgeWeight(amountToReceive, fee, timeLockDelta) // Compute the tentative weight to this new channel/edge @@ -524,7 +519,7 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, distance[fromVertex] = nodeWithDist{ dist: tempDist, weight: tempWeight, - node: fromNode, + node: fromVertex, amountToReceive: amountToReceive, incomingCltv: incomingCltv, probability: probability, @@ -532,9 +527,10 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, next[fromVertex] = edge - // Add this new node to our heap as we'd like to further - // explore backwards through this edge. - heap.Push(&nodeHeap, distance[fromVertex]) + // Either push distance[fromVertex] onto the heap if the node + // represented by fromVertex is not already on the heap OR adjust + // its position within the heap via heap.Fix. + nodeHeap.PushOrFix(distance[fromVertex]) } // TODO(roasbeef): also add path caching @@ -548,22 +544,17 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, // Fetch the node within the smallest distance from our source // from the heap. partialPath := heap.Pop(&nodeHeap).(nodeWithDist) - bestNode := partialPath.node + pivot := partialPath.node // If we've reached our source (or we don't have any incoming // edges), then we're done here and can exit the graph // traversal early. - if bestNode.PubKeyBytes == source { + if pivot == source { break } - // Now that we've found the next potential step to take we'll - // examine all the incoming edges (channels) from this node to - // further our graph traversal. - pivot := route.Vertex(bestNode.PubKeyBytes) - err := bestNode.ForEachChannel(tx, func(tx *bbolt.Tx, - edgeInfo *channeldb.ChannelEdgeInfo, - _, inEdge *channeldb.ChannelEdgePolicy) error { + cb := func(_ *bbolt.Tx, edgeInfo *channeldb.ChannelEdgeInfo, _, + inEdge *channeldb.ChannelEdgePolicy) error { // If there is no edge policy for this candidate // node, skip. Note that we are searching backwards @@ -595,18 +586,21 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, // the node on the _other_ end of this channel as we // may later need to iterate over the incoming edges of // this node if we explore it further. - channelSource, err := edgeInfo.FetchOtherNode( - tx, pivot[:], - ) + chanSource, err := edgeInfo.OtherNodeKeyBytes(pivot[:]) if err != nil { return err } // Check if this candidate node is better than what we // already have. - processEdge(channelSource, inEdge, edgeBandwidth, pivot) + processEdge(route.Vertex(chanSource), edgeBandwidth, inEdge, pivot) return nil - }) + } + + // Now that we've found the next potential step to take we'll + // examine all the incoming edges (channels) from this node to + // further our graph traversal. + err := g.graph.ForEachNodeChannel(tx, pivot[:], cb) if err != nil { return nil, err } @@ -617,9 +611,9 @@ func findPath(g *graphParams, r *RestrictParams, source, target route.Vertex, // routing hint due to having enough capacity for the payment // and use the payment amount as its capacity. bandWidth := partialPath.amountToReceive - for _, reverseEdge := range additionalEdgesWithSrc[bestNode.PubKeyBytes] { - processEdge(reverseEdge.sourceNode, reverseEdge.edge, - bandWidth, pivot) + for _, reverseEdge := range additionalEdgesWithSrc[pivot] { + processEdge(reverseEdge.sourceNode, bandWidth, + reverseEdge.edge, pivot) } }