channeldb: add new ForEachNode method using the channel graph cache

This commit, adds a new ForEachNode method to the channel graph cache
that assumes the contents won't be modified. This is generally useful,
and will be used in a later commit to optimize some heavy RPC calls.
This commit is contained in:
Olaoluwa Osuntokun 2021-10-19 16:04:23 -07:00
parent a6f22c6185
commit fae470293f
No known key found for this signature in database
GPG Key ID: 3BBD59E99B280306
4 changed files with 136 additions and 0 deletions

View File

@ -458,6 +458,72 @@ func (c *ChannelGraph) FetchNodeFeatures(
}
}
// ForEachNodeCached is similar to ForEachNode, but it utilizes the channel
// graph cache instead. Note that this doesn't return all the information the
// regular ForEachNode method does.
//
// NOTE: The callback contents MUST not be modified.
func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex,
chans map[uint64]*DirectedChannel) error) error {
if c.graphCache != nil {
return c.graphCache.ForEachNode(cb)
}
// Otherwise call back to a version that uses the database directly.
// We'll iterate over each node, then the set of channels for each
// node, and construct a similar callback functiopn signature as the
// main funcotin expects.
return c.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error {
channels := make(map[uint64]*DirectedChannel)
err := node.ForEachChannel(tx, func(tx kvdb.RTx,
e *ChannelEdgeInfo, p1 *ChannelEdgePolicy,
p2 *ChannelEdgePolicy) error {
toNodeCallback := func() route.Vertex {
return node.PubKeyBytes
}
toNodeFeatures, err := c.FetchNodeFeatures(
node.PubKeyBytes,
)
if err != nil {
return err
}
var cachedInPolicy *CachedEdgePolicy
if p2 != nil {
cachedInPolicy := NewCachedPolicy(p2)
cachedInPolicy.ToNodePubKey = toNodeCallback
cachedInPolicy.ToNodeFeatures = toNodeFeatures
}
directedChannel := &DirectedChannel{
ChannelID: e.ChannelID,
IsNode1: node.PubKeyBytes == e.NodeKey1Bytes,
OtherNode: e.NodeKey2Bytes,
Capacity: e.Capacity,
OutPolicySet: p1 != nil,
InPolicy: cachedInPolicy,
}
if node.PubKeyBytes == e.NodeKey2Bytes {
directedChannel.OtherNode = e.NodeKey1Bytes
}
channels[e.ChannelID] = directedChannel
return nil
})
if err != nil {
return err
}
return cb(node.PubKeyBytes, channels)
})
}
// DisabledChannelIDs returns the channel ids of disabled channels.
// A channel is disabled when two of the associated ChanelEdgePolicies
// have their disabled bit on.

View File

@ -447,6 +447,30 @@ func (c *GraphCache) ForEachChannel(node route.Vertex,
return nil
}
// ForEachNode iterates over the adjacency list of the graph, executing the
// call back for each node and the set of channels that emanate from the given
// node.
//
// NOTE: This method should be considered _read only_, the channels or nodes
// passed in MUST NOT be modified.
func (c *GraphCache) ForEachNode(cb func(node route.Vertex,
channels map[uint64]*DirectedChannel) error) error {
c.mtx.RLock()
defer c.mtx.RUnlock()
for node, channels := range c.nodeChannels {
// We don't make a copy here since this is a read-only RPC
// call. We also don't need the node features either for this
// call.
if err := cb(node, channels); err != nil {
return err
}
}
return nil
}
// GetFeatures returns the features of the node with the given ID. If no
// features are known for the node, an empty feature vector is returned.
func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector {

View File

@ -120,6 +120,24 @@ func TestGraphCacheAddNode(t *testing.T) {
require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet)
assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy)
// Now that we've inserted two nodes into the graph, check that
// we'll recover the same set of channels during ForEachNode.
nodes := make(map[route.Vertex]struct{})
chans := make(map[uint64]struct{})
_ = cache.ForEachNode(func(node route.Vertex,
edges map[uint64]*DirectedChannel) error {
nodes[node] = struct{}{}
for chanID := range edges {
chans[chanID] = struct{}{}
}
return nil
})
require.Len(t, nodes, 2)
require.Len(t, chans, 1)
}
runTest(pubKey1, pubKey2)

View File

@ -1103,6 +1103,34 @@ func TestGraphTraversal(t *testing.T) {
const numChannels = 5
chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels)
// Make an index of the node list for easy look up below.
nodeIndex := make(map[route.Vertex]struct{})
for _, node := range nodeList {
nodeIndex[node.PubKeyBytes] = struct{}{}
}
// If we turn the channel graph cache _off_, then iterate through the
// set of channels (to force the fall back), we should find all the
// channel as well as the nodes included.
graph.graphCache = nil
err = graph.ForEachNodeCached(func(node route.Vertex,
chans map[uint64]*DirectedChannel) error {
if _, ok := nodeIndex[node]; !ok {
return fmt.Errorf("node %x not found in graph", node)
}
for chanID := range chans {
if _, ok := chanIndex[chanID]; !ok {
return fmt.Errorf("chan %v not found in "+
"graph", chanID)
}
}
return nil
})
require.NoError(t, err)
// Iterate through all the known channels within the graph DB, once
// again if the map is empty that indicates that all edges have
// properly been reached.