mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-03-03 10:46:58 +01:00
generate-wire.py: allow optional typename in csv file.
For our internal CSV files, we can specify the type explicitly rather than trying to guess (eg. bool). Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
parent
8ad1298f88
commit
937a62100f
1 changed files with 107 additions and 64 deletions
|
@ -8,31 +8,70 @@ import re
|
|||
|
||||
Enumtype = namedtuple('Enumtype', ['name', 'value'])
|
||||
|
||||
class FieldType(object):
|
||||
def __init__(self,name):
|
||||
self.name = name
|
||||
self.tsize = FieldType._typesize(name)
|
||||
|
||||
def is_assignable(self):
|
||||
return self.name == 'u8' or self.name == 'u16' or self.name == 'u32' or self.name == 'u64'
|
||||
|
||||
# Returns typename and base size
|
||||
@staticmethod
|
||||
def _typesize(typename):
|
||||
if typename == 'pad':
|
||||
return 1
|
||||
elif typename == 'struct channel_id':
|
||||
return 8
|
||||
elif typename == 'struct ipv6':
|
||||
return 16
|
||||
elif typename == 'struct signature':
|
||||
return 64
|
||||
elif typename == 'struct pubkey':
|
||||
return 33
|
||||
elif typename == 'struct sha256':
|
||||
return 32
|
||||
elif typename == 'u64':
|
||||
return 8
|
||||
elif typename == 'u32':
|
||||
return 4
|
||||
elif typename == 'u16':
|
||||
return 2
|
||||
elif typename == 'u8':
|
||||
return 1
|
||||
else:
|
||||
raise ValueError('Unknown typename {}'.format(typename))
|
||||
|
||||
class Field(object):
|
||||
def __init__(self,message,name,size,comments):
|
||||
def __init__(self,message,name,size,comments,typename=None):
|
||||
self.message = message
|
||||
self.comments = comments
|
||||
self.name = name.replace('-', '_')
|
||||
self.is_len_var = False
|
||||
(self.typename, self.basesize) = Field._guess_type(message,self.name,size)
|
||||
self.lenvar = None
|
||||
|
||||
# Size could be a literal number (eg. 33), or a field (eg 'len'), or
|
||||
# a multiplier of a field (eg. num-htlc-timeouts*64).
|
||||
try:
|
||||
if int(size) % self.basesize != 0:
|
||||
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(size,self.message,self.name,self.basesize))
|
||||
self.num_elems = int(int(size) / self.basesize)
|
||||
base_size = int(size)
|
||||
except ValueError:
|
||||
self.num_elems = 0
|
||||
# If it's a multiplicitive expression, must end in basesize.
|
||||
if '*' in size:
|
||||
tail='*' + str(self.basesize)
|
||||
if not size.endswith(tail):
|
||||
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(size,self.message,self.name,self.basesize))
|
||||
size = size[:-len(tail)]
|
||||
base_size = int(size.split('*')[1])
|
||||
self.lenvar = size.split('*')[0]
|
||||
else:
|
||||
if self.basesize != 1:
|
||||
raise ValueError('Invalid size {} for {}.{} not expressed as a multiple of {}'.format(size,self.message,self.name,self.basesize))
|
||||
base_size = 0
|
||||
self.lenvar = size
|
||||
self.lenvar = self.lenvar.replace('-','_')
|
||||
|
||||
self.lenvar = size.replace('-','_')
|
||||
if typename is None:
|
||||
self.fieldtype = Field._guess_type(message,self.name,base_size)
|
||||
else:
|
||||
self.fieldtype = FieldType(typename)
|
||||
|
||||
if base_size % self.fieldtype.tsize != 0:
|
||||
raise ValueError('Invalid size {} for {}.{} not a multiple of {}'.format(base_size,self.message,self.name,self.fieldtype.tsize))
|
||||
self.num_elems = int(base_size / self.fieldtype.tsize)
|
||||
|
||||
def is_padding(self):
|
||||
return self.name.startswith('pad')
|
||||
|
@ -42,68 +81,67 @@ class Field(object):
|
|||
return self.num_elems > 1 or self.is_padding()
|
||||
|
||||
def is_variable_size(self):
|
||||
return self.num_elems == 0
|
||||
return self.lenvar is not None
|
||||
|
||||
def is_assignable(self):
|
||||
if self.is_array() or self.is_variable_size():
|
||||
return False
|
||||
return self.typename == 'u8' or self.typename == 'u16' or self.typename == 'u32' or self.typename == 'u64'
|
||||
return self.fieldtype.is_assignable()
|
||||
|
||||
# Returns typename and base size
|
||||
# Returns FieldType
|
||||
@staticmethod
|
||||
def _guess_type(message, fieldname, sizestr):
|
||||
def _guess_type(message, fieldname, base_size):
|
||||
if fieldname.startswith('pad'):
|
||||
return ('pad',1)
|
||||
return FieldType('pad')
|
||||
|
||||
if fieldname.endswith('channel_id'):
|
||||
return ('struct channel_id',8)
|
||||
return FieldType('struct channel_id')
|
||||
|
||||
if message == 'node_announcement' and fieldname == 'ipv6':
|
||||
return ('struct ipv6',16)
|
||||
return FieldType('struct ipv6')
|
||||
|
||||
if message == 'node_announcement' and fieldname == 'alias':
|
||||
return ('u8',1)
|
||||
return FieldType('u8')
|
||||
|
||||
if fieldname.endswith('features'):
|
||||
return ('u8',1)
|
||||
|
||||
if fieldname == 'addresses':
|
||||
return ('u8', 1)
|
||||
return FieldType('u8')
|
||||
|
||||
# We translate signatures and pubkeys.
|
||||
if 'signature' in fieldname:
|
||||
return ('struct signature',64)
|
||||
|
||||
# The remainder should be fixed sizes.
|
||||
if sizestr == '33':
|
||||
return ('struct pubkey',33)
|
||||
if sizestr == '32':
|
||||
return ('struct sha256',32)
|
||||
if sizestr == '8':
|
||||
return ('u64',8)
|
||||
if sizestr == '4':
|
||||
return ('u32',4)
|
||||
if sizestr == '2':
|
||||
return ('u16',2)
|
||||
if sizestr == '1':
|
||||
return ('u8',1)
|
||||
return FieldType('struct signature')
|
||||
|
||||
# We whitelist specific things here, otherwise we'd treat everything
|
||||
# as a u8 array.
|
||||
if message == 'update_fail_htlc' and fieldname == 'reason':
|
||||
return ('u8', 1)
|
||||
return FieldType('u8')
|
||||
if message == 'update_add_htlc' and fieldname == 'onion_routing_packet':
|
||||
return ('u8', 1)
|
||||
return FieldType('u8')
|
||||
if message == 'node_announcement' and fieldname == 'alias':
|
||||
return ('u8',1)
|
||||
return FieldType('u8')
|
||||
if message == 'error' and fieldname == 'data':
|
||||
return ('u8',1)
|
||||
return FieldType('u8')
|
||||
if message == 'shutdown' and fieldname == 'scriptpubkey':
|
||||
return ('u8',1)
|
||||
return FieldType('u8')
|
||||
if message == 'node_announcement' and fieldname == 'rgb_color':
|
||||
return ('u8',1)
|
||||
return FieldType('u8')
|
||||
if message == 'node_announcement' and fieldname == 'addresses':
|
||||
return FieldType('u8')
|
||||
|
||||
raise ValueError('Unknown size {} for {}'.format(sizestr,fieldname))
|
||||
# The remainder should be fixed sizes.
|
||||
if base_size == 33:
|
||||
return FieldType('struct pubkey')
|
||||
if base_size == 32:
|
||||
return FieldType('struct sha256')
|
||||
if base_size == 8:
|
||||
return FieldType('u64')
|
||||
if base_size == 4:
|
||||
return FieldType('u32')
|
||||
if base_size == 2:
|
||||
return FieldType('u16')
|
||||
if base_size == 1:
|
||||
return FieldType('u8')
|
||||
|
||||
raise ValueError('Unknown size {} for {}'.format(base_size,fieldname))
|
||||
|
||||
class Message(object):
|
||||
def __init__(self,name,enum,comments):
|
||||
|
@ -116,7 +154,7 @@ class Message(object):
|
|||
def checkLenField(self,field):
|
||||
for f in self.fields:
|
||||
if f.name == field.lenvar:
|
||||
if f.typename != 'u16':
|
||||
if f.fieldtype.name != 'u16':
|
||||
raise ValueError('Field {} has non-u16 length variable {}'
|
||||
.format(field.name, field.lenvar))
|
||||
|
||||
|
@ -151,11 +189,11 @@ class Message(object):
|
|||
if f.is_padding():
|
||||
continue
|
||||
if f.is_array():
|
||||
print(', {} {}[{}]'.format(f.typename, f.name, f.num_elems), end='')
|
||||
print(', {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='')
|
||||
elif f.is_variable_size():
|
||||
print(', {} **{}'.format(f.typename, f.name), end='')
|
||||
print(', {} **{}'.format(f.fieldtype.name, f.name), end='')
|
||||
else:
|
||||
print(', {} *{}'.format(f.typename, f.name), end='')
|
||||
print(', {} *{}'.format(f.fieldtype.name, f.name), end='')
|
||||
|
||||
if is_header:
|
||||
print(');')
|
||||
|
@ -166,7 +204,7 @@ class Message(object):
|
|||
|
||||
for f in self.fields:
|
||||
if f.is_len_var:
|
||||
print('\t{} {};'.format(f.typename, f.name));
|
||||
print('\t{} {};'.format(f.fieldtype.name, f.name));
|
||||
|
||||
print('\tconst u8 *cursor = p;\n'
|
||||
'\tsize_t tmp_len;\n'
|
||||
|
@ -180,9 +218,9 @@ class Message(object):
|
|||
.format(self.enum.name))
|
||||
|
||||
for f in self.fields:
|
||||
basetype=f.typename
|
||||
if f.typename.startswith('struct '):
|
||||
basetype=f.typename[7:]
|
||||
basetype=f.fieldtype.name
|
||||
if f.fieldtype.name.startswith('struct '):
|
||||
basetype=f.fieldtype.name[7:]
|
||||
|
||||
for c in f.comments:
|
||||
print('\t/*{} */'.format(c))
|
||||
|
@ -197,7 +235,7 @@ class Message(object):
|
|||
elif f.is_variable_size():
|
||||
print("\t//2th case", f.name)
|
||||
print('\t*{} = tal_arr(ctx, {}, {});'
|
||||
.format(f.name, f.typename, f.lenvar))
|
||||
.format(f.name, f.fieldtype.name, f.lenvar))
|
||||
print('\tfromwire_{}_array(&cursor, plen, *{}, {});'
|
||||
.format(basetype, f.name, f.lenvar))
|
||||
elif f.is_assignable():
|
||||
|
@ -225,11 +263,11 @@ class Message(object):
|
|||
if f.is_padding():
|
||||
continue
|
||||
if f.is_array():
|
||||
print(', const {} {}[{}]'.format(f.typename, f.name, f.num_elems), end='')
|
||||
print(', const {} {}[{}]'.format(f.fieldtype.name, f.name, f.num_elems), end='')
|
||||
elif f.is_assignable():
|
||||
print(', {} {}'.format(f.typename, f.name), end='')
|
||||
print(', {} {}'.format(f.fieldtype.name, f.name), end='')
|
||||
else:
|
||||
print(', const {} *{}'.format(f.typename, f.name), end='')
|
||||
print(', const {} *{}'.format(f.fieldtype.name, f.name), end='')
|
||||
|
||||
if is_header:
|
||||
print(');')
|
||||
|
@ -242,9 +280,9 @@ class Message(object):
|
|||
'\ttowire_u16(&p, {});'.format(self.enum.name))
|
||||
|
||||
for f in self.fields:
|
||||
basetype=f.typename
|
||||
if f.typename.startswith('struct '):
|
||||
basetype=f.typename[7:]
|
||||
basetype=f.fieldtype.name
|
||||
if f.fieldtype.name.startswith('struct '):
|
||||
basetype=f.fieldtype.name[7:]
|
||||
|
||||
for c in f.comments:
|
||||
print('\t/*{} */'.format(c))
|
||||
|
@ -311,10 +349,15 @@ for line in fileinput.input(args[2:]):
|
|||
messages.append(Message(parts[0],Enumtype("WIRE_" + parts[0].upper(), int(parts[1],0)),comments))
|
||||
comments=[]
|
||||
else:
|
||||
# eg commit_sig,0,channel-id,8
|
||||
# eg commit_sig,0,channel-id,8 OR
|
||||
# commit_sig,0,channel-id,8,u64
|
||||
for m in messages:
|
||||
if m.name == parts[0]:
|
||||
m.addField(Field(parts[0], parts[2], parts[3], comments))
|
||||
if len(parts) == 4:
|
||||
m.addField(Field(parts[0], parts[2], parts[3], comments))
|
||||
else:
|
||||
m.addField(Field(parts[0], parts[2], parts[3], comments,
|
||||
parts[4]))
|
||||
break
|
||||
comments=[]
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue