pyln.proto.message: more mypy fixes.

This includes some real bugfixes, since it noticed some places we were
being loose with different types!

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell 2020-06-18 14:24:18 +09:30 committed by Christian Decker
parent 3882e8bdf7
commit 11a0de877e
4 changed files with 110 additions and 73 deletions

View File

@ -17,7 +17,7 @@ check-flake8:
# mypy . does not recurse. I have no idea why...
check-mypy:
mypy --ignore-missing-imports `find * -name '*.py'`
mypy --ignore-missing-imports `find pyln/proto/message/ -name '*.py'`
$(SDIST_FILE):
python3 setup.py sdist

View File

@ -1,8 +1,8 @@
from .fundamental_types import FieldType, IntegerType, split_field
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union
from typing import List, Optional, Dict, Tuple, TYPE_CHECKING, Any, Union, cast
from io import BufferedIOBase
if TYPE_CHECKING:
from .message import SubtypeType, TlvStreamType
from .message import SubtypeType, TlvMessageType, MessageTypeField
class ArrayType(FieldType):
@ -98,7 +98,7 @@ class SizedArrayType(ArrayType):
class EllipsisArrayType(ArrayType):
"""This is used for ... fields at the end of a tlv: the array ends
when the tlv ends"""
def __init__(self, tlv: 'TlvStreamType', name: str, elemtype: FieldType):
def __init__(self, tlv: 'TlvMessageType', name: str, elemtype: FieldType):
super().__init__(tlv, name, elemtype)
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
@ -119,13 +119,13 @@ class LengthFieldType(FieldType):
super().__init__(inttype.name)
self.underlying_type = inttype
# You can be length for more than one field!
self.len_for: List[DynamicArrayType] = []
self.len_for: List['MessageTypeField'] = []
def is_optional(self) -> bool:
"""This field value is always implies, never specified directly"""
return True
def add_length_for(self, field: 'DynamicArrayType') -> None:
def add_length_for(self, field: 'MessageTypeField') -> None:
assert isinstance(field.fieldtype, DynamicArrayType)
self.len_for.append(field)
@ -160,7 +160,7 @@ class LengthFieldType(FieldType):
they're implied by the length of other fields"""
return ''
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> None:
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[int]:
"""We store this, but it'll be removed from the fields as soon as it's used (i.e. by DynamicArrayType's val_from_bin)"""
return self.underlying_type.read(io_in, otherfields)
@ -186,11 +186,11 @@ they're implied by the length of other fields"""
class DynamicArrayType(ArrayType):
"""This is used for arrays where another field controls the size"""
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: LengthFieldType):
def __init__(self, outer: 'SubtypeType', name: str, elemtype: FieldType, lenfield: 'MessageTypeField'):
super().__init__(outer, name, elemtype)
assert type(lenfield.fieldtype) is LengthFieldType
self.lenfield = lenfield
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> List[Any]:
return super().read_arr(io_in, otherfields,
self.lenfield.fieldtype._maybe_calc_value(self.lenfield.name, otherfields))
cast(LengthFieldType, self.lenfield.fieldtype)._maybe_calc_value(self.lenfield.name, otherfields))

View File

@ -59,6 +59,15 @@ These are further specialized.
def val_to_str(self, v: Any, otherfields: Dict[str, Any]) -> str:
raise NotImplementedError()
def val_from_str(self, s: str) -> Tuple[Any, str]:
raise NotImplementedError()
def write(self, io_out: BufferedIOBase, v: Any, otherfields: Dict[str, Any]) -> None:
raise NotImplementedError()
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Any:
raise NotImplementedError()
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any:
"""Convert to a python object: for simple fields, this means a string"""
return self.val_to_str(v, otherfields)
@ -83,7 +92,7 @@ class IntegerType(FieldType):
a, b = split_field(s)
return int(a), b
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> int:
def val_to_py(self, v: Any, otherfields: Dict[str, Any]) -> Any:
"""Convert to a python object: for integer fields, this means an int"""
return int(v)
@ -240,7 +249,7 @@ class BigSizeType(FieldType):
return int(v)
def fundamental_types():
def fundamental_types() -> List[FieldType]:
# From 01-messaging.md#fundamental-types:
return [IntegerType('byte', 1, 'B'),
IntegerType('u16', 2, '>H'),

View File

@ -1,10 +1,10 @@
import struct
from io import BufferedIOBase, BytesIO
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType
from .fundamental_types import fundamental_types, BigSizeType, split_field, try_unpack, FieldType, IntegerType
from .array_types import (
SizedArrayType, DynamicArrayType, LengthFieldType, EllipsisArrayType
)
from typing import Dict, List, Optional, Tuple, Any, Union, cast
from typing import Dict, List, Optional, Tuple, Any, Union, Callable, cast
class MessageNamespace(object):
@ -12,7 +12,7 @@ class MessageNamespace(object):
domain, such as within a given BOLT"""
def __init__(self, csv_lines: List[str] = []):
self.subtypes: Dict[str, SubtypeType] = {}
self.fundamentaltypes: Dict[str, SubtypeType] = {}
self.fundamentaltypes: Dict[str, FieldType] = {}
self.tlvtypes: Dict[str, TlvStreamType] = {}
self.messagetypes: Dict[str, MessageType] = {}
@ -28,27 +28,35 @@ domain, such as within a given BOLT"""
for v in other.subtypes.values():
ret.add_subtype(v)
ret.tlvtypes = self.tlvtypes.copy()
for v in other.tlvtypes.values():
ret.add_tlvtype(v)
for tlv in other.tlvtypes.values():
ret.add_tlvtype(tlv)
ret.messagetypes = self.messagetypes.copy()
for v in other.messagetypes.values():
ret.add_messagetype(v)
return ret
def _check_unique(self, name: str) -> None:
"""Raise an exception if name already used"""
funtype = self.get_fundamentaltype(name)
if funtype:
raise ValueError('Already have {}'.format(funtype))
subtype = self.get_subtype(name)
if subtype:
raise ValueError('Already have {}'.format(subtype))
tlvtype = self.get_tlvtype(name)
if tlvtype:
raise ValueError('Already have {}'.format(tlvtype))
def add_subtype(self, t: 'SubtypeType') -> None:
prev = self.get_type(t.name)
if prev:
raise ValueError('Already have {}'.format(prev))
self._check_unique(t.name)
self.subtypes[t.name] = t
def add_fundamentaltype(self, t: 'SubtypeType') -> None:
assert not self.get_type(t.name)
def add_fundamentaltype(self, t: FieldType) -> None:
self._check_unique(t.name)
self.fundamentaltypes[t.name] = t
def add_tlvtype(self, t: 'TlvStreamType') -> None:
prev = self.get_type(t.name)
if prev:
raise ValueError('Already have {}'.format(prev))
self._check_unique(t.name)
self.tlvtypes[t.name] = t
def add_messagetype(self, m: 'MessageType') -> None:
@ -70,7 +78,7 @@ domain, such as within a given BOLT"""
return m
return None
def get_fundamentaltype(self, name: str) -> Optional['SubtypeType']:
def get_fundamentaltype(self, name: str) -> Optional[FieldType]:
if name in self.fundamentaltypes:
return self.fundamentaltypes[name]
return None
@ -85,14 +93,6 @@ domain, such as within a given BOLT"""
return self.tlvtypes[name]
return None
def get_type(self, name: str) -> Optional['SubtypeType']:
t = self.get_fundamentaltype(name)
if t is None:
t = self.get_subtype(name)
if t is None:
t = self.get_tlvtype(name)
return t
def load_csv(self, lines: List[str]) -> None:
"""Load a series of comma-separate-value lines into the namespace"""
vals: Dict[str, List[List[str]]] = {'msgtype': [],
@ -152,23 +152,22 @@ class MessageTypeField(object):
return self.full_name
class SubtypeType(object):
class SubtypeType(FieldType):
"""This defines a 'subtype' in BOLT-speak. It consists of fields of
other types. Since 'msgtype' and 'tlvtype' are almost identical, they
inherit from this too.
other types. Since 'msgtype' is almost identical, it inherits from this too.
"""
def __init__(self, name: str):
self.name = name
self.fields: List[FieldType] = []
super().__init__(name)
self.fields: List[MessageTypeField] = []
def find_field(self, fieldname: str):
def find_field(self, fieldname: str) -> Optional[MessageTypeField]:
for f in self.fields:
if f.name == fieldname:
return f
return None
def add_field(self, field: FieldType):
def add_field(self, field: MessageTypeField) -> None:
if self.find_field(field.name):
raise ValueError("{}: duplicate field {}".format(self, field))
self.fields.append(field)
@ -192,12 +191,16 @@ inherit from this too.
.format(parts))
return SubtypeType(parts[0])
def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], ellipsisok=False, option: str = None) -> MessageTypeField:
def _field_from_csv(self, namespace: MessageNamespace, parts: List[str], option: str = None) -> MessageTypeField:
"""Takes msgdata/subtypedata after first two fields
e.g. [...]timestamp_node_id_1,u32,
"""
basetype = namespace.get_type(parts[1])
basetype = namespace.get_fundamentaltype(parts[1])
if basetype is None:
basetype = namespace.get_subtype(parts[1])
if basetype is None:
basetype = namespace.get_tlvtype(parts[1])
if basetype is None:
raise ValueError('Unknown type {}'.format(parts[1]))
@ -206,7 +209,8 @@ inherit from this too.
lenfield = self.find_field(parts[2])
if lenfield is not None:
# If we didn't know that field was a length, we do now!
if type(lenfield.fieldtype) is not LengthFieldType:
if not isinstance(lenfield.fieldtype, LengthFieldType):
assert isinstance(lenfield.fieldtype, IntegerType)
lenfield.fieldtype = LengthFieldType(lenfield.fieldtype)
field = MessageTypeField(self.name, parts[0],
DynamicArrayType(self,
@ -215,7 +219,9 @@ inherit from this too.
lenfield),
option)
lenfield.fieldtype.add_length_for(field)
elif ellipsisok and parts[2] == '...':
elif parts[2] == '...':
# ... is only valid for a TLV.
assert isinstance(self, TlvMessageType)
field = MessageTypeField(self.name, parts[0],
EllipsisArrayType(self,
parts[0], basetype),
@ -264,8 +270,10 @@ inherit from this too.
raise ValueError("Unknown fields specified: {}".format(unknown))
for f in defined.difference(have):
if not f.fieldtype.is_optional():
raise ValueError("Missing value for {}".format(f))
field = self.find_field(f)
assert field
if not field.fieldtype.is_optional():
raise ValueError("Missing value for {}".format(field))
def val_to_str(self, v: Dict[str, Any], otherfields: Dict[str, Any]) -> str:
self._raise_if_badvals(v)
@ -273,6 +281,7 @@ inherit from this too.
sep = ''
for fname, val in v.items():
field = self.find_field(fname)
assert field
s += sep + fname + '=' + field.fieldtype.val_to_str(val, otherfields)
sep = ','
@ -281,16 +290,19 @@ inherit from this too.
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
ret: Dict[str, Any] = {}
for k, v in val.items():
ret[k] = self.find_field(k).fieldtype.val_to_py(v, val)
field = self.find_field(k)
assert field
ret[k] = field.fieldtype.val_to_py(v, val)
return ret
def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
self._raise_if_badvals(v)
for fname, val in v.items():
field = self.find_field(fname)
assert field
field.fieldtype.write(io_out, val, otherfields)
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
vals = {}
for field in self.fields:
val = field.fieldtype.read(io_in, otherfields)
@ -383,25 +395,46 @@ class MessageType(SubtypeType):
messagetype.add_field(field)
class TlvStreamType(SubtypeType):
"""A TlvStreamType is just a Subtype, but its fields are
TlvMessageTypes. In the CSV format these are created implicitly, when
a tlvtype line (which defines a TlvMessageType within the TlvType,
confusingly) refers to them.
class TlvMessageType(MessageType):
"""A 'tlvtype' in BOLT-speak"""
def __init__(self, name: str, value: str):
super().__init__(name, value)
def __str__(self):
return "tlvmsgtype-{}".format(self.name)
class TlvStreamType(FieldType):
"""A TlvStreamType's fields are TlvMessageTypes. In the CSV format
these are created implicitly, when a tlvtype line (which defines a
TlvMessageType within the TlvType, confusingly) refers to them.
"""
def __init__(self, name):
super().__init__(name)
self.fields: List[TlvMessageType] = []
def __str__(self):
return "tlvstreamtype-{}".format(self.name)
def find_field_by_number(self, num: int) -> Optional['TlvMessageType']:
def find_field(self, fieldname: str) -> Optional[TlvMessageType]:
for f in self.fields:
if f.name == fieldname:
return f
return None
def find_field_by_number(self, num: int) -> Optional[TlvMessageType]:
for f in self.fields:
if f.number == num:
return f
return None
def add_field(self, field: TlvMessageType) -> None:
if self.find_field(field.name):
raise ValueError("{}: duplicate field {}".format(self, field))
self.fields.append(field)
def is_optional(self) -> bool:
"""You can omit a tlvstream= altogether"""
return True
@ -438,7 +471,7 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
raise ValueError("Unknown tlv field {}.{}"
.format(tlvstream, parts[1]))
subfield = field._field_from_csv(namespace, parts[2:], ellipsisok=True)
subfield = field._field_from_csv(namespace, parts[2:])
field.add_field(subfield)
def val_from_str(self, s: str) -> Tuple[Dict[str, Any], str]:
@ -480,7 +513,9 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
def val_to_py(self, val: Dict[str, Any], otherfields: Dict[str, Any]) -> Dict[str, Any]:
ret: Dict[str, Any] = {}
for k, v in val.items():
ret[k] = self.find_field(k).val_to_py(v, val)
field = self.find_field(k)
assert field
ret[k] = field.val_to_py(v, val)
return ret
def write(self, io_out: BufferedIOBase, v: Optional[Dict[str, Any]], otherfields: Dict[str, Any]) -> None:
@ -490,14 +525,16 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
# Make a tuple of (fieldnum, val_to_bin, val) so we can sort into
# ascending order as TLV spec requires.
def write_raw_val(iobuf, val, otherfields: Dict[str, Any]):
def write_raw_val(iobuf: BufferedIOBase, val: Any, otherfields: Dict[str, Any]) -> None:
iobuf.write(val)
def get_value(tup):
"""Get value from num, fun, val tuple"""
return tup[0]
ordered = []
ordered: List[Tuple[int,
Callable[[BufferedIOBase, Any, Dict[str, Any]], None],
Any]] = []
for fieldname in v:
f = self.find_field(fieldname)
if f is None:
@ -510,13 +547,13 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
for typenum, writefunc, val in ordered:
buf = BytesIO()
writefunc(buf, val, otherfields)
writefunc(cast(BufferedIOBase, buf), val, otherfields)
BigSizeType.write(io_out, typenum)
BigSizeType.write(io_out, len(buf.getvalue()))
io_out.write(buf.getvalue())
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[str, Any]:
vals: Dict[str, Any] = {}
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Dict[Union[str, int], Any]:
vals: Dict[Union[str, int], Any] = {}
while True:
tlv_type = BigSizeType.read(io_in)
@ -543,16 +580,6 @@ tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,u8,
return " {}={}".format(name, self.val_to_str(v, {}))
class TlvMessageType(MessageType):
"""A 'tlvtype' in BOLT-speak"""
def __init__(self, name: str, value: str):
super().__init__(name, value)
def __str__(self):
return "tlvmsgtype-{}".format(self.name)
class Message(object):
"""A particular message instance"""
def __init__(self, messagetype: MessageType, **kwargs):
@ -679,7 +706,8 @@ Must not have missing fields.
"""Convert to a Python native object: dicts, lists, strings, ints"""
ret: Dict[str, Union[Dict[str, Any], List[Any], str, int]] = {}
for f, v in self.fields.items():
fieldtype = self.messagetype.find_field(f).fieldtype
ret[f] = fieldtype.val_to_py(v, self.fields)
field = self.messagetype.find_field(f)
assert field
ret[f] = field.fieldtype.val_to_py(v, self.fields)
return ret