pylightning: provide a class for Lightning JSONDecoder.

Some JSON functions want a *class*, not just a hook, so provide one.
To make it clear that we want an encoding *class* and a decoding *object*,
rename the UnixDomainSocketRpc encode parameter to encode_cls.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell 2019-02-25 14:45:56 +10:30
parent 464858883b
commit 5a7d038e6e

View File

@ -121,9 +121,9 @@ class Millisatoshi:
class UnixDomainSocketRpc(object):
def __init__(self, socket_path, executor=None, logger=logging, encoder=json.JSONEncoder, decoder=json.JSONDecoder):
def __init__(self, socket_path, executor=None, logger=logging, encoder_cls=json.JSONEncoder, decoder=json.JSONDecoder()):
self.socket_path = socket_path
self.encoder = encoder
self.encoder_cls = encoder_cls
self.decoder = decoder
self.executor = executor
self.logger = logger
@ -133,7 +133,7 @@ class UnixDomainSocketRpc(object):
self.next_id = 0
def _writeobj(self, sock, obj):
s = json.dumps(obj, cls=self.encoder)
s = json.dumps(obj, cls=self.encoder_cls)
sock.sendall(bytearray(s, 'UTF-8'))
def _readobj_compat(self, sock, buff=b''):
@ -245,32 +245,39 @@ class LightningRpc(UnixDomainSocketRpc):
pass
return json.JSONEncoder.default(self, o)
@staticmethod
def lightning_json_hook(json_object):
return json_object
class LightningJSONDecoder(json.JSONDecoder):
def __init__(self, *, object_hook=None, parse_float=None, parse_int=None, parse_constant=None, strict=True, object_pairs_hook=None):
self.object_hook_next = object_hook
super().__init__(object_hook=self.millisatoshi_hook, parse_float=parse_float, parse_int=parse_int, parse_constant=parse_constant, strict=strict, object_pairs_hook=object_pairs_hook)
@staticmethod
def replace_amounts(obj):
"""
Recursively replace _msat fields with appropriate values with Millisatoshi.
"""
if isinstance(obj, dict):
for k, v in obj.items():
if k.endswith('msat'):
if isinstance(v, str) and v.endswith('msat'):
obj[k] = Millisatoshi(v)
# Special case for array of msat values
elif isinstance(v, list) and all(isinstance(e, str) and e.endswith('msat') for e in v):
obj[k] = [Millisatoshi(e) for e in v]
else:
obj[k] = LightningRpc.replace_amounts(v)
elif isinstance(obj, list):
obj = [LightningRpc.replace_amounts(e) for e in obj]
@staticmethod
def replace_amounts(obj):
"""
Recursively replace _msat fields with appropriate values with Millisatoshi.
"""
if isinstance(obj, dict):
for k, v in obj.items():
if k.endswith('msat'):
if isinstance(v, str) and v.endswith('msat'):
obj[k] = Millisatoshi(v)
# Special case for array of msat values
elif isinstance(v, list) and all(isinstance(e, str) and e.endswith('msat') for e in v):
obj[k] = [Millisatoshi(e) for e in v]
else:
obj[k] = LightningRpc.LightningJSONDecoder.replace_amounts(v)
elif isinstance(obj, list):
obj = [LightningRpc.LightningJSONDecoder.replace_amounts(e) for e in obj]
return obj
return obj
def millisatoshi_hook(self, obj):
obj = LightningRpc.LightningJSONDecoder.replace_amounts(obj)
if self.object_hook_next:
obj = self.object_hook_next(obj)
return obj
def __init__(self, socket_path, executor=None, logger=logging):
super().__init__(socket_path, executor, logging, self.LightningJSONEncoder, json.JSONDecoder(object_hook=self.replace_amounts))
super().__init__(socket_path, executor, logging, self.LightningJSONEncoder, self.LightningJSONDecoder())
def getpeer(self, peer_id, level=None):
"""