core-lightning/tools/generate-wire.py
niftynei 175fcf381a psbt: have wally_tx serialization methods be legible for gen'd code
our code generators expect the serialization name to match the struct
type
2020-06-23 14:49:32 +02:00

703 lines
24 KiB
Python
Executable File

#! /usr/bin/env python3
# Script to parse spec output CSVs and produce C files.
# Released by lisa neigut under CC0:
# https://creativecommons.org/publicdomain/zero/1.0/
#
# Reads from stdin, outputs C header or body file.
#
# Standard message types:
# msgtype,<msgname>,<value>[,<option>]
# msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
#
# TLV types:
# tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
# tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
#
# Subtypes:
# subtype,<subtypename>
# subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]
from argparse import ArgumentParser, REMAINDER
from collections import OrderedDict
import copy
import fileinput
from mako.template import Template
import os
import re
import sys
# Generator to give us one line at a time.
def next_line(args, lines):
if lines is None:
lines = fileinput.input(args)
for i, line in enumerate(lines):
yield i + 1, line.strip()
# Class definitions, to keep things classy
class Field(object):
def __init__(self, name, type_obj, extensions=[],
field_comments=[], optional=False):
self.name = name
self.type_obj = type_obj
self.count = 1
self.len_field_of = None
self.len_field = None
self.implicit_len = False
self.extension_names = extensions
self.is_optional = optional
self.field_comments = field_comments
def __deepcopy__(self, memo):
deepcopy_method = self.__deepcopy__
self.__deepcopy__ = None
field = copy.deepcopy(self, memo)
self.__deepcopy__ = deepcopy_method
field.type_obj = self.type_obj
return field
def add_count(self, count):
self.count = int(count)
def add_len_field(self, len_field):
self.count = False
# we cache our len-field's name
self.len_field = len_field.name
# the len-field caches our name
len_field.len_field_of = self.name
def add_implicit_len(self):
self.count = False
self.implicit_len = True
def is_array(self):
return self.count > 1
def is_varlen(self):
return not self.count
def is_implicit_len(self):
return self.implicit_len
def is_extension(self):
return bool(self.extension_names)
def size(self, implicit_expression=None):
if self.count:
return self.count
if self.len_field:
return self.len_field
assert self.is_implicit_len()
assert implicit_expression
return implicit_expression
def needs_context(self):
""" A field needs a context if it's varsized """
return self.is_varlen() or self.type_obj.needs_context()
def arg_desc_to(self):
if self.len_field_of:
return ''
type_name = self.type_obj.type_name()
if self.is_array():
return ', const {} {}[{}]'.format(type_name, self.name, self.count)
if self.type_obj.is_assignable() and not self.is_varlen():
name = self.name
if self.is_optional:
name = '*' + name
return ', {} {}'.format(type_name, name)
if self.is_varlen() and self.type_obj.is_varsize():
return ', const {} **{}'.format(type_name, self.name)
return ', const {} *{}'.format(type_name, self.name)
def arg_desc_from(self):
type_name = self.type_obj.type_name()
if self.type_obj.is_const_ptr_ptr_type():
return ', const {} **{}'.format(type_name, self.name)
if self.len_field_of:
return ''
if self.is_array():
return ', {} {}[{}]'.format(type_name, self.name, self.count)
ptrs = '*'
if self.is_varlen() or self.is_optional or self.type_obj.is_varsize():
ptrs += '*'
if self.is_varlen() and self.type_obj.is_varsize():
ptrs += '*'
return ', {} {}{}'.format(type_name, ptrs, self.name)
class FieldSet(object):
def __init__(self):
self.fields = OrderedDict()
self.len_fields = {}
def add_data_field(self, field_name, type_obj, count=1,
extensions=[], comments=[], optional=False,
implicit_len_ok=False):
field = Field(field_name, type_obj, extensions=extensions,
field_comments=comments, optional=optional)
if bool(count):
try:
field.add_count(int(count))
except ValueError:
if count in self.fields:
len_field = self.find_data_field(count)
field.add_len_field(len_field)
self.len_fields[len_field.name] = len_field
else:
# '...' means "rest of TLV"
assert implicit_len_ok
assert count == '...'
field.add_implicit_len()
# You can't have any fields after an implicit-length field.
if len(self.fields) != 0:
assert not self.fields[next(reversed(self.fields))].is_implicit_len()
self.fields[field_name] = field
def find_data_field(self, field_name):
return self.fields[field_name]
def get_len_fields(self):
return list(self.len_fields.values())
def has_len_fields(self):
return bool(self.len_fields)
def needs_context(self):
return any([field.needs_context() or field.is_optional for field in self.fields.values()])
def singleton(self):
"""Return the single message, if there's only one, otherwise None"""
if len(self.fields) == 1:
return next(iter(self.fields.values()))
return None
class Type(FieldSet):
assignables = [
'u8',
'u16',
'u32',
'u64',
'tu16',
'tu32',
'tu64',
'bool',
'amount_sat',
'amount_msat',
'errcode_t',
'bigsize',
'varint'
]
typedefs = [
'u8',
'u16',
'u32',
'u64',
'bool',
'secp256k1_ecdsa_signature',
'secp256k1_ecdsa_recoverable_signature',
'wirestring',
'errcode_t',
'bigsize',
'varint',
]
truncated_typedefs = [
'tu16',
'tu32',
'tu64',
]
# Externally defined variable size types (require a context)
varsize_types = [
'peer_features',
'gossip_getnodes_entry',
'gossip_getchannels_entry',
'failed_htlc',
'existing_htlc',
'utxo',
'bitcoin_tx',
'wirestring',
'per_peer_state',
'bitcoin_tx_output',
'exclude_entry',
'fee_states',
'onionreply',
'feature_set',
'onionmsg_path',
'route_hop',
'tx_parts',
'wally_psbt',
]
# Some BOLT types are re-typed based on their field name
# ('fieldname partial', 'original type', 'outer type'): ('true type', 'collapse array?')
name_field_map = {
('txid', 'sha256'): ('bitcoin_txid', False),
('amt', 'u64'): ('amount_msat', False),
('msat', 'u64'): ('amount_msat', False),
('satoshis', 'u64'): ('amount_sat', False),
('node_id', 'pubkey', 'channel_announcement'): ('node_id', False),
('node_id', 'pubkey', 'node_announcement'): ('node_id', False),
('temporary_channel_id', 'u8'): ('channel_id', True),
('secret', 'u8'): ('secret', True),
('preimage', 'u8'): ('preimage', True),
}
# For BOLT specified types, a few type names need to be simply 'remapped'
# 'original type': 'true type'
name_remap = {
'byte': 'u8',
'signature': 'secp256k1_ecdsa_signature',
'chain_hash': 'bitcoin_blkid',
'point': 'pubkey',
# FIXME: omits 'pad'
}
# Types that are const pointer-to-pointers, such as chainparams, i.e.,
# they set a reference to some const entry.
const_ptr_ptr_types = [
'chainparams'
]
@staticmethod
def true_type(type_name, field_name=None, outer_name=None):
""" Returns 'true' type of a given type and a flag if
we've remapped a variable size/array type to a single struct
(an example of this is 'temporary_channel_id' which is specified
as a 32*byte, but we re-map it to a channel_id
"""
if type_name in Type.name_remap:
type_name = Type.name_remap[type_name]
if field_name:
for t, true_type in Type.name_field_map.items():
if t[0] in field_name and t[1] == type_name:
if len(t) == 2 or outer_name == t[2]:
return true_type
return (type_name, False)
def __init__(self, name):
FieldSet.__init__(self)
self.name, self.is_enum = self.parse_name(name)
self.depends_on = {}
self.type_comments = []
self.tlv = False
def parse_name(self, name):
if name.startswith('enum '):
return name[5:], True
return name, False
def add_data_field(self, field_name, type_obj, count=1,
extensions=[], comments=[], optional=False):
FieldSet.add_data_field(self, field_name, type_obj, count,
extensions=extensions,
comments=comments, optional=optional)
if type_obj.name not in self.depends_on:
self.depends_on[type_obj.name] = type_obj
def type_name(self):
if self.name in self.typedefs:
return self.name
if self.name in self.truncated_typedefs:
return self.name[1:]
if self.is_enum:
prefix = 'enum '
else:
prefix = 'struct '
return prefix + self.struct_name()
# We only accelerate the u8 case: it's common and trivial.
def has_array_helper(self):
return self.name in ['u8']
def struct_name(self):
if self.is_tlv():
return self.tlv.struct_name()
return self.name
def subtype_deps(self):
return [dep for dep in self.depends_on.values() if dep.is_subtype()]
def is_subtype(self):
return bool(self.fields)
def is_const_ptr_ptr_type(self):
return self.name in self.const_ptr_ptr_types
def is_truncated(self):
return self.name in self.truncated_typedefs
def needs_context(self):
return self.is_varsize()
def is_assignable(self):
""" Generally typedef's and enums """
return self.name in self.assignables or self.is_enum
def is_varsize(self):
""" A type is variably sized if it's marked as such (in varsize_types)
or it contains a field of variable length """
return self.name in self.varsize_types or self.has_len_fields()
def add_comments(self, comments):
self.type_comments = comments
def mark_tlv(self, tlv):
self.tlv = tlv
def is_tlv(self):
return bool(self.tlv)
class Message(FieldSet):
def __init__(self, name, number, option=[], enum_prefix='wire',
struct_prefix=None, comments=[]):
FieldSet.__init__(self)
self.name = name
self.number = number
self.enum_prefix = enum_prefix
self.option = option[0] if len(option) else None
self.struct_prefix = struct_prefix
self.enumname = None
self.msg_comments = comments
self.if_token = None
def has_option(self):
return self.option is not None
def enum_name(self):
name = self.enumname if self.enumname else self.name
return "{}_{}".format(self.enum_prefix, name).upper()
def struct_name(self):
if self.struct_prefix:
return self.struct_prefix + "_" + self.name
return self.name
def add_if(self, if_token):
self.if_token = if_token
class Tlv(object):
def __init__(self, name):
self.name = name
self.messages = {}
def add_message(self, tokens, comments=[]):
""" tokens -> (name, value[, option]) """
self.messages[tokens[0]] = Message(tokens[0], tokens[1], option=tokens[2:],
enum_prefix=self.name,
struct_prefix=self.struct_name(),
comments=comments)
def type_name(self):
return 'struct ' + self.struct_name()
def struct_name(self):
return "tlv_{}".format(self.name)
def find_message(self, name):
return self.messages[name]
def ordered_msgs(self):
return sorted(self.messages.values(), key=lambda item: int(item.number))
class Master(object):
types = {}
tlvs = {}
messages = {}
extension_msgs = {}
inclusions = []
top_comments = []
def add_comments(self, comments):
self.top_comments += comments
def add_include(self, inclusion):
self.inclusions.append(inclusion)
def add_tlv(self, tlv_name):
if tlv_name not in self.tlvs:
self.tlvs[tlv_name] = Tlv(tlv_name)
if tlv_name not in self.types:
self.types[tlv_name] = Type(tlv_name)
return self.tlvs[tlv_name]
def add_message(self, tokens, comments=[]):
""" tokens -> (name, value[, option])"""
self.messages[tokens[0]] = Message(tokens[0], tokens[1], option=tokens[2:],
comments=comments)
def add_extension_msg(self, name, msg):
self.extension_msgs[name] = msg
def add_type(self, type_name, field_name=None, outer_name=None):
optional = False
if type_name.startswith('?'):
type_name = type_name[1:]
optional = True
# Check for special type name re-mapping
type_name, collapse_original = Type.true_type(type_name, field_name,
outer_name)
if type_name not in self.types:
self.types[type_name] = Type(type_name)
return self.types[type_name], collapse_original, optional
def find_type(self, type_name):
return self.types[type_name]
def find_message(self, msg_name):
if msg_name in self.messages:
return self.messages[msg_name]
if msg_name in self.extension_msgs:
return self.extension_msgs[msg_name]
return None
def find_tlv(self, tlv_name):
return self.tlvs[tlv_name]
def get_ordered_subtypes(self):
""" We want to order subtypes such that the 'no dependency'
types are printed first """
subtypes = [s for s in self.types.values() if s.is_subtype()]
# Start with subtypes without subtype dependencies
sorted_types = [s for s in subtypes if not len(s.subtype_deps())]
unsorted = [s for s in subtypes if len(s.subtype_deps())]
while len(unsorted):
names = [s.name for s in sorted_types]
for s in list(unsorted):
if all([dependency.name in names for dependency in s.subtype_deps()]):
sorted_types.append(s)
unsorted.remove(s)
return sorted_types
def tlv_structs(self):
ret = []
for tlv in self.tlvs.values():
for v in tlv.messages.values():
if not v.singleton():
ret.append(v)
return ret
def find_template(self, options):
dirpath = os.path.dirname(os.path.abspath(__file__))
filename = dirpath + '/gen/{}{}_template'.format(
'print_' if options.print_wire else '', options.page)
return Template(filename=filename)
def post_process(self):
""" method to handle any 'post processing' that needs to be done.
for now, we just need match up types to TLVs """
for tlv_name, tlv in self.tlvs.items():
if tlv_name in self.types:
self.types[tlv_name].mark_tlv(tlv)
def write(self, options, output):
template = self.find_template(options)
enum_sets = []
if len(self.messages.values()) != 0:
enum_sets.append({
'name': options.enum_name,
'set': self.messages.values(),
})
stuff = {}
stuff['top_comments'] = self.top_comments
stuff['options'] = options
stuff['idem'] = re.sub(r'[^A-Z]+', '_', options.header_filename.upper())
stuff['header_filename'] = options.header_filename
stuff['includes'] = self.inclusions
stuff['enum_sets'] = enum_sets
subtypes = self.get_ordered_subtypes()
stuff['structs'] = subtypes + self.tlv_structs()
stuff['tlvs'] = self.tlvs
# We leave out extension messages in the printing pages. Any extension
# fields will get printed under the 'original' message, if present
if options.print_wire:
stuff['messages'] = list(self.messages.values())
else:
stuff['messages'] = list(self.messages.values()) + list(self.extension_msgs.values())
stuff['subtypes'] = subtypes
print(template.render(**stuff), file=output)
def main(options, args=None, output=sys.stdout, lines=None):
genline = next_line(args, lines)
comment_set = []
token_name = None
# Create a new 'master' that serves as the coordinator for the file generation
master = Master()
for i in options.include:
master.add_include('#include <{}>'.format(i))
try:
while True:
ln, line = next(genline)
tokens = line.split(',')
token_type = tokens[0]
if not bool(line):
master.add_comments(comment_set)
comment_set = []
token_name = None
continue
if len(tokens) > 2:
token_name = tokens[1]
if token_type == 'subtype':
subtype, _, _ = master.add_type(tokens[1])
subtype.add_comments(list(comment_set))
comment_set = []
elif token_type == 'subtypedata':
subtype = master.find_type(tokens[1])
if not subtype:
raise ValueError('Unknown subtype {} for data.\nat {}:{}'
.format(tokens[1], ln, line))
type_obj, collapse, optional = master.add_type(tokens[3], tokens[2], tokens[1])
if optional:
raise ValueError('Subtypes cannot have optional fields {}.{}\n at {}:{}'
.format(subtype.name, tokens[2], ln, line))
if collapse:
count = 1
else:
count = tokens[4]
subtype.add_data_field(tokens[2], type_obj, count, comments=list(comment_set),
optional=optional)
comment_set = []
elif token_type == 'tlvtype':
tlv = master.add_tlv(tokens[1])
tlv.add_message(tokens[2:], comments=list(comment_set))
comment_set = []
elif token_type == 'tlvdata':
type_obj, collapse, optional = master.add_type(tokens[4], tokens[3], tokens[1])
if optional:
raise ValueError('TLV messages cannot have optional fields {}.{}\n at {}:{}'
.format(tokens[2], tokens[3], ln, line))
tlv = master.find_tlv(tokens[1])
if not tlv:
raise ValueError('tlvdata for unknown tlv {}.\nat {}:{}'
.format(tokens[1], ln, line))
msg = tlv.find_message(tokens[2])
if not msg:
raise ValueError('tlvdata for unknown tlv-message {}.\nat {}:{}'
.format(tokens[2], ln, line))
if collapse:
count = 1
else:
count = tokens[5]
msg.add_data_field(tokens[3], type_obj, count, comments=list(comment_set),
optional=optional, implicit_len_ok=True)
comment_set = []
elif token_type == 'msgtype':
master.add_message(tokens[1:], comments=list(comment_set))
comment_set = []
elif token_type == 'msgdata':
msg = master.find_message(tokens[1])
if not msg:
raise ValueError('Unknown message type {}. {}:{}'.format(tokens[1], ln, line))
type_obj, collapse, optional = master.add_type(tokens[3], tokens[2], tokens[1])
if collapse:
count = 1
elif len(tokens) < 5:
raise ValueError('problem with parsing {}:{}'.format(ln, line))
else:
count = tokens[4]
# if this is an 'extension' field*, we want to add a new 'message' type
# in the future, extensions will be handled as TLV's
#
# *(in the spec they're called 'optional', but that term is overloaded
# in that internal wire messages have 'optional' fields that are treated
# differently. for the sake of clarity here, for bolt-wire messages,
# we'll refer to 'optional' message fields as 'extensions')
#
if tokens[5:] == []:
msg.add_data_field(tokens[2], type_obj, count, comments=list(comment_set),
optional=optional)
else: # is one or more extension fields
if optional:
raise ValueError("Extension fields cannot be optional. {}:{}"
.format(ln, line))
orig_msg = msg
for extension in tokens[5:]:
extension_name = "{}_{}".format(tokens[1], extension)
msg = master.find_message(extension_name)
if not msg:
msg = copy.deepcopy(orig_msg)
msg.enumname = msg.name
msg.name = extension_name
master.add_extension_msg(msg.name, msg)
msg.add_data_field(tokens[2], type_obj, count, comments=list(comment_set), optional=optional)
# If this is a print_wire page, add the extension fields to the
# original message, so we can print them if present.
if options.print_wire:
orig_msg.add_data_field(tokens[2], type_obj, count=count,
extensions=tokens[5:],
comments=list(comment_set),
optional=optional)
comment_set = []
elif token_type.startswith('#include'):
master.add_include(token_type)
elif token_type.startswith('#if'):
msg = master.find_message(token_name)
if (msg):
if_token = token_type[token_type.index(' ') + 1:]
msg.add_if(if_token)
elif token_type.startswith('#'):
comment_set.append(token_type[1:])
else:
raise ValueError("Unknown token type {} on line {}:{}".format(token_type, ln, line))
except StopIteration:
pass
master.post_process()
master.write(options, output)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("-s", "--expose-subtypes", help="print subtypes in header",
action="store_true", default=False)
parser.add_argument("-P", "--print_wire", help="generate wire printing source files",
action="store_true", default=False)
parser.add_argument("--page", choices=['header', 'impl'], help="page to print")
parser.add_argument('--expose-tlv-type', action='append', default=[])
parser.add_argument('--include', action='append', default=[])
parser.add_argument('header_filename', help='The filename of the header')
parser.add_argument('enum_name', help='The name of the enum to produce')
parser.add_argument("files", help='Files to read in (or stdin)', nargs=REMAINDER)
parsed_args = parser.parse_args()
main(parsed_args, parsed_args.files)