mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-01-18 05:12:45 +01:00
pyln: Add code to unwrap an encrypted onion at the intended node
Changelog-Added: pyln-proto: Added pure python implementation of the sphinx onion creation and processing functionality.
This commit is contained in:
parent
e8dcd59b24
commit
04462f6a64
@ -15,8 +15,9 @@ from cryptography.hazmat.primitives import hashes, hmac
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
|
||||
from hashlib import sha256
|
||||
from io import BytesIO, SEEK_CUR
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Tuple
|
||||
import coincurve
|
||||
import io
|
||||
import os
|
||||
import struct
|
||||
|
||||
@ -44,7 +45,7 @@ class OnionPayload(object):
|
||||
s = s.encode('ASCII')
|
||||
return cls.from_bytes(bytes(unhexlify(s)))
|
||||
|
||||
def to_bytes(self):
|
||||
def to_bytes(self, include_prefix):
|
||||
raise ValueError("OnionPayload is an abstract class, use "
|
||||
"LegacyOnionPayload or TlvPayload instead")
|
||||
|
||||
@ -92,20 +93,20 @@ class LegacyOnionPayload(OnionPayload):
|
||||
padding = b.read(12)
|
||||
return LegacyOnionPayload(a, o, s, padding)
|
||||
|
||||
def to_bytes(self, include_realm=True):
|
||||
def to_bytes(self, include_prefix=True):
|
||||
b = b''
|
||||
if include_realm:
|
||||
if include_prefix:
|
||||
b += b'\x00'
|
||||
|
||||
b += struct.pack("!Q", self.short_channel_id)
|
||||
b += struct.pack("!Q", self.amt_to_forward)
|
||||
b += struct.pack("!L", self.outgoing_cltv_value)
|
||||
b += self.padding
|
||||
assert(len(b) == 32 + include_realm)
|
||||
assert(len(b) == 32 + include_prefix)
|
||||
return b
|
||||
|
||||
def to_hex(self, include_realm=True):
|
||||
return hexlify(self.to_bytes(include_realm)).decode('ASCII')
|
||||
def to_hex(self, include_prefix=True):
|
||||
return hexlify(self.to_bytes(include_prefix)).decode('ASCII')
|
||||
|
||||
def __str__(self):
|
||||
return ("LegacyOnionPayload[scid={self.short_channel_id}, "
|
||||
@ -143,6 +144,12 @@ class TlvPayload(OnionPayload):
|
||||
raise ValueError(
|
||||
"Unable to read length at position {}".format(b.tell())
|
||||
)
|
||||
|
||||
elif length > start + payload_length - b.tell():
|
||||
b.seek(start + payload_length)
|
||||
raise ValueError("Failed to parse TLV payload: value length "
|
||||
"is longer than available bytes.")
|
||||
|
||||
val = b.read(length)
|
||||
|
||||
# Get the subclass that is the correct interpretation of this
|
||||
@ -167,10 +174,11 @@ class TlvPayload(OnionPayload):
|
||||
return f
|
||||
return default
|
||||
|
||||
def to_bytes(self):
|
||||
def to_bytes(self, include_prefix=True) -> bytes:
|
||||
ser = [f.to_bytes() for f in self.fields]
|
||||
b = BytesIO()
|
||||
varint_encode(sum([len(b) for b in ser]), b)
|
||||
if include_prefix:
|
||||
varint_encode(sum([len(b) for b in ser]), b)
|
||||
for f in ser:
|
||||
b.write(f)
|
||||
return b.getvalue()
|
||||
@ -179,6 +187,40 @@ class TlvPayload(OnionPayload):
|
||||
return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]"
|
||||
|
||||
|
||||
class RawPayload(OnionPayload):
|
||||
"""A payload that doesn't deserialize correctly as TLV stream.
|
||||
|
||||
Mainly used if TLV parsing fails, but we still want access to the raw
|
||||
payload.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.content: Optional[bytes] = None
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, b):
|
||||
if isinstance(b, str):
|
||||
b = b.encode('ASCII')
|
||||
if isinstance(b, bytes):
|
||||
b = BytesIO(b)
|
||||
|
||||
self = cls()
|
||||
payload_length = varint_decode(b)
|
||||
self.content = b.read(payload_length)
|
||||
return self
|
||||
|
||||
def to_bytes(self, include_prefix=True) -> bytes:
|
||||
b = BytesIO()
|
||||
if self.content is None:
|
||||
raise ValueError("Cannot serialize empty TLV payload")
|
||||
|
||||
if include_prefix:
|
||||
varint_encode(len(self.content), b)
|
||||
b.write(self.content)
|
||||
return b.getvalue()
|
||||
|
||||
|
||||
class TlvField(object):
|
||||
|
||||
def __init__(self, typenum, value=None, description=None):
|
||||
@ -319,6 +361,57 @@ class RoutingOnion(object):
|
||||
def to_hex(self):
|
||||
return hexlify(self.to_bin())
|
||||
|
||||
def unwrap(self, privkey: PrivateKey, assocdata: Optional[bytes]) \
|
||||
-> Tuple[OnionPayload, Optional['RoutingOnion']]:
|
||||
shared_secret = ecdh(privkey, self.ephemeralkey)
|
||||
keys = generate_keyset(shared_secret)
|
||||
|
||||
h = hmac.HMAC(keys.mu, hashes.SHA256(),
|
||||
backend=default_backend())
|
||||
h.update(self.payloads)
|
||||
if assocdata is not None:
|
||||
h.update(assocdata)
|
||||
hh = h.finalize()
|
||||
|
||||
if hh != self.hmac:
|
||||
raise ValueError("HMAC does not match, onion might have been "
|
||||
"tampered with: {hh} != {hmac}".format(
|
||||
hh=hexlify(hh).decode('ascii'),
|
||||
hmac=hexlify(self.hmac).decode('ascii'),
|
||||
))
|
||||
|
||||
# Create the scratch twice as large as the original packet, since we
|
||||
# need to left-shift a single payload off, which may itself be up to
|
||||
# ROUTING_INFO_SIZE in length.
|
||||
payloads = bytearray(2 * ROUTING_INFO_SIZE)
|
||||
payloads[:ROUTING_INFO_SIZE] = self.payloads
|
||||
chacha20_stream(keys.rho, payloads)
|
||||
|
||||
r = io.BytesIO(payloads)
|
||||
start = r.tell()
|
||||
|
||||
try:
|
||||
payload = OnionPayload.from_bytes(r)
|
||||
except ValueError:
|
||||
r.seek(start)
|
||||
payload = RawPayload.from_bytes(r)
|
||||
|
||||
next_hmac = r.read(32)
|
||||
shift_size = r.tell()
|
||||
|
||||
if next_hmac == bytes(32):
|
||||
return payload, None
|
||||
else:
|
||||
b = blind(self.ephemeralkey, shared_secret)
|
||||
ek = blind_group_element(self.ephemeralkey, b)
|
||||
payloads = payloads[shift_size:shift_size + ROUTING_INFO_SIZE]
|
||||
return payload, RoutingOnion(
|
||||
version=self.version,
|
||||
ephemeralkey=ek,
|
||||
payloads=payloads,
|
||||
hmac=next_hmac,
|
||||
)
|
||||
|
||||
|
||||
KeySet = namedtuple('KeySet', ['rho', 'mu', 'um', 'pad', 'gamma', 'pi'])
|
||||
|
||||
|
@ -12,7 +12,7 @@ def test_legacy_payload():
|
||||
b'00000067000001000100000000000003e800000075000000000000000000000000'
|
||||
)
|
||||
payload = onion.OnionPayload.from_bytes(legacy)
|
||||
assert(payload.to_bytes(include_realm=True) == legacy)
|
||||
assert(payload.to_bytes(include_prefix=True) == legacy)
|
||||
|
||||
|
||||
def test_tlv_payload():
|
||||
@ -325,3 +325,21 @@ def test_sphinx_path_compile():
|
||||
o = sp.compile()
|
||||
|
||||
assert(o.to_bin() == unhexlify(v['onion']))
|
||||
|
||||
|
||||
def test_unwrap():
|
||||
f = 'tests/vectors/onion-test-multi-frame.json'
|
||||
sp, v = sphinx_path_from_test_vector(f)
|
||||
o = onion.RoutingOnion.from_hex(v['onion'])
|
||||
assocdata = unhexlify(v['generate']['associated_data'])
|
||||
privkeys = [onion.PrivateKey(unhexlify(h)) for h in v['decode']]
|
||||
|
||||
for pk, h in zip(privkeys, v['generate']['hops']):
|
||||
pl, o = o.unwrap(pk, assocdata=assocdata)
|
||||
|
||||
b = hexlify(pl.to_bytes(include_prefix=False))
|
||||
if h['type'] == 'legacy':
|
||||
assert(b == h['payload'].encode('ascii') + b'00' * 12)
|
||||
else:
|
||||
assert(b == h['payload'].encode('ascii'))
|
||||
assert(o is None)
|
||||
|
Loading…
Reference in New Issue
Block a user