bolt-gen: add field optional handling

we'll need this for internal wire message formats. also disambiguates
from 'bolt message optional fields', which we rename to extensions here.

example of an optional field declaration (note the ? prefixing the
type):

    msgdata,msg_name,field_name,?type,count

these are handled with either a boolean if they're not present,
or a true value and then the object if they are.
This commit is contained in:
lisa neigut 2019-07-19 11:43:50 -05:00 committed by Rusty Russell
parent ade594e941
commit 181e1916b2

View File

@ -38,13 +38,16 @@ def next_line(args, lines):
# Class definitions, to keep things classy
class Field(object):
def __init__(self, name, type_obj, optional=False, field_comments=[]):
def __init__(self, name, type_obj, extension=False,
field_comments=[], optional=False):
self.name = name
self.type_obj = type_obj
self.count = 1
self.is_optional = optional
self.len_field_of = None
self.len_field = None
self.is_extension = extension
self.is_optional = optional
self.field_comments = field_comments
def add_count(self, count):
@ -66,6 +69,9 @@ class Field(object):
def is_optional(self):
return self.is_optional
def is_extension(self):
return self.is_extension
def size(self):
if self.count:
return self.count
@ -101,15 +107,16 @@ class Field(object):
class FieldSet(object):
def __init__(self):
self.fields = OrderedDict()
self.optional_fields = False
self.extension_fields = False
self.len_fields = {}
def add_data_field(self, field_name, type_obj, count=1, is_optional=[], comments=[]):
# FIXME: use this somewhere?
if is_optional:
self.optional_fields = True
def add_data_field(self, field_name, type_obj, count=1,
is_extension=[], comments=[], optional=False):
if is_extension:
self.extension_fields = True
field = Field(field_name, type_obj, bool(is_optional), comments)
field = Field(field_name, type_obj, extension=bool(is_extension),
field_comments=comments, optional=optional)
if bool(count):
try:
field.add_count(int(count))
@ -201,8 +208,10 @@ class Type(FieldSet):
self.is_enum = False
self.type_comments = []
def add_data_field(self, field_name, type_obj, count=1, is_optional=[], comments=[]):
FieldSet.add_data_field(self, field_name, type_obj, count, is_optional, comments)
def add_data_field(self, field_name, type_obj, count=1,
is_extension=[], comments=[], optional=False):
FieldSet.add_data_field(self, field_name, type_obj, count,
is_extension, comments=comments, optional=optional)
if type_obj.name not in self.depends_on:
self.depends_on[type_obj.name] = type_obj
@ -305,12 +314,16 @@ class Master(object):
self.extension_msgs[name] = msg
def add_type(self, type_name, field_name=None):
optional = False
if type_name.startswith('?'):
type_name = type_name[1:]
optional = True
# Check for special type name re-mapping
type_name, collapse_original = Type.true_type(type_name, field_name)
if type_name not in self.types:
self.types[type_name] = Type(type_name)
return self.types[type_name], collapse_original
return self.types[type_name], collapse_original, optional
def find_type(self, type_name):
return self.types[type_name]
@ -398,7 +411,7 @@ def main(options, args=None, output=sys.stdout, lines=None):
continue
if token_type == 'subtype':
subtype, _ = master.add_type(tokens[1])
subtype, _, _ = master.add_type(tokens[1])
subtype.add_comments(list(comment_set))
comment_set = []
@ -407,13 +420,17 @@ def main(options, args=None, output=sys.stdout, lines=None):
if not subtype:
raise ValueError('Unknown subtype {} for data.\nat {}:{}'
.format(tokens[1], ln, line))
type_obj, collapse = master.add_type(tokens[3], tokens[2])
type_obj, collapse, optional = master.add_type(tokens[3], tokens[2])
if optional:
raise ValueError('Subtypes cannot have optional fields {}.{}\n at {}:{}'
.format(subtype.name, tokens[2], ln, line))
if collapse:
count = 1
else:
count = tokens[4]
subtype.add_data_field(tokens[2], type_obj, count, list(comment_set))
subtype.add_data_field(tokens[2], type_obj, count, comments=list(comment_set),
optional=optional)
comment_set = []
elif token_type == 'tlvtype':
tlv = master.add_tlv(tokens[1])
@ -421,7 +438,11 @@ def main(options, args=None, output=sys.stdout, lines=None):
comment_set = []
elif token_type == 'tlvdata':
type_obj, collapse = master.add_type(tokens[4], tokens[3])
type_obj, collapse, optional = master.add_type(tokens[4], tokens[3])
if optional:
raise ValueError('TLV messages cannot have optional fields {}.{}\n at {}:{}'
.format(tokens[2], tokens[3], ln, line))
tlv = master.find_tlv(tokens[1])
if not tlv:
raise ValueError('tlvdata for unknown tlv {}.\nat {}:{}'
@ -435,7 +456,8 @@ def main(options, args=None, output=sys.stdout, lines=None):
else:
count = tokens[5]
msg.add_data_field(tokens[3], type_obj, count, list(comment_set))
msg.add_data_field(tokens[3], type_obj, count, comments=list(comment_set),
optional=optional)
comment_set = []
elif token_type == 'msgtype':
master.add_message(tokens[1:], comments=list(comment_set))
@ -444,7 +466,7 @@ def main(options, args=None, output=sys.stdout, lines=None):
msg = master.find_message(tokens[1])
if not msg:
raise ValueError('Unknown message type {}. {}:{}'.format(tokens[1], ln, line))
type_obj, collapse = master.add_type(tokens[3], tokens[2])
type_obj, collapse, optional = master.add_type(tokens[3], tokens[2])
# if this is an 'extension' field*, we want to add a new 'message' type
# in the future, extensions will be handled as TLV's
@ -469,7 +491,8 @@ def main(options, args=None, output=sys.stdout, lines=None):
else:
count = tokens[4]
msg.add_data_field(tokens[2], type_obj, count, list(comment_set))
msg.add_data_field(tokens[2], type_obj, count, comments=list(comment_set),
optional=optional)
comment_set = []
elif token_type.startswith('#include'):
master.add_include(token_type)