channeldb: optimize memory usage of initial cache fill

With this commit we use an optimized version of the node iteration that
causes fewer memory allocations by only loading the part of the graph
node that we actually need to know for the cache.
This commit is contained in:
Oliver Gugger 2021-09-28 13:23:02 +02:00
parent a5202a89e6
commit 6240851f93
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
2 changed files with 215 additions and 13 deletions

View File

@ -211,8 +211,8 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
startTime := time.Now()
log.Debugf("Populating in-memory channel graph, this might take a " +
"while...")
err := g.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error {
return g.graphCache.AddNode(tx, &graphCacheNode{node})
err := g.ForEachNodeCacheable(func(tx kvdb.RTx, node GraphCacheNode) error {
return g.graphCache.AddNode(tx, node)
})
if err != nil {
return nil, err
@ -468,6 +468,47 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro
return kvdb.View(c.db, traversal, func() {})
}
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
GraphCacheNode) error) error {
traversal := func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
cacheableNode := newGraphCacheNode(route.Vertex{}, nil)
return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
// If this is the source key, then we skip this
// iteration as the value for this key is a pubKey
// rather than raw node information.
if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
return nil
}
nodeReader := bytes.NewReader(nodeBytes)
err := deserializeLightningNodeCacheable(
nodeReader, cacheableNode,
)
if err != nil {
return err
}
// Execute the callback, the transaction will abort if
// this returns an error.
return cb(tx, cacheableNode)
})
}
return kvdb.View(c.db, traversal, func() {})
}
// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
@ -559,8 +600,10 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode,
r := &batch.Request{
Update: func(tx kvdb.RwTx) error {
wNode := &graphCacheNode{node}
if err := c.graphCache.AddNode(tx, wNode); err != nil {
cNode := newGraphCacheNode(
node.PubKeyBytes, node.Features,
)
if err := c.graphCache.AddNode(tx, cNode); err != nil {
return err
}
@ -2532,17 +2575,30 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (
// graphCacheNode is a struct that wraps a LightningNode in a way that it can be
// cached in the graph cache.
type graphCacheNode struct {
lnNode *LightningNode
pubKeyBytes route.Vertex
features *lnwire.FeatureVector
nodeScratch [8]byte
}
// newGraphCacheNode returns a new cache optimized node.
func newGraphCacheNode(pubKey route.Vertex,
features *lnwire.FeatureVector) *graphCacheNode {
return &graphCacheNode{
pubKeyBytes: pubKey,
features: features,
}
}
// PubKey returns the node's public identity key.
func (w *graphCacheNode) PubKey() route.Vertex {
return w.lnNode.PubKeyBytes
func (n *graphCacheNode) PubKey() route.Vertex {
return n.pubKeyBytes
}
// Features returns the node's features.
func (w *graphCacheNode) Features() *lnwire.FeatureVector {
return w.lnNode.Features
func (n *graphCacheNode) Features() *lnwire.FeatureVector {
return n.features
}
// ForEachChannel iterates through all channels of this node, executing the
@ -2553,11 +2609,11 @@ func (w *graphCacheNode) Features() *lnwire.FeatureVector {
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
func (w *graphCacheNode) ForEachChannel(tx kvdb.RTx,
func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
return w.lnNode.ForEachChannel(tx, cb)
return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb)
}
var _ GraphCacheNode = (*graphCacheNode)(nil)
@ -3865,6 +3921,53 @@ func fetchLightningNode(nodeBucket kvdb.RBucket,
return deserializeLightningNode(nodeReader)
}
func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error {
// Always populate a feature vector, even if we don't have a node
// announcement and short circuit below.
node.features = lnwire.EmptyFeatureVector()
// Skip ahead:
// - LastUpdate (8 bytes)
if _, err := r.Read(node.nodeScratch[:]); err != nil {
return err
}
if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil {
return err
}
// Read the node announcement flag.
if _, err := r.Read(node.nodeScratch[:2]); err != nil {
return err
}
hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2])
// The rest of the data is optional, and will only be there if we got a
// node announcement for this node.
if hasNodeAnn == 0 {
return nil
}
// We did get a node announcement for this node, so we'll have the rest
// of the data available.
var rgb uint8
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if _, err := wire.ReadVarString(r, 0); err != nil {
return err
}
return node.features.Decode(r)
}
func deserializeLightningNode(r io.Reader) (LightningNode, error) {
var (
node LightningNode

View File

@ -21,6 +21,7 @@ import (
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
@ -1145,6 +1146,58 @@ func TestGraphTraversal(t *testing.T) {
require.Equal(t, numChannels, numNodeChans)
}
// TestGraphTraversalCacheable tests that the memory optimized node traversal is
// working correctly.
func TestGraphTraversalCacheable(t *testing.T) {
t.Parallel()
graph, cleanUp, err := MakeTestGraph()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
// We'd like to test some of the graph traversal capabilities within
// the DB, so we'll create a series of fake nodes to insert into the
// graph. And we'll create 5 channels between the first two nodes.
const numNodes = 20
const numChannels = 5
chanIndex, _ := fillTestGraph(t, graph, numNodes, numChannels)
// Create a map of all nodes with the iteration we know works (because
// it is tested in another test).
nodeMap := make(map[route.Vertex]struct{})
err = graph.ForEachNode(func(tx kvdb.RTx, n *LightningNode) error {
nodeMap[n.PubKeyBytes] = struct{}{}
return nil
})
require.NoError(t, err)
require.Len(t, nodeMap, numNodes)
// Iterate through all the known channels within the graph DB by
// iterating over each node, once again if the map is empty that
// indicates that all edges have properly been reached.
err = graph.ForEachNodeCacheable(
func(tx kvdb.RTx, node GraphCacheNode) error {
delete(nodeMap, node.PubKey())
return node.ForEachChannel(
tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,
policy *ChannelEdgePolicy,
policy2 *ChannelEdgePolicy) error {
delete(chanIndex, info.ChannelID)
return nil
},
)
},
)
require.NoError(t, err)
require.Len(t, nodeMap, 0)
require.Len(t, chanIndex, 0)
}
func TestGraphCacheTraversal(t *testing.T) {
t.Parallel()
@ -1164,6 +1217,8 @@ func TestGraphCacheTraversal(t *testing.T) {
// properly been reached.
numNodeChans := 0
for _, node := range nodeList {
node := node
err = graph.graphCache.ForEachChannel(
node.PubKeyBytes, func(d *DirectedChannel) error {
delete(chanIndex, d.ChannelID)
@ -1197,7 +1252,7 @@ func TestGraphCacheTraversal(t *testing.T) {
require.Equal(t, numChannels*2*(numNodes-1), numNodeChans)
}
func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes,
func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes,
numChannels int) (map[uint64]struct{}, []*LightningNode) {
nodes := make([]*LightningNode, numNodes)
@ -1237,7 +1292,7 @@ func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes,
for i := 0; i < numChannels; i++ {
txHash := sha256.Sum256([]byte{byte(i)})
chanID := uint64((n << 4) + i + 1)
chanID := uint64((n << 8) + i + 1)
op := wire.OutPoint{
Hash: txHash,
Index: 0,
@ -3592,3 +3647,47 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) {
require.Nil(t, err)
}
}
// BenchmarkForEachChannel is a benchmark test that measures the number of
// allocations and the total memory consumed by the full graph traversal.
func BenchmarkForEachChannel(b *testing.B) {
graph, cleanUp, err := MakeTestGraph()
require.Nil(b, err)
defer cleanUp()
const numNodes = 100
const numChannels = 4
_, _ = fillTestGraph(b, graph, numNodes, numChannels)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
var (
totalCapacity btcutil.Amount
maxHTLCs lnwire.MilliSatoshi
)
err := graph.ForEachNodeCacheable(
func(tx kvdb.RTx, n GraphCacheNode) error {
return n.ForEachChannel(
tx, func(tx kvdb.RTx,
info *ChannelEdgeInfo,
policy *ChannelEdgePolicy,
policy2 *ChannelEdgePolicy) error {
// We need to do something with
// the data here, otherwise the
// compiler is going to optimize
// this away, and we get bogus
// results.
totalCapacity += info.Capacity
maxHTLCs += policy.MaxHTLC
maxHTLCs += policy2.MaxHTLC
return nil
},
)
},
)
require.NoError(b, err)
}
}