channeldb: optimize memory usage of initial cache fill

With this commit we use an optimized version of the node iteration that causes fewer memory allocations by loading only the parts of each graph node that the cache actually needs.

parent a5202a89e6
commit 6240851f93
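The saving comes from two things visible in the hunks below: the new iterator decodes only the public key and feature vector instead of a full LightningNode, and it reuses a single scratch node for every row it visits. A minimal, self-contained sketch of that pattern (illustrative only, not lnd code; all names are made up):

package main

import (
    "bytes"
    "encoding/binary"
    "fmt"
)

// record holds only the fields the cache cares about.
type record struct {
    id    uint64
    value uint32
}

// forEachRecord decodes just the needed prefix of every row and reuses a
// single scratch struct, so a full scan allocates a constant number of
// records instead of one per row.
func forEachRecord(rows [][]byte, cb func(*record) error) error {
    var scratch record
    for _, row := range rows {
        r := bytes.NewReader(row)
        if err := binary.Read(r, binary.BigEndian, &scratch.id); err != nil {
            return err
        }
        if err := binary.Read(r, binary.BigEndian, &scratch.value); err != nil {
            return err
        }
        // Whatever follows in the row is simply never decoded.
        if err := cb(&scratch); err != nil {
            return err
        }
    }
    return nil
}

func main() {
    row := make([]byte, 16) // 8-byte id, 4-byte value, 4 trailing bytes we skip
    binary.BigEndian.PutUint64(row[0:8], 42)
    binary.BigEndian.PutUint32(row[8:12], 7)

    _ = forEachRecord([][]byte{row}, func(rec *record) error {
        fmt.Println(rec.id, rec.value)
        return nil
    })
}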
channeldb/graph.go

@@ -211,8 +211,8 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
     startTime := time.Now()
     log.Debugf("Populating in-memory channel graph, this might take a " +
         "while...")
-    err := g.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error {
-        return g.graphCache.AddNode(tx, &graphCacheNode{node})
+    err := g.ForEachNodeCacheable(func(tx kvdb.RTx, node GraphCacheNode) error {
+        return g.graphCache.AddNode(tx, node)
     })
     if err != nil {
         return nil, err
@@ -468,6 +468,47 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) error {
     return kvdb.View(c.db, traversal, func() {})
 }
 
+// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
+// graph, executing the passed callback with each node encountered. If the
+// callback returns an error, then the transaction is aborted and the iteration
+// stops early.
+func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
+    GraphCacheNode) error) error {
+
+    traversal := func(tx kvdb.RTx) error {
+        // First grab the nodes bucket which stores the mapping from
+        // pubKey to node information.
+        nodes := tx.ReadBucket(nodeBucket)
+        if nodes == nil {
+            return ErrGraphNotFound
+        }
+
+        cacheableNode := newGraphCacheNode(route.Vertex{}, nil)
+        return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
+            // If this is the source key, then we skip this
+            // iteration as the value for this key is a pubKey
+            // rather than raw node information.
+            if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
+                return nil
+            }
+
+            nodeReader := bytes.NewReader(nodeBytes)
+            err := deserializeLightningNodeCacheable(
+                nodeReader, cacheableNode,
+            )
+            if err != nil {
+                return err
+            }
+
+            // Execute the callback, the transaction will abort if
+            // this returns an error.
+            return cb(tx, cacheableNode)
+        })
+    }
+
+    return kvdb.View(c.db, traversal, func() {})
+}
+
 // SourceNode returns the source node of the graph. The source node is treated
 // as the center node within a star-graph. This method may be used to kick off
 // a path finding algorithm in order to explore the reachability of another
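For orientation, this is how the new iterator is used elsewhere in this diff (in NewChannelGraph and in TestGraphTraversalCacheable); the sketch below is a fragment that assumes a surrounding function with a graph *ChannelGraph in scope and the fmt import:

err := graph.ForEachNodeCacheable(func(tx kvdb.RTx, node GraphCacheNode) error {
    // PubKey returns the 33-byte vertex by value, so it is safe to keep.
    fmt.Printf("seen node %x\n", node.PubKey())
    return nil
})
if err != nil {
    return err
}

One caveat that follows directly from the implementation above: the same graphCacheNode is reused for every bucket entry, so the callback must not retain the GraphCacheNode itself (or its Features() pointer) past its return.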
@@ -559,8 +600,10 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode,
 
     r := &batch.Request{
         Update: func(tx kvdb.RwTx) error {
-            wNode := &graphCacheNode{node}
-            if err := c.graphCache.AddNode(tx, wNode); err != nil {
+            cNode := newGraphCacheNode(
+                node.PubKeyBytes, node.Features,
+            )
+            if err := c.graphCache.AddNode(tx, cNode); err != nil {
                 return err
             }
 
@@ -2532,17 +2575,30 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (
 // graphCacheNode is a struct that wraps a LightningNode in a way that it can be
 // cached in the graph cache.
 type graphCacheNode struct {
-    lnNode *LightningNode
+    pubKeyBytes route.Vertex
+    features    *lnwire.FeatureVector
+
+    nodeScratch [8]byte
 }
 
+// newGraphCacheNode returns a new cache optimized node.
+func newGraphCacheNode(pubKey route.Vertex,
+    features *lnwire.FeatureVector) *graphCacheNode {
+
+    return &graphCacheNode{
+        pubKeyBytes: pubKey,
+        features:    features,
+    }
+}
+
 // PubKey returns the node's public identity key.
-func (w *graphCacheNode) PubKey() route.Vertex {
-    return w.lnNode.PubKeyBytes
+func (n *graphCacheNode) PubKey() route.Vertex {
+    return n.pubKeyBytes
 }
 
 // Features returns the node's features.
-func (w *graphCacheNode) Features() *lnwire.FeatureVector {
-    return w.lnNode.Features
+func (n *graphCacheNode) Features() *lnwire.FeatureVector {
+    return n.features
 }
 
 // ForEachChannel iterates through all channels of this node, executing the
@@ -2553,11 +2609,11 @@ func (w *graphCacheNode) Features() *lnwire.FeatureVector {
 // halted with the error propagated back up to the caller.
 //
 // Unknown policies are passed into the callback as nil values.
-func (w *graphCacheNode) ForEachChannel(tx kvdb.RTx,
+func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx,
     cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
         *ChannelEdgePolicy) error) error {
 
-    return w.lnNode.ForEachChannel(tx, cb)
+    return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb)
 }
 
 var _ GraphCacheNode = (*graphCacheNode)(nil)
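The var _ GraphCacheNode assertion above still holds because the wrapper keeps the same three methods. Pieced together from those methods, the interface looks roughly like this (a sketch; the authoritative definition lives in the graph cache code, not in this diff):

type GraphCacheNode interface {
    PubKey() route.Vertex
    Features() *lnwire.FeatureVector
    ForEachChannel(kvdb.RTx, func(kvdb.RTx, *ChannelEdgeInfo,
        *ChannelEdgePolicy, *ChannelEdgePolicy) error) error
}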
@@ -3865,6 +3921,53 @@ func fetchLightningNode(nodeBucket kvdb.RBucket,
     return deserializeLightningNode(nodeReader)
 }
 
+func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error {
+    // Always populate a feature vector, even if we don't have a node
+    // announcement and short circuit below.
+    node.features = lnwire.EmptyFeatureVector()
+
+    // Skip ahead:
+    // - LastUpdate (8 bytes)
+    if _, err := r.Read(node.nodeScratch[:]); err != nil {
+        return err
+    }
+
+    if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil {
+        return err
+    }
+
+    // Read the node announcement flag.
+    if _, err := r.Read(node.nodeScratch[:2]); err != nil {
+        return err
+    }
+    hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2])
+
+    // The rest of the data is optional, and will only be there if we got a
+    // node announcement for this node.
+    if hasNodeAnn == 0 {
+        return nil
+    }
+
+    // We did get a node announcement for this node, so we'll have the rest
+    // of the data available.
+    var rgb uint8
+    if err := binary.Read(r, byteOrder, &rgb); err != nil {
+        return err
+    }
+    if err := binary.Read(r, byteOrder, &rgb); err != nil {
+        return err
+    }
+    if err := binary.Read(r, byteOrder, &rgb); err != nil {
+        return err
+    }
+
+    if _, err := wire.ReadVarString(r, 0); err != nil {
+        return err
+    }
+
+    return node.features.Decode(r)
+}
+
 func deserializeLightningNode(r io.Reader) (LightningNode, error) {
     var (
         node LightningNode
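deserializeLightningNodeCacheable implies the on-disk layout it walks over: an 8-byte LastUpdate timestamp, the 33-byte identity key, a 2-byte has-announcement flag, and, only when that flag is set, three colour bytes, a var-string alias and the feature vector (followed by further announcement fields this decoder never touches). A hypothetical encoder for just the fixed-size prefix, assuming channeldb's byteOrder is big-endian (imports: encoding/binary, io; not lnd code):

// encodeNodePrefix is illustrative only; it writes the prefix the cacheable
// decoder reads before deciding whether to stop (flag == 0) or keep going.
func encodeNodePrefix(w io.Writer, lastUpdate uint64, pubKey [33]byte,
    hasAnnouncement bool) error {

    var scratch [8]byte

    // 8-byte LastUpdate timestamp, skipped by the decoder.
    binary.BigEndian.PutUint64(scratch[:], lastUpdate)
    if _, err := w.Write(scratch[:]); err != nil {
        return err
    }

    // 33-byte identity public key.
    if _, err := w.Write(pubKey[:]); err != nil {
        return err
    }

    // 2-byte flag: 0 means no node announcement follows.
    var flag uint16
    if hasAnnouncement {
        flag = 1
    }
    binary.BigEndian.PutUint16(scratch[:2], flag)
    _, err := w.Write(scratch[:2])
    return err
}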
channeldb/graph_test.go

@@ -21,6 +21,7 @@ import (
     "github.com/btcsuite/btcd/btcec"
     "github.com/btcsuite/btcd/chaincfg/chainhash"
     "github.com/btcsuite/btcd/wire"
+    "github.com/btcsuite/btcutil"
     "github.com/davecgh/go-spew/spew"
     "github.com/lightningnetwork/lnd/kvdb"
     "github.com/lightningnetwork/lnd/lnwire"
@@ -1145,6 +1146,58 @@ func TestGraphTraversal(t *testing.T) {
     require.Equal(t, numChannels, numNodeChans)
 }
 
+// TestGraphTraversalCacheable tests that the memory optimized node traversal is
+// working correctly.
+func TestGraphTraversalCacheable(t *testing.T) {
+    t.Parallel()
+
+    graph, cleanUp, err := MakeTestGraph()
+    defer cleanUp()
+    if err != nil {
+        t.Fatalf("unable to make test database: %v", err)
+    }
+
+    // We'd like to test some of the graph traversal capabilities within
+    // the DB, so we'll create a series of fake nodes to insert into the
+    // graph. And we'll create 5 channels between the first two nodes.
+    const numNodes = 20
+    const numChannels = 5
+    chanIndex, _ := fillTestGraph(t, graph, numNodes, numChannels)
+
+    // Create a map of all nodes with the iteration we know works (because
+    // it is tested in another test).
+    nodeMap := make(map[route.Vertex]struct{})
+    err = graph.ForEachNode(func(tx kvdb.RTx, n *LightningNode) error {
+        nodeMap[n.PubKeyBytes] = struct{}{}
+
+        return nil
+    })
+    require.NoError(t, err)
+    require.Len(t, nodeMap, numNodes)
+
+    // Iterate through all the known channels within the graph DB by
+    // iterating over each node, once again if the map is empty that
+    // indicates that all edges have properly been reached.
+    err = graph.ForEachNodeCacheable(
+        func(tx kvdb.RTx, node GraphCacheNode) error {
+            delete(nodeMap, node.PubKey())
+
+            return node.ForEachChannel(
+                tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,
+                    policy *ChannelEdgePolicy,
+                    policy2 *ChannelEdgePolicy) error {
+
+                    delete(chanIndex, info.ChannelID)
+                    return nil
+                },
+            )
+        },
+    )
+    require.NoError(t, err)
+    require.Len(t, nodeMap, 0)
+    require.Len(t, chanIndex, 0)
+}
+
 func TestGraphCacheTraversal(t *testing.T) {
     t.Parallel()
 
@@ -1164,6 +1217,8 @@ func TestGraphCacheTraversal(t *testing.T) {
     // properly been reached.
     numNodeChans := 0
     for _, node := range nodeList {
+        node := node
+
         err = graph.graphCache.ForEachChannel(
             node.PubKeyBytes, func(d *DirectedChannel) error {
                 delete(chanIndex, d.ChannelID)
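The added node := node line is the usual Go idiom (needed before Go 1.22's per-iteration loop variables) for giving each iteration its own copy of the range variable, so nothing that outlives the loop body can observe a later value. A minimal illustration, unrelated to lnd:

funcs := make([]func() int, 0, 3)
for _, v := range []int{1, 2, 3} {
    v := v // per-iteration copy shadows the shared loop variable
    funcs = append(funcs, func() int { return v })
}
// With the shadowing line each closure returns 1, 2 and 3; without it,
// all three would return 3 on Go versions before 1.22.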
@@ -1197,7 +1252,7 @@ func TestGraphCacheTraversal(t *testing.T) {
     require.Equal(t, numChannels*2*(numNodes-1), numNodeChans)
 }
 
-func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes,
+func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes,
     numChannels int) (map[uint64]struct{}, []*LightningNode) {
 
     nodes := make([]*LightningNode, numNodes)
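Widening the first parameter from *testing.T to require.TestingT is what allows the new BenchmarkForEachChannel below to reuse fillTestGraph with a *testing.B. testify's require.TestingT is a small interface satisfied by both; it is essentially:

type TestingT interface {
    Errorf(format string, args ...interface{})
    FailNow()
}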
@@ -1237,7 +1292,7 @@ func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes,
 
     for i := 0; i < numChannels; i++ {
         txHash := sha256.Sum256([]byte{byte(i)})
-        chanID := uint64((n << 4) + i + 1)
+        chanID := uint64((n << 8) + i + 1)
         op := wire.OutPoint{
             Hash:  txHash,
             Index: 0,
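The wider shift presumably keeps the synthetic channel IDs unique for larger channel counts per node pair: with the old (n << 4) + i + 1 scheme the IDs collide once a pair reaches 17 channels, since (0 << 4) + 16 + 1 = 17 = (1 << 4) + 0 + 1, while n << 8 gives every pair 256 IDs to itself, which comfortably covers the counts used by the tests and the new benchmark.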
@@ -3592,3 +3647,47 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) {
         require.Nil(t, err)
     }
 }
+
+// BenchmarkForEachChannel is a benchmark test that measures the number of
+// allocations and the total memory consumed by the full graph traversal.
+func BenchmarkForEachChannel(b *testing.B) {
+    graph, cleanUp, err := MakeTestGraph()
+    require.Nil(b, err)
+    defer cleanUp()
+
+    const numNodes = 100
+    const numChannels = 4
+    _, _ = fillTestGraph(b, graph, numNodes, numChannels)
+
+    b.ReportAllocs()
+    b.ResetTimer()
+    for i := 0; i < b.N; i++ {
+        var (
+            totalCapacity btcutil.Amount
+            maxHTLCs      lnwire.MilliSatoshi
+        )
+        err := graph.ForEachNodeCacheable(
+            func(tx kvdb.RTx, n GraphCacheNode) error {
+                return n.ForEachChannel(
+                    tx, func(tx kvdb.RTx,
+                        info *ChannelEdgeInfo,
+                        policy *ChannelEdgePolicy,
+                        policy2 *ChannelEdgePolicy) error {
+
+                        // We need to do something with
+                        // the data here, otherwise the
+                        // compiler is going to optimize
+                        // this away, and we get bogus
+                        // results.
+                        totalCapacity += info.Capacity
+                        maxHTLCs += policy.MaxHTLC
+                        maxHTLCs += policy2.MaxHTLC
+
+                        return nil
+                    },
+                )
+            },
+        )
+        require.NoError(b, err)
+    }
+}
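To reproduce the allocation figures locally, an invocation along these lines should work (the ./channeldb package path is assumed):

go test -run='^$' -bench=BenchmarkForEachChannel -benchmem ./channeldb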