diff --git a/autopilot/graph.go b/autopilot/graph.go index 83447af9b..2ce49c127 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -89,7 +89,7 @@ func (d *dbNode) Addrs() []net.Addr { // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { - return d.db.ForEachNodeChannel(d.tx, d.node.PubKeyBytes, + return d.db.ForEachNodeChannelTx(d.tx, d.node.PubKeyBytes, func(tx kvdb.RTx, ei *models.ChannelEdgeInfo, ep, _ *models.ChannelEdgePolicy) error { diff --git a/channeldb/graph.go b/channeldb/graph.go index a37ea9682..414672166 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -565,7 +565,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, return c.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error { channels := make(map[uint64]*DirectedChannel) - err := c.ForEachNodeChannel(tx, node.PubKeyBytes, + err := c.ForEachNodeChannelTx(tx, node.PubKeyBytes, func(tx kvdb.RTx, e *models.ChannelEdgeInfo, p1 *models.ChannelEdgePolicy, p2 *models.ChannelEdgePolicy) error { @@ -2931,7 +2931,7 @@ func (c *ChannelGraph) isPublic(tx kvdb.RTx, nodePub route.Vertex, // used to terminate the check early. nodeIsPublic := false errDone := errors.New("done") - err := c.ForEachNodeChannel(tx, nodePub, func(tx kvdb.RTx, + err := c.ForEachNodeChannelTx(tx, nodePub, func(tx kvdb.RTx, info *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -3224,13 +3224,29 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. +func (c *ChannelGraph) ForEachNodeChannel(nodePub route.Vertex, + cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error { + + return nodeTraversal(nil, nodePub[:], c.db, cb) +} + +// ForEachNodeChannelTx iterates through all channels of the given node, +// executing the passed callback with an edge info structure and the policies +// of each end of the channel. The first edge policy is the outgoing edge *to* +// the connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. // // If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should +// should be passed as the first argument. Otherwise, the first argument should // be nil and a fresh transaction will be created to execute the graph // traversal. -func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub route.Vertex, - cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, +func (c *ChannelGraph) ForEachNodeChannelTx(tx kvdb.RTx, + nodePub route.Vertex, cb func(kvdb.RTx, *models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { return nodeTraversal(tx, nodePub[:], c.db, cb) diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 46bc0d3fd..256851238 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -1055,7 +1055,7 @@ func TestGraphTraversal(t *testing.T) { // outgoing channels for a particular node. numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] - err = graph.ForEachNodeChannel(nil, firstNode.PubKeyBytes, + err = graph.ForEachNodeChannel(firstNode.PubKeyBytes, func(_ kvdb.RTx, _ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { @@ -2737,7 +2737,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { // Ensure that channel is reported with unknown policies. checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { calls := 0 - err := graph.ForEachNodeChannel(nil, node.PubKeyBytes, + err := graph.ForEachNodeChannel(node.PubKeyBytes, func(_ kvdb.RTx, _ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { diff --git a/routing/router.go b/routing/router.go index c91e0b187..33f5a7814 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2937,7 +2937,7 @@ func (r *ChannelRouter) ForEachNode( func (r *ChannelRouter) ForAllOutgoingChannels(cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { - return r.cfg.Graph.ForEachNodeChannel(nil, r.cfg.SelfNode, + return r.cfg.Graph.ForEachNodeChannel(r.cfg.SelfNode, func(tx kvdb.RTx, c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { diff --git a/rpcserver.go b/rpcserver.go index 2012e75d9..011674e6c 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6361,7 +6361,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, channels []*lnrpc.ChannelEdge ) - err = graph.ForEachNodeChannel(nil, node.PubKeyBytes, + err = graph.ForEachNodeChannel(node.PubKeyBytes, func(_ kvdb.RTx, edge *models.ChannelEdgeInfo, c1, c2 *models.ChannelEdgePolicy) error { @@ -7014,7 +7014,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, } var feeReports []*lnrpc.ChannelFeeReport - err = channelGraph.ForEachNodeChannel(nil, selfNode.PubKeyBytes, + err = channelGraph.ForEachNodeChannel(selfNode.PubKeyBytes, func(_ kvdb.RTx, chanInfo *models.ChannelEdgeInfo, edgePolicy, _ *models.ChannelEdgePolicy) error { diff --git a/server.go b/server.go index 555513406..f32fb611e 100644 --- a/server.go +++ b/server.go @@ -3119,7 +3119,7 @@ func (s *server) establishPersistentConnections() error { // TODO(roasbeef): instead iterate over link nodes and query graph for // each of the nodes. selfPub := s.identityECDH.PubKey().SerializeCompressed() - err = s.graphDB.ForEachNodeChannel(nil, sourceNode.PubKeyBytes, func( + err = s.graphDB.ForEachNodeChannel(sourceNode.PubKeyBytes, func( tx kvdb.RTx, chanInfo *models.ChannelEdgeInfo, policy, _ *models.ChannelEdgePolicy) error {