msggen: Provide parent in visit-method of patch

In the next commit I'll change the behavior of `OptionalPatch`.
The changes require me to have access to the `parent` of a field.

Splitting it up in a separate commit makes it easier to review.
You can run `msggen` against this version and the previous `version`.

I've tested it. It returns exactly the same output.
This commit is contained in:
Erik De Smedt 2024-02-08 10:35:51 +01:00 committed by Christian Decker
parent 15276b7ff5
commit 6e61dde41e

View file

@ -1,4 +1,5 @@
from abc import ABC
from typing import Optional
from msggen import model
@ -10,9 +11,8 @@ class Patch(ABC):
"""
def visit(self, field: model.Field) -> None:
"""Gets called for each node in the model.
"""
def visit(self, field: model.Field, parent: Optional[model.Field] = None) -> None:
"""Gets called for each node in the model."""
pass
def apply(self, service: model.Service) -> None:
@ -20,17 +20,20 @@ class Patch(ABC):
pre-order on each node in the schema tree.
"""
def recurse(f: model.Field):
# First recurse if we have further type definitions
self.visit(f)
if isinstance(f, model.ArrayField):
self.visit(f.itemtype)
self.visit(f.itemtype, f)
recurse(f.itemtype)
elif isinstance(f, model.CompositeField):
for c in f.fields:
self.visit(c)
self.visit(c, f)
recurse(c)
# Now visit ourselves
self.visit(f)
for m in service.methods:
recurse(m.request)
recurse(m.response)
@ -52,12 +55,11 @@ class VersionAnnotationPatch(Patch):
"""
def __init__(self, meta) -> None:
"""Create a patch that can annotate `added` and `deprecated`
"""
"""Create a patch that can annotate `added` and `deprecated`"""
self.meta = meta
def visit(self, f: model.Field) -> None:
m = self.meta['model-field-versions'].get(f.path, {})
def visit(self, f: model.Field, parent: Optional[model.Field] = None) -> None:
m = self.meta["model-field-versions"].get(f.path, {})
# The following lines are used to backfill fields that predate
# the introduction, so they need to use a default version to
@ -67,15 +69,17 @@ class VersionAnnotationPatch(Patch):
# if f.added is None and 'added' not in m:
# m['added'] = 'pre-v0.10.1'
added = m.get('added', None)
deprecated = m.get('deprecated', None)
added = m.get("added", None)
deprecated = m.get("deprecated", None)
assert added or f.added, f"Field {f.path} does not have an `added` annotation"
# We do not allow the added and deprecated flags to be
# modified after the fact.
if f.added and added and f.added != m['added']:
raise ValueError(f"Field {f.path} changed `added` annotation: {f.added} != {m['added']}")
if f.added and added and f.added != m["added"]:
raise ValueError(
f"Field {f.path} changed `added` annotation: {f.added} != {m['added']}"
)
if f.deprecated:
# We don't care about finishing value.
@ -83,7 +87,9 @@ class VersionAnnotationPatch(Patch):
assert len(f.deprecated) == 2
f.deprecated = f.deprecated[0]
if f.deprecated != deprecated:
raise ValueError(f"Field {f.path} changed `deprecated` annotation: {f.deprecated} vs {deprecated}")
raise ValueError(
f"Field {f.path} changed `deprecated` annotation: {f.deprecated} vs {deprecated}"
)
if f.added is None:
f.added = added
@ -91,9 +97,9 @@ class VersionAnnotationPatch(Patch):
f.deprecated = deprecated
# Backfill the metadata using the annotation
self.meta['model-field-versions'][f.path] = {
'added': f.added,
'deprecated': f.deprecated,
self.meta["model-field-versions"][f.path] = {
"added": f.added,
"deprecated": f.deprecated,
}
@ -108,10 +114,10 @@ class OptionalPatch(Patch):
@staticmethod
def version_to_number(version):
# Dummy versions collecting all fields that predate the versioning.
if version == 'pre-v0.10.1':
if version == "pre-v0.10.1":
return 0
assert version[0] == 'v'
parts = version[1:].split('.')
assert version[0] == "v"
parts = version[1:].split(".")
# Months, plus 10 for minor versions.
num = (int(parts[0]) * 12 + int(parts[1])) * 10
@ -126,9 +132,14 @@ class OptionalPatch(Patch):
fields more stringent.
"""
return OptionalPatch.version_to_number('v0.10.1')
return OptionalPatch.version_to_number("v0.10.1")
def visit(self, f: model.Field, parent: Optional[model.Field] = None) -> None:
# Return if the optional field has been set already
if "optional" in dir(f):
if f.optional is not None:
return
def visit(self, f: model.Field) -> None:
# Default to false, and then overwrite it if required.
f.optional = False
if not f.required:
@ -144,60 +155,58 @@ class OptionalPatch(Patch):
class OverridePatch(Patch):
"""Allows omitting some fields and overriding the type of fields based on configuration.
"""Allows omitting some fields and overriding the type of fields based on configuration."""
"""
omit = [
'Decode.invoice_paths[]',
'Decode.invoice_paths[].payinfo',
'Decode.offer_paths[].path[]',
'Decode.offer_recurrence',
'Decode.routes[][]',
'Decode.unknown_invoice_request_tlvs[]',
'Decode.unknown_invoice_tlvs[]',
'Decode.unknown_offer_tlvs[]',
'DecodePay.routes[][]',
'DecodeRoutes.routes',
'Invoice.exposeprivatechannels',
'ListClosedChannels.closedchannels[].channel_type',
'ListPeerChannels.channels[].channel_type',
'ListPeerChannels.channels[].features[]',
'ListPeerChannels.channels[].state_changes[]',
'ListPeers.peers[].channels[].state_changes[]',
'ListTransactions.transactions[].type[]',
"Decode.invoice_paths[]",
"Decode.invoice_paths[].payinfo",
"Decode.offer_paths[].path[]",
"Decode.offer_recurrence",
"Decode.routes[][]",
"Decode.unknown_invoice_request_tlvs[]",
"Decode.unknown_invoice_tlvs[]",
"Decode.unknown_offer_tlvs[]",
"DecodePay.routes[][]",
"DecodeRoutes.routes",
"Invoice.exposeprivatechannels",
"ListClosedChannels.closedchannels[].channel_type",
"ListPeerChannels.channels[].channel_type",
"ListPeerChannels.channels[].features[]",
"ListPeerChannels.channels[].state_changes[]",
"ListPeers.peers[].channels[].state_changes[]",
"ListTransactions.transactions[].type[]",
]
# Handcoded types to use instead of generating the types from the
# schema. Useful for repeated types, and types that have
# redundancies.
overrides = {
'ListClosedChannels.closedchannels[].closer': "ChannelSide",
'ListClosedChannels.closedchannels[].opener': "ChannelSide",
'ListFunds.channels[].state': 'ChannelState',
'ListPeerChannels.channels[].closer': "ChannelSide",
'ListPeerChannels.channels[].opener': "ChannelSide",
'ListPeers.peers[].channels[].closer': "ChannelSide",
'ListPeers.peers[].channels[].features[]': "string",
'ListPeers.peers[].channels[].opener': "ChannelSide",
'ListPeers.peers[].channels[].state_changes[].cause': "ChannelStateChangeCause",
'ListPeers.peers[].channels[].state_changes[].old_state': "ChannelState",
'ListPeers.peers[].channels[].htlcs[].state': "HtlcState",
'ListPeerChannels.channels[].htlcs[].state': "HtlcState",
'ListHtlcs.htlcs[].state': "HtlcState",
'FundChannel.channel_type.names[]': 'ChannelTypeName',
'FundChannel_Start.channel_type.names[]': 'ChannelTypeName',
'MultiFundChannel.channel_ids[].channel_type.names[]': 'ChannelTypeName',
'OpenChannel_Init.channel_type.names[]': 'ChannelTypeName',
'OpenChannel_Bump.channel_type.names[]': 'ChannelTypeName',
'OpenChannel_Update.channel_type.names[]': 'ChannelTypeName',
'AutoClean-Once.subsystem': "AutocleanSubsystem",
'AutoClean-Status.subsystem': "AutocleanSubsystem",
'Plugin.subcommand': 'PluginSubcommand',
'Plugin.command': 'PluginSubcommand',
"ListClosedChannels.closedchannels[].closer": "ChannelSide",
"ListClosedChannels.closedchannels[].opener": "ChannelSide",
"ListFunds.channels[].state": "ChannelState",
"ListPeerChannels.channels[].closer": "ChannelSide",
"ListPeerChannels.channels[].opener": "ChannelSide",
"ListPeers.peers[].channels[].closer": "ChannelSide",
"ListPeers.peers[].channels[].features[]": "string",
"ListPeers.peers[].channels[].opener": "ChannelSide",
"ListPeers.peers[].channels[].state_changes[].cause": "ChannelStateChangeCause",
"ListPeers.peers[].channels[].state_changes[].old_state": "ChannelState",
"ListPeers.peers[].channels[].htlcs[].state": "HtlcState",
"ListPeerChannels.channels[].htlcs[].state": "HtlcState",
"ListHtlcs.htlcs[].state": "HtlcState",
"FundChannel.channel_type.names[]": "ChannelTypeName",
"FundChannel_Start.channel_type.names[]": "ChannelTypeName",
"MultiFundChannel.channel_ids[].channel_type.names[]": "ChannelTypeName",
"OpenChannel_Init.channel_type.names[]": "ChannelTypeName",
"OpenChannel_Bump.channel_type.names[]": "ChannelTypeName",
"OpenChannel_Update.channel_type.names[]": "ChannelTypeName",
"AutoClean-Once.subsystem": "AutocleanSubsystem",
"AutoClean-Status.subsystem": "AutocleanSubsystem",
"Plugin.subcommand": "PluginSubcommand",
"Plugin.command": "PluginSubcommand",
}
def visit(self, f: model.Field) -> None:
"""For now just skips the fields we can't convert.
"""
def visit(self, f: model.Field, parent: Optional[model.Field] = None) -> None:
"""For now just skips the fields we can't convert."""
f.omitted = f.path in self.omit
f.type_override = self.overrides.get(f.path, None)