pyln: Add code to unwrap an encrypted onion at the intended node

Changelog-Added: pyln-proto: Added pure python implementation of the sphinx onion creation and processing functionality.
This commit is contained in:
Christian Decker 2020-07-31 18:36:25 +02:00 committed by Rusty Russell
parent e8dcd59b24
commit 04462f6a64
2 changed files with 121 additions and 10 deletions

View File

@ -15,8 +15,9 @@ from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
from hashlib import sha256
from io import BytesIO, SEEK_CUR
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple
import coincurve
import io
import os
import struct
@ -44,7 +45,7 @@ class OnionPayload(object):
s = s.encode('ASCII')
return cls.from_bytes(bytes(unhexlify(s)))
def to_bytes(self):
def to_bytes(self, include_prefix):
raise ValueError("OnionPayload is an abstract class, use "
"LegacyOnionPayload or TlvPayload instead")
@ -92,20 +93,20 @@ class LegacyOnionPayload(OnionPayload):
padding = b.read(12)
return LegacyOnionPayload(a, o, s, padding)
def to_bytes(self, include_realm=True):
def to_bytes(self, include_prefix=True):
b = b''
if include_realm:
if include_prefix:
b += b'\x00'
b += struct.pack("!Q", self.short_channel_id)
b += struct.pack("!Q", self.amt_to_forward)
b += struct.pack("!L", self.outgoing_cltv_value)
b += self.padding
assert(len(b) == 32 + include_realm)
assert(len(b) == 32 + include_prefix)
return b
def to_hex(self, include_realm=True):
return hexlify(self.to_bytes(include_realm)).decode('ASCII')
def to_hex(self, include_prefix=True):
return hexlify(self.to_bytes(include_prefix)).decode('ASCII')
def __str__(self):
return ("LegacyOnionPayload[scid={self.short_channel_id}, "
@ -143,6 +144,12 @@ class TlvPayload(OnionPayload):
raise ValueError(
"Unable to read length at position {}".format(b.tell())
)
elif length > start + payload_length - b.tell():
b.seek(start + payload_length)
raise ValueError("Failed to parse TLV payload: value length "
"is longer than available bytes.")
val = b.read(length)
# Get the subclass that is the correct interpretation of this
@ -167,10 +174,11 @@ class TlvPayload(OnionPayload):
return f
return default
def to_bytes(self):
def to_bytes(self, include_prefix=True) -> bytes:
ser = [f.to_bytes() for f in self.fields]
b = BytesIO()
varint_encode(sum([len(b) for b in ser]), b)
if include_prefix:
varint_encode(sum([len(b) for b in ser]), b)
for f in ser:
b.write(f)
return b.getvalue()
@ -179,6 +187,40 @@ class TlvPayload(OnionPayload):
return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]"
class RawPayload(OnionPayload):
"""A payload that doesn't deserialize correctly as TLV stream.
Mainly used if TLV parsing fails, but we still want access to the raw
payload.
"""
def __init__(self):
self.content: Optional[bytes] = None
@classmethod
def from_bytes(cls, b):
if isinstance(b, str):
b = b.encode('ASCII')
if isinstance(b, bytes):
b = BytesIO(b)
self = cls()
payload_length = varint_decode(b)
self.content = b.read(payload_length)
return self
def to_bytes(self, include_prefix=True) -> bytes:
b = BytesIO()
if self.content is None:
raise ValueError("Cannot serialize empty TLV payload")
if include_prefix:
varint_encode(len(self.content), b)
b.write(self.content)
return b.getvalue()
class TlvField(object):
def __init__(self, typenum, value=None, description=None):
@ -319,6 +361,57 @@ class RoutingOnion(object):
def to_hex(self):
return hexlify(self.to_bin())
def unwrap(self, privkey: PrivateKey, assocdata: Optional[bytes]) \
-> Tuple[OnionPayload, Optional['RoutingOnion']]:
shared_secret = ecdh(privkey, self.ephemeralkey)
keys = generate_keyset(shared_secret)
h = hmac.HMAC(keys.mu, hashes.SHA256(),
backend=default_backend())
h.update(self.payloads)
if assocdata is not None:
h.update(assocdata)
hh = h.finalize()
if hh != self.hmac:
raise ValueError("HMAC does not match, onion might have been "
"tampered with: {hh} != {hmac}".format(
hh=hexlify(hh).decode('ascii'),
hmac=hexlify(self.hmac).decode('ascii'),
))
# Create the scratch twice as large as the original packet, since we
# need to left-shift a single payload off, which may itself be up to
# ROUTING_INFO_SIZE in length.
payloads = bytearray(2 * ROUTING_INFO_SIZE)
payloads[:ROUTING_INFO_SIZE] = self.payloads
chacha20_stream(keys.rho, payloads)
r = io.BytesIO(payloads)
start = r.tell()
try:
payload = OnionPayload.from_bytes(r)
except ValueError:
r.seek(start)
payload = RawPayload.from_bytes(r)
next_hmac = r.read(32)
shift_size = r.tell()
if next_hmac == bytes(32):
return payload, None
else:
b = blind(self.ephemeralkey, shared_secret)
ek = blind_group_element(self.ephemeralkey, b)
payloads = payloads[shift_size:shift_size + ROUTING_INFO_SIZE]
return payload, RoutingOnion(
version=self.version,
ephemeralkey=ek,
payloads=payloads,
hmac=next_hmac,
)
KeySet = namedtuple('KeySet', ['rho', 'mu', 'um', 'pad', 'gamma', 'pi'])

View File

@ -12,7 +12,7 @@ def test_legacy_payload():
b'00000067000001000100000000000003e800000075000000000000000000000000'
)
payload = onion.OnionPayload.from_bytes(legacy)
assert(payload.to_bytes(include_realm=True) == legacy)
assert(payload.to_bytes(include_prefix=True) == legacy)
def test_tlv_payload():
@ -325,3 +325,21 @@ def test_sphinx_path_compile():
o = sp.compile()
assert(o.to_bin() == unhexlify(v['onion']))
def test_unwrap():
f = 'tests/vectors/onion-test-multi-frame.json'
sp, v = sphinx_path_from_test_vector(f)
o = onion.RoutingOnion.from_hex(v['onion'])
assocdata = unhexlify(v['generate']['associated_data'])
privkeys = [onion.PrivateKey(unhexlify(h)) for h in v['decode']]
for pk, h in zip(privkeys, v['generate']['hops']):
pl, o = o.unwrap(pk, assocdata=assocdata)
b = hexlify(pl.to_bytes(include_prefix=False))
if h['type'] == 'legacy':
assert(b == h['payload'].encode('ascii') + b'00' * 12)
else:
assert(b == h['payload'].encode('ascii'))
assert(o is None)