msggen: Add model-side overrides

Sometimes we just want to paper over the schema directly. Mostly
useful to sidestep the `oneof` things that are required for
expressiveness.
This commit is contained in:
Christian Decker 2022-04-01 14:43:34 +10:30 committed by Rusty Russell
parent 1f40db3594
commit ec5cd92580
6 changed files with 54 additions and 28 deletions

View File

@ -40,7 +40,7 @@ enum ChannelState {
message ChannelStateChangeCause {}
message Utxo {
message Outpoint {
bytes txid = 1;
uint32 outnum = 2;
}

View File

@ -2,7 +2,7 @@ tonic::include_proto!("cln");
use cln_rpc::primitives::{
Amount as JAmount, AmountOrAll as JAmountOrAll, AmountOrAny as JAmountOrAny,
Feerate as JFeerate, OutputDesc as JOutputDesc, Utxo as JUtxo,
Feerate as JFeerate, OutputDesc as JOutputDesc, Outpoint as JOutpoint,
};
impl From<JAmount> for Amount {
@ -17,18 +17,18 @@ impl From<&Amount> for JAmount {
}
}
impl From<JUtxo> for Utxo {
fn from(a: JUtxo) -> Self {
Utxo {
impl From<JOutpoint> for Outpoint {
fn from(a: JOutpoint) -> Self {
Outpoint {
txid: a.txid,
outnum: a.outnum,
}
}
}
impl From<&Utxo> for JUtxo {
fn from(a: &Utxo) -> Self {
JUtxo {
impl From<&Outpoint> for JOutpoint {
fn from(a: &Outpoint) -> Self {
JOutpoint {
txid: a.txid.clone(),
outnum: a.outnum,
}

View File

@ -69,12 +69,12 @@ impl Amount {
}
#[derive(Clone, Debug, PartialEq)]
pub struct Utxo {
pub struct Outpoint {
pub txid: Vec<u8>,
pub outnum: u32,
}
impl Serialize for Utxo {
impl Serialize for Outpoint {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
@ -83,7 +83,7 @@ impl Serialize for Utxo {
}
}
impl<'de> Deserialize<'de> for Utxo {
impl<'de> Deserialize<'de> for Outpoint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
@ -102,7 +102,7 @@ impl<'de> Deserialize<'de> for Utxo {
.parse()
.map_err(|e| Error::custom(format!("{} is not a valid number: {}", s, e)))?;
Ok(Utxo { txid, outnum })
Ok(Outpoint { txid, outnum })
}
}

View File

@ -24,8 +24,9 @@ typemap = {
'u16': 'uint32', # Yeah, I know...
'f32': 'float',
'integer': 'sint64',
"utxo": "Utxo",
"outpoint": "Outpoint",
"feerate": "Feerate",
"outputdesc": "OutputDesc",
}
@ -41,6 +42,7 @@ overrides = {
'ListTransactions.transactions[].type[]': None,
}
method_name_overrides = {
"Connect": "ConnectPeer",
}
@ -373,7 +375,12 @@ class GrpcUnconverterGenerator(GrpcConverterGenerator):
for f in field.fields:
name = f.normalized()
if isinstance(f, ArrayField):
self.write(f"{name}: c.{name}.iter().map(|s| s.into()).collect(),\n", numindent=3)
typ = f.itemtype.typename
mapping = {
'hex': f'hex::decode(s).unwrap()',
'u32': f's.clone()',
}.get(typ, f's.into()')
self.write(f"{name}: c.{name}.iter().map(|s| {mapping}).collect(),\n", numindent=3)
elif isinstance(f, EnumField):
if f.required:

View File

@ -1,5 +1,6 @@
from typing import List, Union, Optional
import logging
from copy import copy
logger = logging.getLogger(__name__)
@ -18,7 +19,7 @@ class FieldName:
"type": "item_type"
}.get(self.name, self.name)
name = name.replace(' ', '_').replace('-', '_')
name = name.replace(' ', '_').replace('-', '_').replace('[]', '')
return name
def __str__(self):
@ -133,8 +134,12 @@ class CompositeField(Field):
logger.warning(f"Unmanaged {fpath}, it is deprecated")
continue
if 'oneOf' in ftype:
field = UnionField.from_js(ftype, fpath)
if fpath in overrides:
field = copy(overrides[fpath])
field.path = fpath
field.description = desc
if isinstance(field, ArrayField):
field.itemtype.path = fpath
elif "type" not in ftype:
logger.warning(f"Unmanaged {fpath}, it doesn't have a type")
@ -320,11 +325,6 @@ class ArrayField(Field):
itemtype, dims=dims, path=path, description=js.get("description", "")
)
def normalized(self):
# Strip the '[]' that we use to signal an array. The name
# itself doesn't need this.
return Field.normalized(self)[:-2]
class Command:
def __init__(self, name, fields):
@ -336,6 +336,23 @@ class Command:
return f"Command[name={self.name}, fields=[{fieldnames}]]"
InvoiceLabelField = PrimitiveField("string", None, None)
DatastoreKeyField = ArrayField(itemtype=PrimitiveField("string", None, None), dims=1, path=None, description=None)
InvoiceExposeprivatechannelsField = PrimitiveField("boolean", None, None)
PayExclude = ArrayField(itemtype=PrimitiveField("string", None, None), dims=1, path=None, description=None)
# Override fields with manually managed types, fieldpath -> field mapping
overrides = {
'Invoice.label': InvoiceLabelField,
'DelInvoice.label': InvoiceLabelField,
'ListInvoices.label': InvoiceLabelField,
'Datastore.key': DatastoreKeyField,
'DelDatastore.key': DatastoreKeyField,
'ListDatastore.key': DatastoreKeyField,
'Invoice.exposeprivatechannels': InvoiceExposeprivatechannelsField,
'Pay.exclude': PayExclude,
}
def parse_doc(command, js) -> Union[CompositeField, Command]:
"""Given a command name and its schema, generate the IR model"""
path = command

View File

@ -25,6 +25,7 @@ overrides = {
'ListPeers.peers[].channels[].features[]': "string",
'ListFunds.channels[].state': 'ChannelState',
'ListTransactions.transactions[].type[]': None,
'Invoice.exposeprivatechannels': None,
}
# A map of schema type to rust primitive types.
@ -43,6 +44,8 @@ typemap = {
'float': 'f32',
'utxo': 'Utxo',
'feerate': 'Feerate',
'outpoint': 'Outpoint',
'outputdesc': 'OutputDesc',
}
header = f"""#![allow(non_camel_case_types)]
@ -123,7 +126,7 @@ def gen_enum(e):
if e.required:
defi = f" // Path `{e.path}`\n #[serde(rename = \"{e.name}\")]\n pub {e.name.normalized()}: {typename},\n"
else:
defi = f' #[serde(skip_serializing_if = "Option::is_none")]'
defi = f' #[serde(skip_serializing_if = "Option::is_none")]\n'
defi = f" pub {e.name.normalized()}: Option<{typename}>,\n"
return defi, decl
@ -148,17 +151,16 @@ def gen_array(a):
logger.debug(f"Generating array field {a.name} -> {name} ({a.path})")
_, decl = gen_field(a.itemtype)
if isinstance(a.itemtype, PrimitiveField):
if a.path in overrides:
decl = "" # No declaration if we have an override
itemtype = overrides[a.path]
elif isinstance(a.itemtype, PrimitiveField):
itemtype = a.itemtype.typename
elif isinstance(a.itemtype, CompositeField):
itemtype = a.itemtype.typename
elif isinstance(a.itemtype, EnumField):
itemtype = a.itemtype.typename
if a.path in overrides:
decl = "" # No declaration if we have an override
itemtype = overrides[a.path]
if itemtype is None:
return ("", "") # Override said not to include