diff --git a/contrib/pyln-proto/pyln/proto/message/message.py b/contrib/pyln-proto/pyln/proto/message/message.py index c40464c66..3877a117e 100644 --- a/contrib/pyln-proto/pyln/proto/message/message.py +++ b/contrib/pyln-proto/pyln/proto/message/message.py @@ -11,12 +11,13 @@ class MessageNamespace(object): domain, such as within a given BOLT""" def __init__(self, csv_lines=[]): self.subtypes = {} + self.fundamentaltypes = {} self.tlvtypes = {} self.messagetypes = {} # For convenience, basic types go in every namespace for t in fundamental_types(): - self.add_subtype(t) + self.add_fundamentaltype(t) self.load_csv(csv_lines) @@ -26,6 +27,10 @@ domain, such as within a given BOLT""" return ValueError('Already have {}'.format(prev)) self.subtypes[t.name] = t + def add_fundamentaltype(self, t): + assert not self.get_type(t.name) + self.fundamentaltypes[t.name] = t + def add_tlvtype(self, t): prev = self.get_type(t.name) if prev: @@ -51,6 +56,11 @@ domain, such as within a given BOLT""" return m return None + def get_fundamentaltype(self, name): + if name in self.fundamentaltypes: + return self.fundamentaltypes[name] + return None + def get_subtype(self, name): if name in self.subtypes: return self.subtypes[name] @@ -62,7 +72,9 @@ domain, such as within a given BOLT""" return None def get_type(self, name): - t = self.get_subtype(name) + t = self.get_fundamentaltype(name) + if not t: + t = self.get_subtype(name) if not t: t = self.get_tlvtype(name) return t