diff --git a/contrib/pyln-proto/pyln/proto/message/message.py b/contrib/pyln-proto/pyln/proto/message/message.py index e43d497a7..c40464c66 100644 --- a/contrib/pyln-proto/pyln/proto/message/message.py +++ b/contrib/pyln-proto/pyln/proto/message/message.py @@ -109,10 +109,11 @@ domain, such as within a given BOLT""" class MessageTypeField(object): """A field within a particular message type or subtype""" - def __init__(self, ownername, name, fieldtype): + def __init__(self, ownername, name, fieldtype, option=None): self.full_name = "{}.{}".format(ownername, name) self.name = name self.fieldtype = fieldtype + self.option = option def missing_fields(self, fields): """Return this field if it's not in fields""" @@ -171,7 +172,7 @@ inherit from this too. .format(parts)) return SubtypeType(parts[0]) - def _field_from_csv(self, namespace, parts, ellipsisok=False): + def _field_from_csv(self, namespace, parts, ellipsisok=False, option=None): """Takes msgdata/subtypedata after first two fields e.g. [...]timestamp_node_id_1,u32, @@ -191,19 +192,22 @@ inherit from this too. DynamicArrayType(self, parts[0], basetype, - lenfield)) + lenfield), + option) lenfield.fieldtype.add_length_for(field) elif ellipsisok and parts[2] == '...': field = MessageTypeField(self.name, parts[0], EllipsisArrayType(self, - parts[0], basetype)) + parts[0], basetype), + option) else: field = MessageTypeField(self.name, parts[0], SizedArrayType(self, parts[0], basetype, - int(parts[2]))) + int(parts[2])), + option) else: - field = MessageTypeField(self.name, parts[0], basetype) + field = MessageTypeField(self.name, parts[0], basetype, option) return field @@ -299,9 +303,10 @@ class MessageType(SubtypeType): 'NODE': 0x2000, 'UPDATE': 0x1000} - def __init__(self, name, value): + def __init__(self, name, value, option=None): super().__init__(name) self.number = self.parse_value(value) + self.option = option def parse_value(self, value): result = 0 @@ -318,23 +323,30 @@ class MessageType(SubtypeType): @staticmethod def type_from_csv(parts): - """e.g msgtype,open_channel,32""" - if len(parts) != 2: + """e.g msgtype,open_channel,32,option_foo""" + option = None + if len(parts) == 3: + option = parts[2] + elif len(parts) < 2 or len(parts) > 3: raise ValueError("msgtype expected 3 CSV parts, not {}" .format(parts)) - return MessageType(parts[0], parts[1]) + return MessageType(parts[0], parts[1], option) @staticmethod def field_from_csv(namespace, parts): - """e.g msgdata,open_channel,temporary_channel_id,byte,32""" - if len(parts) != 4: + """e.g msgdata,open_channel,temporary_channel_id,byte,32[,opt]""" + option = None + if len(parts) == 5: + option = parts[4] + elif len(parts) != 4: raise ValueError("msgdata expected 4 CSV parts, not {}" .format(parts)) messagetype = namespace.get_msgtype(parts[0]) if not messagetype: raise ValueError("unknown subtype {}".format(parts[0])) - field = messagetype._field_from_csv(namespace, parts[1:]) + field = messagetype._field_from_csv(namespace, parts[1:4], + option=option) messagetype.add_field(field)