contrib/pyln-client: construct JSON ID correctly.

They can set their name explicitly, but if they don't we extract it from argv[0].

We also set it around callbacks, so it will be expanded by default.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell 2022-09-13 06:49:12 +09:30
parent d360075d22
commit f1f2c1322d
5 changed files with 43 additions and 11 deletions

View file

@ -2,6 +2,7 @@ import json
import logging import logging
import os import os
import socket import socket
import sys
from contextlib import contextmanager from contextlib import contextmanager
from decimal import Decimal from decimal import Decimal
from json import JSONEncoder from json import JSONEncoder
@ -277,13 +278,18 @@ class UnixSocket(object):
class UnixDomainSocketRpc(object): class UnixDomainSocketRpc(object):
def __init__(self, socket_path, executor=None, logger=logging, encoder_cls=json.JSONEncoder, decoder=json.JSONDecoder()): def __init__(self, socket_path, executor=None, logger=logging, encoder_cls=json.JSONEncoder, decoder=json.JSONDecoder(), caller_name=None):
self.socket_path = socket_path self.socket_path = socket_path
self.encoder_cls = encoder_cls self.encoder_cls = encoder_cls
self.decoder = decoder self.decoder = decoder
self.executor = executor self.executor = executor
self.logger = logger self.logger = logger
self._notify = None self._notify = None
if caller_name is None:
self.caller_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]
else:
self.caller_name = caller_name
self.cmdprefix = None
self.next_id = 1 self.next_id = 1
@ -323,7 +329,11 @@ class UnixDomainSocketRpc(object):
return self.call(name, payload=kwargs) return self.call(name, payload=kwargs)
return wrapper return wrapper
def call(self, method, payload=None): def call(self, method, payload=None, cmdprefix=None):
"""Generic call API: you can set cmdprefix here, or set self.cmdprefix
before the call is made.
"""
self.logger.debug("Calling %s with payload %r", method, payload) self.logger.debug("Calling %s with payload %r", method, payload)
if payload is None: if payload is None:
@ -332,10 +342,16 @@ class UnixDomainSocketRpc(object):
if isinstance(payload, dict): if isinstance(payload, dict):
payload = {k: v for k, v in payload.items() if v is not None} payload = {k: v for k, v in payload.items() if v is not None}
this_id = "{}:{}#{}".format(self.caller_name, method, str(self.next_id))
self.next_id += 1
# FIXME: we open a new socket for every readobj call... # FIXME: we open a new socket for every readobj call...
sock = UnixSocket(self.socket_path) sock = UnixSocket(self.socket_path)
this_id = self.next_id if cmdprefix is None:
self.next_id += 0 cmdprefix = self.cmdprefix
if cmdprefix:
this_id = cmdprefix + '/' + this_id
buf = b'' buf = b''
if self._notify is not None: if self._notify is not None:
@ -343,7 +359,7 @@ class UnixDomainSocketRpc(object):
self._writeobj(sock, { self._writeobj(sock, {
"jsonrpc": "2.0", "jsonrpc": "2.0",
"method": "notifications", "method": "notifications",
"id": 0, "id": this_id + "+notify-enable",
"params": { "params": {
"enable": True "enable": True
}, },

View file

@ -607,17 +607,25 @@ class Plugin(object):
def _exec_func(self, func: Callable[..., Any], def _exec_func(self, func: Callable[..., Any],
request: Request) -> JSONType: request: Request) -> JSONType:
# By default, any RPC calls this makes will have JSON id prefixed by incoming id.
if self.rpc:
self.rpc.cmdprefix = request.id
params = request.params params = request.params
if isinstance(params, list): if isinstance(params, list):
ba = self._bind_pos(func, params, request) ba = self._bind_pos(func, params, request)
return func(*ba.args, **ba.kwargs) ret = func(*ba.args, **ba.kwargs)
elif isinstance(params, dict): elif isinstance(params, dict):
ba = self._bind_kwargs(func, params, request) ba = self._bind_kwargs(func, params, request)
return func(*ba.args, **ba.kwargs) ret = func(*ba.args, **ba.kwargs)
else: else:
if self.rpc:
self.rpc.cmdprefix = None
raise TypeError( raise TypeError(
"Parameters to function call must be either a dict or a list." "Parameters to function call must be either a dict or a list."
) )
if self.rpc:
self.rpc.cmdprefix = None
return ret
def _dispatch_request(self, request: Request) -> None: def _dispatch_request(self, request: Request) -> None:
name = request.method name = request.method

View file

@ -4,7 +4,9 @@ from pyln.client import RpcError, Millisatoshi
from utils import only_one, wait_for, wait_channel_quiescent, mine_funding_to_announce from utils import only_one, wait_for, wait_channel_quiescent, mine_funding_to_announce
import os
import pytest import pytest
import sys
import time import time
import unittest import unittest
@ -18,7 +20,8 @@ def test_invoice(node_factory, chainparams):
inv = l1.rpc.invoice(123000, 'label', 'description', 3700, [addr1, addr2]) inv = l1.rpc.invoice(123000, 'label', 'description', 3700, [addr1, addr2])
# Side note: invoice calls out to listincoming, so check JSON id is as expected # Side note: invoice calls out to listincoming, so check JSON id is as expected
l1.daemon.wait_for_log(": OUT:id=1/cln:listincoming#[0-9]*") myname = os.path.splitext(os.path.basename(sys.argv[0]))[0]
l1.daemon.wait_for_log(": OUT:id={}:invoice#[0-9]*/cln:listincoming#[0-9]*".format(myname))
after = int(time.time()) after = int(time.time())
b11 = l1.rpc.decodepay(inv['bolt11']) b11 = l1.rpc.decodepay(inv['bolt11'])

View file

@ -24,6 +24,7 @@ import signal
import sqlite3 import sqlite3
import stat import stat
import subprocess import subprocess
import sys
import time import time
import unittest import unittest
@ -1486,8 +1487,10 @@ def test_libplugin(node_factory):
l1.rpc.plugin_start(plugin) l1.rpc.plugin_start(plugin)
l1.rpc.check("helloworld") l1.rpc.check("helloworld")
myname = os.path.splitext(os.path.basename(sys.argv[0]))[0]
# Side note: getmanifest will trace back to plugin_start # Side note: getmanifest will trace back to plugin_start
l1.daemon.wait_for_log(": OUT:id=[0-9]*/cln:getmanifest#[0-9]*") l1.daemon.wait_for_log(": OUT:id={}:plugin#[0-9]*/cln:getmanifest#[0-9]*".format(myname))
# Test commands # Test commands
assert l1.rpc.call("helloworld") == {"hello": "world"} assert l1.rpc.call("helloworld") == {"hello": "world"}
@ -1503,7 +1506,7 @@ def test_libplugin(node_factory):
# Test hooks and notifications (add plugin, so we can test hook id) # Test hooks and notifications (add plugin, so we can test hook id)
l2 = node_factory.get_node(options={"plugin": plugin}) l2 = node_factory.get_node(options={"plugin": plugin})
l2.connect(l1) l2.connect(l1)
l2.daemon.wait_for_log(": OUT:id=[0-9]*/cln:peer_connected#[0-9]*") l2.daemon.wait_for_log(": OUT:id={}:connect#[0-9]*/cln:peer_connected#[0-9]*".format(myname))
l1.daemon.wait_for_log("{} peer_connected".format(l2.info["id"])) l1.daemon.wait_for_log("{} peer_connected".format(l2.info["id"]))
l1.daemon.wait_for_log("{} connected".format(l2.info["id"])) l1.daemon.wait_for_log("{} connected".format(l2.info["id"]))

View file

@ -12,6 +12,7 @@ from utils import (
import os import os
import pytest import pytest
import subprocess import subprocess
import sys
import time import time
import unittest import unittest
@ -60,7 +61,8 @@ def test_withdraw(node_factory, bitcoind):
out = l1.rpc.withdraw(waddr, 2 * amount) out = l1.rpc.withdraw(waddr, 2 * amount)
# Side note: sendrawtransaction will trace back to withdrawl # Side note: sendrawtransaction will trace back to withdrawl
l1.daemon.wait_for_log(": OUT:id=[0-9]*/cln:withdraw#[0-9]*/txprepare:sendpsbt#[0-9]*/cln:sendrawtransaction#[0-9]*") myname = os.path.splitext(os.path.basename(sys.argv[0]))[0]
l1.daemon.wait_for_log(": OUT:id={}:withdraw#[0-9]*/cln:withdraw#[0-9]*/txprepare:sendpsbt#[0-9]*/cln:sendrawtransaction#[0-9]*".format(myname))
# Make sure bitcoind received the withdrawal # Make sure bitcoind received the withdrawal
unspent = l1.bitcoin.rpc.listunspent(0) unspent = l1.bitcoin.rpc.listunspent(0)