mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 21:35:24 +01:00
channeldb: add new ForEachNode method using the channel graph cache
This commit, adds a new ForEachNode method to the channel graph cache that assumes the contents won't be modified. This is generally useful, and will be used in a later commit to optimize some heavy RPC calls.
This commit is contained in:
parent
a6f22c6185
commit
fae470293f
@ -458,6 +458,72 @@ func (c *ChannelGraph) FetchNodeFeatures(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForEachNodeCached is similar to ForEachNode, but it utilizes the channel
|
||||||
|
// graph cache instead. Note that this doesn't return all the information the
|
||||||
|
// regular ForEachNode method does.
|
||||||
|
//
|
||||||
|
// NOTE: The callback contents MUST not be modified.
|
||||||
|
func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex,
|
||||||
|
chans map[uint64]*DirectedChannel) error) error {
|
||||||
|
|
||||||
|
if c.graphCache != nil {
|
||||||
|
return c.graphCache.ForEachNode(cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise call back to a version that uses the database directly.
|
||||||
|
// We'll iterate over each node, then the set of channels for each
|
||||||
|
// node, and construct a similar callback functiopn signature as the
|
||||||
|
// main funcotin expects.
|
||||||
|
return c.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error {
|
||||||
|
channels := make(map[uint64]*DirectedChannel)
|
||||||
|
|
||||||
|
err := node.ForEachChannel(tx, func(tx kvdb.RTx,
|
||||||
|
e *ChannelEdgeInfo, p1 *ChannelEdgePolicy,
|
||||||
|
p2 *ChannelEdgePolicy) error {
|
||||||
|
|
||||||
|
toNodeCallback := func() route.Vertex {
|
||||||
|
return node.PubKeyBytes
|
||||||
|
}
|
||||||
|
toNodeFeatures, err := c.FetchNodeFeatures(
|
||||||
|
node.PubKeyBytes,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var cachedInPolicy *CachedEdgePolicy
|
||||||
|
if p2 != nil {
|
||||||
|
cachedInPolicy := NewCachedPolicy(p2)
|
||||||
|
cachedInPolicy.ToNodePubKey = toNodeCallback
|
||||||
|
cachedInPolicy.ToNodeFeatures = toNodeFeatures
|
||||||
|
}
|
||||||
|
|
||||||
|
directedChannel := &DirectedChannel{
|
||||||
|
ChannelID: e.ChannelID,
|
||||||
|
IsNode1: node.PubKeyBytes == e.NodeKey1Bytes,
|
||||||
|
OtherNode: e.NodeKey2Bytes,
|
||||||
|
Capacity: e.Capacity,
|
||||||
|
OutPolicySet: p1 != nil,
|
||||||
|
InPolicy: cachedInPolicy,
|
||||||
|
}
|
||||||
|
|
||||||
|
if node.PubKeyBytes == e.NodeKey2Bytes {
|
||||||
|
directedChannel.OtherNode = e.NodeKey1Bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
channels[e.ChannelID] = directedChannel
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cb(node.PubKeyBytes, channels)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// DisabledChannelIDs returns the channel ids of disabled channels.
|
// DisabledChannelIDs returns the channel ids of disabled channels.
|
||||||
// A channel is disabled when two of the associated ChanelEdgePolicies
|
// A channel is disabled when two of the associated ChanelEdgePolicies
|
||||||
// have their disabled bit on.
|
// have their disabled bit on.
|
||||||
|
@ -447,6 +447,30 @@ func (c *GraphCache) ForEachChannel(node route.Vertex,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForEachNode iterates over the adjacency list of the graph, executing the
|
||||||
|
// call back for each node and the set of channels that emanate from the given
|
||||||
|
// node.
|
||||||
|
//
|
||||||
|
// NOTE: This method should be considered _read only_, the channels or nodes
|
||||||
|
// passed in MUST NOT be modified.
|
||||||
|
func (c *GraphCache) ForEachNode(cb func(node route.Vertex,
|
||||||
|
channels map[uint64]*DirectedChannel) error) error {
|
||||||
|
|
||||||
|
c.mtx.RLock()
|
||||||
|
defer c.mtx.RUnlock()
|
||||||
|
|
||||||
|
for node, channels := range c.nodeChannels {
|
||||||
|
// We don't make a copy here since this is a read-only RPC
|
||||||
|
// call. We also don't need the node features either for this
|
||||||
|
// call.
|
||||||
|
if err := cb(node, channels); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetFeatures returns the features of the node with the given ID. If no
|
// GetFeatures returns the features of the node with the given ID. If no
|
||||||
// features are known for the node, an empty feature vector is returned.
|
// features are known for the node, an empty feature vector is returned.
|
||||||
func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector {
|
func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector {
|
||||||
|
@ -120,6 +120,24 @@ func TestGraphCacheAddNode(t *testing.T) {
|
|||||||
|
|
||||||
require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet)
|
require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet)
|
||||||
assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy)
|
assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy)
|
||||||
|
|
||||||
|
// Now that we've inserted two nodes into the graph, check that
|
||||||
|
// we'll recover the same set of channels during ForEachNode.
|
||||||
|
nodes := make(map[route.Vertex]struct{})
|
||||||
|
chans := make(map[uint64]struct{})
|
||||||
|
_ = cache.ForEachNode(func(node route.Vertex,
|
||||||
|
edges map[uint64]*DirectedChannel) error {
|
||||||
|
|
||||||
|
nodes[node] = struct{}{}
|
||||||
|
for chanID := range edges {
|
||||||
|
chans[chanID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Len(t, nodes, 2)
|
||||||
|
require.Len(t, chans, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
runTest(pubKey1, pubKey2)
|
runTest(pubKey1, pubKey2)
|
||||||
|
@ -1103,6 +1103,34 @@ func TestGraphTraversal(t *testing.T) {
|
|||||||
const numChannels = 5
|
const numChannels = 5
|
||||||
chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels)
|
chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels)
|
||||||
|
|
||||||
|
// Make an index of the node list for easy look up below.
|
||||||
|
nodeIndex := make(map[route.Vertex]struct{})
|
||||||
|
for _, node := range nodeList {
|
||||||
|
nodeIndex[node.PubKeyBytes] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we turn the channel graph cache _off_, then iterate through the
|
||||||
|
// set of channels (to force the fall back), we should find all the
|
||||||
|
// channel as well as the nodes included.
|
||||||
|
graph.graphCache = nil
|
||||||
|
err = graph.ForEachNodeCached(func(node route.Vertex,
|
||||||
|
chans map[uint64]*DirectedChannel) error {
|
||||||
|
|
||||||
|
if _, ok := nodeIndex[node]; !ok {
|
||||||
|
return fmt.Errorf("node %x not found in graph", node)
|
||||||
|
}
|
||||||
|
|
||||||
|
for chanID := range chans {
|
||||||
|
if _, ok := chanIndex[chanID]; !ok {
|
||||||
|
return fmt.Errorf("chan %v not found in "+
|
||||||
|
"graph", chanID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Iterate through all the known channels within the graph DB, once
|
// Iterate through all the known channels within the graph DB, once
|
||||||
// again if the map is empty that indicates that all edges have
|
// again if the map is empty that indicates that all edges have
|
||||||
// properly been reached.
|
// properly been reached.
|
||||||
|
Loading…
Reference in New Issue
Block a user