diff --git a/tools/generate-wire.py b/tools/generate-wire.py index b4d0a0fed..02e5bc410 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -641,6 +641,35 @@ class Message(object): subcalls=str(subcalls) ) + def print_struct(self): + """ returns a string representation of this message as + a struct""" + if not self.is_tlv: + raise TypeError('{} is not a TLV-message').format(self.name) + + fmt_fields = CCode() + for f in self.fields: + if f.is_len_var or f.is_padding(): + # there is no ethical padding under TLVs + continue + elif f.is_variable_size(): + fmt_fields.append('{} *{};'.format(f.fieldtype.name, f.name)) + elif f.is_array(): + fmt_fields.append('{} {}[{}];'.format(f.fieldtype.name, f.name, f.num_elems)) + else: + fmt_fields.append('{} {};'.format(f.fieldtype.name, f.name)) + + return tlv_struct_template.format( + tlv_name=self.name, + fields=str(fmt_fields)) + + +tlv_struct_template = """ +struct tlv_{tlv_name} {{ +{fields} +}}; +""" + def find_message(messages, name): for m in messages: @@ -836,6 +865,14 @@ def build_impl_enums(toplevel_enumname, toplevel_messages, tlv_fields): return enum_set +def build_tlv_structs(tlv_fields): + structs = "" + for field_name, tlv_messages in tlv_fields.items(): + for m in tlv_messages: + structs += m.print_struct() + return structs + + enum_header_template = """enum {enumname} {{ {enums} }}; @@ -863,7 +900,7 @@ header_template = """/* This file was generated by generate-wire.py */ #include #include {includes} -{formatted_hdr_enums} +{formatted_hdr_enums}{tlv_structs} {func_decls} #endif /* LIGHTNING_{idem} */ """ @@ -927,6 +964,7 @@ else: toplevel_messages = [m for m in messages if not m.is_tlv] built_hdr_enums = build_hdr_enums(options.enumname, toplevel_messages, tlv_fields) built_impl_enums = build_impl_enums(options.enumname, toplevel_messages, tlv_fields) +tlv_structs = build_tlv_structs(tlv_fields) includes = '\n'.join(includes) printcases = ['case {enum.name}: printf("{enum.name}:\\n"); printwire_{name}("{name}", msg); return;'.format(enum=m.enum, name=m.name) for m in toplevel_messages] @@ -945,4 +983,5 @@ print(template.format( enumname=options.enumname, formatted_hdr_enums=built_hdr_enums, formatted_impl_enums=built_impl_enums, + tlv_structs=tlv_structs, func_decls='\n'.join(decls)))