pyln-proto: write out length of arrays of subtypes to wire

We weren't writing out the length of a nested subtype's
dynamicarraylenght, now we do. The trick is to iterate through the
fields on a subtype (since the length field is added separately)
and to also iterate down through the otherfield values as we 'descend'
This commit is contained in:
niftynei 2021-03-24 16:36:32 -05:00 committed by Rusty Russell
parent 6db6ba6c03
commit 5142dc81f6
3 changed files with 45 additions and 8 deletions

View File

@ -48,9 +48,16 @@ wants an array of some type.
return [self.elemtype.val_to_py(i, otherfields) for i in v]
def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None:
for i in v:
self.elemtype.write(io_out, i, otherfields)
def write(self, io_out: BufferedIOBase, vals: List[Any], otherfields: Dict[str, Any]) -> None:
name = self.name.split('.')[1]
if otherfields and name in otherfields:
otherfields = otherfields[name]
for i, val in enumerate(vals):
if isinstance(otherfields, list) and len(otherfields) > i:
fields = otherfields[i]
else:
fields = otherfields
self.elemtype.write(io_out, val, fields)
def read_arr(self, io_in: BufferedIOBase, otherfields: Dict[str, Any], arraysize: Optional[int]) -> List[Any]:
"""arraysize None means take rest of io entirely and exactly"""
@ -179,7 +186,7 @@ they're implied by the length of other fields"""
if mylen != len(otherfields[lens.name]):
return [fieldname]
# Field might be missing!
if lens.name in otherfields:
if otherfields and lens.name in otherfields:
mylen = len(otherfields[lens.name])
return []

View File

@ -297,10 +297,17 @@ other types. Since 'msgtype' is almost identical, it inherits from this too.
def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
self._raise_if_badvals(v)
for fname, val in v.items():
field = self.find_field(fname)
assert field
field.fieldtype.write(io_out, val, otherfields)
for f in self.fields:
if f.name in v:
val = v[f.name]
else:
if f.option is not None:
raise ValueError("Missing field {} {}".format(f.name, otherfields))
val = None
if self.name in otherfields:
otherfields = otherfields[self.name]
f.fieldtype.write(io_out, val, otherfields)
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
vals = {}

View File

@ -90,6 +90,29 @@ def test_subtype():
assert m.missing_fields()
def test_subtype_array():
ns = MessageNamespace()
ns.load_csv(['msgtype,tx_signatures,1',
'msgdata,tx_signatures,num_witnesses,u16,',
'msgdata,tx_signatures,witness_stack,witness_stack,num_witnesses',
'subtype,witness_stack',
'subtypedata,witness_stack,num_input_witness,u16,',
'subtypedata,witness_stack,witness_element,witness_element,num_input_witness',
'subtype,witness_element',
'subtypedata,witness_element,len,u16,',
'subtypedata,witness_element,witness,byte,len'])
for test in [["tx_signatures witness_stack="
"[{witness_element=[{witness=3045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01},{witness=02d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b}]}]",
bytes.fromhex('00010001000200483045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01002102d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b')]]:
m = Message.from_str(ns, test[0])
assert m.to_str() == test[0]
buf = io.BytesIO()
m.write(buf)
assert buf.getvalue().hex() == test[1].hex()
assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0]
def test_tlv():
ns = MessageNamespace()
ns.load_csv(['msgtype,test1,1',