diff --git a/channeldb/graph.go b/channeldb/graph.go index 69c86d1f9..68f4fb537 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -211,8 +211,8 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, startTime := time.Now() log.Debugf("Populating in-memory channel graph, this might take a " + "while...") - err := g.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error { - return g.graphCache.AddNode(tx, &graphCacheNode{node}) + err := g.ForEachNodeCacheable(func(tx kvdb.RTx, node GraphCacheNode) error { + return g.graphCache.AddNode(tx, node) }) if err != nil { return nil, err @@ -468,6 +468,47 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro return kvdb.View(c.db, traversal, func() {}) } +// ForEachNodeCacheable iterates through all the stored vertices/nodes in the +// graph, executing the passed callback with each node encountered. If the +// callback returns an error, then the transaction is aborted and the iteration +// stops early. +func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx, + GraphCacheNode) error) error { + + traversal := func(tx kvdb.RTx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.ReadBucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + cacheableNode := newGraphCacheNode(route.Vertex{}, nil) + return nodes.ForEach(func(pubKey, nodeBytes []byte) error { + // If this is the source key, then we skip this + // iteration as the value for this key is a pubKey + // rather than raw node information. + if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { + return nil + } + + nodeReader := bytes.NewReader(nodeBytes) + err := deserializeLightningNodeCacheable( + nodeReader, cacheableNode, + ) + if err != nil { + return err + } + + // Execute the callback, the transaction will abort if + // this returns an error. + return cb(tx, cacheableNode) + }) + } + + return kvdb.View(c.db, traversal, func() {}) +} + // SourceNode returns the source node of the graph. The source node is treated // as the center node within a star-graph. This method may be used to kick off // a path finding algorithm in order to explore the reachability of another @@ -559,8 +600,10 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode, r := &batch.Request{ Update: func(tx kvdb.RwTx) error { - wNode := &graphCacheNode{node} - if err := c.graphCache.AddNode(tx, wNode); err != nil { + cNode := newGraphCacheNode( + node.PubKeyBytes, node.Features, + ) + if err := c.graphCache.AddNode(tx, cNode); err != nil { return err } @@ -2532,17 +2575,30 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) ( // graphCacheNode is a struct that wraps a LightningNode in a way that it can be // cached in the graph cache. type graphCacheNode struct { - lnNode *LightningNode + pubKeyBytes route.Vertex + features *lnwire.FeatureVector + + nodeScratch [8]byte +} + +// newGraphCacheNode returns a new cache optimized node. +func newGraphCacheNode(pubKey route.Vertex, + features *lnwire.FeatureVector) *graphCacheNode { + + return &graphCacheNode{ + pubKeyBytes: pubKey, + features: features, + } } // PubKey returns the node's public identity key. -func (w *graphCacheNode) PubKey() route.Vertex { - return w.lnNode.PubKeyBytes +func (n *graphCacheNode) PubKey() route.Vertex { + return n.pubKeyBytes } // Features returns the node's features. -func (w *graphCacheNode) Features() *lnwire.FeatureVector { - return w.lnNode.Features +func (n *graphCacheNode) Features() *lnwire.FeatureVector { + return n.features } // ForEachChannel iterates through all channels of this node, executing the @@ -2553,11 +2609,11 @@ func (w *graphCacheNode) Features() *lnwire.FeatureVector { // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. -func (w *graphCacheNode) ForEachChannel(tx kvdb.RTx, +func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx, cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - return w.lnNode.ForEachChannel(tx, cb) + return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb) } var _ GraphCacheNode = (*graphCacheNode)(nil) @@ -3865,6 +3921,53 @@ func fetchLightningNode(nodeBucket kvdb.RBucket, return deserializeLightningNode(nodeReader) } +func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error { + // Always populate a feature vector, even if we don't have a node + // announcement and short circuit below. + node.features = lnwire.EmptyFeatureVector() + + // Skip ahead: + // - LastUpdate (8 bytes) + if _, err := r.Read(node.nodeScratch[:]); err != nil { + return err + } + + if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil { + return err + } + + // Read the node announcement flag. + if _, err := r.Read(node.nodeScratch[:2]); err != nil { + return err + } + hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2]) + + // The rest of the data is optional, and will only be there if we got a + // node announcement for this node. + if hasNodeAnn == 0 { + return nil + } + + // We did get a node announcement for this node, so we'll have the rest + // of the data available. + var rgb uint8 + if err := binary.Read(r, byteOrder, &rgb); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &rgb); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &rgb); err != nil { + return err + } + + if _, err := wire.ReadVarString(r, 0); err != nil { + return err + } + + return node.features.Decode(r) +} + func deserializeLightningNode(r io.Reader) (LightningNode, error) { var ( node LightningNode diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index b43dbb972..27b979842 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -21,6 +21,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -1145,6 +1146,58 @@ func TestGraphTraversal(t *testing.T) { require.Equal(t, numChannels, numNodeChans) } +// TestGraphTraversalCacheable tests that the memory optimized node traversal is +// working correctly. +func TestGraphTraversalCacheable(t *testing.T) { + t.Parallel() + + graph, cleanUp, err := MakeTestGraph() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + // We'd like to test some of the graph traversal capabilities within + // the DB, so we'll create a series of fake nodes to insert into the + // graph. And we'll create 5 channels between the first two nodes. + const numNodes = 20 + const numChannels = 5 + chanIndex, _ := fillTestGraph(t, graph, numNodes, numChannels) + + // Create a map of all nodes with the iteration we know works (because + // it is tested in another test). + nodeMap := make(map[route.Vertex]struct{}) + err = graph.ForEachNode(func(tx kvdb.RTx, n *LightningNode) error { + nodeMap[n.PubKeyBytes] = struct{}{} + + return nil + }) + require.NoError(t, err) + require.Len(t, nodeMap, numNodes) + + // Iterate through all the known channels within the graph DB by + // iterating over each node, once again if the map is empty that + // indicates that all edges have properly been reached. + err = graph.ForEachNodeCacheable( + func(tx kvdb.RTx, node GraphCacheNode) error { + delete(nodeMap, node.PubKey()) + + return node.ForEachChannel( + tx, func(tx kvdb.RTx, info *ChannelEdgeInfo, + policy *ChannelEdgePolicy, + policy2 *ChannelEdgePolicy) error { + + delete(chanIndex, info.ChannelID) + return nil + }, + ) + }, + ) + require.NoError(t, err) + require.Len(t, nodeMap, 0) + require.Len(t, chanIndex, 0) +} + func TestGraphCacheTraversal(t *testing.T) { t.Parallel() @@ -1164,6 +1217,8 @@ func TestGraphCacheTraversal(t *testing.T) { // properly been reached. numNodeChans := 0 for _, node := range nodeList { + node := node + err = graph.graphCache.ForEachChannel( node.PubKeyBytes, func(d *DirectedChannel) error { delete(chanIndex, d.ChannelID) @@ -1197,7 +1252,7 @@ func TestGraphCacheTraversal(t *testing.T) { require.Equal(t, numChannels*2*(numNodes-1), numNodeChans) } -func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes, +func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, numChannels int) (map[uint64]struct{}, []*LightningNode) { nodes := make([]*LightningNode, numNodes) @@ -1237,7 +1292,7 @@ func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes, for i := 0; i < numChannels; i++ { txHash := sha256.Sum256([]byte{byte(i)}) - chanID := uint64((n << 4) + i + 1) + chanID := uint64((n << 8) + i + 1) op := wire.OutPoint{ Hash: txHash, Index: 0, @@ -3592,3 +3647,47 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { require.Nil(t, err) } } + +// BenchmarkForEachChannel is a benchmark test that measures the number of +// allocations and the total memory consumed by the full graph traversal. +func BenchmarkForEachChannel(b *testing.B) { + graph, cleanUp, err := MakeTestGraph() + require.Nil(b, err) + defer cleanUp() + + const numNodes = 100 + const numChannels = 4 + _, _ = fillTestGraph(b, graph, numNodes, numChannels) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var ( + totalCapacity btcutil.Amount + maxHTLCs lnwire.MilliSatoshi + ) + err := graph.ForEachNodeCacheable( + func(tx kvdb.RTx, n GraphCacheNode) error { + return n.ForEachChannel( + tx, func(tx kvdb.RTx, + info *ChannelEdgeInfo, + policy *ChannelEdgePolicy, + policy2 *ChannelEdgePolicy) error { + + // We need to do something with + // the data here, otherwise the + // compiler is going to optimize + // this away, and we get bogus + // results. + totalCapacity += info.Capacity + maxHTLCs += policy.MaxHTLC + maxHTLCs += policy2.MaxHTLC + + return nil + }, + ) + }, + ) + require.NoError(b, err) + } +}