mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-01-18 05:12:45 +01:00
pyln.proto.message: more mypy fixes.
This includes some real bugfixes, since it noticed some places we were being loose with different types! Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
parent
3882e8bdf7
commit
11a0de877e
@ -17,7 +17,7 @@ check-flake8:
|
||||
|
||||
# mypy . does not recurse. I have no idea why...
|
||||
check-mypy:
|
||||
mypy --ignore-missing-imports `find * -name '*.py'`
|
||||
mypy --ignore-missing-imports `find pyln/proto/message/ -name '*.py'`
|
||||
|
||||
$(SDIST_FILE):
|
||||
python3 setup.py sdist
|
||||
|
@ -1,8 +1,8 @@
|
||||
from .fundamental_types import FieldType, IntegerType, split_field
|
||||
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union
|
||||
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union, cast
|
||||
from io import BufferedIOBase
|
||||
if TYPE_CHECKING:
|
||||
from .message import SubtypeType, TlvStreamType
|
||||
from .message import SubtypeType, TlvMessageType, MessageTypeField
|
||||
|
||||
|
||||
class ArrayType(FieldType):
|
||||
@ -98,7 +98,7 @@ class SizedArrayType(ArrayType):
|
||||
class EllipsisArrayType(ArrayType):
|
||||
"""This is used for ... fields at the end of a tlv: the array ends
|
||||
when the tlv ends"""
|
||||
def __init__(self, tlv: 'TlvStreamType', name: str, elemtype: FieldType):
|
||||
def __init__(self, tlv: 'TlvMessageType', name: str, elemtype: FieldType):
|
||||
super().__init__(tlv, name, elemtype)
|
||||
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
|
||||
@ -119,13 +119,13 @@ class LengthFieldType(FieldType):
|
||||
super().__init__(inttype.name)
|
||||
self.underlying_type = inttype
|
||||
# You can be length for more than one field!
|
||||
self.len_for: List[DynamicArrayType] = []
|
||||
self.len_for: List['MessageTypeField'] = []
|
||||
|
||||
def is_optional(self) -> bool:
|
||||
"""This field value is always implies, never specified directly"""
|
||||
return True
|
||||
|
||||
def add_length_for(self, field: 'DynamicArrayType') -> None:
|
||||
def add_length_for(self, field: 'MessageTypeField') -> None:
|
||||
assert isinstance(field.fieldtype, DynamicArrayType)
|
||||
self.len_for.append(field)
|
||||
|
||||
@ -160,7 +160,7 @@ class LengthFieldType(FieldType):
|
||||
they're implied by the length of other fields"""
|
||||
return ''
|
||||
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None:
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[int]:
|
||||
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
|
||||
return self.underlying_type.read(io_in, otherfields)
|
||||
|
||||
@ -186,11 +186,11 @@ they're implied by the length of other fields"""
|
||||
|
||||
class DynamicArrayType(ArrayType):
|
||||
"""This is used for arrays where another field controls the size"""
|
||||
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: LengthFieldType):
|
||||
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: 'MessageTypeField'):
|
||||
super().__init__(outer, name, elemtype)
|
||||
assert type(lenfield.fieldtype) is LengthFieldType
|
||||
self.lenfield = lenfield
|
||||
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
|
||||
return super().read_arr(io_in, otherfields,
|
||||
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))
|
||||
cast(LengthFieldType, self.lenfield.fieldtype)._maybe_calc_value(self.lenfield.name, otherfields))
|
||||
|
@ -59,6 +59,15 @@ These are further specialized.
|
||||
def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def val_from_str(self, s: str) -> Tuple[Any, str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def write(self, io_out: BufferedIOBase, v: Any, otherfields: Dict[str, Any]) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any:
|
||||
"""Convert to a python object: for simple fields, this means a string"""
|
||||
return self.val_to_str(v, otherfields)
|
||||
@ -83,7 +92,7 @@ class IntegerType(FieldType):
|
||||
a, b = split_field(s)
|
||||
return int(a), b
|
||||
|
||||
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> int:
|
||||
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any:
|
||||
"""Convert to a python object: for integer fields, this means an int"""
|
||||
return int(v)
|
||||
|
||||
@ -240,7 +249,7 @@ class BigSizeType(FieldType):
|
||||
return int(v)
|
||||
|
||||
|
||||
def fundamental_types():
|
||||
def fundamental_types() -> List[FieldType]:
|
||||
# From 01-messaging.md#fundamental-types:
|
||||
return [IntegerType('byte', 1, 'B'),
|
||||
IntegerType('u16', 2, '>H'),
|
||||
|
@ -1,10 +1,10 @@
|
||||
import struct
|
||||
from io import BufferedIOBase, BytesIO
|
||||
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType
|
||||
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType, IntegerType
|
||||
from .array_types import (
|
||||
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
|
||||
)
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union, cast
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union, Callable, cast
|
||||
|
||||
|
||||
class MessageNamespace(object):
|
||||
@ -12,7 +12,7 @@ class MessageNamespace(object):
|
||||
domain, such as within a given BOLT"""
|
||||
def __init__(self, csv_lines: List[str] = []):
|
||||
self.subtypes: Dict[str, SubtypeType] = {}
|
||||
self.fundamentaltypes: Dict[str, SubtypeType] = {}
|
||||
self.fundamentaltypes: Dict[str, FieldType] = {}
|
||||
self.tlvtypes: Dict[str, TlvStreamType] = {}
|
||||
self.messagetypes: Dict[str, MessageType] = {}
|
||||
|
||||
@ -28,27 +28,35 @@ domain, such as within a given BOLT"""
|
||||
for v in other.subtypes.values():
|
||||
ret.add_subtype(v)
|
||||
ret.tlvtypes = self.tlvtypes.copy()
|
||||
for v in other.tlvtypes.values():
|
||||
ret.add_tlvtype(v)
|
||||
for tlv in other.tlvtypes.values():
|
||||
ret.add_tlvtype(tlv)
|
||||
ret.messagetypes = self.messagetypes.copy()
|
||||
for v in other.messagetypes.values():
|
||||
ret.add_messagetype(v)
|
||||
return ret
|
||||
|
||||
def _check_unique(self, name: str) -> None:
|
||||
"""Raise an exception if name already used"""
|
||||
funtype = self.get_fundamentaltype(name)
|
||||
if funtype:
|
||||
raise ValueError('Already have {}'.format(funtype))
|
||||
subtype = self.get_subtype(name)
|
||||
if subtype:
|
||||
raise ValueError('Already have {}'.format(subtype))
|
||||
tlvtype = self.get_tlvtype(name)
|
||||
if tlvtype:
|
||||
raise ValueError('Already have {}'.format(tlvtype))
|
||||
|
||||
def add_subtype(self, t: 'SubtypeType') -> None:
|
||||
prev = self.get_type(t.name)
|
||||
if prev:
|
||||
raise ValueError('Already have {}'.format(prev))
|
||||
self._check_unique(t.name)
|
||||
self.subtypes[t.name] = t
|
||||
|
||||
def add_fundamentaltype(self, t: 'SubtypeType') -> None:
|
||||
assert not self.get_type(t.name)
|
||||
def add_fundamentaltype(self, t: FieldType) -> None:
|
||||
self._check_unique(t.name)
|
||||
self.fundamentaltypes[t.name] = t
|
||||
|
||||
def add_tlvtype(self, t: 'TlvStreamType') -> None:
|
||||
prev = self.get_type(t.name)
|
||||
if prev:
|
||||
raise ValueError('Already have {}'.format(prev))
|
||||
self._check_unique(t.name)
|
||||
self.tlvtypes[t.name] = t
|
||||
|
||||
def add_messagetype(self, m: 'MessageType') -> None:
|
||||
@ -70,7 +78,7 @@ domain, such as within a given BOLT"""
|
||||
return m
|
||||
return None
|
||||
|
||||
def get_fundamentaltype(self, name: str) -> Optional['SubtypeType']:
|
||||
def get_fundamentaltype(self, name: str) -> Optional[FieldType]:
|
||||
if name in self.fundamentaltypes:
|
||||
return self.fundamentaltypes[name]
|
||||
return None
|
||||
@ -85,14 +93,6 @@ domain, such as within a given BOLT"""
|
||||
return self.tlvtypes[name]
|
||||
return None
|
||||
|
||||
def get_type(self, name: str) -> Optional['SubtypeType']:
|
||||
t = self.get_fundamentaltype(name)
|
||||
if t is None:
|
||||
t = self.get_subtype(name)
|
||||
if t is None:
|
||||
t = self.get_tlvtype(name)
|
||||
return t
|
||||
|
||||
def load_csv(self, lines: List[str]) -> None:
|
||||
"""Load a series of comma-separate-value lines into the namespace"""
|
||||
vals: Dict[str, List[List[str]]] = {'msgtype': [],
|
||||
@ -152,23 +152,22 @@ class MessageTypeField(object):
|
||||
return self.full_name
|
||||
|
||||
|
||||
class SubtypeType(object):
|
||||
class SubtypeType(FieldType):
|
||||
"""This defines a 'subtype' in BOLT-speak. It consists of fields of
|
||||
other types. Since 'msgtype' and 'tlvtype' are almost identical, they
|
||||
inherit from this too.
|
||||
other types. Since 'msgtype' is almost identical, it inherits from this too.
|
||||
|
||||
"""
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.fields: List[FieldType] = []
|
||||
super().__init__(name)
|
||||
self.fields: List[MessageTypeField] = []
|
||||
|
||||
def find_field(self, fieldname: str):
|
||||
def find_field(self, fieldname: str) -> Optional[MessageTypeField]:
|
||||
for f in self.fields:
|
||||
if f.name == fieldname:
|
||||
return f
|
||||
return None
|
||||
|
||||
def add_field(self, field: FieldType):
|
||||
def add_field(self, field: MessageTypeField) -> None:
|
||||
if self.find_field(field.name):
|
||||
raise ValueError("{}: duplicate field {}".format(self, field))
|
||||
self.fields.append(field)
|
||||
@ -192,12 +191,16 @@ inherit from this too.
|
||||
.format(parts))
|
||||
return SubtypeType(parts[0])
|
||||
|
||||
def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], ellipsisok=False, option: str = None) -> MessageTypeField:
|
||||
def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], option: str = None) -> MessageTypeField:
|
||||
"""Takes msgdata/subtypedata after first two fields
|
||||
e.g. [...]timestamp_node_id_1,u32,
|
||||
|
||||
"""
|
||||
basetype = namespace.get_type(parts[1])
|
||||
basetype = namespace.get_fundamentaltype(parts[1])
|
||||
if basetype is None:
|
||||
basetype = namespace.get_subtype(parts[1])
|
||||
if basetype is None:
|
||||
basetype = namespace.get_tlvtype(parts[1])
|
||||
if basetype is None:
|
||||
raise ValueError('Unknown type {}'.format(parts[1]))
|
||||
|
||||
@ -206,7 +209,8 @@ inherit from this too.
|
||||
lenfield = self.find_field(parts[2])
|
||||
if lenfield is not None:
|
||||
# If we didn't know that field was a length, we do now!
|
||||
if type(lenfield.fieldtype) is not LengthFieldType:
|
||||
if not isinstance(lenfield.fieldtype, LengthFieldType):
|
||||
assert isinstance(lenfield.fieldtype, IntegerType)
|
||||
lenfield.fieldtype = LengthFieldType(lenfield.fieldtype)
|
||||
field = MessageTypeField(self.name, parts[0],
|
||||
DynamicArrayType(self,
|
||||
@ -215,7 +219,9 @@ inherit from this too.
|
||||
lenfield),
|
||||
option)
|
||||
lenfield.fieldtype.add_length_for(field)
|
||||
elif ellipsisok and parts[2] == '...':
|
||||
elif parts[2] == '...':
|
||||
# ... is only valid for a TLV.
|
||||
assert isinstance(self, TlvMessageType)
|
||||
field = MessageTypeField(self.name, parts[0],
|
||||
EllipsisArrayType(self,
|
||||
parts[0], basetype),
|
||||
@ -264,8 +270,10 @@ inherit from this too.
|
||||
raise ValueError("Unknown fields specified: {}".format(unknown))
|
||||
|
||||
for f in defined.difference(have):
|
||||
if not f.fieldtype.is_optional():
|
||||
raise ValueError("Missing value for {}".format(f))
|
||||
field = self.find_field(f)
|
||||
assert field
|
||||
if not field.fieldtype.is_optional():
|
||||
raise ValueError("Missing value for {}".format(field))
|
||||
|
||||
def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str:
|
||||
self._raise_if_badvals(v)
|
||||
@ -273,6 +281,7 @@ inherit from this too.
|
||||
sep = ''
|
||||
for fname, val in v.items():
|
||||
field = self.find_field(fname)
|
||||
assert field
|
||||
s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields)
|
||||
sep = ','
|
||||
|
||||
@ -281,16 +290,19 @@ inherit from this too.
|
||||
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
||||
ret: Dict[str, Any] = {}
|
||||
for k, v in val.items():
|
||||
ret[k] = self.find_field(k).fieldtype.val_to_py(v, val)
|
||||
field = self.find_field(k)
|
||||
assert field
|
||||
ret[k] = field.fieldtype.val_to_py(v, val)
|
||||
return ret
|
||||
|
||||
def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
|
||||
self._raise_if_badvals(v)
|
||||
for fname, val in v.items():
|
||||
field = self.find_field(fname)
|
||||
assert field
|
||||
field.fieldtype.write(io_out, val, otherfields)
|
||||
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
vals = {}
|
||||
for field in self.fields:
|
||||
val = field.fieldtype.read(io_in, otherfields)
|
||||
@ -383,25 +395,46 @@ class MessageType(SubtypeType):
|
||||
messagetype.add_field(field)
|
||||
|
||||
|
||||
class TlvStreamType(SubtypeType):
|
||||
"""A TlvStreamType is just a Subtype, but its fields are
|
||||
TlvMessageTypes. In the CSV format these are created implicitly, when
|
||||
a tlvtype line (which defines a TlvMessageType within the TlvType,
|
||||
confusingly) refers to them.
|
||||
class TlvMessageType(MessageType):
|
||||
"""A 'tlvtype' in BOLT-speak"""
|
||||
|
||||
def __init__(self, name: str, value: str):
|
||||
super().__init__(name, value)
|
||||
|
||||
def __str__(self):
|
||||
return "tlvmsgtype-{}".format(self.name)
|
||||
|
||||
|
||||
class TlvStreamType(FieldType):
|
||||
"""A TlvStreamType's fields are TlvMessageTypes. In the CSV format
|
||||
these are created implicitly, when a tlvtype line (which defines a
|
||||
TlvMessageType within the TlvType, confusingly) refers to them.
|
||||
|
||||
"""
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
self.fields: List[TlvMessageType] = []
|
||||
|
||||
def __str__(self):
|
||||
return "tlvstreamtype-{}".format(self.name)
|
||||
|
||||
def find_field_by_number(self, num: int) -> Optional['TlvMessageType']:
|
||||
def find_field(self, fieldname: str) -> Optional[TlvMessageType]:
|
||||
for f in self.fields:
|
||||
if f.name == fieldname:
|
||||
return f
|
||||
return None
|
||||
|
||||
def find_field_by_number(self, num: int) -> Optional[TlvMessageType]:
|
||||
for f in self.fields:
|
||||
if f.number == num:
|
||||
return f
|
||||
return None
|
||||
|
||||
def add_field(self, field: TlvMessageType) -> None:
|
||||
if self.find_field(field.name):
|
||||
raise ValueError("{}: duplicate field {}".format(self, field))
|
||||
self.fields.append(field)
|
||||
|
||||
def is_optional(self) -> bool:
|
||||
"""You can omit a tlvstream= altogether"""
|
||||
return True
|
||||
@ -438,7 +471,7 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
||||
raise ValueError("Unknown tlv field {}.{}"
|
||||
.format(tlvstream, parts[1]))
|
||||
|
||||
subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True)
|
||||
subfield = field._field_from_csv(namespace, parts[2:])
|
||||
field.add_field(subfield)
|
||||
|
||||
def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]:
|
||||
@ -480,7 +513,9 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
||||
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
||||
ret: Dict[str, Any] = {}
|
||||
for k, v in val.items():
|
||||
ret[k] = self.find_field(k).val_to_py(v, val)
|
||||
field = self.find_field(k)
|
||||
assert field
|
||||
ret[k] = field.val_to_py(v, val)
|
||||
return ret
|
||||
|
||||
def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None:
|
||||
@ -490,14 +525,16 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
||||
|
||||
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
|
||||
# ascending order as TLV spec requires.
|
||||
def write_raw_val(iobuf, val, otherfields: Dict[str, Any]):
|
||||
def write_raw_val(iobuf: BufferedIOBase, val: Any, otherfields: Dict[str, Any]) -> None:
|
||||
iobuf.write(val)
|
||||
|
||||
def get_value(tup):
|
||||
"""Get value from num, fun, val tuple"""
|
||||
return tup[0]
|
||||
|
||||
ordered = []
|
||||
ordered: List[Tuple[int,
|
||||
Callable[[BufferedIOBase, Any, Dict[str, Any]], None],
|
||||
Any]] = []
|
||||
for fieldname in v:
|
||||
f = self.find_field(fieldname)
|
||||
if f is None:
|
||||
@ -510,13 +547,13 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
||||
|
||||
for typenum, writefunc, val in ordered:
|
||||
buf = BytesIO()
|
||||
writefunc(buf, val, otherfields)
|
||||
writefunc(cast(BufferedIOBase, buf), val, otherfields)
|
||||
BigSizeType.write(io_out, typenum)
|
||||
BigSizeType.write(io_out, len(buf.getvalue()))
|
||||
io_out.write(buf.getvalue())
|
||||
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
|
||||
vals: Dict[str, Any] = {}
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[Union[str, int], Any]:
|
||||
vals: Dict[Union[str, int], Any] = {}
|
||||
|
||||
while True:
|
||||
tlv_type = BigSizeType.read(io_in)
|
||||
@ -543,16 +580,6 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
|
||||
return " {}={}".format(name, self.val_to_str(v, {}))
|
||||
|
||||
|
||||
class TlvMessageType(MessageType):
|
||||
"""A 'tlvtype' in BOLT-speak"""
|
||||
|
||||
def __init__(self, name: str, value: str):
|
||||
super().__init__(name, value)
|
||||
|
||||
def __str__(self):
|
||||
return "tlvmsgtype-{}".format(self.name)
|
||||
|
||||
|
||||
class Message(object):
|
||||
"""A particular message instance"""
|
||||
def __init__(self, messagetype: MessageType, **kwargs):
|
||||
@ -679,7 +706,8 @@ Must not have missing fields.
|
||||
"""Convert to a Python native object: dicts, lists, strings, ints"""
|
||||
ret: Dict[str, Union[Dict[str, Any], List[Any], str, int]] = {}
|
||||
for f, v in self.fields.items():
|
||||
fieldtype = self.messagetype.find_field(f).fieldtype
|
||||
ret[f] = fieldtype.val_to_py(v, self.fields)
|
||||
field = self.messagetype.find_field(f)
|
||||
assert field
|
||||
ret[f] = field.fieldtype.val_to_py(v, self.fields)
|
||||
|
||||
return ret
|
||||
|
Loading…
Reference in New Issue
Block a user