diff --git a/contrib/pyln-client/pyln/client/plugin.py b/contrib/pyln-client/pyln/client/plugin.py index 6d7586948..78c1976d5 100644 --- a/contrib/pyln-client/pyln/client/plugin.py +++ b/contrib/pyln-client/pyln/client/plugin.py @@ -1,10 +1,12 @@ +from .lightning import LightningRpc, Millisatoshi from binascii import hexlify from collections import OrderedDict from enum import Enum -from .lightning import LightningRpc, Millisatoshi from threading import RLock +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import inspect +import io import json import math import os @@ -12,6 +14,16 @@ import re import sys import traceback +# Notice that this definition is incomplete as it only checks the +# top-level. Arrays and Dicts could contain types that aren't encodeable. This +# limitation stems from the fact that recursive types are not really supported +# yet. +JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] + +# Yes, decorators are weird... +NoneDecoratorType = Callable[..., Callable[..., None]] +JsonDecoratorType = Callable[..., Callable[..., JSONType]] + class MethodType(Enum): RPCMETHOD = 0 @@ -32,8 +44,10 @@ class Method(object): - RPC exposed by RPC passthrough - HOOK registered to be called synchronously by lightningd """ - def __init__(self, name, func, mtype=MethodType.RPCMETHOD, category=None, - desc=None, long_desc=None, deprecated=False): + def __init__(self, name: str, func: Callable[..., JSONType], + mtype: MethodType = MethodType.RPCMETHOD, + category: str = None, desc: str = None, + long_desc: str = None, deprecated: bool = False): self.name = name self.func = func self.mtype = mtype @@ -47,7 +61,8 @@ class Method(object): class Request(dict): """A request object that wraps params and allows async return """ - def __init__(self, plugin, req_id, method, params, background=False): + def __init__(self, plugin: 'Plugin', req_id: Optional[int], method: str, + params: Any, background: bool = False): self.method = method self.params = params self.background = background @@ -55,15 +70,19 @@ class Request(dict): self.state = RequestState.PENDING self.id = req_id - def getattr(self, key): + def getattr(self, key: str) -> Union[Method, Any, int]: if key == "params": return self.params elif key == "id": return self.id elif key == "method": return self.method + else: + raise ValueError( + 'Cannot get attribute "{key}" on Request'.format(key=key) + ) - def set_result(self, result): + def set_result(self, result: Any) -> None: if self.state != RequestState.PENDING: raise ValueError( "Cannot set the result of a request that is not pending, " @@ -75,7 +94,7 @@ class Request(dict): 'result': self.result }) - def set_exception(self, exc): + def set_exception(self, exc: Exception) -> None: if self.state != RequestState.PENDING: raise ValueError( "Cannot set the exception of a request that is not pending, " @@ -93,7 +112,7 @@ class Request(dict): }, }) - def _write_result(self, result): + def _write_result(self, result: dict) -> None: self.plugin._write_locked(result) @@ -126,12 +145,20 @@ class Plugin(object): """ - def __init__(self, stdout=None, stdin=None, autopatch=True, dynamic=True, - init_features=None, node_features=None, invoice_features=None): - self.methods = {'init': Method('init', self._init, MethodType.RPCMETHOD)} - self.options = {} + def __init__(self, stdout: Optional[io.TextIOBase] = None, + stdin: Optional[io.TextIOBase] = None, autopatch: bool = True, + dynamic: bool = True, + init_features: Optional[Union[int, str, bytes]] = None, + node_features: Optional[Union[int, str, bytes]] = None, + invoice_features: Optional[Union[int, str, bytes]] = None): + self.methods = { + 'init': Method('init', self._init, MethodType.RPCMETHOD) + } - def convert_featurebits(bits): + self.options: Dict[str, Dict[str, Any]] = {} + + def convert_featurebits( + bits: Optional[Union[int, str, bytes]]) -> Optional[str]: """Convert the featurebits into the bytes required to hexencode. """ if bits is None: @@ -149,7 +176,9 @@ class Plugin(object): return hexlify(bits).decode('ASCII') else: - raise ValueError("Could not convert featurebits to hex-encoded string") + raise ValueError( + "Could not convert featurebits to hex-encoded string" + ) self.featurebits = { 'init': convert_featurebits(init_features), @@ -158,7 +187,7 @@ class Plugin(object): } # A dict from topics to handler functions - self.subscriptions = {} + self.subscriptions: Dict[str, Callable[..., None]] = {} if not stdout: self.stdout = sys.stdout @@ -172,17 +201,21 @@ class Plugin(object): monkey_patch(self, stdout=True, stderr=True) self.add_method("getmanifest", self._getmanifest, background=False) - self.rpc_filename = None - self.lightning_dir = None - self.rpc = None + self.rpc_filename: Optional[str] = None + self.lightning_dir: Optional[str] = None + self.rpc: Optional[LightningRpc] = None self.startup = True self.dynamic = dynamic - self.child_init = None + self.child_init: Optional[Callable[..., None]] = None self.write_lock = RLock() - def add_method(self, name, func, background=False, category=None, desc=None, - long_desc=None, deprecated=False): + def add_method(self, name: str, func: Callable[..., Any], + background: bool = False, + category: Optional[str] = None, + desc: Optional[str] = None, + long_desc: Optional[str] = None, + deprecated: bool = False) -> None: """Add a plugin method to the dispatch table. The function will be expected at call time (see `_dispatch`) @@ -221,11 +254,15 @@ class Plugin(object): ) # Register the function with the name - method = Method(name, func, MethodType.RPCMETHOD, category, desc, long_desc, deprecated) + method = Method( + name, func, MethodType.RPCMETHOD, category, desc, long_desc, + deprecated + ) + method.background = background self.methods[name] = method - def add_subscription(self, topic, func): + def add_subscription(self, topic: str, func: Callable[..., None]) -> None: """Add a subscription to our list of subscriptions. A subscription is an association between a topic and a handler @@ -243,9 +280,9 @@ class Plugin(object): "Topic {} already has a handler".format(topic) ) - # Make sure the notification callback has a **kwargs argument so that it - # doesn't break if we add more arguments to the call later on. Issue a - # warning if it does not. + # Make sure the notification callback has a **kwargs argument so that + # it doesn't break if we add more arguments to the call later + # on. Issue a warning if it does not. s = inspect.signature(func) kinds = [p.kind for p in s.parameters.values()] if inspect.Parameter.VAR_KEYWORD not in kinds: @@ -257,16 +294,20 @@ class Plugin(object): self.subscriptions[topic] = func - def subscribe(self, topic): + def subscribe(self, topic: str) -> NoneDecoratorType: """Function decorator to register a notification handler. + """ - def decorator(f): + # Yes, decorator type annotations are just weird, don't think too much + # about it... + def decorator(f: Callable[..., None]) -> Callable[..., None]: self.add_subscription(topic, f) return f return decorator - def add_option(self, name, default, description, opt_type="string", - deprecated=False): + def add_option(self, name: str, default: Optional[str], + description: Optional[str], + opt_type: str = "string", deprecated: bool = False) -> None: """Add an option that we'd like to register with lightningd. Needs to be called before `Plugin.run`, otherwise we might not @@ -279,7 +320,9 @@ class Plugin(object): ) if opt_type not in ["string", "int", "bool", "flag"]: - raise ValueError('{} not in supported type set (string, int, bool, flag)') + raise ValueError( + '{} not in supported type set (string, int, bool, flag)' + ) self.options[name] = { 'name': name, @@ -290,7 +333,8 @@ class Plugin(object): 'deprecated': deprecated, } - def add_flag_option(self, name, description, deprecated=False): + def add_flag_option(self, name: str, description: str, + deprecated: bool = False) -> None: """Add a flag option that we'd like to register with lightningd. Needs to be called before `Plugin.run`, otherwise we might not @@ -300,7 +344,7 @@ class Plugin(object): self.add_option(name, None, description, opt_type="flag", deprecated=deprecated) - def get_option(self, name): + def get_option(self, name: str) -> str: if name not in self.options: raise ValueError("No option with name {} registered".format(name)) @@ -309,31 +353,42 @@ class Plugin(object): else: return self.options[name]['default'] - def async_method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False): + def async_method(self, method_name: str, category: Optional[str] = None, + desc: Optional[str] = None, + long_desc: Optional[str] = None, + deprecated: bool = False) -> NoneDecoratorType: """Decorator to add an async plugin method to the dispatch table. Internally uses add_method. """ - def decorator(f): + def decorator(f: Callable[..., None]) -> Callable[..., None]: self.add_method(method_name, f, background=True, category=category, desc=desc, long_desc=long_desc, deprecated=deprecated) return f return decorator - def method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False): + def method(self, method_name: str, category: Optional[str] = None, + desc: Optional[str] = None, + long_desc: Optional[str] = None, + deprecated: bool = False) -> JsonDecoratorType: """Decorator to add a plugin method to the dispatch table. Internally uses add_method. """ - def decorator(f): - self.add_method(method_name, f, background=False, category=category, - desc=desc, long_desc=long_desc, + def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]: + self.add_method(method_name, + f, + background=False, + category=category, + desc=desc, + long_desc=long_desc, deprecated=deprecated) return f return decorator - def add_hook(self, name, func, background=False): + def add_hook(self, name: str, func: Callable[..., JSONType], + background: bool = False) -> None: """Register a hook that is called synchronously by lightningd on events """ if name in self.methods: @@ -357,40 +412,47 @@ class Plugin(object): method.background = background self.methods[name] = method - def hook(self, method_name): + def hook(self, method_name: str) -> JsonDecoratorType: """Decorator to add a plugin hook to the dispatch table. Internally uses add_hook. """ - def decorator(f): + def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]: self.add_hook(method_name, f, background=False) return f return decorator - def async_hook(self, method_name): + def async_hook(self, method_name: str) -> NoneDecoratorType: """Decorator to add an async plugin hook to the dispatch table. Internally uses add_hook. """ - def decorator(f): + def decorator(f: Callable[..., None]) -> Callable[..., None]: self.add_hook(method_name, f, background=True) return f return decorator - def init(self, *args, **kwargs): + def init(self) -> NoneDecoratorType: """Decorator to add a function called after plugin initialization """ - def decorator(f): + def decorator(f: Callable[..., None]) -> Callable[..., None]: if self.child_init is not None: - raise ValueError('The @plugin.init decorator should only be used once') + raise ValueError( + 'The @plugin.init decorator should only be used once' + ) self.child_init = f return f return decorator @staticmethod - def _coerce_arguments(func, ba): + def _coerce_arguments( + func: Callable[..., Any], + ba: inspect.BoundArguments) -> inspect.BoundArguments: args = OrderedDict() - annotations = func.__annotations__ if hasattr(func, "__annotations__") else {} + annotations = {} + if hasattr(func, "__annotations__"): + annotations = func.__annotations__ + for key, val in ba.arguments.items(): annotation = annotations.get(key, None) if annotation is not None and annotation == Millisatoshi: @@ -400,7 +462,8 @@ class Plugin(object): ba.arguments = args return ba - def _bind_pos(self, func, params, request): + def _bind_pos(self, func: Callable[..., Any], params: List[str], + request: Request) -> inspect.BoundArguments: """Positional binding of parameters """ assert(isinstance(params, list)) @@ -409,7 +472,7 @@ class Plugin(object): # Collect injections so we can sort them and insert them in the right # order later. If we don't apply inject them in increasing order we # might shift away an earlier injection. - injections = [] + injections: List[Tuple[int, Any]] = [] if 'plugin' in sig.parameters: pos = list(sig.parameters.keys()).index('plugin') injections.append((pos, self)) @@ -425,7 +488,8 @@ class Plugin(object): ba.apply_defaults() return ba - def _bind_kwargs(self, func, params, request): + def _bind_kwargs(self, func: Callable[..., Any], params: Dict[str, Any], + request: Request) -> inspect.BoundArguments: """Keyword based binding of parameters """ assert(isinstance(params, dict)) @@ -445,7 +509,8 @@ class Plugin(object): self._coerce_arguments(func, ba) return ba - def _exec_func(self, func, request): + def _exec_func(self, func: Callable[..., Any], + request: Request) -> JSONType: params = request.params if isinstance(params, list): ba = self._bind_pos(func, params, request) @@ -454,9 +519,11 @@ class Plugin(object): ba = self._bind_kwargs(func, params, request) return func(*ba.args, **ba.kwargs) else: - raise TypeError("Parameters to function call must be either a dict or a list.") + raise TypeError( + "Parameters to function call must be either a dict or a list." + ) - def _dispatch_request(self, request): + def _dispatch_request(self, request: Request) -> None: name = request.method if name not in self.methods: @@ -487,7 +554,7 @@ class Plugin(object): request.set_exception(e) self.log(traceback.format_exc()) - def _dispatch_notification(self, request): + def _dispatch_notification(self, request: Request) -> None: if request.method not in self.subscriptions: raise ValueError("No subscription for {name} found.".format( name=request.method)) @@ -498,15 +565,19 @@ class Plugin(object): except Exception: self.log(traceback.format_exc()) - def _write_locked(self, obj): + def _write_locked(self, obj: JSONType) -> None: # ensure_ascii turns UTF-8 into \uXXXX so we need to suppress that, # then utf8 ourselves. - s = bytes(json.dumps(obj, cls=LightningRpc.LightningJSONEncoder, ensure_ascii=False) + "\n\n", encoding='utf-8') + s = bytes(json.dumps( + obj, + cls=LightningRpc.LightningJSONEncoder, + ensure_ascii=False + ) + "\n\n", encoding='utf-8') with self.write_lock: self.stdout.buffer.write(s) self.stdout.flush() - def notify(self, method, params): + def notify(self, method: str, params: JSONType) -> None: payload = { 'jsonrpc': '2.0', 'method': method, @@ -514,30 +585,35 @@ class Plugin(object): } self._write_locked(payload) - def log(self, message, level='info'): + def log(self, message: str, level: str = 'info') -> None: # Split the log into multiple lines and print them # individually. Makes tracebacks much easier to read. for line in message.split('\n'): self.notify('log', {'level': level, 'message': line}) - def _parse_request(self, jsrequest): + def _parse_request(self, jsrequest: Dict[str, JSONType]) -> Request: + i = jsrequest.get('id', None) + if not isinstance(i, int) and i is not None: + raise ValueError('Non-integer request id "{i}"'.format(i=i)) + request = Request( plugin=self, - req_id=jsrequest.get('id', None), - method=jsrequest['method'], + req_id=i, + method=str(jsrequest['method']), params=jsrequest['params'], background=False, ) return request - def _multi_dispatch(self, msgs): + def _multi_dispatch(self, msgs: List[bytes]) -> bytes: """We received a couple of messages, now try to dispatch them all. Returns the last partial message that was not complete yet. """ for payload in msgs[:-1]: - # Note that we use function annotations to do Millisatoshi conversions - # in _exec_func, so we don't use LightningJSONDecoder here. + # Note that we use function annotations to do Millisatoshi + # conversions in _exec_func, so we don't use LightningJSONDecoder + # here. request = self._parse_request(json.loads(payload.decode('utf8'))) # If this has an 'id'-field, it's a request and returns a @@ -550,7 +626,7 @@ class Plugin(object): return msgs[-1] - def run(self): + def run(self) -> None: partial = b"" for l in self.stdin.buffer: partial += l @@ -561,7 +637,7 @@ class Plugin(object): partial = self._multi_dispatch(msgs) - def _getmanifest(self, **kwargs): + def _getmanifest(self, **kwargs) -> JSONType: if 'allow-deprecated-apis' in kwargs: self.deprecated_apis = kwargs['allow-deprecated-apis'] else: @@ -582,13 +658,21 @@ class Plugin(object): doc = inspect.getdoc(method.func) if not doc: self.log( - 'RPC method \'{}\' does not have a docstring.'.format(method.name) + 'RPC method \'{}\' does not have a docstring.'.format( + method.name + ) ) doc = "Undocumented RPC method from a plugin." doc = re.sub('\n+', ' ', doc) # Handles out-of-order use of parameters like: - # def hello_obfus(arg1, arg2, plugin, thing3, request=None, thing5='at', thing6=21) + # + # ```python3 + # + # def hello_obfus(arg1, arg2, plugin, thing3, request=None, + # thing5='at', thing6=21) + # + # ``` argspec = inspect.getfullargspec(method.func) defaults = argspec.defaults num_defaults = len(defaults) if defaults else 0 @@ -611,7 +695,8 @@ class Plugin(object): 'description': doc if not method.desc else method.desc }) if method.long_desc: - methods[len(methods) - 1]["long_description"] = method.long_desc + m = methods[len(methods) - 1] + m["long_description"] = method.long_desc manifest = { 'options': list(self.options.values()), @@ -628,12 +713,30 @@ class Plugin(object): return manifest - def _init(self, options, configuration, request): - self.rpc_filename = configuration['rpc-file'] - self.lightning_dir = configuration['lightning-dir'] + def _init(self, options: Dict[str, JSONType], + configuration: Dict[str, JSONType], + request: Request) -> JSONType: + + def verify_str(d: Dict[str, JSONType], key: str) -> str: + v = d.get(key) + if not isinstance(v, str): + raise ValueError("Wrong argument to init: expected {key} to be" + " a string, got {v}".format(key=key, v=v)) + return v + + def verify_bool(d: Dict[str, JSONType], key: str) -> bool: + v = d.get(key) + if not isinstance(v, bool): + raise ValueError("Wrong argument to init: expected {key} to be" + " a bool, got {v}".format(key=key, v=v)) + return v + + self.rpc_filename = verify_str(configuration, 'rpc-file') + self.lightning_dir = verify_str(configuration, 'lightning-dir') + path = os.path.join(self.lightning_dir, self.rpc_filename) self.rpc = LightningRpc(path) - self.startup = configuration['startup'] + self.startup = verify_bool(configuration, 'startup') for name, value in options.items(): self.options[name]['value'] = value @@ -647,18 +750,18 @@ class PluginStream(object): """Sink that turns everything that is written to it into a notification. """ - def __init__(self, plugin, level="info"): + def __init__(self, plugin: Plugin, level: str = "info"): self.plugin = plugin self.level = level self.buff = '' - def write(self, payload): + def write(self, payload: str) -> None: self.buff += payload if len(payload) > 0 and payload[-1] == '\n': self.flush() - def flush(self): + def flush(self) -> None: lines = self.buff.split('\n') if len(lines) < 2: return @@ -670,7 +773,8 @@ class PluginStream(object): self.buff = lines[-1] -def monkey_patch(plugin, stdout=True, stderr=False): +def monkey_patch(plugin: Plugin, stdout: bool = True, + stderr: bool = False) -> None: """Monkey patch stderr and stdout so we use notifications instead. A plugin commonly communicates with lightningd over its stdout and diff --git a/contrib/pyln-proto/pyln/proto/primitives.py b/contrib/pyln-proto/pyln/proto/primitives.py index 4c1d10ebe..b0bec15d3 100644 --- a/contrib/pyln-proto/pyln/proto/primitives.py +++ b/contrib/pyln-proto/pyln/proto/primitives.py @@ -1,3 +1,4 @@ +import coincurve import struct @@ -66,5 +67,79 @@ class ShortChannelId(object): def __str__(self): return "{self.block}x{self.txnum}x{self.outnum}".format(self=self) - def __eq__(self, other): - return self.block == other.block and self.txnum == other.txnum and self.outnum == other.outnum + def __eq__(self, other: object) -> bool: + if not isinstance(other, ShortChannelId): + return False + + return ( + self.block == other.block + and self.txnum == other.txnum + and self.outnum == other.outnum + ) + + +class Secret(object): + def __init__(self, data: bytes) -> None: + assert(len(data) == 32) + self.data = data + + def to_bytes(self) -> bytes: + return self.data + + def __eq__(self, other: object) -> bool: + return isinstance(other, Secret) and self.data == other.data + + def __str__(self): + return "Secret[0x{}]".format(self.data.hex()) + + +class PrivateKey(object): + def __init__(self, rawkey) -> None: + if not isinstance(rawkey, bytes): + raise TypeError(f"rawkey must be bytes, {type(rawkey)} received") + elif len(rawkey) != 32: + raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received") + + self.rawkey = rawkey + self.key = coincurve.PrivateKey(rawkey) + + def serializeCompressed(self): + return self.key.secret + + def public_key(self): + return PublicKey(self.key.public_key) + + +class PublicKey(object): + def __init__(self, innerkey): + # We accept either 33-bytes raw keys, or an EC PublicKey as returned + # by coincurve + if isinstance(innerkey, bytes): + if innerkey[0] in [2, 3] and len(innerkey) == 33: + innerkey = coincurve.PublicKey(innerkey) + else: + raise ValueError( + "Byte keys must be 33-byte long starting from either 02 or 03" + ) + + elif not isinstance(innerkey, coincurve.keys.PublicKey): + raise ValueError( + "Key must either be bytes or coincurve.keys.PublicKey" + ) + self.key = innerkey + + def serializeCompressed(self): + return self.key.format(compressed=True) + + def to_bytes(self) -> bytes: + return self.serializeCompressed() + + def __str__(self): + return "PublicKey[0x{}]".format( + self.serializeCompressed().hex() + ) + + +def Keypair(object): + def __init__(self, priv, pub): + self.priv, self.pub = priv, pub diff --git a/contrib/pyln-proto/pyln/proto/wire.py b/contrib/pyln-proto/pyln/proto/wire.py index f306938c5..adb1f0907 100644 --- a/contrib/pyln-proto/pyln/proto/wire.py +++ b/contrib/pyln-proto/pyln/proto/wire.py @@ -4,6 +4,7 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from .primitives import Secret, PrivateKey, PublicKey from hashlib import sha256 import coincurve import os @@ -55,64 +56,6 @@ def decryptWithAD(k, n, ad, ciphertext): return chacha.decrypt(n, ciphertext, ad) -class PrivateKey(object): - def __init__(self, rawkey): - if not isinstance(rawkey, bytes): - raise TypeError(f"rawkey must be bytes, {type(rawkey)} received") - elif len(rawkey) != 32: - raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received") - - self.rawkey = rawkey - self.key = coincurve.PrivateKey(rawkey) - - def serializeCompressed(self): - return self.key.secret - - def public_key(self): - return PublicKey(self.key.public_key) - - -class Secret(object): - def __init__(self, raw): - assert(len(raw) == 32) - self.raw = raw - - def __str__(self): - return "Secret[0x{}]".format(self.raw.hex()) - - -class PublicKey(object): - def __init__(self, innerkey): - # We accept either 33-bytes raw keys, or an EC PublicKey as returned - # by coincurve - if isinstance(innerkey, bytes): - if innerkey[0] in [2, 3] and len(innerkey) == 33: - innerkey = coincurve.PublicKey(innerkey) - else: - raise ValueError( - "Byte keys must be 33-byte long starting from either 02 or 03" - ) - - elif not isinstance(innerkey, coincurve.keys.PublicKey): - raise ValueError( - "Key must either be bytes or coincurve.keys.PublicKey" - ) - self.key = innerkey - - def serializeCompressed(self): - return self.key.format(compressed=True) - - def __str__(self): - return "PublicKey[0x{}]".format( - self.serializeCompressed().hex() - ) - - -def Keypair(object): - def __init__(self, priv, pub): - self.priv, self.pub = priv, pub - - class Sha256Mixer(object): def __init__(self, base): self.hash = sha256(base).digest() @@ -174,7 +117,7 @@ class LightningConnection(object): h.hash = self.handshake['h'] h.update(self.handshake['e'].public_key().serializeCompressed()) es = ecdh(self.handshake['e'], self.remote_pubkey) - t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'') + t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'') assert(len(t) == 64) self.chaining_key, temp_k1 = t[:32], t[32:] c = encryptWithAD(temp_k1, self.nonce(0), h.digest(), b'') @@ -194,7 +137,7 @@ class LightningConnection(object): h.update(re.serializeCompressed()) es = ecdh(self.local_privkey, re) self.handshake['re'] = re - t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'') + t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'') self.chaining_key, temp_k1 = t[:32], t[32:] try: @@ -210,7 +153,7 @@ class LightningConnection(object): h.hash = self.handshake['h'] h.update(self.handshake['e'].public_key().serializeCompressed()) ee = ecdh(self.handshake['e'], self.handshake['re']) - t = hkdf(salt=self.chaining_key, ikm=ee.raw, info=b'') + t = hkdf(salt=self.chaining_key, ikm=ee.data, info=b'') assert(len(t) == 64) self.chaining_key, self.temp_k2 = t[:32], t[32:] c = encryptWithAD(self.temp_k2, self.nonce(0), h.digest(), b'') @@ -231,7 +174,7 @@ class LightningConnection(object): h.update(re.serializeCompressed()) ee = ecdh(self.handshake['e'], re) self.chaining_key, self.temp_k2 = hkdf_two_keys( - salt=self.chaining_key, ikm=ee.raw + salt=self.chaining_key, ikm=ee.data ) try: decryptWithAD(self.temp_k2, self.nonce(0), h.digest(), c) @@ -249,7 +192,7 @@ class LightningConnection(object): se = ecdh(self.local_privkey, self.re) self.chaining_key, self.temp_k3 = hkdf_two_keys( - salt=self.chaining_key, ikm=se.raw + salt=self.chaining_key, ikm=se.data ) t = encryptWithAD(self.temp_k3, self.nonce(0), h.digest(), b'') m = b'\x00' + c + t @@ -272,7 +215,7 @@ class LightningConnection(object): se = ecdh(self.handshake['e'], self.remote_pubkey) self.chaining_key, self.temp_k3 = hkdf_two_keys( - se.raw, self.chaining_key + se.data, self.chaining_key ) decryptWithAD(self.temp_k3, self.nonce(0), h.digest(), t) self.rn, self.sn = 0, 0