mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-01-18 05:12:45 +01:00
pyln: Add type-annotations to plugin.py
This should help users that have type-checking enabled.
This commit is contained in:
parent
d27da4d152
commit
49ec800a07
@ -1,10 +1,12 @@
|
||||
from .lightning import LightningRpc, Millisatoshi
|
||||
from binascii import hexlify
|
||||
from collections import OrderedDict
|
||||
from enum import Enum
|
||||
from .lightning import LightningRpc, Millisatoshi
|
||||
from threading import RLock
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@ -12,6 +14,16 @@ import re
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Notice that this definition is incomplete as it only checks the
|
||||
# top-level. Arrays and Dicts could contain types that aren't encodeable. This
|
||||
# limitation stems from the fact that recursive types are not really supported
|
||||
# yet.
|
||||
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
|
||||
|
||||
# Yes, decorators are weird...
|
||||
NoneDecoratorType = Callable[..., Callable[..., None]]
|
||||
JsonDecoratorType = Callable[..., Callable[..., JSONType]]
|
||||
|
||||
|
||||
class MethodType(Enum):
|
||||
RPCMETHOD = 0
|
||||
@ -32,8 +44,10 @@ class Method(object):
|
||||
- RPC exposed by RPC passthrough
|
||||
- HOOK registered to be called synchronously by lightningd
|
||||
"""
|
||||
def __init__(self, name, func, mtype=MethodType.RPCMETHOD, category=None,
|
||||
desc=None, long_desc=None, deprecated=False):
|
||||
def __init__(self, name: str, func: Callable[..., JSONType],
|
||||
mtype: MethodType = MethodType.RPCMETHOD,
|
||||
category: str = None, desc: str = None,
|
||||
long_desc: str = None, deprecated: bool = False):
|
||||
self.name = name
|
||||
self.func = func
|
||||
self.mtype = mtype
|
||||
@ -47,7 +61,8 @@ class Method(object):
|
||||
class Request(dict):
|
||||
"""A request object that wraps params and allows async return
|
||||
"""
|
||||
def __init__(self, plugin, req_id, method, params, background=False):
|
||||
def __init__(self, plugin: 'Plugin', req_id: Optional[int], method: str,
|
||||
params: Any, background: bool = False):
|
||||
self.method = method
|
||||
self.params = params
|
||||
self.background = background
|
||||
@ -55,15 +70,19 @@ class Request(dict):
|
||||
self.state = RequestState.PENDING
|
||||
self.id = req_id
|
||||
|
||||
def getattr(self, key):
|
||||
def getattr(self, key: str) -> Union[Method, Any, int]:
|
||||
if key == "params":
|
||||
return self.params
|
||||
elif key == "id":
|
||||
return self.id
|
||||
elif key == "method":
|
||||
return self.method
|
||||
else:
|
||||
raise ValueError(
|
||||
'Cannot get attribute "{key}" on Request'.format(key=key)
|
||||
)
|
||||
|
||||
def set_result(self, result):
|
||||
def set_result(self, result: Any) -> None:
|
||||
if self.state != RequestState.PENDING:
|
||||
raise ValueError(
|
||||
"Cannot set the result of a request that is not pending, "
|
||||
@ -75,7 +94,7 @@ class Request(dict):
|
||||
'result': self.result
|
||||
})
|
||||
|
||||
def set_exception(self, exc):
|
||||
def set_exception(self, exc: Exception) -> None:
|
||||
if self.state != RequestState.PENDING:
|
||||
raise ValueError(
|
||||
"Cannot set the exception of a request that is not pending, "
|
||||
@ -93,7 +112,7 @@ class Request(dict):
|
||||
},
|
||||
})
|
||||
|
||||
def _write_result(self, result):
|
||||
def _write_result(self, result: dict) -> None:
|
||||
self.plugin._write_locked(result)
|
||||
|
||||
|
||||
@ -126,12 +145,20 @@ class Plugin(object):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, stdout=None, stdin=None, autopatch=True, dynamic=True,
|
||||
init_features=None, node_features=None, invoice_features=None):
|
||||
self.methods = {'init': Method('init', self._init, MethodType.RPCMETHOD)}
|
||||
self.options = {}
|
||||
def __init__(self, stdout: Optional[io.TextIOBase] = None,
|
||||
stdin: Optional[io.TextIOBase] = None, autopatch: bool = True,
|
||||
dynamic: bool = True,
|
||||
init_features: Optional[Union[int, str, bytes]] = None,
|
||||
node_features: Optional[Union[int, str, bytes]] = None,
|
||||
invoice_features: Optional[Union[int, str, bytes]] = None):
|
||||
self.methods = {
|
||||
'init': Method('init', self._init, MethodType.RPCMETHOD)
|
||||
}
|
||||
|
||||
def convert_featurebits(bits):
|
||||
self.options: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def convert_featurebits(
|
||||
bits: Optional[Union[int, str, bytes]]) -> Optional[str]:
|
||||
"""Convert the featurebits into the bytes required to hexencode.
|
||||
"""
|
||||
if bits is None:
|
||||
@ -149,7 +176,9 @@ class Plugin(object):
|
||||
return hexlify(bits).decode('ASCII')
|
||||
|
||||
else:
|
||||
raise ValueError("Could not convert featurebits to hex-encoded string")
|
||||
raise ValueError(
|
||||
"Could not convert featurebits to hex-encoded string"
|
||||
)
|
||||
|
||||
self.featurebits = {
|
||||
'init': convert_featurebits(init_features),
|
||||
@ -158,7 +187,7 @@ class Plugin(object):
|
||||
}
|
||||
|
||||
# A dict from topics to handler functions
|
||||
self.subscriptions = {}
|
||||
self.subscriptions: Dict[str, Callable[..., None]] = {}
|
||||
|
||||
if not stdout:
|
||||
self.stdout = sys.stdout
|
||||
@ -172,17 +201,21 @@ class Plugin(object):
|
||||
monkey_patch(self, stdout=True, stderr=True)
|
||||
|
||||
self.add_method("getmanifest", self._getmanifest, background=False)
|
||||
self.rpc_filename = None
|
||||
self.lightning_dir = None
|
||||
self.rpc = None
|
||||
self.rpc_filename: Optional[str] = None
|
||||
self.lightning_dir: Optional[str] = None
|
||||
self.rpc: Optional[LightningRpc] = None
|
||||
self.startup = True
|
||||
self.dynamic = dynamic
|
||||
self.child_init = None
|
||||
self.child_init: Optional[Callable[..., None]] = None
|
||||
|
||||
self.write_lock = RLock()
|
||||
|
||||
def add_method(self, name, func, background=False, category=None, desc=None,
|
||||
long_desc=None, deprecated=False):
|
||||
def add_method(self, name: str, func: Callable[..., Any],
|
||||
background: bool = False,
|
||||
category: Optional[str] = None,
|
||||
desc: Optional[str] = None,
|
||||
long_desc: Optional[str] = None,
|
||||
deprecated: bool = False) -> None:
|
||||
"""Add a plugin method to the dispatch table.
|
||||
|
||||
The function will be expected at call time (see `_dispatch`)
|
||||
@ -221,11 +254,15 @@ class Plugin(object):
|
||||
)
|
||||
|
||||
# Register the function with the name
|
||||
method = Method(name, func, MethodType.RPCMETHOD, category, desc, long_desc, deprecated)
|
||||
method = Method(
|
||||
name, func, MethodType.RPCMETHOD, category, desc, long_desc,
|
||||
deprecated
|
||||
)
|
||||
|
||||
method.background = background
|
||||
self.methods[name] = method
|
||||
|
||||
def add_subscription(self, topic, func):
|
||||
def add_subscription(self, topic: str, func: Callable[..., None]) -> None:
|
||||
"""Add a subscription to our list of subscriptions.
|
||||
|
||||
A subscription is an association between a topic and a handler
|
||||
@ -243,9 +280,9 @@ class Plugin(object):
|
||||
"Topic {} already has a handler".format(topic)
|
||||
)
|
||||
|
||||
# Make sure the notification callback has a **kwargs argument so that it
|
||||
# doesn't break if we add more arguments to the call later on. Issue a
|
||||
# warning if it does not.
|
||||
# Make sure the notification callback has a **kwargs argument so that
|
||||
# it doesn't break if we add more arguments to the call later
|
||||
# on. Issue a warning if it does not.
|
||||
s = inspect.signature(func)
|
||||
kinds = [p.kind for p in s.parameters.values()]
|
||||
if inspect.Parameter.VAR_KEYWORD not in kinds:
|
||||
@ -257,16 +294,20 @@ class Plugin(object):
|
||||
|
||||
self.subscriptions[topic] = func
|
||||
|
||||
def subscribe(self, topic):
|
||||
def subscribe(self, topic: str) -> NoneDecoratorType:
|
||||
"""Function decorator to register a notification handler.
|
||||
|
||||
"""
|
||||
def decorator(f):
|
||||
# Yes, decorator type annotations are just weird, don't think too much
|
||||
# about it...
|
||||
def decorator(f: Callable[..., None]) -> Callable[..., None]:
|
||||
self.add_subscription(topic, f)
|
||||
return f
|
||||
return decorator
|
||||
|
||||
def add_option(self, name, default, description, opt_type="string",
|
||||
deprecated=False):
|
||||
def add_option(self, name: str, default: Optional[str],
|
||||
description: Optional[str],
|
||||
opt_type: str = "string", deprecated: bool = False) -> None:
|
||||
"""Add an option that we'd like to register with lightningd.
|
||||
|
||||
Needs to be called before `Plugin.run`, otherwise we might not
|
||||
@ -279,7 +320,9 @@ class Plugin(object):
|
||||
)
|
||||
|
||||
if opt_type not in ["string", "int", "bool", "flag"]:
|
||||
raise ValueError('{} not in supported type set (string, int, bool, flag)')
|
||||
raise ValueError(
|
||||
'{} not in supported type set (string, int, bool, flag)'
|
||||
)
|
||||
|
||||
self.options[name] = {
|
||||
'name': name,
|
||||
@ -290,7 +333,8 @@ class Plugin(object):
|
||||
'deprecated': deprecated,
|
||||
}
|
||||
|
||||
def add_flag_option(self, name, description, deprecated=False):
|
||||
def add_flag_option(self, name: str, description: str,
|
||||
deprecated: bool = False) -> None:
|
||||
"""Add a flag option that we'd like to register with lightningd.
|
||||
|
||||
Needs to be called before `Plugin.run`, otherwise we might not
|
||||
@ -300,7 +344,7 @@ class Plugin(object):
|
||||
self.add_option(name, None, description, opt_type="flag",
|
||||
deprecated=deprecated)
|
||||
|
||||
def get_option(self, name):
|
||||
def get_option(self, name: str) -> str:
|
||||
if name not in self.options:
|
||||
raise ValueError("No option with name {} registered".format(name))
|
||||
|
||||
@ -309,31 +353,42 @@ class Plugin(object):
|
||||
else:
|
||||
return self.options[name]['default']
|
||||
|
||||
def async_method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False):
|
||||
def async_method(self, method_name: str, category: Optional[str] = None,
|
||||
desc: Optional[str] = None,
|
||||
long_desc: Optional[str] = None,
|
||||
deprecated: bool = False) -> NoneDecoratorType:
|
||||
"""Decorator to add an async plugin method to the dispatch table.
|
||||
|
||||
Internally uses add_method.
|
||||
"""
|
||||
def decorator(f):
|
||||
def decorator(f: Callable[..., None]) -> Callable[..., None]:
|
||||
self.add_method(method_name, f, background=True, category=category,
|
||||
desc=desc, long_desc=long_desc,
|
||||
deprecated=deprecated)
|
||||
return f
|
||||
return decorator
|
||||
|
||||
def method(self, method_name, category=None, desc=None, long_desc=None, deprecated=False):
|
||||
def method(self, method_name: str, category: Optional[str] = None,
|
||||
desc: Optional[str] = None,
|
||||
long_desc: Optional[str] = None,
|
||||
deprecated: bool = False) -> JsonDecoratorType:
|
||||
"""Decorator to add a plugin method to the dispatch table.
|
||||
|
||||
Internally uses add_method.
|
||||
"""
|
||||
def decorator(f):
|
||||
self.add_method(method_name, f, background=False, category=category,
|
||||
desc=desc, long_desc=long_desc,
|
||||
def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]:
|
||||
self.add_method(method_name,
|
||||
f,
|
||||
background=False,
|
||||
category=category,
|
||||
desc=desc,
|
||||
long_desc=long_desc,
|
||||
deprecated=deprecated)
|
||||
return f
|
||||
return decorator
|
||||
|
||||
def add_hook(self, name, func, background=False):
|
||||
def add_hook(self, name: str, func: Callable[..., JSONType],
|
||||
background: bool = False) -> None:
|
||||
"""Register a hook that is called synchronously by lightningd on events
|
||||
"""
|
||||
if name in self.methods:
|
||||
@ -357,40 +412,47 @@ class Plugin(object):
|
||||
method.background = background
|
||||
self.methods[name] = method
|
||||
|
||||
def hook(self, method_name):
|
||||
def hook(self, method_name: str) -> JsonDecoratorType:
|
||||
"""Decorator to add a plugin hook to the dispatch table.
|
||||
|
||||
Internally uses add_hook.
|
||||
"""
|
||||
def decorator(f):
|
||||
def decorator(f: Callable[..., JSONType]) -> Callable[..., JSONType]:
|
||||
self.add_hook(method_name, f, background=False)
|
||||
return f
|
||||
return decorator
|
||||
|
||||
def async_hook(self, method_name):
|
||||
def async_hook(self, method_name: str) -> NoneDecoratorType:
|
||||
"""Decorator to add an async plugin hook to the dispatch table.
|
||||
|
||||
Internally uses add_hook.
|
||||
"""
|
||||
def decorator(f):
|
||||
def decorator(f: Callable[..., None]) -> Callable[..., None]:
|
||||
self.add_hook(method_name, f, background=True)
|
||||
return f
|
||||
return decorator
|
||||
|
||||
def init(self, *args, **kwargs):
|
||||
def init(self) -> NoneDecoratorType:
|
||||
"""Decorator to add a function called after plugin initialization
|
||||
"""
|
||||
def decorator(f):
|
||||
def decorator(f: Callable[..., None]) -> Callable[..., None]:
|
||||
if self.child_init is not None:
|
||||
raise ValueError('The @plugin.init decorator should only be used once')
|
||||
raise ValueError(
|
||||
'The @plugin.init decorator should only be used once'
|
||||
)
|
||||
self.child_init = f
|
||||
return f
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def _coerce_arguments(func, ba):
|
||||
def _coerce_arguments(
|
||||
func: Callable[..., Any],
|
||||
ba: inspect.BoundArguments) -> inspect.BoundArguments:
|
||||
args = OrderedDict()
|
||||
annotations = func.__annotations__ if hasattr(func, "__annotations__") else {}
|
||||
annotations = {}
|
||||
if hasattr(func, "__annotations__"):
|
||||
annotations = func.__annotations__
|
||||
|
||||
for key, val in ba.arguments.items():
|
||||
annotation = annotations.get(key, None)
|
||||
if annotation is not None and annotation == Millisatoshi:
|
||||
@ -400,7 +462,8 @@ class Plugin(object):
|
||||
ba.arguments = args
|
||||
return ba
|
||||
|
||||
def _bind_pos(self, func, params, request):
|
||||
def _bind_pos(self, func: Callable[..., Any], params: List[str],
|
||||
request: Request) -> inspect.BoundArguments:
|
||||
"""Positional binding of parameters
|
||||
"""
|
||||
assert(isinstance(params, list))
|
||||
@ -409,7 +472,7 @@ class Plugin(object):
|
||||
# Collect injections so we can sort them and insert them in the right
|
||||
# order later. If we don't apply inject them in increasing order we
|
||||
# might shift away an earlier injection.
|
||||
injections = []
|
||||
injections: List[Tuple[int, Any]] = []
|
||||
if 'plugin' in sig.parameters:
|
||||
pos = list(sig.parameters.keys()).index('plugin')
|
||||
injections.append((pos, self))
|
||||
@ -425,7 +488,8 @@ class Plugin(object):
|
||||
ba.apply_defaults()
|
||||
return ba
|
||||
|
||||
def _bind_kwargs(self, func, params, request):
|
||||
def _bind_kwargs(self, func: Callable[..., Any], params: Dict[str, Any],
|
||||
request: Request) -> inspect.BoundArguments:
|
||||
"""Keyword based binding of parameters
|
||||
"""
|
||||
assert(isinstance(params, dict))
|
||||
@ -445,7 +509,8 @@ class Plugin(object):
|
||||
self._coerce_arguments(func, ba)
|
||||
return ba
|
||||
|
||||
def _exec_func(self, func, request):
|
||||
def _exec_func(self, func: Callable[..., Any],
|
||||
request: Request) -> JSONType:
|
||||
params = request.params
|
||||
if isinstance(params, list):
|
||||
ba = self._bind_pos(func, params, request)
|
||||
@ -454,9 +519,11 @@ class Plugin(object):
|
||||
ba = self._bind_kwargs(func, params, request)
|
||||
return func(*ba.args, **ba.kwargs)
|
||||
else:
|
||||
raise TypeError("Parameters to function call must be either a dict or a list.")
|
||||
raise TypeError(
|
||||
"Parameters to function call must be either a dict or a list."
|
||||
)
|
||||
|
||||
def _dispatch_request(self, request):
|
||||
def _dispatch_request(self, request: Request) -> None:
|
||||
name = request.method
|
||||
|
||||
if name not in self.methods:
|
||||
@ -487,7 +554,7 @@ class Plugin(object):
|
||||
request.set_exception(e)
|
||||
self.log(traceback.format_exc())
|
||||
|
||||
def _dispatch_notification(self, request):
|
||||
def _dispatch_notification(self, request: Request) -> None:
|
||||
if request.method not in self.subscriptions:
|
||||
raise ValueError("No subscription for {name} found.".format(
|
||||
name=request.method))
|
||||
@ -498,15 +565,19 @@ class Plugin(object):
|
||||
except Exception:
|
||||
self.log(traceback.format_exc())
|
||||
|
||||
def _write_locked(self, obj):
|
||||
def _write_locked(self, obj: JSONType) -> None:
|
||||
# ensure_ascii turns UTF-8 into \uXXXX so we need to suppress that,
|
||||
# then utf8 ourselves.
|
||||
s = bytes(json.dumps(obj, cls=LightningRpc.LightningJSONEncoder, ensure_ascii=False) + "\n\n", encoding='utf-8')
|
||||
s = bytes(json.dumps(
|
||||
obj,
|
||||
cls=LightningRpc.LightningJSONEncoder,
|
||||
ensure_ascii=False
|
||||
) + "\n\n", encoding='utf-8')
|
||||
with self.write_lock:
|
||||
self.stdout.buffer.write(s)
|
||||
self.stdout.flush()
|
||||
|
||||
def notify(self, method, params):
|
||||
def notify(self, method: str, params: JSONType) -> None:
|
||||
payload = {
|
||||
'jsonrpc': '2.0',
|
||||
'method': method,
|
||||
@ -514,30 +585,35 @@ class Plugin(object):
|
||||
}
|
||||
self._write_locked(payload)
|
||||
|
||||
def log(self, message, level='info'):
|
||||
def log(self, message: str, level: str = 'info') -> None:
|
||||
# Split the log into multiple lines and print them
|
||||
# individually. Makes tracebacks much easier to read.
|
||||
for line in message.split('\n'):
|
||||
self.notify('log', {'level': level, 'message': line})
|
||||
|
||||
def _parse_request(self, jsrequest):
|
||||
def _parse_request(self, jsrequest: Dict[str, JSONType]) -> Request:
|
||||
i = jsrequest.get('id', None)
|
||||
if not isinstance(i, int) and i is not None:
|
||||
raise ValueError('Non-integer request id "{i}"'.format(i=i))
|
||||
|
||||
request = Request(
|
||||
plugin=self,
|
||||
req_id=jsrequest.get('id', None),
|
||||
method=jsrequest['method'],
|
||||
req_id=i,
|
||||
method=str(jsrequest['method']),
|
||||
params=jsrequest['params'],
|
||||
background=False,
|
||||
)
|
||||
return request
|
||||
|
||||
def _multi_dispatch(self, msgs):
|
||||
def _multi_dispatch(self, msgs: List[bytes]) -> bytes:
|
||||
"""We received a couple of messages, now try to dispatch them all.
|
||||
|
||||
Returns the last partial message that was not complete yet.
|
||||
"""
|
||||
for payload in msgs[:-1]:
|
||||
# Note that we use function annotations to do Millisatoshi conversions
|
||||
# in _exec_func, so we don't use LightningJSONDecoder here.
|
||||
# Note that we use function annotations to do Millisatoshi
|
||||
# conversions in _exec_func, so we don't use LightningJSONDecoder
|
||||
# here.
|
||||
request = self._parse_request(json.loads(payload.decode('utf8')))
|
||||
|
||||
# If this has an 'id'-field, it's a request and returns a
|
||||
@ -550,7 +626,7 @@ class Plugin(object):
|
||||
|
||||
return msgs[-1]
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
partial = b""
|
||||
for l in self.stdin.buffer:
|
||||
partial += l
|
||||
@ -561,7 +637,7 @@ class Plugin(object):
|
||||
|
||||
partial = self._multi_dispatch(msgs)
|
||||
|
||||
def _getmanifest(self, **kwargs):
|
||||
def _getmanifest(self, **kwargs) -> JSONType:
|
||||
if 'allow-deprecated-apis' in kwargs:
|
||||
self.deprecated_apis = kwargs['allow-deprecated-apis']
|
||||
else:
|
||||
@ -582,13 +658,21 @@ class Plugin(object):
|
||||
doc = inspect.getdoc(method.func)
|
||||
if not doc:
|
||||
self.log(
|
||||
'RPC method \'{}\' does not have a docstring.'.format(method.name)
|
||||
'RPC method \'{}\' does not have a docstring.'.format(
|
||||
method.name
|
||||
)
|
||||
)
|
||||
doc = "Undocumented RPC method from a plugin."
|
||||
doc = re.sub('\n+', ' ', doc)
|
||||
|
||||
# Handles out-of-order use of parameters like:
|
||||
# def hello_obfus(arg1, arg2, plugin, thing3, request=None, thing5='at', thing6=21)
|
||||
#
|
||||
# ```python3
|
||||
#
|
||||
# def hello_obfus(arg1, arg2, plugin, thing3, request=None,
|
||||
# thing5='at', thing6=21)
|
||||
#
|
||||
# ```
|
||||
argspec = inspect.getfullargspec(method.func)
|
||||
defaults = argspec.defaults
|
||||
num_defaults = len(defaults) if defaults else 0
|
||||
@ -611,7 +695,8 @@ class Plugin(object):
|
||||
'description': doc if not method.desc else method.desc
|
||||
})
|
||||
if method.long_desc:
|
||||
methods[len(methods) - 1]["long_description"] = method.long_desc
|
||||
m = methods[len(methods) - 1]
|
||||
m["long_description"] = method.long_desc
|
||||
|
||||
manifest = {
|
||||
'options': list(self.options.values()),
|
||||
@ -628,12 +713,30 @@ class Plugin(object):
|
||||
|
||||
return manifest
|
||||
|
||||
def _init(self, options, configuration, request):
|
||||
self.rpc_filename = configuration['rpc-file']
|
||||
self.lightning_dir = configuration['lightning-dir']
|
||||
def _init(self, options: Dict[str, JSONType],
|
||||
configuration: Dict[str, JSONType],
|
||||
request: Request) -> JSONType:
|
||||
|
||||
def verify_str(d: Dict[str, JSONType], key: str) -> str:
|
||||
v = d.get(key)
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("Wrong argument to init: expected {key} to be"
|
||||
" a string, got {v}".format(key=key, v=v))
|
||||
return v
|
||||
|
||||
def verify_bool(d: Dict[str, JSONType], key: str) -> bool:
|
||||
v = d.get(key)
|
||||
if not isinstance(v, bool):
|
||||
raise ValueError("Wrong argument to init: expected {key} to be"
|
||||
" a bool, got {v}".format(key=key, v=v))
|
||||
return v
|
||||
|
||||
self.rpc_filename = verify_str(configuration, 'rpc-file')
|
||||
self.lightning_dir = verify_str(configuration, 'lightning-dir')
|
||||
|
||||
path = os.path.join(self.lightning_dir, self.rpc_filename)
|
||||
self.rpc = LightningRpc(path)
|
||||
self.startup = configuration['startup']
|
||||
self.startup = verify_bool(configuration, 'startup')
|
||||
for name, value in options.items():
|
||||
self.options[name]['value'] = value
|
||||
|
||||
@ -647,18 +750,18 @@ class PluginStream(object):
|
||||
"""Sink that turns everything that is written to it into a notification.
|
||||
"""
|
||||
|
||||
def __init__(self, plugin, level="info"):
|
||||
def __init__(self, plugin: Plugin, level: str = "info"):
|
||||
self.plugin = plugin
|
||||
self.level = level
|
||||
self.buff = ''
|
||||
|
||||
def write(self, payload):
|
||||
def write(self, payload: str) -> None:
|
||||
self.buff += payload
|
||||
|
||||
if len(payload) > 0 and payload[-1] == '\n':
|
||||
self.flush()
|
||||
|
||||
def flush(self):
|
||||
def flush(self) -> None:
|
||||
lines = self.buff.split('\n')
|
||||
if len(lines) < 2:
|
||||
return
|
||||
@ -670,7 +773,8 @@ class PluginStream(object):
|
||||
self.buff = lines[-1]
|
||||
|
||||
|
||||
def monkey_patch(plugin, stdout=True, stderr=False):
|
||||
def monkey_patch(plugin: Plugin, stdout: bool = True,
|
||||
stderr: bool = False) -> None:
|
||||
"""Monkey patch stderr and stdout so we use notifications instead.
|
||||
|
||||
A plugin commonly communicates with lightningd over its stdout and
|
||||
|
@ -1,3 +1,4 @@
|
||||
import coincurve
|
||||
import struct
|
||||
|
||||
|
||||
@ -66,5 +67,79 @@ class ShortChannelId(object):
|
||||
def __str__(self):
|
||||
return "{self.block}x{self.txnum}x{self.outnum}".format(self=self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.block == other.block and self.txnum == other.txnum and self.outnum == other.outnum
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, ShortChannelId):
|
||||
return False
|
||||
|
||||
return (
|
||||
self.block == other.block
|
||||
and self.txnum == other.txnum
|
||||
and self.outnum == other.outnum
|
||||
)
|
||||
|
||||
|
||||
class Secret(object):
|
||||
def __init__(self, data: bytes) -> None:
|
||||
assert(len(data) == 32)
|
||||
self.data = data
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return self.data
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, Secret) and self.data == other.data
|
||||
|
||||
def __str__(self):
|
||||
return "Secret[0x{}]".format(self.data.hex())
|
||||
|
||||
|
||||
class PrivateKey(object):
|
||||
def __init__(self, rawkey) -> None:
|
||||
if not isinstance(rawkey, bytes):
|
||||
raise TypeError(f"rawkey must be bytes, {type(rawkey)} received")
|
||||
elif len(rawkey) != 32:
|
||||
raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received")
|
||||
|
||||
self.rawkey = rawkey
|
||||
self.key = coincurve.PrivateKey(rawkey)
|
||||
|
||||
def serializeCompressed(self):
|
||||
return self.key.secret
|
||||
|
||||
def public_key(self):
|
||||
return PublicKey(self.key.public_key)
|
||||
|
||||
|
||||
class PublicKey(object):
|
||||
def __init__(self, innerkey):
|
||||
# We accept either 33-bytes raw keys, or an EC PublicKey as returned
|
||||
# by coincurve
|
||||
if isinstance(innerkey, bytes):
|
||||
if innerkey[0] in [2, 3] and len(innerkey) == 33:
|
||||
innerkey = coincurve.PublicKey(innerkey)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Byte keys must be 33-byte long starting from either 02 or 03"
|
||||
)
|
||||
|
||||
elif not isinstance(innerkey, coincurve.keys.PublicKey):
|
||||
raise ValueError(
|
||||
"Key must either be bytes or coincurve.keys.PublicKey"
|
||||
)
|
||||
self.key = innerkey
|
||||
|
||||
def serializeCompressed(self):
|
||||
return self.key.format(compressed=True)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return self.serializeCompressed()
|
||||
|
||||
def __str__(self):
|
||||
return "PublicKey[0x{}]".format(
|
||||
self.serializeCompressed().hex()
|
||||
)
|
||||
|
||||
|
||||
def Keypair(object):
|
||||
def __init__(self, priv, pub):
|
||||
self.priv, self.pub = priv, pub
|
||||
|
@ -4,6 +4,7 @@ from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
from .primitives import Secret, PrivateKey, PublicKey
|
||||
from hashlib import sha256
|
||||
import coincurve
|
||||
import os
|
||||
@ -55,64 +56,6 @@ def decryptWithAD(k, n, ad, ciphertext):
|
||||
return chacha.decrypt(n, ciphertext, ad)
|
||||
|
||||
|
||||
class PrivateKey(object):
|
||||
def __init__(self, rawkey):
|
||||
if not isinstance(rawkey, bytes):
|
||||
raise TypeError(f"rawkey must be bytes, {type(rawkey)} received")
|
||||
elif len(rawkey) != 32:
|
||||
raise ValueError(f"rawkey must be 32-byte long. {len(rawkey)} received")
|
||||
|
||||
self.rawkey = rawkey
|
||||
self.key = coincurve.PrivateKey(rawkey)
|
||||
|
||||
def serializeCompressed(self):
|
||||
return self.key.secret
|
||||
|
||||
def public_key(self):
|
||||
return PublicKey(self.key.public_key)
|
||||
|
||||
|
||||
class Secret(object):
|
||||
def __init__(self, raw):
|
||||
assert(len(raw) == 32)
|
||||
self.raw = raw
|
||||
|
||||
def __str__(self):
|
||||
return "Secret[0x{}]".format(self.raw.hex())
|
||||
|
||||
|
||||
class PublicKey(object):
|
||||
def __init__(self, innerkey):
|
||||
# We accept either 33-bytes raw keys, or an EC PublicKey as returned
|
||||
# by coincurve
|
||||
if isinstance(innerkey, bytes):
|
||||
if innerkey[0] in [2, 3] and len(innerkey) == 33:
|
||||
innerkey = coincurve.PublicKey(innerkey)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Byte keys must be 33-byte long starting from either 02 or 03"
|
||||
)
|
||||
|
||||
elif not isinstance(innerkey, coincurve.keys.PublicKey):
|
||||
raise ValueError(
|
||||
"Key must either be bytes or coincurve.keys.PublicKey"
|
||||
)
|
||||
self.key = innerkey
|
||||
|
||||
def serializeCompressed(self):
|
||||
return self.key.format(compressed=True)
|
||||
|
||||
def __str__(self):
|
||||
return "PublicKey[0x{}]".format(
|
||||
self.serializeCompressed().hex()
|
||||
)
|
||||
|
||||
|
||||
def Keypair(object):
|
||||
def __init__(self, priv, pub):
|
||||
self.priv, self.pub = priv, pub
|
||||
|
||||
|
||||
class Sha256Mixer(object):
|
||||
def __init__(self, base):
|
||||
self.hash = sha256(base).digest()
|
||||
@ -174,7 +117,7 @@ class LightningConnection(object):
|
||||
h.hash = self.handshake['h']
|
||||
h.update(self.handshake['e'].public_key().serializeCompressed())
|
||||
es = ecdh(self.handshake['e'], self.remote_pubkey)
|
||||
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'')
|
||||
t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'')
|
||||
assert(len(t) == 64)
|
||||
self.chaining_key, temp_k1 = t[:32], t[32:]
|
||||
c = encryptWithAD(temp_k1, self.nonce(0), h.digest(), b'')
|
||||
@ -194,7 +137,7 @@ class LightningConnection(object):
|
||||
h.update(re.serializeCompressed())
|
||||
es = ecdh(self.local_privkey, re)
|
||||
self.handshake['re'] = re
|
||||
t = hkdf(salt=self.chaining_key, ikm=es.raw, info=b'')
|
||||
t = hkdf(salt=self.chaining_key, ikm=es.data, info=b'')
|
||||
self.chaining_key, temp_k1 = t[:32], t[32:]
|
||||
|
||||
try:
|
||||
@ -210,7 +153,7 @@ class LightningConnection(object):
|
||||
h.hash = self.handshake['h']
|
||||
h.update(self.handshake['e'].public_key().serializeCompressed())
|
||||
ee = ecdh(self.handshake['e'], self.handshake['re'])
|
||||
t = hkdf(salt=self.chaining_key, ikm=ee.raw, info=b'')
|
||||
t = hkdf(salt=self.chaining_key, ikm=ee.data, info=b'')
|
||||
assert(len(t) == 64)
|
||||
self.chaining_key, self.temp_k2 = t[:32], t[32:]
|
||||
c = encryptWithAD(self.temp_k2, self.nonce(0), h.digest(), b'')
|
||||
@ -231,7 +174,7 @@ class LightningConnection(object):
|
||||
h.update(re.serializeCompressed())
|
||||
ee = ecdh(self.handshake['e'], re)
|
||||
self.chaining_key, self.temp_k2 = hkdf_two_keys(
|
||||
salt=self.chaining_key, ikm=ee.raw
|
||||
salt=self.chaining_key, ikm=ee.data
|
||||
)
|
||||
try:
|
||||
decryptWithAD(self.temp_k2, self.nonce(0), h.digest(), c)
|
||||
@ -249,7 +192,7 @@ class LightningConnection(object):
|
||||
se = ecdh(self.local_privkey, self.re)
|
||||
|
||||
self.chaining_key, self.temp_k3 = hkdf_two_keys(
|
||||
salt=self.chaining_key, ikm=se.raw
|
||||
salt=self.chaining_key, ikm=se.data
|
||||
)
|
||||
t = encryptWithAD(self.temp_k3, self.nonce(0), h.digest(), b'')
|
||||
m = b'\x00' + c + t
|
||||
@ -272,7 +215,7 @@ class LightningConnection(object):
|
||||
se = ecdh(self.handshake['e'], self.remote_pubkey)
|
||||
|
||||
self.chaining_key, self.temp_k3 = hkdf_two_keys(
|
||||
se.raw, self.chaining_key
|
||||
se.data, self.chaining_key
|
||||
)
|
||||
decryptWithAD(self.temp_k3, self.nonce(0), h.digest(), t)
|
||||
self.rn, self.sn = 0, 0
|
||||
|
Loading…
Reference in New Issue
Block a user