pyln: Add type-annotations to plugin.py

This should help users that have type-checking enabled.
This commit is contained in:
Christian Decker 2020-06-26 16:39:02 +02:00 committed by Rusty Russell
parent d27da4d152
commit 49ec800a07
3 changed files with 267 additions and 145 deletions

View File

@ -1,10 +1,12 @@
from .lightning import LightningRpc, Millisatoshi
from binascii import hexlify
from collections import OrderedDict
from enum import Enum
from .lightning import LightningRpc, Millisatoshi
from threading import RLock
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import inspect
import io
import json
import math
import os
@ -12,6 +14,16 @@ import re
import sys
import traceback
# Notice that this definition is incomplete as it only checks the
# top-level. Arrays and Dicts could contain types that aren't encodeable. This
# limitation stems from the fact that recursive types are not really supported
# yet.
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
# Yes, decorators are weird...
NoneDecoratorType = Callable[..., Callable[..., None]]
JsonDecoratorType = Callable[..., Callable[..., JSONType]]
class MethodType(Enum):
RPCMETHOD = 0
@ -32,8 +44,10 @@ class Method(object):
- RPC exposed by RPC passthrough
- HOOK registered to be called synchronously by lightningd
"""
def __init__(self, name, func, mtype=MethodType.RPCMETHOD, category=None,
desc=None, long_desc=None, deprecated=False):
def __init__(self, name: str, func: Callable[..., JSONType],
mtype: MethodType = MethodType.RPCMETHOD,
category: str = None, desc: str = None,
long_desc: str = None, deprecated: bool = False):
self.name = name
self.func = func
self.mtype = mtype
@ -47,7 +61,8 @@ class Method(object):
class Request(dict):
"""A request object that wraps params and allows async return
"""
def __init__(self, plugin, req_id, method, params, background=False):
def __init__(self, plugin: 'Plugin', req_id: Optional[int], method: str,
params: Any, background: bool = False):
self.method = method
self.params = params
self.background = background
@ -55,15 +70,19 @@ class Request(dict):
self.state = RequestState.PENDING
self.id = req_id
def getattr(self, key):
def getattr(self, key: str) -> Union[Method, Any, int]:
if key == "params":
return self.params
elif key == "id":
return self.id
elif key == "method":
return self.method
else:
raise ValueError(
'Cannot get attribute "{key}" on Request'.format(key=key)
)
def set_result(self, result):
def set_result(self, result: Any) -> None:
if self.state != RequestState.PENDING:
raise ValueError(
"Cannot set the result of a request that is not pending, "
@ -75,7 +94,7 @@ class Request(dict):
'result': self.result
})
def set_exception(self, exc):
def set_exception(self, exc: Exception) -> None:
if self.state != RequestState.PENDING:
raise ValueError(
"Cannot set the exception of a request that is not pending, "
@ -93,7 +112,7 @@ class Request(dict):
},
})
def _write_result(self, result):
def _write_result(self, result: dict) -> None:
self.plugin._write_locked(result)
@ -126,12 +145,20 @@ class Plugin(object):
"""
def __init__(self, stdout=None, stdin=None, autopatch=True, dynamic=True,
init_features=None, node_features=None, invoice_features=None):
self.methods = {'init': Method('init', self._init, MethodType.RPCMETHOD)}
self.options = {}
def __init__(self, stdout: Optional[io.TextIOBase] = None,
stdin: Optional[io.TextIOBase] = None, autopatch: bool = True,
dynamic: bool = True,
init_features: Optional[Union[int, str, bytes]] = None,
node_features: Optional[Union[int, str, bytes]] = None,
invoice_features: Optional[Union[int, str, bytes]] = None):
self.methods = {
'init': Method('init', self._init, MethodType.RPCMETHOD)
}
def convert_featurebits(bits):
self.options: Dict[str, Dict[str, Any]] = {}
def convert_featurebits(
bits: Optional[Union[int, str, bytes]]) -> Optional[str]:
"""Convert the featurebits into the bytes required to hexencode.
"""
if bits is None:
@ -149,7 +176,9 @@ class Plugin(object):
return hexlify(bits).decode('ASCII')
else:
raise ValueError("Could not convert featurebits to hex-encoded string")
raise ValueError(
"Could not convert featurebits to hex-encoded string"
)
self.featurebits = {
'init': convert_featurebits(init_features),
@ -158,7 +187,7 @@ class Plugin(object):
}
# A dict from topics to handler functions
self.subscriptions = {}
self.subscriptions: Dict[str, Callable[..., None]] = {}
if not stdout:
self.stdout = sys.stdout
@ -172,17 +201,21 @@ class Plugin(object):
monkey_patch(self, stdout=True, stderr=True)
self.add_method("getmanifest", self._getmanifest, background=False)
self.rpc_filename = None
self.lightning_dir = None
self.rpc = None
self.rpc_filename: Optional[str] = None
self.lightning_dir: Optional[str] = None
self.rpc: Optional[LightningRpc] = None
self.startup = True
self.dynamic = dynamic
self.child_init = None
self.child_init: Optional[Callable[..., None]] = None
self.write_lock = RLock()
def add_method(self, name, func, background=False, category=None, desc=None,
long_desc=None, deprecated=False):
def add_method(self, name: str, func: Callable[..., Any],
background: bool = False,
category: Optional[str] = None,
desc: Optional[str] = None,
long_desc: Optional[str] = None,
deprecated: bool = False) -> None:
"""Add a plugin method to the dispatch table.
The function will be expected at call time (see `_dispatch`)
@ -221,11 +254,15 @@ class Plugin(object):
)
# Register the function with the name
method = Method(name, func, MethodType.RPCMETHOD, category, desc, long_desc, deprecated)
method = Method(
name, func, MethodType.RPCMETHOD, category, desc, long_desc,
deprecated
)
method.background = background
self.methods[name] = method
def add_subscription(self, topic, func):
def add_subscription(self, topic: str, func: Callable[..., None]) -> None:
"""Add a subscription to our list of subscriptions.
A subscription is an association between a topic and a handler
@ -243,9 +280,9 @@ class Plugin(object):
"Topic {} already has a handler".format(topic)
)
# Make sure the notification callback has a **kwargs argument so that it
# doesn't break if we add more arguments to the call later on. Issue a
# warning if it does not.
# Make sure the notification callback has a **kwargs argument so that
# it doesn't break if we add more arguments to the call later
# on. Issue a warning if it does not.
s = inspect.signature(func)
kinds = [p.kind for p in s.parameters.values()]
if inspect.Parameter.VAR_KEYWORD not in kinds:
@ -257,16 +294,20 @@ class Plugin(object):
self.subscriptions[topic] = func
def subscribe(self, topic):
def subscribe(self, topic: str) -> NoneDecoratorType:
"""Function decorator to register a notification handler.
"""
def decorator(f):
# Yes, decorator type annotations are just weird, don't think too much
# about it...
def decorator(f: Callable[..., None]) -> Callable[..., None]:
self.add_subscription(topic, f)
return f
return decorator
def add_option(self, name, default, description, opt_type="string",
deprecated=False):
def add_option(self, name: str, default: Optional[str],
description: Optional[str],
opt_type: str = "string", deprecated: bool = False) -> None:
"""Add an option that we'd like to register with lightningd.
Needs to be called before `Plugin.run`, otherwise we might not
@ -279,7 +320,9 @@ class Plugin(object):
)
if opt_type not in ["string", "int", "bool", "flag"]:
raise ValueError('{} not in supported type set (string, int, bool, flag)')
raise ValueError(
'{} not in supported type set (string, int, bool, flag)'
)
self.options[name] = {
'name': name,
@ -290,7 +333,8 @@ class Plugin(object):
'deprecated': deprecated,
}
def add_flag_option(self, name, description, deprecated=False):
def add_flag_option(self, name: str, description: str,
deprecated: bool = False) -> None:
"""Add a flag option that we'd like to register with lightningd.
Needs to be called before `Plugin.run`, otherwise we might not
@ -300,7 +344,7 @@ class Plugin(object):
self.add_option(name, None, description, opt_type="flag",
deprecated=deprecated)
def get_option(self, name):
def get_option(self, name: str) -> str:
if name not in self.options:
raise ValueError("No option with name {} registered".format(name))
@ -309,31 +353,42 @@ class Plugin(object):
else:
return self.options[name]['default']
def async_method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False):
def async_method(self, method_name: str, category: Optional[str] = None,
desc: Optional[str] = None,
long_desc: Optional[str] = None,
deprecated: bool = False) -> NoneDecoratorType:
"""Decorator to add an async plugin method to the dispatch table.
Internally uses add_method.
"""
def decorator(f):
def decorator(f: Callable[..., None]) -> Callable[..., None]:
self.add_method(method_name, f, background=True, category=category,
desc=desc, long_desc=long_desc,
deprecated=deprecated)
return f
return decorator
def method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False):
def method(self, method_name: str, category: Optional[str] = None,
desc: Optional[str] = None,
long_desc: Optional[str] = None,
deprecated: bool = False) -> JsonDecoratorType:
"""Decorator to add a plugin method to the dispatch table.
Internally uses add_method.
"""
def decorator(f):
self.add_method(method_name, f, background=False, category=category,
desc=desc, long_desc=long_desc,
def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]:
self.add_method(method_name,
f,
background=False,
category=category,
desc=desc,
long_desc=long_desc,
deprecated=deprecated)
return f
return decorator
def add_hook(self, name, func, background=False):
def add_hook(self, name: str, func: Callable[..., JSONType],
background: bool = False) -> None:
"""Register a hook that is called synchronously by lightningd on events
"""
if name in self.methods:
@ -357,40 +412,47 @@ class Plugin(object):
method.background = background
self.methods[name] = method
def hook(self, method_name):
def hook(self, method_name: str) -> JsonDecoratorType:
"""Decorator to add a plugin hook to the dispatch table.
Internally uses add_hook.
"""
def decorator(f):
def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]:
self.add_hook(method_name, f, background=False)
return f
return decorator
def async_hook(self, method_name):
def async_hook(self, method_name: str) -> NoneDecoratorType:
"""Decorator to add an async plugin hook to the dispatch table.
Internally uses add_hook.
"""
def decorator(f):
def decorator(f: Callable[..., None]) -> Callable[..., None]:
self.add_hook(method_name, f, background=True)
return f
return decorator
def init(self, *args, **kwargs):
def init(self) -> NoneDecoratorType:
"""Decorator to add a function called after plugin initialization
"""
def decorator(f):
def decorator(f: Callable[..., None]) -> Callable[..., None]:
if self.child_init is not None:
raise ValueError('The @plugin.init decorator should only be used once')
raise ValueError(
'The @plugin.init decorator should only be used once'
)
self.child_init = f
return f
return decorator
@staticmethod
def _coerce_arguments(func, ba):
def _coerce_arguments(
func: Callable[..., Any],
ba: inspect.BoundArguments) -> inspect.BoundArguments:
args = OrderedDict()
annotations = func.__annotations__ if hasattr(func, "__annotations__") else {}
annotations = {}
if hasattr(func, "__annotations__"):
annotations = func.__annotations__
for key, val in ba.arguments.items():
annotation = annotations.get(key, None)
if annotation is not None and annotation == Millisatoshi:
@ -400,7 +462,8 @@ class Plugin(object):
ba.arguments = args
return ba
def _bind_pos(self, func, params, request):
def _bind_pos(self, func: Callable[..., Any], params: List[str],
request: Request) -> inspect.BoundArguments:
"""Positional binding of parameters
"""
assert(isinstance(params, list))
@ -409,7 +472,7 @@ class Plugin(object):
# Collect injections so we can sort them and insert them in the right
# order later. If we don't apply inject them in increasing order we
# might shift away an earlier injection.
injections = []
injections: List[Tuple[int, Any]] = []
if 'plugin' in sig.parameters:
pos = list(sig.parameters.keys()).index('plugin')
injections.append((pos, self))
@ -425,7 +488,8 @@ class Plugin(object):
ba.apply_defaults()
return ba
def _bind_kwargs(self, func, params, request):
def _bind_kwargs(self, func: Callable[..., Any], params: Dict[str, Any],
request: Request) -> inspect.BoundArguments:
"""Keyword based binding of parameters
"""
assert(isinstance(params, dict))
@ -445,7 +509,8 @@ class Plugin(object):
self._coerce_arguments(func, ba)
return ba
def _exec_func(self, func, request):
def _exec_func(self, func: Callable[..., Any],
request: Request) -> JSONType:
params = request.params
if isinstance(params, list):
ba = self._bind_pos(func, params, request)
@ -454,9 +519,11 @@ class Plugin(object):
ba = self._bind_kwargs(func, params, request)
return func(*ba.args, **ba.kwargs)
else:
raise TypeError("Parameters to function call must be either a dict or a list.")
raise TypeError(
"Parameters to function call must be either a dict or a list."
)
def _dispatch_request(self, request):
def _dispatch_request(self, request: Request) -> None:
name = request.method
if name not in self.methods:
@ -487,7 +554,7 @@ class Plugin(object):
request.set_exception(e)
self.log(traceback.format_exc())
def _dispatch_notification(self, request):
def _dispatch_notification(self, request: Request) -> None:
if request.method not in self.subscriptions:
raise ValueError("No subscription for {name} found.".format(
name=request.method))
@ -498,15 +565,19 @@ class Plugin(object):
except Exception:
self.log(traceback.format_exc())
def _write_locked(self, obj):
def _write_locked(self, obj: JSONType) -> None:
# ensure_ascii turns UTF-8 into \uXXXX so we need to suppress that,
# then utf8 ourselves.
s = bytes(json.dumps(obj, cls=LightningRpc.LightningJSONEncoder, ensure_ascii=False) + "\n\n", encoding='utf-8')
s = bytes(json.dumps(
obj,
cls=LightningRpc.LightningJSONEncoder,
ensure_ascii=False
) + "\n\n", encoding='utf-8')
with self.write_lock:
self.stdout.buffer.write(s)
self.stdout.flush()
def notify(self, method, params):
def notify(self, method: str, params: JSONType) -> None:
payload = {
'jsonrpc': '2.0',
'method': method,
@ -514,30 +585,35 @@ class Plugin(object):
}
self._write_locked(payload)
def log(self, message, level='info'):
def log(self, message: str, level: str = 'info') -> None:
# Split the log into multiple lines and print them
# individually. Makes tracebacks much easier to read.
for line in message.split('\n'):
self.notify('log', {'level': level, 'message': line})
def _parse_request(self, jsrequest):
def _parse_request(self, jsrequest: Dict[str, JSONType]) -> Request:
i = jsrequest.get('id', None)
if not isinstance(i, int) and i is not None:
raise ValueError('Non-integer request id "{i}"'.format(i=i))
request = Request(
plugin=self,
req_id=jsrequest.get('id', None),
method=jsrequest['method'],
req_id=i,
method=str(jsrequest['method']),
params=jsrequest['params'],
background=False,
)
return request
def _multi_dispatch(self, msgs):
def _multi_dispatch(self, msgs: List[bytes]) -> bytes:
"""We received a couple of messages, now try to dispatch them all.
Returns the last partial message that was not complete yet.
"""
for payload in msgs[:-1]:
# Note that we use function annotations to do Millisatoshi conversions
# in _exec_func, so we don't use LightningJSONDecoder here.
# Note that we use function annotations to do Millisatoshi
# conversions in _exec_func, so we don't use LightningJSONDecoder
# here.
request = self._parse_request(json.loads(payload.decode('utf8')))
# If this has an 'id'-field, it's a request and returns a
@ -550,7 +626,7 @@ class Plugin(object):
return msgs[-1]
def run(self):
def run(self) -> None:
partial = b""
for l in self.stdin.buffer:
partial += l
@ -561,7 +637,7 @@ class Plugin(object):
partial = self._multi_dispatch(msgs)
def _getmanifest(self, **kwargs):
def _getmanifest(self, **kwargs) -> JSONType:
if 'allow-deprecated-apis' in kwargs:
self.deprecated_apis = kwargs['allow-deprecated-apis']
else:
@ -582,13 +658,21 @@ class Plugin(object):
doc = inspect.getdoc(method.func)
if not doc:
self.log(
'RPC method \'{}\' does not have a docstring.'.format(method.name)
'RPC method \'{}\' does not have a docstring.'.format(
method.name
)
)
doc = "Undocumented RPC method from a plugin."
doc = re.sub('\n+', ' ', doc)
# Handles out-of-order use of parameters like:
# def hello_obfus(arg1, arg2, plugin, thing3, request=None, thing5='at', thing6=21)
#
# ```python3
#
# def hello_obfus(arg1, arg2, plugin, thing3, request=None,
# thing5='at', thing6=21)
#
# ```
argspec = inspect.getfullargspec(method.func)
defaults = argspec.defaults
num_defaults = len(defaults) if defaults else 0
@ -611,7 +695,8 @@ class Plugin(object):
'description': doc if not method.desc else method.desc
})
if method.long_desc:
methods[len(methods) - 1]["long_description"] = method.long_desc
m = methods[len(methods) - 1]
m["long_description"] = method.long_desc
manifest = {
'options': list(self.options.values()),
@ -628,12 +713,30 @@ class Plugin(object):
return manifest
def _init(self, options, configuration, request):
self.rpc_filename = configuration['rpc-file']
self.lightning_dir = configuration['lightning-dir']
def _init(self, options: Dict[str, JSONType],
configuration: Dict[str, JSONType],
request: Request) -> JSONType:
def verify_str(d: Dict[str, JSONType], key: str) -> str:
v = d.get(key)
if not isinstance(v, str):
raise ValueError("Wrong argument to init: expected {key} to be"
" a string, got {v}".format(key=key, v=v))
return v
def verify_bool(d: Dict[str, JSONType], key: str) -> bool:
v = d.get(key)
if not isinstance(v, bool):
raise ValueError("Wrong argument to init: expected {key} to be"
" a bool, got {v}".format(key=key, v=v))
return v
self.rpc_filename = verify_str(configuration, 'rpc-file')
self.lightning_dir = verify_str(configuration, 'lightning-dir')
path = os.path.join(self.lightning_dir, self.rpc_filename)
self.rpc = LightningRpc(path)
self.startup = configuration['startup']
self.startup = verify_bool(configuration, 'startup')
for name, value in options.items():
self.options[name]['value'] = value
@ -647,18 +750,18 @@ class PluginStream(object):
"""Sink that turns everything that is written to it into a notification.
"""
def __init__(self, plugin, level="info"):
def __init__(self, plugin: Plugin, level: str = "info"):
self.plugin = plugin
self.level = level
self.buff = ''
def write(self, payload):
def write(self, payload: str) -> None:
self.buff += payload
if len(payload) > 0 and payload[-1] == '\n':
self.flush()
def flush(self):
def flush(self) -> None:
lines = self.buff.split('\n')
if len(lines) < 2:
return
@ -670,7 +773,8 @@ class PluginStream(object):
self.buff = lines[-1]
def monkey_patch(plugin, stdout=True, stderr=False):
def monkey_patch(plugin: Plugin, stdout: bool = True,
stderr: bool = False) -> None:
"""Monkey patch stderr and stdout so we use notifications instead.
A plugin commonly communicates with lightningd over its stdout and

View File

@ -1,3 +1,4 @@
import coincurve
import struct
@ -66,5 +67,79 @@ class ShortChannelId(object):
def __str__(self):
return "{self.block}x{self.txnum}x{self.outnum}".format(self=self)
def __eq__(self, other):
return self.block == other.block and self.txnum == other.txnum and self.outnum == other.outnum
def __eq__(self, other: object) -> bool:
if not isinstance(other, ShortChannelId):
return False
return (
self.block == other.block
and self.txnum == other.txnum
and self.outnum == other.outnum
)
class Secret(object):
def __init__(self, data: bytes) -> None:
assert(len(data) == 32)
self.data = data
def to_bytes(self) -> bytes:
return self.data
def __eq__(self, other: object) -> bool:
return isinstance(other, Secret) and self.data == other.data
def __str__(self):
return "Secret[0x{}]".format(self.data.hex())
class PrivateKey(object):
def __init__(self, rawkey) -> None:
if not isinstance(rawkey, bytes):
raise TypeError(f"rawkey must be bytes, {type(rawkey)} received")
elif len(rawkey) != 32:
raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received")
self.rawkey = rawkey
self.key = coincurve.PrivateKey(rawkey)
def serializeCompressed(self):
return self.key.secret
def public_key(self):
return PublicKey(self.key.public_key)
class PublicKey(object):
def __init__(self, innerkey):
# We accept either 33-bytes raw keys, or an EC PublicKey as returned
# by coincurve
if isinstance(innerkey, bytes):
if innerkey[0] in [2, 3] and len(innerkey) == 33:
innerkey = coincurve.PublicKey(innerkey)
else:
raise ValueError(
"Byte keys must be 33-byte long starting from either 02 or 03"
)
elif not isinstance(innerkey, coincurve.keys.PublicKey):
raise ValueError(
"Key must either be bytes or coincurve.keys.PublicKey"
)
self.key = innerkey
def serializeCompressed(self):
return self.key.format(compressed=True)
def to_bytes(self) -> bytes:
return self.serializeCompressed()
def __str__(self):
return "PublicKey[0x{}]".format(
self.serializeCompressed().hex()
)
def Keypair(object):
def __init__(self, priv, pub):
self.priv, self.pub = priv, pub

View File

@ -4,6 +4,7 @@ from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from .primitives import Secret, PrivateKey, PublicKey
from hashlib import sha256
import coincurve
import os
@ -55,64 +56,6 @@ def decryptWithAD(k, n, ad, ciphertext):
return chacha.decrypt(n, ciphertext, ad)
class PrivateKey(object):
def __init__(self, rawkey):
if not isinstance(rawkey, bytes):
raise TypeError(f"rawkey must be bytes, {type(rawkey)} received")
elif len(rawkey) != 32:
raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received")
self.rawkey = rawkey
self.key = coincurve.PrivateKey(rawkey)
def serializeCompressed(self):
return self.key.secret
def public_key(self):
return PublicKey(self.key.public_key)
class Secret(object):
def __init__(self, raw):
assert(len(raw) == 32)
self.raw = raw
def __str__(self):
return "Secret[0x{}]".format(self.raw.hex())
class PublicKey(object):
def __init__(self, innerkey):
# We accept either 33-bytes raw keys, or an EC PublicKey as returned
# by coincurve
if isinstance(innerkey, bytes):
if innerkey[0] in [2, 3] and len(innerkey) == 33:
innerkey = coincurve.PublicKey(innerkey)
else:
raise ValueError(
"Byte keys must be 33-byte long starting from either 02 or 03"
)
elif not isinstance(innerkey, coincurve.keys.PublicKey):
raise ValueError(
"Key must either be bytes or coincurve.keys.PublicKey"
)
self.key = innerkey
def serializeCompressed(self):
return self.key.format(compressed=True)
def __str__(self):
return "PublicKey[0x{}]".format(
self.serializeCompressed().hex()
)
def Keypair(object):
def __init__(self, priv, pub):
self.priv, self.pub = priv, pub
class Sha256Mixer(object):
def __init__(self, base):
self.hash = sha256(base).digest()
@ -174,7 +117,7 @@ class LightningConnection(object):
h.hash = self.handshake['h']
h.update(self.handshake['e'].public_key().serializeCompressed())
es = ecdh(self.handshake['e'], self.remote_pubkey)
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'')
t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'')
assert(len(t) == 64)
self.chaining_key, temp_k1 = t[:32], t[32:]
c = encryptWithAD(temp_k1, self.nonce(0), h.digest(), b'')
@ -194,7 +137,7 @@ class LightningConnection(object):
h.update(re.serializeCompressed())
es = ecdh(self.local_privkey, re)
self.handshake['re'] = re
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'')
t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'')
self.chaining_key, temp_k1 = t[:32], t[32:]
try:
@ -210,7 +153,7 @@ class LightningConnection(object):
h.hash = self.handshake['h']
h.update(self.handshake['e'].public_key().serializeCompressed())
ee = ecdh(self.handshake['e'], self.handshake['re'])
t = hkdf(salt=self.chaining_key, ikm=ee.raw, info=b'')
t = hkdf(salt=self.chaining_key, ikm=ee.data, info=b'')
assert(len(t) == 64)
self.chaining_key, self.temp_k2 = t[:32], t[32:]
c = encryptWithAD(self.temp_k2, self.nonce(0), h.digest(), b'')
@ -231,7 +174,7 @@ class LightningConnection(object):
h.update(re.serializeCompressed())
ee = ecdh(self.handshake['e'], re)
self.chaining_key, self.temp_k2 = hkdf_two_keys(
salt=self.chaining_key, ikm=ee.raw
salt=self.chaining_key, ikm=ee.data
)
try:
decryptWithAD(self.temp_k2, self.nonce(0), h.digest(), c)
@ -249,7 +192,7 @@ class LightningConnection(object):
se = ecdh(self.local_privkey, self.re)
self.chaining_key, self.temp_k3 = hkdf_two_keys(
salt=self.chaining_key, ikm=se.raw
salt=self.chaining_key, ikm=se.data
)
t = encryptWithAD(self.temp_k3, self.nonce(0), h.digest(), b'')
m = b'\x00' + c + t
@ -272,7 +215,7 @@ class LightningConnection(object):
se = ecdh(self.handshake['e'], self.remote_pubkey)
self.chaining_key, self.temp_k3 = hkdf_two_keys(
se.raw, self.chaining_key
se.data, self.chaining_key
)
decryptWithAD(self.temp_k3, self.nonce(0), h.digest(), t)
self.rn, self.sn = 0, 0