diff --git a/tools/generate-wire.py b/tools/generate-wire.py index f3765f0a0..131de5c46 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -169,7 +169,9 @@ class Field(object): # Bolts use just a number: Guess type based on size. if options.bolt: - if size == 'var_int': + if size == '$': # this is a subtype + self.fieldtype = FieldType('struct {}'.format(name)) + elif size == 'var_int': base_size = 8 self.fieldtype = FieldType(size) else: @@ -853,6 +855,14 @@ struct {struct_name} {{ fields=str(fmt_fields)) +class Subtype(Message): + def __init__(self, name, comments): + super().__init__(name, None, comments, False) + + def print_struct(self): + return TlvMessage._inner_print_struct(self.name, self.fields) + + tlv_message_towire_stub = """static void towire_{tlv_name}_{name}(u8 **p, struct tlv_msg_{name} *{name}) {{ {field_decls} {subcalls} @@ -1071,6 +1081,7 @@ options = parser.parse_args() # Maps message names to messages messages = [] messages_with_option = [] +subtypes = [] comments = [] includes = [] tlv_fields = {} @@ -1096,18 +1107,21 @@ for line in fileinput.input(options.files): is_tlv_msg = len(parts) == 3 if len(parts) == 2 or is_tlv_msg: # eg: commit_sig,132,(_tlv) - if is_tlv_msg: - message = TlvMessage(parts[0], - Enumtype("WIRE_" + parts[0].upper(), parts[1]), - comments) + if parts[1] == '$': # this is a subtype + subtypes.append(Subtype(parts[0], comments)) else: - message = Message(parts[0], - Enumtype("WIRE_" + parts[0].upper(), parts[1]), - comments) + if is_tlv_msg: + message = TlvMessage(parts[0], + Enumtype("WIRE_" + parts[0].upper(), parts[1]), + comments) + else: + message = Message(parts[0], + Enumtype("WIRE_" + parts[0].upper(), parts[1]), + comments) - messages.append(message) - if is_tlv_msg: - tlv_fields[parts[2]].append(message) + messages.append(message) + if is_tlv_msg: + tlv_fields[parts[2]].append(message) comments = [] prevfield = None @@ -1115,9 +1129,9 @@ for line in fileinput.input(options.files): if len(parts) == 4: # eg commit_sig,0,channel-id,8 OR # commit_sig,0,channel-id,u64 - m = find_message(messages, parts[0]) + m = find_message(messages + subtypes, parts[0]) if m is None: - raise ValueError('Unknown message {}'.format(parts[0])) + raise ValueError('Unknown message or subtype {}'.format(parts[0])) elif len(parts) == 5: # eg. # channel_reestablish,48,your_last_per_commitment_secret,32,option209 @@ -1187,6 +1201,13 @@ def build_tlv_structs(tlv_fields): return structs +def build_subtype_structs(subtypes): + structs = "" + for subtype in subtypes: + structs += subtype.print_struct() + return structs + + enum_header_template = """enum {enumname} {{ {enums} }}; @@ -1214,7 +1235,7 @@ header_template = """/* This file was generated by generate-wire.py */ #include #include {includes} -{formatted_hdr_enums}{tlv_structs} +{formatted_hdr_enums}{gen_structs} {func_decls} #endif /* LIGHTNING_{idem} */ """ @@ -1304,6 +1325,7 @@ built_hdr_enums = build_hdr_enums(options.enumname, toplevel_messages, tlv_field built_impl_enums = build_impl_enums(options.enumname, toplevel_messages, tlv_fields) tlv_structs = build_tlv_structs(tlv_fields) tlv_structs += build_tlv_type_structs(tlv_fields) +subtype_structs = build_subtype_structs(subtypes) includes = '\n'.join(includes) printcases = ['case {enum.name}: printf("{enum.name}:\\n"); printwire_{name}("{name}", msg); return;'.format(enum=m.enum, name=m.name) for m in toplevel_messages] @@ -1339,5 +1361,5 @@ print(template.format( enumname=options.enumname, formatted_hdr_enums=built_hdr_enums, formatted_impl_enums=built_impl_enums, - tlv_structs=tlv_structs, + gen_structs=tlv_structs + subtype_structs, func_decls='\n'.join(decls)))