tlv: calculate sizeof by measuring message length

much better than statically calculating the sizeof
This commit is contained in:
lisa neigut 2019-03-28 14:25:19 -07:00 committed by Rusty Russell
parent 9a23a354fd
commit aba4e161ce

View File

@ -235,15 +235,12 @@ fromwire_impl_templ = """bool fromwire_{name}({ctx}const void *p{args})
fromwire_tlv_impl_templ = """static bool _fromwire_{tlv_name}_{name}({ctx}{args})
{{
\tsize_t start_len, plen;
\tsize_t start_len = *plen;
{fields}
\tconst u8 *cursor = p;
\tplen = tal_count(p);
\tif (plen < len)
\tif (start_len < len)
\t\treturn false;
\tstart_len = plen;
{subcalls}
\treturn cursor != NULL && (start_len - plen == len);
\treturn cursor != NULL && (start_len - *plen == len);
}}
"""
@ -382,22 +379,23 @@ class Message(object):
self.has_variable_fields = True
self.fields.append(field)
def print_fromwire_array(self, ctx, subcalls, basetype, f, name, num_elems):
def print_fromwire_array(self, ctx, subcalls, basetype, f, name, num_elems, is_tlv=False):
p_ref = '' if is_tlv else '&'
if f.has_array_helper():
subcalls.append('fromwire_{}_array(&cursor, &plen, {}, {});'
.format(basetype, name, num_elems))
subcalls.append('fromwire_{}_array(&cursor, {}plen, {}, {});'
.format(basetype, p_ref, name, num_elems))
else:
subcalls.append('for (size_t i = 0; i < {}; i++)'
.format(num_elems))
if f.fieldtype.is_assignable():
subcalls.append('({})[i] = fromwire_{}(&cursor, &plen);'
.format(name, basetype))
subcalls.append('({})[i] = fromwire_{}(&cursor, {}plen);'
.format(name, basetype, p_ref))
elif basetype in varlen_structs:
subcalls.append('({})[i] = fromwire_{}({}, &cursor, &plen);'
.format(name, basetype, ctx))
subcalls.append('({})[i] = fromwire_{}({}, &cursor, {}plen);'
.format(name, basetype, ctx, p_ref))
else:
subcalls.append('fromwire_{}(&cursor, &plen, {} + i);'
.format(basetype, name))
subcalls.append('fromwire_{}(&cursor, {}plen, {} + i);'
.format(basetype, p_ref, name))
def print_tlv_fromwire(self, tlv_name):
""" prints fromwire function definition for a TLV message.
@ -405,7 +403,7 @@ class Message(object):
to populate, instead of fields, as well as a length to read in
"""
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
args = 'const void *p, const u16 len, struct _tlv_msg_{name} *{name}'.format(name=self.name)
args = 'const u8 *cursor, size_t *plen, const u16 len, struct _tlv_msg_{name} *{name}'.format(name=self.name)
fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var]
subcalls = CCode()
for f in self.fields:
@ -419,12 +417,12 @@ class Message(object):
subcalls.append('/*{} */'.format(c))
if f.is_padding():
subcalls.append('fromwire_pad(&cursor, &plen, {});'
subcalls.append('fromwire_pad(&cursor, plen, {});'
.format(f.num_elems))
elif f.is_array():
name = '*{}->{}'.format(self.name, f.name)
self.print_fromwire_array('ctx', subcalls, basetype, f, name,
f.num_elems)
f.num_elems, is_tlv=True)
elif f.is_variable_size():
subcalls.append("// 2nd case {name}".format(name=f.name))
typename = f.fieldtype.name
@ -437,16 +435,16 @@ class Message(object):
name = '{}->{}'.format(self.name, f.name)
# Allocate these off the array itself, if they need alloc.
self.print_fromwire_array('*' + f.name, subcalls, basetype, f,
name, f.lenvar)
name, f.lenvar, is_tlv=True)
else:
if f.is_assignable():
if f.is_len_var:
s = '{} = fromwire_{}(&cursor, &plen);'.format(f.name, basetype)
s = '{} = fromwire_{}(&cursor, plen);'.format(f.name, basetype)
else:
s = '{}->{} = fromwire_{}(&cursor, &plen);'.format(
s = '{}->{} = fromwire_{}(&cursor, plen);'.format(
self.name, f.name, basetype)
else:
s = 'fromwire_{}(&cursor, &plen, *{}->{});'.format(
s = 'fromwire_{}(&cursor, plen, *{}->{});'.format(
basetype, self.name, f.name)
subcalls.append(s)
@ -506,7 +504,7 @@ class Message(object):
elif f.is_tlv:
if not f.is_variable_size():
raise TypeError('TLV {} not variable size'.format(f.name))
subcalls.append('if (!fromwire__{tlv_name}(ctx, &cursor, &{tlv_len}, {tlv_name}))'
subcalls.append('if (!fromwire__{tlv_name}(ctx, &cursor, &plen, &{tlv_len}, {tlv_name}))'
.format(tlv_name=f.name, tlv_len=f.lenvar))
subcalls.append('return false;')
elif f.is_variable_size():
@ -585,7 +583,7 @@ class Message(object):
elif f.optional:
raise TypeError("Optional fields on TLV messages not currently supported. {}->{}".format(tlv_name, f.name))
if f.is_len_var:
field_decls.append('\t{0} {1} = tal_count(&{2}->{3});'.format(
field_decls.append('\t{0} {1} = tal_count({2}->{3});'.format(
f.fieldtype.name, f.name, self.name, f.lenvar_for.name
))
@ -611,6 +609,9 @@ class Message(object):
field_decls='\n'.join(field_decls),
subcalls=str(subcalls))
def find_tlv_lenvar_field(self, tlv_name):
return [f for f in self.fields if f.is_len_var and f.lenvar_for.is_tlv and f.lenvar_for.name == tlv_name][0]
def print_towire(self, is_header, tlv_name):
if self.is_tlv:
if is_header:
@ -636,9 +637,8 @@ class Message(object):
for f in self.fields:
if f.is_len_var:
if f.lenvar_for.is_tlv:
field_decls.append('\t{0} {1} = sizeof({2});'.format(
f.fieldtype.name, f.name, f.lenvar_for.name
))
# used below...
field_decls.append('\t{0} {1};'.format(f.fieldtype.name, f.name))
else:
field_decls.append('\t{0} {1} = tal_count({2});'.format(
f.fieldtype.name, f.name, f.lenvar_for.name
@ -656,11 +656,21 @@ class Message(object):
.format(f.num_elems))
elif f.is_array():
self.print_towire_array(subcalls, basetype, f, f.num_elems)
elif f.is_len_var and f.lenvar_for.is_tlv:
continue # taken care of below
elif f.is_tlv:
if not f.is_variable_size():
raise TypeError('TLV {} not variable size'.format(f.name))
subcalls.append('towire__{tlv_name}(&p, {tlv_name});'.format(
tlv_name=f.name))
raise ValueError('TLV {} not variable size'.format(f.name))
lenvar_field = self.find_tlv_lenvar_field(f.name)
subcalls.append('/* ~~build TLV for {} ~~*/'.format(f.name))
subcalls.append("u8 *{tlv_name}_buffer = tal_arr(ctx, u8, 0);\n"
"towire__{tlv_name}(ctx, &{tlv_name}_buffer, {tlv_name});\n"
"{lenvar_field} = tal_count({tlv_name}_buffer);\n"
"towire_{lenvar_fieldtype}(&p, {lenvar_field});\n"
"towire_u8_array(&p, {tlv_name}_buffer, {lenvar_field});\n".format(
tlv_name=f.name,
lenvar_field=lenvar_field.name,
lenvar_fieldtype=lenvar_field.fieldtype.name))
elif f.is_variable_size():
self.print_towire_array(subcalls, basetype, f, f.lenvar)
else:
@ -821,43 +831,47 @@ struct _{tlv_name} {{
"""
tlv__type_impl_towire_fields = """\tif ({tlv_name}->{name}) {{
\t\ttlv_msg = tal_arr(ctx, u8, 0);
\t\ttowire_u16(p, {enum});
\t\ttowire_u16(p, sizeof(*{tlv_name}->{name}));
\t\t_towire_{tlv_name}_{name}(p, {tlv_name}->{name});
\t\t_towire_{tlv_name}_{name}(&tlv_msg, {tlv_name}->{name});
\t\tmsg_len = tal_count(tlv_msg);
\t\ttowire_u16(p, msg_len);
\t\ttowire_u8_array(p, tlv_msg, msg_len);
\t}}
"""
tlv__type_impl_towire_template = """static void towire__{tlv_name}(u8 **p, const struct _{tlv_name} *{tlv_name}) {{
tlv__type_impl_towire_template = """static void towire__{tlv_name}(const tal_t *ctx, u8 **p, const struct _{tlv_name} *{tlv_name}) {{
\tu16 msg_len;
\tu8 *tlv_msg;
{fields}}}
"""
tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, const u16 *len, struct _{tlv_name} *{tlv_name}) {{
tlv__type_impl_fromwire_template = """static bool fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, size_t *plen, const u16 *len, struct _{tlv_name} *{tlv_name}) {{
\tu16 msg_type, msg_len;
\tconst u8 *cursor = *p;
\tsize_t plen = tal_count(p);
\tif (plen != *len)
\tif (*plen < *len)
\t\treturn false;
\twhile (cursor && plen) {{
\t\tmsg_type = fromwire_u16(&cursor, &plen);
\t\tmsg_len = fromwire_u16(&cursor, &plen);
\t\tif (plen < msg_len) {{
\t\t\tfromwire_fail(&cursor, &plen);
\twhile (*plen) {{
\t\tmsg_type = fromwire_u16(p, plen);
\t\tmsg_len = fromwire_u16(p, plen);
\t\tif (*plen < msg_len) {{
\t\t\tfromwire_fail(p, plen);
\t\t\tbreak;
\t\t}}
\t\tswitch((enum {tlv_name}_type)msg_type) {{
{cases}\t\tdefault:
\t\t\t// FIXME: print a warning / message?
\t\t\tcursor += msg_len;
\t\t\t*p+= msg_len;
\t\t\tplen -= msg_len;
\t\t}}
\t}}
\treturn cursor != NULL;
\treturn *p != NULL;
}}
"""
case_tmpl = """\t\tcase {tlv_msg_enum}:
\t\t\tif (!_fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}cursor, msg_len, {tlv_name}->{tlv_msg_name}))
\t\t\t{tlv_name}->{tlv_msg_name} = tal(ctx, struct _tlv_msg_{tlv_msg_name});
\t\t\tif (!_fromwire_{tlv_name}_{tlv_msg_name}({ctx_arg}*p, plen, msg_len, {tlv_name}->{tlv_msg_name}))
\t\t\t\treturn false;
\t\t\tbreak;
"""
@ -1227,10 +1241,8 @@ else:
fromwire_decls.append(m.print_fromwire(options.header, tlv_field))
if not options.header:
tlv_towires = build_tlv_towires(tlv_fields)
tlv_fromwires = build_tlv_fromwires(tlv_fields)
towire_decls += tlv_towires
fromwire_decls += tlv_fromwires
towire_decls += build_tlv_towires(tlv_fields)
fromwire_decls += build_tlv_fromwires(tlv_fields)
towire_decls += [m.print_towire(options.header, '') for m in messages + messages_with_option]
fromwire_decls += [m.print_fromwire(options.header, '') for m in messages + messages_with_option]