mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-03-03 10:46:58 +01:00
generate-wire.py: allow optional typename in csv file.
For our internal CSV files, we can specify the type explicitly rather than trying to guess (eg. bool). Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
parent
8ad1298f88
commit
937a62100f
1 changed files with 107 additions and 64 deletions
|
@ -8,31 +8,70 @@ import re
|
||||||
|
|
||||||
Enumtype = namedtuple('Enumtype', ['name', 'value'])
|
Enumtype = namedtuple('Enumtype', ['name', 'value'])
|
||||||
|
|
||||||
|
class FieldType(object):
|
||||||
|
def __init__(self,name):
|
||||||
|
self.name = name
|
||||||
|
self.tsize = FieldType._typesize(name)
|
||||||
|
|
||||||
|
def is_assignable(self):
|
||||||
|
return self.name == 'u8' or self.name == 'u16' or self.name == 'u32' or self.name == 'u64'
|
||||||
|
|
||||||
|
# Returns typename and base size
|
||||||
|
@staticmethod
|
||||||
|
def _typesize(typename):
|
||||||
|
if typename == 'pad':
|
||||||
|
return 1
|
||||||
|
elif typename == 'struct channel_id':
|
||||||
|
return 8
|
||||||
|
elif typename == 'struct ipv6':
|
||||||
|
return 16
|
||||||
|
elif typename == 'struct signature':
|
||||||
|
return 64
|
||||||
|
elif typename == 'struct pubkey':
|
||||||
|
return 33
|
||||||
|
elif typename == 'struct sha256':
|
||||||
|
return 32
|
||||||
|
elif typename == 'u64':
|
||||||
|
return 8
|
||||||
|
elif typename == 'u32':
|
||||||
|
return 4
|
||||||
|
elif typename == 'u16':
|
||||||
|
return 2
|
||||||
|
elif typename == 'u8':
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown typename {}'.format(typename))
|
||||||
|
|
||||||
class Field(object):
|
class Field(object):
|
||||||
def __init__(self,message,name,size,comments):
|
def __init__(self,message,name,size,comments,typename=None):
|
||||||
self.message = message
|
self.message = message
|
||||||
self.comments = comments
|
self.comments = comments
|
||||||
self.name = name.replace('-', '_')
|
self.name = name.replace('-', '_')
|
||||||
self.is_len_var = False
|
self.is_len_var = False
|
||||||
(self.typename, self.basesize) = Field._guess_type(message,self.name,size)
|
self.lenvar = None
|
||||||
|
|
||||||
|
# Size could be a literal number (eg. 33), or a field (eg 'len'), or
|
||||||
|
# a multiplier of a field (eg. num-htlc-timeouts*64).
|
||||||
try:
|
try:
|
||||||
if int(size) % self.basesize != 0:
|
base_size = int(size)
|
||||||
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(size,self.message,self.name,self.basesize))
|
|
||||||
self.num_elems = int(int(size) / self.basesize)
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
self.num_elems = 0
|
|
||||||
# If it's a multiplicitive expression, must end in basesize.
|
# If it's a multiplicitive expression, must end in basesize.
|
||||||
if '*' in size:
|
if '*' in size:
|
||||||
tail='*' + str(self.basesize)
|
base_size = int(size.split('*')[1])
|
||||||
if not size.endswith(tail):
|
self.lenvar = size.split('*')[0]
|
||||||
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(size,self.message,self.name,self.basesize))
|
|
||||||
size = size[:-len(tail)]
|
|
||||||
else:
|
else:
|
||||||
if self.basesize != 1:
|
base_size = 0
|
||||||
raise ValueError('Invalid size {} for {}.{} not expressed as a multiple of {}'.format(size,self.message,self.name,self.basesize))
|
self.lenvar = size
|
||||||
|
self.lenvar = self.lenvar.replace('-','_')
|
||||||
|
|
||||||
self.lenvar = size.replace('-','_')
|
if typename is None:
|
||||||
|
self.fieldtype = Field._guess_type(message,self.name,base_size)
|
||||||
|
else:
|
||||||
|
self.fieldtype = FieldType(typename)
|
||||||
|
|
||||||
|
if base_size % self.fieldtype.tsize != 0:
|
||||||
|
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(base_size,self.message,self.name,self.fieldtype.tsize))
|
||||||
|
self.num_elems = int(base_size / self.fieldtype.tsize)
|
||||||
|
|
||||||
def is_padding(self):
|
def is_padding(self):
|
||||||
return self.name.startswith('pad')
|
return self.name.startswith('pad')
|
||||||
|
@ -42,68 +81,67 @@ class Field(object):
|
||||||
return self.num_elems > 1 or self.is_padding()
|
return self.num_elems > 1 or self.is_padding()
|
||||||
|
|
||||||
def is_variable_size(self):
|
def is_variable_size(self):
|
||||||
return self.num_elems == 0
|
return self.lenvar is not None
|
||||||
|
|
||||||
def is_assignable(self):
|
def is_assignable(self):
|
||||||
if self.is_array() or self.is_variable_size():
|
if self.is_array() or self.is_variable_size():
|
||||||
return False
|
return False
|
||||||
return self.typename == 'u8' or self.typename == 'u16' or self.typename == 'u32' or self.typename == 'u64'
|
return self.fieldtype.is_assignable()
|
||||||
|
|
||||||
# Returns typename and base size
|
# Returns FieldType
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _guess_type(message, fieldname, sizestr):
|
def _guess_type(message, fieldname, base_size):
|
||||||
if fieldname.startswith('pad'):
|
if fieldname.startswith('pad'):
|
||||||
return ('pad',1)
|
return FieldType('pad')
|
||||||
|
|
||||||
if fieldname.endswith('channel_id'):
|
if fieldname.endswith('channel_id'):
|
||||||
return ('struct channel_id',8)
|
return FieldType('struct channel_id')
|
||||||
|
|
||||||
if message == 'node_announcement' and fieldname == 'ipv6':
|
if message == 'node_announcement' and fieldname == 'ipv6':
|
||||||
return ('struct ipv6',16)
|
return FieldType('struct ipv6')
|
||||||
|
|
||||||
if message == 'node_announcement' and fieldname == 'alias':
|
if message == 'node_announcement' and fieldname == 'alias':
|
||||||
return ('u8',1)
|
return FieldType('u8')
|
||||||
|
|
||||||
if fieldname.endswith('features'):
|
if fieldname.endswith('features'):
|
||||||
return ('u8',1)
|
return FieldType('u8')
|
||||||
|
|
||||||
if fieldname == 'addresses':
|
|
||||||
return ('u8', 1)
|
|
||||||
|
|
||||||
# We translate signatures and pubkeys.
|
# We translate signatures and pubkeys.
|
||||||
if 'signature' in fieldname:
|
if 'signature' in fieldname:
|
||||||
return ('struct signature',64)
|
return FieldType('struct signature')
|
||||||
|
|
||||||
# The remainder should be fixed sizes.
|
|
||||||
if sizestr == '33':
|
|
||||||
return ('struct pubkey',33)
|
|
||||||
if sizestr == '32':
|
|
||||||
return ('struct sha256',32)
|
|
||||||
if sizestr == '8':
|
|
||||||
return ('u64',8)
|
|
||||||
if sizestr == '4':
|
|
||||||
return ('u32',4)
|
|
||||||
if sizestr == '2':
|
|
||||||
return ('u16',2)
|
|
||||||
if sizestr == '1':
|
|
||||||
return ('u8',1)
|
|
||||||
|
|
||||||
# We whitelist specific things here, otherwise we'd treat everything
|
# We whitelist specific things here, otherwise we'd treat everything
|
||||||
# as a u8 array.
|
# as a u8 array.
|
||||||
if message == 'update_fail_htlc' and fieldname == 'reason':
|
if message == 'update_fail_htlc' and fieldname == 'reason':
|
||||||
return ('u8', 1)
|
return FieldType('u8')
|
||||||
if message == 'update_add_htlc' and fieldname == 'onion_routing_packet':
|
if message == 'update_add_htlc' and fieldname == 'onion_routing_packet':
|
||||||
return ('u8', 1)
|
return FieldType('u8')
|
||||||
if message == 'node_announcement' and fieldname == 'alias':
|
if message == 'node_announcement' and fieldname == 'alias':
|
||||||
return ('u8',1)
|
return FieldType('u8')
|
||||||
if message == 'error' and fieldname == 'data':
|
if message == 'error' and fieldname == 'data':
|
||||||
return ('u8',1)
|
return FieldType('u8')
|
||||||
if message == 'shutdown' and fieldname == 'scriptpubkey':
|
if message == 'shutdown' and fieldname == 'scriptpubkey':
|
||||||
return ('u8',1)
|
return FieldType('u8')
|
||||||
if message == 'node_announcement' and fieldname == 'rgb_color':
|
if message == 'node_announcement' and fieldname == 'rgb_color':
|
||||||
return ('u8',1)
|
return FieldType('u8')
|
||||||
|
if message == 'node_announcement' and fieldname == 'addresses':
|
||||||
|
return FieldType('u8')
|
||||||
|
|
||||||
raise ValueError('Unknown size {} for {}'.format(sizestr,fieldname))
|
# The remainder should be fixed sizes.
|
||||||
|
if base_size == 33:
|
||||||
|
return FieldType('struct pubkey')
|
||||||
|
if base_size == 32:
|
||||||
|
return FieldType('struct sha256')
|
||||||
|
if base_size == 8:
|
||||||
|
return FieldType('u64')
|
||||||
|
if base_size == 4:
|
||||||
|
return FieldType('u32')
|
||||||
|
if base_size == 2:
|
||||||
|
return FieldType('u16')
|
||||||
|
if base_size == 1:
|
||||||
|
return FieldType('u8')
|
||||||
|
|
||||||
|
raise ValueError('Unknown size {} for {}'.format(base_size,fieldname))
|
||||||
|
|
||||||
class Message(object):
|
class Message(object):
|
||||||
def __init__(self,name,enum,comments):
|
def __init__(self,name,enum,comments):
|
||||||
|
@ -116,7 +154,7 @@ class Message(object):
|
||||||
def checkLenField(self,field):
|
def checkLenField(self,field):
|
||||||
for f in self.fields:
|
for f in self.fields:
|
||||||
if f.name == field.lenvar:
|
if f.name == field.lenvar:
|
||||||
if f.typename != 'u16':
|
if f.fieldtype.name != 'u16':
|
||||||
raise ValueError('Field {} has non-u16 length variable {}'
|
raise ValueError('Field {} has non-u16 length variable {}'
|
||||||
.format(field.name, field.lenvar))
|
.format(field.name, field.lenvar))
|
||||||
|
|
||||||
|
@ -151,11 +189,11 @@ class Message(object):
|
||||||
if f.is_padding():
|
if f.is_padding():
|
||||||
continue
|
continue
|
||||||
if f.is_array():
|
if f.is_array():
|
||||||
print(', {} {}[{}]'.format(f.typename, f.name, f.num_elems), end='')
|
print(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='')
|
||||||
elif f.is_variable_size():
|
elif f.is_variable_size():
|
||||||
print(', {} **{}'.format(f.typename, f.name), end='')
|
print(', {} **{}'.format(f.fieldtype.name, f.name), end='')
|
||||||
else:
|
else:
|
||||||
print(', {} *{}'.format(f.typename, f.name), end='')
|
print(', {} *{}'.format(f.fieldtype.name, f.name), end='')
|
||||||
|
|
||||||
if is_header:
|
if is_header:
|
||||||
print(');')
|
print(');')
|
||||||
|
@ -166,7 +204,7 @@ class Message(object):
|
||||||
|
|
||||||
for f in self.fields:
|
for f in self.fields:
|
||||||
if f.is_len_var:
|
if f.is_len_var:
|
||||||
print('\t{} {};'.format(f.typename, f.name));
|
print('\t{} {};'.format(f.fieldtype.name, f.name));
|
||||||
|
|
||||||
print('\tconst u8 *cursor = p;\n'
|
print('\tconst u8 *cursor = p;\n'
|
||||||
'\tsize_t tmp_len;\n'
|
'\tsize_t tmp_len;\n'
|
||||||
|
@ -180,9 +218,9 @@ class Message(object):
|
||||||
.format(self.enum.name))
|
.format(self.enum.name))
|
||||||
|
|
||||||
for f in self.fields:
|
for f in self.fields:
|
||||||
basetype=f.typename
|
basetype=f.fieldtype.name
|
||||||
if f.typename.startswith('struct '):
|
if f.fieldtype.name.startswith('struct '):
|
||||||
basetype=f.typename[7:]
|
basetype=f.fieldtype.name[7:]
|
||||||
|
|
||||||
for c in f.comments:
|
for c in f.comments:
|
||||||
print('\t/*{} */'.format(c))
|
print('\t/*{} */'.format(c))
|
||||||
|
@ -197,7 +235,7 @@ class Message(object):
|
||||||
elif f.is_variable_size():
|
elif f.is_variable_size():
|
||||||
print("\t//2th case", f.name)
|
print("\t//2th case", f.name)
|
||||||
print('\t*{} = tal_arr(ctx, {}, {});'
|
print('\t*{} = tal_arr(ctx, {}, {});'
|
||||||
.format(f.name, f.typename, f.lenvar))
|
.format(f.name, f.fieldtype.name, f.lenvar))
|
||||||
print('\tfromwire_{}_array(&cursor, plen, *{}, {});'
|
print('\tfromwire_{}_array(&cursor, plen, *{}, {});'
|
||||||
.format(basetype, f.name, f.lenvar))
|
.format(basetype, f.name, f.lenvar))
|
||||||
elif f.is_assignable():
|
elif f.is_assignable():
|
||||||
|
@ -225,11 +263,11 @@ class Message(object):
|
||||||
if f.is_padding():
|
if f.is_padding():
|
||||||
continue
|
continue
|
||||||
if f.is_array():
|
if f.is_array():
|
||||||
print(', const {} {}[{}]'.format(f.typename, f.name, f.num_elems), end='')
|
print(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='')
|
||||||
elif f.is_assignable():
|
elif f.is_assignable():
|
||||||
print(', {} {}'.format(f.typename, f.name), end='')
|
print(', {} {}'.format(f.fieldtype.name, f.name), end='')
|
||||||
else:
|
else:
|
||||||
print(', const {} *{}'.format(f.typename, f.name), end='')
|
print(', const {} *{}'.format(f.fieldtype.name, f.name), end='')
|
||||||
|
|
||||||
if is_header:
|
if is_header:
|
||||||
print(');')
|
print(');')
|
||||||
|
@ -242,9 +280,9 @@ class Message(object):
|
||||||
'\ttowire_u16(&p, {});'.format(self.enum.name))
|
'\ttowire_u16(&p, {});'.format(self.enum.name))
|
||||||
|
|
||||||
for f in self.fields:
|
for f in self.fields:
|
||||||
basetype=f.typename
|
basetype=f.fieldtype.name
|
||||||
if f.typename.startswith('struct '):
|
if f.fieldtype.name.startswith('struct '):
|
||||||
basetype=f.typename[7:]
|
basetype=f.fieldtype.name[7:]
|
||||||
|
|
||||||
for c in f.comments:
|
for c in f.comments:
|
||||||
print('\t/*{} */'.format(c))
|
print('\t/*{} */'.format(c))
|
||||||
|
@ -311,10 +349,15 @@ for line in fileinput.input(args[2:]):
|
||||||
messages.append(Message(parts[0],Enumtype("WIRE_" + parts[0].upper(), int(parts[1],0)),comments))
|
messages.append(Message(parts[0],Enumtype("WIRE_" + parts[0].upper(), int(parts[1],0)),comments))
|
||||||
comments=[]
|
comments=[]
|
||||||
else:
|
else:
|
||||||
# eg commit_sig,0,channel-id,8
|
# eg commit_sig,0,channel-id,8 OR
|
||||||
|
# commit_sig,0,channel-id,8,u64
|
||||||
for m in messages:
|
for m in messages:
|
||||||
if m.name == parts[0]:
|
if m.name == parts[0]:
|
||||||
m.addField(Field(parts[0], parts[2], parts[3], comments))
|
if len(parts) == 4:
|
||||||
|
m.addField(Field(parts[0], parts[2], parts[3], comments))
|
||||||
|
else:
|
||||||
|
m.addField(Field(parts[0], parts[2], parts[3], comments,
|
||||||
|
parts[4]))
|
||||||
break
|
break
|
||||||
comments=[]
|
comments=[]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue