mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-02-23 15:00:34 +01:00
pygossmap: adds get_neighbors and get_neighbors_hc flodding method
This commit is contained in:
parent
5a9a3d83c9
commit
9409f2f1ea
2 changed files with 206 additions and 1 deletions
|
@ -3,7 +3,7 @@
|
||||||
from pyln.spec.bolt7 import (channel_announcement, channel_update,
|
from pyln.spec.bolt7 import (channel_announcement, channel_update,
|
||||||
node_announcement)
|
node_announcement)
|
||||||
from pyln.proto import ShortChannelId, PublicKey
|
from pyln.proto import ShortChannelId, PublicKey
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Set, Optional, Union
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import struct
|
import struct
|
||||||
|
@ -273,12 +273,97 @@ class Gossmap(object):
|
||||||
channel = self.get_channel(short_channel_id)
|
channel = self.get_channel(short_channel_id)
|
||||||
return channel.half_channels[direction]
|
return channel.half_channels[direction]
|
||||||
|
|
||||||
|
def get_neighbors_hc(self,
|
||||||
|
source: Union[GossmapNodeId, str, None] = None,
|
||||||
|
destination: Union[GossmapNodeId, str, None] = None,
|
||||||
|
depth: int = 0,
|
||||||
|
excludes: Union[Set[Any], List[Any]] = set()):
|
||||||
|
""" Returns a set[GossmapHalfchannel]` from `source` or towards
|
||||||
|
`destination` node ID. Using the optional `depth` greater than `0`
|
||||||
|
will result in a second, third, ... order list of connected
|
||||||
|
channels towards or from that node.
|
||||||
|
Note: only one of `source` or `destination` can be given. """
|
||||||
|
assert (source is None) ^ (destination is None), "Only one of source or destination must be given"
|
||||||
|
assert depth >= 0, "Depth cannot be smaller than 0"
|
||||||
|
node = self.get_node(source if source else destination)
|
||||||
|
assert node is not None, "source or destination unknown"
|
||||||
|
if isinstance(excludes, List):
|
||||||
|
excludes = set(excludes)
|
||||||
|
|
||||||
|
# first get set of reachable nodes ...
|
||||||
|
reachable = self.get_neighbors(source, destination, depth, excludes)
|
||||||
|
# and iterate and check any each source/dest channel from here
|
||||||
|
result = set()
|
||||||
|
for node in reachable:
|
||||||
|
for channel in node.channels:
|
||||||
|
if channel in excludes:
|
||||||
|
continue
|
||||||
|
other = channel.node1 if node != channel.node1 else channel.node2
|
||||||
|
if other in reachable or other in excludes:
|
||||||
|
continue
|
||||||
|
direction = 0
|
||||||
|
if source is not None and node > other:
|
||||||
|
direction = 1
|
||||||
|
if destination is not None and node < other:
|
||||||
|
direction = 1
|
||||||
|
hc = channel.half_channels[direction]
|
||||||
|
# skip excluded or non existent halfchannels
|
||||||
|
if hc is None or hc in excludes:
|
||||||
|
continue
|
||||||
|
result.add(hc)
|
||||||
|
return result
|
||||||
|
|
||||||
def get_node(self, node_id: Union[GossmapNodeId, str]):
|
def get_node(self, node_id: Union[GossmapNodeId, str]):
|
||||||
""" Resolves a node by its public key node_id """
|
""" Resolves a node by its public key node_id """
|
||||||
if isinstance(node_id, str):
|
if isinstance(node_id, str):
|
||||||
node_id = GossmapNodeId.from_str(node_id)
|
node_id = GossmapNodeId.from_str(node_id)
|
||||||
return self.nodes.get(node_id)
|
return self.nodes.get(node_id)
|
||||||
|
|
||||||
|
def get_neighbors(self,
|
||||||
|
source: Union[GossmapNodeId, str, None] = None,
|
||||||
|
destination: Union[GossmapNodeId, str, None] = None,
|
||||||
|
depth: int = 0,
|
||||||
|
excludes: Union[Set[Any], List[Any]] = set()):
|
||||||
|
""" Returns a set of nodes within a given depth from a source node """
|
||||||
|
assert (source is None) ^ (destination is None), "Only one of source or destination must be given"
|
||||||
|
assert depth >= 0, "Depth cannot be smaller than 0"
|
||||||
|
node = self.get_node(source if source else destination)
|
||||||
|
assert node is not None, "source or destination unknown"
|
||||||
|
if isinstance(excludes, List):
|
||||||
|
excludes = set(excludes)
|
||||||
|
|
||||||
|
result = set()
|
||||||
|
result.add(node)
|
||||||
|
inner = set()
|
||||||
|
inner.add(node)
|
||||||
|
while depth > 0:
|
||||||
|
shell = set()
|
||||||
|
for node in inner:
|
||||||
|
for channel in node.channels:
|
||||||
|
if channel in excludes: # skip excluded channels
|
||||||
|
continue
|
||||||
|
other = channel.node1 if channel.node1 != node else channel.node2
|
||||||
|
direction = 0
|
||||||
|
if source is not None and node > other:
|
||||||
|
direction = 1
|
||||||
|
if destination is not None and node < other:
|
||||||
|
direction = 1
|
||||||
|
if channel.half_channels[direction] is None:
|
||||||
|
continue # one way channel in the wrong direction
|
||||||
|
halfchannel = channel.half_channels[direction]
|
||||||
|
if halfchannel in excludes: # skip excluded halfchannels
|
||||||
|
continue
|
||||||
|
# skip excluded or already seen nodes
|
||||||
|
if other in excludes or other in inner or other in result:
|
||||||
|
continue
|
||||||
|
shell.add(other)
|
||||||
|
if len(shell) == 0:
|
||||||
|
break
|
||||||
|
depth -= 1
|
||||||
|
result.update(shell)
|
||||||
|
inner = shell
|
||||||
|
return result
|
||||||
|
|
||||||
def _update_channel(self, rec: bytes, hdr: GossipStoreHeader):
|
def _update_channel(self, rec: bytes, hdr: GossipStoreHeader):
|
||||||
fields = channel_update.read(io.BytesIO(rec[2:]), {})
|
fields = channel_update.read(io.BytesIO(rec[2:]), {})
|
||||||
direction = fields['channel_flags'] & 1
|
direction = fields['channel_flags'] & 1
|
||||||
|
|
|
@ -159,6 +159,8 @@ def test_mesh(tmp_path):
|
||||||
scids = [scid12, scid14, scid23, scid25, scid36, scid45, scid47, scid56,
|
scids = [scid12, scid14, scid23, scid25, scid36, scid45, scid47, scid56,
|
||||||
scid58, scid69, scid78, scid89]
|
scid58, scid69, scid78, scid89]
|
||||||
|
|
||||||
|
nodes = [g.get_node(nid) for nid in nodeids]
|
||||||
|
|
||||||
# check all nodes are there
|
# check all nodes are there
|
||||||
for nodeid in nodeids:
|
for nodeid in nodeids:
|
||||||
node = g.get_node(nodeid)
|
node = g.get_node(nodeid)
|
||||||
|
@ -174,3 +176,121 @@ def test_mesh(tmp_path):
|
||||||
assert str(channel.scid) == scid
|
assert str(channel.scid) == scid
|
||||||
assert channel.half_channels[0]
|
assert channel.half_channels[0]
|
||||||
assert channel.half_channels[1]
|
assert channel.half_channels[1]
|
||||||
|
|
||||||
|
# check basic relations
|
||||||
|
# get_neighbors l5 in the middle depth=0 returns just that node
|
||||||
|
result = g.get_neighbors(source=nodeids[4])
|
||||||
|
assert len(result) == 1
|
||||||
|
assert str(next(iter(result)).node_id) == nodeids[4]
|
||||||
|
result = g.get_neighbors(source=nodeids[4], depth=1)
|
||||||
|
assert len(result) == 5
|
||||||
|
# on depth=1 the cross l2, l4, l5, l6, l8 must be returned
|
||||||
|
assert nodes[1] in result
|
||||||
|
assert nodes[3] in result
|
||||||
|
assert nodes[4] in result
|
||||||
|
assert nodes[5] in result
|
||||||
|
assert nodes[7] in result
|
||||||
|
# on depth>=2 all nodes must be returned as we visited the whole graph
|
||||||
|
for d in range(2, 4):
|
||||||
|
result = g.get_neighbors(source=nodeids[4], depth=d)
|
||||||
|
assert len(result) == 9
|
||||||
|
for node in nodes:
|
||||||
|
assert node in result
|
||||||
|
# get_neighbors on l9 with depth=3 must return all but l1
|
||||||
|
result = g.get_neighbors(nodeids[8], depth=3)
|
||||||
|
assert len(result) == 8
|
||||||
|
assert nodes[0] not in result
|
||||||
|
# get_neighbors on l9 with depth=4 and excludes l5 must return all but l5
|
||||||
|
result = g.get_neighbors(nodeids[8], depth=4, excludes=[nodes[4]])
|
||||||
|
assert len(result) == 8
|
||||||
|
assert nodes[4] not in result
|
||||||
|
|
||||||
|
# get_neighbors_hc l5 in the middle expect: 25, 45, 65 and 85
|
||||||
|
result = g.get_neighbors_hc(source=nodeids[4])
|
||||||
|
exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]]
|
||||||
|
exp_scidds = [scid25 + '/1', scid45 + '/0', scid56 + '/1', scid58 + '/0']
|
||||||
|
assert len(result) == len(exp_ids)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan.source.node_id) == nodeids[4]
|
||||||
|
assert str(halfchan.destination.node_id) in exp_ids
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# same but other direction
|
||||||
|
result = g.get_neighbors_hc(destination=nodeids[4])
|
||||||
|
exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]]
|
||||||
|
exp_scidds = [scid25 + '/0', scid45 + '/1', scid56 + '/0', scid58 + '/1']
|
||||||
|
assert len(result) == len(exp_ids)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan.destination.node_id) == nodeids[4]
|
||||||
|
assert str(halfchan.source.node_id) in exp_ids
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# get all channels which have l1 as destination
|
||||||
|
result = g.get_neighbors_hc(destination=nodeids[0])
|
||||||
|
exp_ids = [nodeids[1], nodeids[3]]
|
||||||
|
exp_scidds = [scid12 + '/0', scid14 + '/1']
|
||||||
|
assert len(result) == len(exp_ids)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan.destination.node_id) == nodeids[0]
|
||||||
|
assert str(halfchan.source.node_id) in exp_ids
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# l5 as destination in the middle but depth=1, so the outer ring
|
||||||
|
# epxect: 12, 14, 32, 36, 74, 78, 98, 96
|
||||||
|
result = g.get_neighbors_hc(destination=nodeids[4], depth=1)
|
||||||
|
exp_scidds = [scid12 + '/1', scid14 + '/0', scid23 + '/1', scid36 + '/1',
|
||||||
|
scid47 + '/0', scid69 + '/1', scid78 + '/0', scid89 + '/0']
|
||||||
|
assert len(result) == len(exp_scidds)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# same but other direction
|
||||||
|
result = g.get_neighbors_hc(source=nodeids[4], depth=1)
|
||||||
|
exp_scidds = [scid12 + '/0', scid14 + '/1', scid23 + '/0', scid36 + '/0',
|
||||||
|
scid47 + '/1', scid69 + '/0', scid78 + '/1', scid89 + '/1']
|
||||||
|
assert len(result) == len(exp_scidds)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# l9 as destination and depth=2 expect: 23 25 45 47
|
||||||
|
result = g.get_neighbors_hc(destination=nodeids[8], depth=2)
|
||||||
|
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1', scid47 + '/1']
|
||||||
|
assert len(result) == len(exp_scidds)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# l9 as destination depth=2 exclude=[l7] expect: 23 25 45
|
||||||
|
result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=[nodes[6]])
|
||||||
|
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1']
|
||||||
|
assert len(result) == len(exp_scidds)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# same as above, but excludes halfchannels of l7 expect: 23 25 45
|
||||||
|
hcs = [c.half_channels[0] for c in nodes[6].channels]
|
||||||
|
hcs += [c.half_channels[1] for c in nodes[6].channels]
|
||||||
|
result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=hcs)
|
||||||
|
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1']
|
||||||
|
assert len(result) == len(exp_scidds)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# again, same as above, but excludes channels of l7 expect: 23 25 45
|
||||||
|
chs = [c for c in nodes[6].channels]
|
||||||
|
result = g.get_neighbors_hc(destination=nodeids[8], depth=2, excludes=chs)
|
||||||
|
exp_scidds = [scid23 + '/0', scid25 + '/0', scid45 + '/1']
|
||||||
|
assert len(result) == len(exp_scidds)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# l9 as destination and depth=3 expect: 12 14
|
||||||
|
result = g.get_neighbors_hc(destination=nodeids[8], depth=3)
|
||||||
|
exp_scidds = [scid12 + '/1', scid14 + '/0']
|
||||||
|
assert len(result) == len(exp_scidds)
|
||||||
|
for halfchan in result:
|
||||||
|
assert str(halfchan) in exp_scidds
|
||||||
|
|
||||||
|
# l9 as destination and depth>=4 expect: empty set
|
||||||
|
for d in range(4, 6):
|
||||||
|
result = g.get_neighbors_hc(destination=nodeids[8], depth=d)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
Loading…
Add table
Reference in a new issue