pygossmap: adds get_neighbors and get_neighbors_hc flodding method

This commit is contained in:
Michael Schmoock 2023-02-18 00:01:53 +01:00 committed by Rusty Russell
parent 5a9a3d83c9
commit 9409f2f1ea
2 changed files with 206 additions and 1 deletions

View file

@ -3,7 +3,7 @@
from pyln.spec.bolt7 import (channel_announcement, channel_update, from pyln.spec.bolt7 import (channel_announcement, channel_update,
node_announcement) node_announcement)
from pyln.proto import ShortChannelId, PublicKey from pyln.proto import ShortChannelId, PublicKey
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Set, Optional, Union
import io import io
import struct import struct
@ -273,12 +273,97 @@ class Gossmap(object):
channel = self.get_channel(short_channel_id) channel = self.get_channel(short_channel_id)
return channel.half_channels[direction] return channel.half_channels[direction]
def get_neighbors_hc(self,
source: Union[GossmapNodeId, str, None] = None,
destination: Union[GossmapNodeId, str, None] = None,
depth: int = 0,
excludes: Union[Set[Any], List[Any]] = set()):
""" Returns a set[GossmapHalfchannel]` from `source` or towards
`destination` node ID. Using the optional `depth` greater than `0`
will result in a second, third, ... order list of connected
channels towards or from that node.
Note: only one of `source` or `destination` can be given. """
assert (source is None) ^ (destination is None), "Only one of source or destination must be given"
assert depth >= 0, "Depth cannot be smaller than 0"
node = self.get_node(source if source else destination)
assert node is not None, "source or destination unknown"
if isinstance(excludes, List):
excludes = set(excludes)
# first get set of reachable nodes ...
reachable = self.get_neighbors(source, destination, depth, excludes)
# and iterate and check any each source/dest channel from here
result = set()
for node in reachable:
for channel in node.channels:
if channel in excludes:
continue
other = channel.node1 if node != channel.node1 else channel.node2
if other in reachable or other in excludes:
continue
direction = 0
if source is not None and node > other:
direction = 1
if destination is not None and node < other:
direction = 1
hc = channel.half_channels[direction]
# skip excluded or non existent halfchannels
if hc is None or hc in excludes:
continue
result.add(hc)
return result
def get_node(self, node_id: Union[GossmapNodeId, str]): def get_node(self, node_id: Union[GossmapNodeId, str]):
""" Resolves a node by its public key node_id """ """ Resolves a node by its public key node_id """
if isinstance(node_id, str): if isinstance(node_id, str):
node_id = GossmapNodeId.from_str(node_id) node_id = GossmapNodeId.from_str(node_id)
return self.nodes.get(node_id) return self.nodes.get(node_id)
def get_neighbors(self,
source: Union[GossmapNodeId, str, None] = None,
destination: Union[GossmapNodeId, str, None] = None,
depth: int = 0,
excludes: Union[Set[Any], List[Any]] = set()):
""" Returns a set of nodes within a given depth from a source node """
assert (source is None) ^ (destination is None), "Only one of source or destination must be given"
assert depth >= 0, "Depth cannot be smaller than 0"
node = self.get_node(source if source else destination)
assert node is not None, "source or destination unknown"
if isinstance(excludes, List):
excludes = set(excludes)
result = set()
result.add(node)
inner = set()
inner.add(node)
while depth > 0:
shell = set()
for node in inner:
for channel in node.channels:
if channel in excludes: # skip excluded channels
continue
other = channel.node1 if channel.node1 != node else channel.node2
direction = 0
if source is not None and node > other:
direction = 1
if destination is not None and node < other:
direction = 1
if channel.half_channels[direction] is None:
continue # one way channel in the wrong direction
halfchannel = channel.half_channels[direction]
if halfchannel in excludes: # skip excluded halfchannels
continue
# skip excluded or already seen nodes
if other in excludes or other in inner or other in result:
continue
shell.add(other)
if len(shell) == 0:
break
depth -= 1
result.update(shell)
inner = shell
return result
def _update_channel(self, rec: bytes, hdr: GossipStoreHeader): def _update_channel(self, rec: bytes, hdr: GossipStoreHeader):
fields = channel_update.read(io.BytesIO(rec[2:]), {}) fields = channel_update.read(io.BytesIO(rec[2:]), {})
direction = fields['channel_flags'] & 1 direction = fields['channel_flags'] & 1

View file

@ -159,6 +159,8 @@ def test_mesh(tmp_path):
scids = [scid12, scid14, scid23, scid25, scid36, scid45, scid47, scid56, scids = [scid12, scid14, scid23, scid25, scid36, scid45, scid47, scid56,
scid58, scid69, scid78, scid89] scid58, scid69, scid78, scid89]
nodes = [g.get_node(nid) for nid in nodeids]
# check all nodes are there # check all nodes are there
for nodeid in nodeids: for nodeid in nodeids:
node = g.get_node(nodeid) node = g.get_node(nodeid)
@ -174,3 +176,121 @@ def test_mesh(tmp_path):
assert str(channel.scid) == scid assert str(channel.scid) == scid
assert channel.half_channels[0] assert channel.half_channels[0]
assert channel.half_channels[1] assert channel.half_channels[1]
# check basic relations
# get_neighbors l5 in the middle depth=0 returns just that node
result = g.get_neighbors(source=nodeids[4])
assert len(result) == 1
assert str(next(iter(result)).node_id) == nodeids[4]
result = g.get_neighbors(source=nodeids[4], depth=1)
assert len(result) == 5
# on depth=1 the cross l2, l4, l5, l6, l8 must be returned
assert nodes[1] in result
assert nodes[3] in result
assert nodes[4] in result
assert nodes[5] in result
assert nodes[7] in result
# on depth>=2 all nodes must be returned as we visited the whole graph
for d in range(2, 4):
result = g.get_neighbors(source=nodeids[4], depth=d)
assert len(result) == 9
for node in nodes:
assert node in result
# get_neighbors on l9 with depth=3 must return all but l1
result = g.get_neighbors(nodeids[8], depth=3)
assert len(result) == 8
assert nodes[0] not in result
# get_neighbors on l9 with depth=4 and excludes l5 must return all but l5
result = g.get_neighbors(nodeids[8], depth=4, excludes=[nodes[4]])
assert len(result) == 8
assert nodes[4] not in result
# get_neighbors_hc l5 in the middle expect: 25, 45, 65 and 85
result = g.get_neighbors_hc(source=nodeids[4])
exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]]
exp_scidds = [scid25 + '/1', scid45 + '/0', scid56 + '/1', scid58 + '/0']
assert len(result) == len(exp_ids)
for halfchan in result:
assert str(halfchan.source.node_id) == nodeids[4]
assert str(halfchan.destination.node_id) in exp_ids
assert str(halfchan) in exp_scidds
# same but other direction
result = g.get_neighbors_hc(destination=nodeids[4])
exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]]
exp_scidds = [scid25 + '/0', scid45 + '/1', scid56 + '/0', scid58 + '/1']
assert len(result) == len(exp_ids)
for halfchan in result:
assert str(halfchan.destination.node_id) == nodeids[4]
assert str(halfchan.source.node_id) in exp_ids
assert str(halfchan) in exp_scidds
# get all channels which have l1 as destination
result = g.get_neighbors_hc(destination=nodeids[0])
exp_ids = [nodeids[1], nodeids[3]]
exp_scidds = [scid12 + '/0', scid14 + '/1']
assert len(result) == len(exp_ids)
for halfchan in result:
assert str(halfchan.destination.node_id) == nodeids[0]
assert str(halfchan.source.node_id) in exp_ids
assert str(halfchan) in exp_scidds
# l5 as destination in the middle but depth=1, so the outer ring
# epxect: 12, 14, 32, 36, 74, 78, 98, 96
result = g.get_neighbors_hc(destination=nodeids[4], depth=1)
exp_scidds = [scid12 + '/1', scid14 + '/0', scid23 + '/1', scid36 + '/1',
scid47 + '/0', scid69 + '/1', scid78 + '/0', scid89 + '/0']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# same but other direction
result = g.get_neighbors_hc(source=nodeids[4], depth=1)
exp_scidds = [scid12 + '/0', scid14 + '/1', scid23 + '/0', scid36 + '/0',
scid47 + '/1', scid69 + '/0', scid78 + '/1', scid89 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# l9 as destination and depth=2 expect: 23 25 45 47
result = g.get_neighbors_hc(destination=nodeids[8], depth=2)
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1', scid47 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# l9 as destination depth=2 exclude=[l7] expect: 23 25 45
result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=[nodes[6]])
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# same as above, but excludes halfchannels of l7 expect: 23 25 45
hcs = [c.half_channels[0] for c in nodes[6].channels]
hcs += [c.half_channels[1] for c in nodes[6].channels]
result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=hcs)
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# again, same as above, but excludes channels of l7 expect: 23 25 45
chs = [c for c in nodes[6].channels]
result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=chs)
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# l9 as destination and depth=3 expect: 12 14
result = g.get_neighbors_hc(destination=nodeids[8], depth=3)
exp_scidds = [scid12 + '/1', scid14 + '/0']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds
# l9 as destination and depth>=4 expect: empty set
for d in range(4, 6):
result = g.get_neighbors_hc(destination=nodeids[8], depth=d)
assert len(result) == 0