mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-01-18 05:12:45 +01:00
pyln-proto: write out length of arrays of subtypes to wire
We weren't writing out the length of a nested subtype's dynamicarraylenght, now we do. The trick is to iterate through the fields on a subtype (since the length field is added separately) and to also iterate down through the otherfield values as we 'descend'
This commit is contained in:
parent
6db6ba6c03
commit
5142dc81f6
@ -48,9 +48,16 @@ wants an array of some type.
|
||||
|
||||
return [self.elemtype.val_to_py(i, otherfields) for i in v]
|
||||
|
||||
def write(self, io_out: BufferedIOBase, v: List[Any], otherfields: Dict[str, Any]) -> None:
|
||||
for i in v:
|
||||
self.elemtype.write(io_out, i, otherfields)
|
||||
def write(self, io_out: BufferedIOBase, vals: List[Any], otherfields: Dict[str, Any]) -> None:
|
||||
name = self.name.split('.')[1]
|
||||
if otherfields and name in otherfields:
|
||||
otherfields = otherfields[name]
|
||||
for i, val in enumerate(vals):
|
||||
if isinstance(otherfields, list) and len(otherfields) > i:
|
||||
fields = otherfields[i]
|
||||
else:
|
||||
fields = otherfields
|
||||
self.elemtype.write(io_out, val, fields)
|
||||
|
||||
def read_arr(self, io_in: BufferedIOBase, otherfields: Dict[str, Any], arraysize: Optional[int]) -> List[Any]:
|
||||
"""arraysize None means take rest of io entirely and exactly"""
|
||||
@ -179,7 +186,7 @@ they're implied by the length of other fields"""
|
||||
if mylen != len(otherfields[lens.name]):
|
||||
return [fieldname]
|
||||
# Field might be missing!
|
||||
if lens.name in otherfields:
|
||||
if otherfields and lens.name in otherfields:
|
||||
mylen = len(otherfields[lens.name])
|
||||
return []
|
||||
|
||||
|
@ -297,10 +297,17 @@ other types. Since 'msgtype' is almost identical, it inherits from this too.
|
||||
|
||||
def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str, Any]) -> None:
|
||||
self._raise_if_badvals(v)
|
||||
for fname, val in v.items():
|
||||
field = self.find_field(fname)
|
||||
assert field
|
||||
field.fieldtype.write(io_out, val, otherfields)
|
||||
for f in self.fields:
|
||||
if f.name in v:
|
||||
val = v[f.name]
|
||||
else:
|
||||
if f.option is not None:
|
||||
raise ValueError("Missing field {} {}".format(f.name, otherfields))
|
||||
val = None
|
||||
|
||||
if self.name in otherfields:
|
||||
otherfields = otherfields[self.name]
|
||||
f.fieldtype.write(io_out, val, otherfields)
|
||||
|
||||
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
vals = {}
|
||||
|
@ -90,6 +90,29 @@ def test_subtype():
|
||||
assert m.missing_fields()
|
||||
|
||||
|
||||
def test_subtype_array():
|
||||
ns = MessageNamespace()
|
||||
ns.load_csv(['msgtype,tx_signatures,1',
|
||||
'msgdata,tx_signatures,num_witnesses,u16,',
|
||||
'msgdata,tx_signatures,witness_stack,witness_stack,num_witnesses',
|
||||
'subtype,witness_stack',
|
||||
'subtypedata,witness_stack,num_input_witness,u16,',
|
||||
'subtypedata,witness_stack,witness_element,witness_element,num_input_witness',
|
||||
'subtype,witness_element',
|
||||
'subtypedata,witness_element,len,u16,',
|
||||
'subtypedata,witness_element,witness,byte,len'])
|
||||
|
||||
for test in [["tx_signatures witness_stack="
|
||||
"[{witness_element=[{witness=3045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01},{witness=02d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b}]}]",
|
||||
bytes.fromhex('00010001000200483045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01002102d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b')]]:
|
||||
m = Message.from_str(ns, test[0])
|
||||
assert m.to_str() == test[0]
|
||||
buf = io.BytesIO()
|
||||
m.write(buf)
|
||||
assert buf.getvalue().hex() == test[1].hex()
|
||||
assert Message.read(ns, io.BytesIO(test[1])).to_str() == test[0]
|
||||
|
||||
|
||||
def test_tlv():
|
||||
ns = MessageNamespace()
|
||||
ns.load_csv(['msgtype,test1,1',
|
||||
|
Loading…
Reference in New Issue
Block a user