mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-01-17 19:03:42 +01:00
pyln-proto: Added a couple of utilities to manage onions and zbase32
This commit is contained in:
parent
d36af2c340
commit
d3f6ebf911
236
contrib/pyln-proto/pyln/proto/onion.py
Normal file
236
contrib/pyln-proto/pyln/proto/onion.py
Normal file
@ -0,0 +1,236 @@
|
||||
from .primitives import varint_decode, varint_encode
|
||||
from io import BytesIO, SEEK_CUR
|
||||
from binascii import hexlify, unhexlify
|
||||
import struct
|
||||
|
||||
|
||||
class OnionPayload(object):
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, b):
|
||||
if isinstance(b, bytes):
|
||||
b = BytesIO(b)
|
||||
|
||||
realm = b.read(1)
|
||||
b.seek(-1, SEEK_CUR)
|
||||
|
||||
if realm == b'\x00':
|
||||
return LegacyOnionPayload.from_bytes(b)
|
||||
elif realm != b'\x01':
|
||||
return TlvPayload.from_bytes(b, skip_length=False)
|
||||
else:
|
||||
raise ValueError("Onion payloads with realm 0x01 are unsupported")
|
||||
|
||||
@classmethod
|
||||
def from_hex(cls, s):
|
||||
if isinstance(s, str):
|
||||
s = s.encode('ASCII')
|
||||
return cls.from_bytes(bytes(unhexlify(s)))
|
||||
|
||||
def to_bytes(self):
|
||||
raise ValueError("OnionPayload is an abstract class, use "
|
||||
"LegacyOnionPayload or TlvPayload instead")
|
||||
|
||||
def to_hex(self):
|
||||
return hexlify(self.to_bytes()).decode('ASCII')
|
||||
|
||||
|
||||
class LegacyOnionPayload(OnionPayload):
|
||||
|
||||
def __init__(self, amt_to_forward, outgoing_cltv_value,
|
||||
short_channel_id=None, padding=None):
|
||||
assert(padding is None or len(padding) == 12)
|
||||
self.padding = b'\x00' * 12 if padding is None else padding
|
||||
|
||||
if isinstance(amt_to_forward, str):
|
||||
self.amt_to_forward = int(amt_to_forward)
|
||||
else:
|
||||
self.amt_to_forward = amt_to_forward
|
||||
|
||||
self.outgoing_cltv_value = outgoing_cltv_value
|
||||
|
||||
if isinstance(short_channel_id, str) and 'x' in short_channel_id:
|
||||
# Convert the short_channel_id from its string representation to its numeric representation
|
||||
block, tx, out = short_channel_id.split('x')
|
||||
num_scid = int(block) << 40 | int(tx) << 16 | int(out)
|
||||
self.short_channel_id = num_scid
|
||||
elif isinstance(short_channel_id, int):
|
||||
self.short_channel_id = short_channel_id
|
||||
else:
|
||||
raise ValueError("short_channel_id format cannot be recognized: {}".format(short_channel_id))
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, b):
|
||||
if isinstance(b, bytes):
|
||||
b = BytesIO(b)
|
||||
|
||||
assert(b.read(1) == b'\x00')
|
||||
|
||||
s, a, o = struct.unpack("!QQL", b.read(20))
|
||||
padding = b.read(12)
|
||||
return LegacyOnionPayload(a, o, s, padding)
|
||||
|
||||
def to_bytes(self, include_realm=True):
|
||||
b = b''
|
||||
if include_realm:
|
||||
b += b'\x00'
|
||||
|
||||
b += struct.pack("!Q", self.short_channel_id)
|
||||
b += struct.pack("!Q", self.amt_to_forward)
|
||||
b += struct.pack("!L", self.outgoing_cltv_value)
|
||||
b += self.padding
|
||||
assert(len(b) == 32 + include_realm)
|
||||
return b
|
||||
|
||||
def to_hex(self, include_realm=True):
|
||||
return hexlify(self.to_bytes(include_realm)).decode('ASCII')
|
||||
|
||||
def __str__(self):
|
||||
return ("LegacyOnionPayload[scid={self.short_channel_id}, "
|
||||
"amt_to_forward={self.amt_to_forward}, "
|
||||
"outgoing_cltv={self.outgoing_cltv_value}]").format(self=self)
|
||||
|
||||
|
||||
class TlvPayload(OnionPayload):
|
||||
|
||||
def __init__(self, fields=None):
|
||||
self.fields = [] if fields is None else fields
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, b, skip_length=False):
|
||||
if isinstance(b, str):
|
||||
b = b.encode('ASCII')
|
||||
if isinstance(b, bytes):
|
||||
b = BytesIO(b)
|
||||
|
||||
if skip_length:
|
||||
# Consume the entire remainder of the buffer.
|
||||
payload_length = len(b.getvalue()) - b.tell()
|
||||
else:
|
||||
payload_length = varint_decode(b)
|
||||
|
||||
instance = TlvPayload()
|
||||
|
||||
start = b.tell()
|
||||
while b.tell() < start + payload_length:
|
||||
typenum = varint_decode(b)
|
||||
if typenum is None:
|
||||
break
|
||||
length = varint_decode(b)
|
||||
if length is None:
|
||||
raise ValueError(
|
||||
"Unable to read length at position {}".format(b.tell())
|
||||
)
|
||||
val = b.read(length)
|
||||
|
||||
# Get the subclass that is the correct interpretation of this
|
||||
# field. Default to the binary field type.
|
||||
c = tlv_types.get(typenum, (TlvField, "unknown"))
|
||||
cls = c[0]
|
||||
field = cls.from_bytes(typenum=typenum, b=val, description=c[1])
|
||||
instance.fields.append(field)
|
||||
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_hex(cls, h):
|
||||
return cls.from_bytes(unhexlify(h))
|
||||
|
||||
def add_field(self, typenum, value):
|
||||
self.fields.append(TlvField(typenum=typenum, value=value))
|
||||
|
||||
def get(self, key, default=None):
|
||||
for f in self.fields:
|
||||
if f.typenum == key:
|
||||
return f
|
||||
return default
|
||||
|
||||
def to_bytes(self):
|
||||
ser = [f.to_bytes() for f in self.fields]
|
||||
b = BytesIO()
|
||||
varint_encode(sum([len(b) for b in ser]), b)
|
||||
for f in ser:
|
||||
b.write(f)
|
||||
return b.getvalue()
|
||||
|
||||
def __str__(self):
|
||||
return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]"
|
||||
|
||||
|
||||
class TlvField(object):
|
||||
|
||||
def __init__(self, typenum, value=None, description=None):
|
||||
self.typenum = typenum
|
||||
self.value = value
|
||||
self.description = description
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, typenum, b, description=None):
|
||||
return TlvField(typenum=typenum, value=b, description=description)
|
||||
|
||||
def __str__(self):
|
||||
return "TlvField[{description},{num}={hex}]".format(
|
||||
description=self.description,
|
||||
num=self.typenum,
|
||||
hex=hexlify(self.value).decode('ASCII')
|
||||
)
|
||||
|
||||
def to_bytes(self):
|
||||
b = BytesIO()
|
||||
varint_encode(self.typenum, b)
|
||||
varint_encode(len(self.value), b)
|
||||
b.write(self.value)
|
||||
return b.getvalue()
|
||||
|
||||
|
||||
class Tu32Field(TlvField):
|
||||
pass
|
||||
|
||||
|
||||
class Tu64Field(TlvField):
|
||||
pass
|
||||
|
||||
|
||||
class ShortChannelIdField(TlvField):
|
||||
pass
|
||||
|
||||
|
||||
class TextField(TlvField):
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, typenum, b, description=None):
|
||||
val = b.decode('UTF-8')
|
||||
return TextField(typenum, value=val, description=description)
|
||||
|
||||
def to_bytes(self):
|
||||
b = BytesIO()
|
||||
val = self.value.encode('UTF-8')
|
||||
varint_encode(self.typenum, b)
|
||||
varint_encode(len(val), b)
|
||||
b.write(val)
|
||||
return b.getvalue()
|
||||
|
||||
def __str__(self):
|
||||
return "TextField[{description},{num}=\"{val}\"]".format(
|
||||
description=self.description,
|
||||
num=self.typenum,
|
||||
val=self.value,
|
||||
)
|
||||
|
||||
|
||||
class HashField(TlvField):
|
||||
pass
|
||||
|
||||
|
||||
class SignatureField(TlvField):
|
||||
pass
|
||||
|
||||
|
||||
# A mapping of known TLV types
|
||||
tlv_types = {
|
||||
2: (Tu64Field, 'amt_to_forward'),
|
||||
4: (Tu32Field, 'outgoing_cltv_value'),
|
||||
6: (ShortChannelIdField, 'short_channel_id'),
|
||||
34349334: (TextField, 'noise_message_body'),
|
||||
34349336: (SignatureField, 'noise_message_signature'),
|
||||
}
|
70
contrib/pyln-proto/pyln/proto/primitives.py
Normal file
70
contrib/pyln-proto/pyln/proto/primitives.py
Normal file
@ -0,0 +1,70 @@
|
||||
import struct
|
||||
|
||||
|
||||
def varint_encode(i, w):
|
||||
"""Encode an integer `i` into the writer `w`
|
||||
"""
|
||||
if i < 0xFD:
|
||||
w.write(struct.pack("!B", i))
|
||||
elif i <= 0xFFFF:
|
||||
w.write(struct.pack("!BH", 0xFD, i))
|
||||
elif i <= 0xFFFFFFFF:
|
||||
w.write(struct.pack("!BL", 0xFE, i))
|
||||
else:
|
||||
w.write(struct.pack("!BQ", 0xFF, i))
|
||||
|
||||
|
||||
def varint_decode(r):
|
||||
"""Decode an integer from reader `r`
|
||||
"""
|
||||
raw = r.read(1)
|
||||
if len(raw) != 1:
|
||||
return None
|
||||
|
||||
i, = struct.unpack("!B", raw)
|
||||
if i < 0xFD:
|
||||
return i
|
||||
elif i == 0xFD:
|
||||
return struct.unpack("!H", r.read(2))[0]
|
||||
elif i == 0xFE:
|
||||
return struct.unpack("!L", r.read(4))[0]
|
||||
else:
|
||||
return struct.unpack("!Q", r.read(8))[0]
|
||||
|
||||
|
||||
class ShortChannelId(object):
|
||||
def __init__(self, block, txnum, outnum):
|
||||
self.block = block
|
||||
self.txnum = txnum
|
||||
self.outnum = outnum
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, b):
|
||||
assert(len(b) == 8)
|
||||
i, = struct.unpack("!Q", b)
|
||||
return cls.from_int(i)
|
||||
|
||||
@classmethod
|
||||
def from_int(cls, i):
|
||||
block = (i >> 40) & 0xFFFFFF
|
||||
txnum = (i >> 16) & 0xFFFFFF
|
||||
outnum = (i >> 0) & 0xFFFF
|
||||
return cls(block=block, txnum=txnum, outnum=outnum)
|
||||
|
||||
@classmethod
|
||||
def from_str(self, s):
|
||||
block, txnum, outnum = s.split('x')
|
||||
return ShortChannelId(block=int(block), txnum=int(txnum),
|
||||
outnum=int(outnum))
|
||||
|
||||
def to_int(self):
|
||||
return self.block << 40 | self.txnum << 16 | self.outnum
|
||||
|
||||
def to_bytes(self):
|
||||
return struct.pack("!Q", self.to_int())
|
||||
|
||||
def __str__(self):
|
||||
return "{self.block}x{self.txnum}x{self.outnum}".format(self=self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.block == other.block and self.txnum == other.txnum and self.outnum == other.outnum
|
56
contrib/pyln-proto/pyln/proto/zbase32.py
Normal file
56
contrib/pyln-proto/pyln/proto/zbase32.py
Normal file
@ -0,0 +1,56 @@
|
||||
import bitstring
|
||||
|
||||
|
||||
zbase32_chars = b'ybndrfg8ejkmcpqxot1uwisza345h769'
|
||||
zbase32_revchars = [
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 18, 255, 25, 26, 27, 30, 29, 7, 31, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 24, 1, 12, 3, 8, 5, 6, 28, 21, 9, 10, 255, 11, 2,
|
||||
16, 13, 14, 4, 22, 17, 19, 255, 20, 15, 0, 23, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
|
||||
255, 255, 255, 255, 255, 255, 255
|
||||
]
|
||||
|
||||
|
||||
def bitarray_to_u5(barr):
|
||||
assert len(barr) % 5 == 0
|
||||
ret = []
|
||||
s = bitstring.ConstBitStream(barr)
|
||||
while s.pos != s.len:
|
||||
ret.append(s.read(5).uint)
|
||||
return ret
|
||||
|
||||
|
||||
def u5_to_bitarray(arr):
|
||||
ret = bitstring.BitArray()
|
||||
for a in arr:
|
||||
ret += bitstring.pack("uint:5", a)
|
||||
return ret
|
||||
|
||||
|
||||
def encode(b):
|
||||
uint5s = bitarray_to_u5(b)
|
||||
res = [zbase32_chars[c] for c in uint5s]
|
||||
return bytes(res)
|
||||
|
||||
|
||||
def decode(b):
|
||||
if isinstance(b, str):
|
||||
b = b.encode('ASCII')
|
||||
|
||||
uint5s = []
|
||||
for c in b:
|
||||
uint5s.append(zbase32_revchars[c])
|
||||
dec = u5_to_bitarray(uint5s)
|
||||
return dec.bytes
|
@ -1,2 +1,3 @@
|
||||
bitstring==3.1.6
|
||||
cryptography==2.7
|
||||
coincurve==12.0.0
|
||||
|
32
contrib/pyln-proto/tests/test_onion.py
Normal file
32
contrib/pyln-proto/tests/test_onion.py
Normal file
@ -0,0 +1,32 @@
|
||||
from binascii import unhexlify
|
||||
|
||||
from pyln.proto import onion
|
||||
|
||||
|
||||
def test_legacy_payload():
|
||||
legacy = unhexlify(
|
||||
b'00000067000001000100000000000003e800000075000000000000000000000000'
|
||||
)
|
||||
payload = onion.OnionPayload.from_bytes(legacy)
|
||||
assert(payload.to_bytes(include_realm=True) == legacy)
|
||||
|
||||
|
||||
def test_tlv_payload():
|
||||
tlv = unhexlify(
|
||||
b'58fe020c21160c48656c6c6f20776f726c6421fe020c21184076e8acd54afbf2361'
|
||||
b'0b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d0205f7e4e1a12620e'
|
||||
b'7fc8ce1c7d3651acefde899c33f12b6958d3304106a0'
|
||||
)
|
||||
payload = onion.OnionPayload.from_bytes(tlv)
|
||||
assert(payload.to_bytes() == tlv)
|
||||
|
||||
fields = payload.fields
|
||||
assert(len(fields) == 2)
|
||||
assert(isinstance(fields[0], onion.TextField))
|
||||
assert(fields[0].typenum == 34349334 and fields[0].value == "Hello world!")
|
||||
assert(fields[1].typenum == 34349336 and fields[1].value == unhexlify(
|
||||
b'76e8acd54afbf23610b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d'
|
||||
b'0205f7e4e1a12620e7fc8ce1c7d3651acefde899c33f12b6958d3304106a0'
|
||||
))
|
||||
|
||||
assert(payload.to_bytes() == tlv)
|
30
contrib/pyln-proto/tests/test_primitives.py
Normal file
30
contrib/pyln-proto/tests/test_primitives.py
Normal file
@ -0,0 +1,30 @@
|
||||
from binascii import hexlify, unhexlify
|
||||
from pyln.proto import zbase32
|
||||
from pyln.proto.primitives import ShortChannelId
|
||||
|
||||
|
||||
def test_short_channel_id():
|
||||
num = 618150934845652992
|
||||
b = unhexlify(b'08941d00090d0000')
|
||||
s = '562205x2317x0'
|
||||
s1 = ShortChannelId.from_int(num)
|
||||
s2 = ShortChannelId.from_str(s)
|
||||
s3 = ShortChannelId.from_bytes(b)
|
||||
expected = ShortChannelId(block=562205, txnum=2317, outnum=0)
|
||||
|
||||
assert(s1 == expected)
|
||||
assert(s2 == expected)
|
||||
assert(s3 == expected)
|
||||
|
||||
assert(expected.to_bytes() == b)
|
||||
assert(str(expected) == s)
|
||||
assert(expected.to_int() == num)
|
||||
|
||||
|
||||
def test_zbase32():
|
||||
zb32 = b'd75qtmgijm79rpooshmgzjwji9gj7dsdat8remuskyjp9oq1ugkaoj6orbxzhuo4njtyh96e3aq84p1tiuz77nchgxa1s4ka4carnbiy'
|
||||
b = zbase32.decode(zb32)
|
||||
assert(hexlify(b) == b'1f76e8acd54afbf23610b7166ba689afcc9e8ec3c44e442e765012dfc1d299958827d0205f7e4e1a12620e7fc8ce1c7d3651acefde899c33f12b6958d3304106a0')
|
||||
|
||||
enc = zbase32.encode(b)
|
||||
assert(enc == zb32)
|
Loading…
Reference in New Issue
Block a user