msggen: Add classes for MethodName and TypeName

This is required for types and methods with names that need
post-processing (`bkpr-listincome`).
This commit is contained in:
Christian Decker 2024-01-18 13:25:18 +01:00
parent 85b79bc2e1
commit 19af808f45
2 changed files with 46 additions and 11 deletions

View file

@ -1,5 +1,5 @@
# A grpc model # A grpc model
from msggen.model import ArrayField, Field, CompositeField, EnumField, PrimitiveField, Service from msggen.model import ArrayField, Field, CompositeField, EnumField, PrimitiveField, Service, MethodName, TypeName
from msggen.gen import IGenerator from msggen.gen import IGenerator
from typing import TextIO, List, Dict, Any from typing import TextIO, List, Dict, Any
from textwrap import indent, dedent from textwrap import indent, dedent
@ -60,9 +60,11 @@ class GrpcGenerator(IGenerator):
else: else:
self.dest.write(text) self.dest.write(text)
def field2number(self, message_name, field): def field2number(self, message_name: TypeName, field):
m = self.meta['grpc-field-map'] m = self.meta['grpc-field-map']
message_name = message_name.name # TypeName is not JSON-serializable, use the unaltered name.
# Wrap each field mapping by the message_name, since otherwise # Wrap each field mapping by the message_name, since otherwise
# requests and responses share the same number space (just # requests and responses share the same number space (just
# cosmetic really, but why not do it?) # cosmetic really, but why not do it?)
@ -94,11 +96,14 @@ class GrpcGenerator(IGenerator):
for f in fields: for f in fields:
yield (self.field2number(message_name, f), f) yield (self.field2number(message_name, f), f)
def enumvar2number(self, typename, variant): def enumvar2number(self, typename: TypeName, variant):
"""Find an existing variant number of generate a new one. """Find an existing variant number of generate a new one.
If we don't have a variant number yet we'll just take the If we don't have a variant number yet we'll just take the
largest one assigned so far and increment it by 1. """ largest one assigned so far and increment it by 1. """
typename = str(typename.name)
m = self.meta['grpc-enum-map'] m = self.meta['grpc-enum-map']
variant = str(variant) variant = str(variant)
if typename not in m: if typename not in m:
@ -149,7 +154,7 @@ class GrpcGenerator(IGenerator):
""") """)
for method in service.methods: for method in service.methods:
mname = method_name_overrides.get(method.name, method.name) mname = MethodName(method_name_overrides.get(method.name, method.name))
self.write( self.write(
f" rpc {mname}({method.request.typename}) returns ({method.response.typename}) {{}}\n", f" rpc {mname}({method.request.typename}) returns ({method.response.typename}) {{}}\n",
cleanup=False, cleanup=False,
@ -202,7 +207,7 @@ class GrpcGenerator(IGenerator):
typename = f.override(f.typename) typename = f.override(f.typename)
self.write(f"\t{opt}{typename} {f.normalized()} = {i};\n", False) self.write(f"\t{opt}{typename} {f.normalized()} = {i};\n", False)
self.write(f"""}} self.write("""}
""") """)
def generate(self, service: Service) -> None: def generate(self, service: Service) -> None:
@ -250,7 +255,7 @@ class GrpcConverterGenerator(IGenerator):
elif isinstance(f, CompositeField): elif isinstance(f, CompositeField):
self.generate_composite(prefix, f) self.generate_composite(prefix, f)
pbname = self.to_camel_case(field.typename) pbname = self.to_camel_case(str(field.typename))
# If any of the field accesses would result in a deprecated # If any of the field accesses would result in a deprecated
# warning we mark the construction here to allow deprecated # warning we mark the construction here to allow deprecated
@ -421,7 +426,7 @@ class GrpcUnconverterGenerator(GrpcConverterGenerator):
has_deprecated = any([f.deprecated for f in field.fields]) has_deprecated = any([f.deprecated for f in field.fields])
deprecated = ",deprecated" if has_deprecated else "" deprecated = ",deprecated" if has_deprecated else ""
pbname = self.to_camel_case(field.typename) pbname = self.to_camel_case(str(field.typename))
# And now we can convert the current field: # And now we can convert the current field:
self.write(f"""\ self.write(f"""\
#[allow(unused_variables{deprecated})] #[allow(unused_variables{deprecated})]

View file

@ -26,6 +26,36 @@ class FieldName:
return self.name return self.name
class TypeName:
def __init__(self, name: Optional[str]):
if name is None:
raise ValueError("empty typename")
self.name = name
def __str__(self) -> str:
"""Return the normalized typename."""
return (
self.name
.replace(' ', '_')
.replace('-', '')
.replace('/', '_')
)
def __repr__(self) -> str:
return f"Typename[raw={self.name}, str={self}"
def __iadd__(self, other):
self.name += str(other)
return self
def __lt__(self, other) -> bool:
return str(self.name) < str(other)
class MethodName(TypeName):
"""A class encapsulating the naming rules for methods. """
class Field: class Field:
def __init__( def __init__(
self, self,
@ -140,7 +170,7 @@ class Method:
class CompositeField(Field): class CompositeField(Field):
def __init__( def __init__(
self, self,
typename, typename: TypeName,
fields, fields,
path, path,
description, description,
@ -159,7 +189,7 @@ class CompositeField(Field):
@classmethod @classmethod
def from_js(cls, js, path): def from_js(cls, js, path):
typename = path2type(path) typename = TypeName(path2type(path))
properties = js.get("properties", {}) properties = js.get("properties", {})
# Ok, let's flatten the conditional properties. We do this by # Ok, let's flatten the conditional properties. We do this by
@ -257,7 +287,7 @@ class EnumVariant(Field):
class EnumField(Field): class EnumField(Field):
def __init__(self, typename, values, path, description, added, deprecated): def __init__(self, typename: TypeName, values, path, description, added, deprecated):
Field.__init__(self, path, description, added=added, deprecated=deprecated) Field.__init__(self, path, description, added=added, deprecated=deprecated)
self.typename = typename self.typename = typename
self.values = values self.values = values
@ -266,7 +296,7 @@ class EnumField(Field):
@classmethod @classmethod
def from_js(cls, js, path): def from_js(cls, js, path):
# Transform the path into something that is a valid TypeName # Transform the path into something that is a valid TypeName
typename = path2type(path) typename = TypeName(path2type(path))
return EnumField( return EnumField(
typename, typename,
values=filter(lambda i: i is not None, js["enum"]), values=filter(lambda i: i is not None, js["enum"]),