pyln-client: Add support to monkey patch the JSONEncoder

Several times we had issues with plugins not being able to re-encode an RPC
result because they forgot to use the custom encoder class. This allows us to
patch the JSONEncoder when we start the RPC or the plugin and automagically
support classes that provide a `to_json` method.
This commit is contained in:
Christian Decker 2020-06-15 12:52:42 +02:00 committed by Rusty Russell
parent 382230509b
commit 748caf91d3

View File

@ -5,6 +5,21 @@ import logging
import os
import socket
import warnings
from json import JSONEncoder
def _patched_default(self, obj):
return getattr(obj.__class__, "to_json", _patched_default.default)(obj)
def monkey_patch_json(patch=True):
is_patched = JSONEncoder.default == _patched_default
if patch and not is_patched:
_patched_default.default = JSONEncoder.default # Save unmodified
JSONEncoder.default = _patched_default # Replace it.
elif not patch and is_patched:
JSONEncoder.default = _patched_default.default
class RpcError(ValueError):
@ -327,7 +342,10 @@ class LightningRpc(UnixDomainSocketRpc):
return json.JSONEncoder.default(self, o)
class LightningJSONDecoder(json.JSONDecoder):
def __init__(self, *, object_hook=None, parse_float=None, parse_int=None, parse_constant=None, strict=True, object_pairs_hook=None):
def __init__(self, *, object_hook=None, parse_float=None,
parse_int=None, parse_constant=None,
strict=True, object_pairs_hook=None,
patch_json=True):
self.object_hook_next = object_hook
super().__init__(object_hook=self.millisatoshi_hook, parse_float=parse_float, parse_int=parse_int, parse_constant=parse_constant, strict=strict, object_pairs_hook=object_pairs_hook)
@ -357,8 +375,18 @@ class LightningRpc(UnixDomainSocketRpc):
obj = self.object_hook_next(obj)
return obj
def __init__(self, socket_path, executor=None, logger=logging):
super().__init__(socket_path, executor, logger, self.LightningJSONEncoder, self.LightningJSONDecoder())
def __init__(self, socket_path, executor=None, logger=logging,
patch_json=True):
super().__init__(
socket_path,
executor,
logger,
self.LightningJSONEncoder,
self.LightningJSONDecoder()
)
if patch_json:
monkey_patch_json(patch=True)
def autocleaninvoice(self, cycle_seconds=None, expired_by=None):
"""