contrib/pyln-client: construct JSON ID correctly.

They can set their name explicitly, but if they don't we extract it from argv[0].

We also set it around callbacks, so it will be expanded by default.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell 2022-09-13 06:49:12 +09:30
parent d360075d22
commit f1f2c1322d
5 changed files with 43 additions and 11 deletions

View file

@ -2,6 +2,7 @@ import json
import logging
import os
import socket
import sys
from contextlib import contextmanager
from decimal import Decimal
from json import JSONEncoder
@ -277,13 +278,18 @@ class UnixSocket(object):
class UnixDomainSocketRpc(object):
def __init__(self, socket_path, executor=None, logger=logging, encoder_cls=json.JSONEncoder, decoder=json.JSONDecoder()):
def __init__(self, socket_path, executor=None, logger=logging, encoder_cls=json.JSONEncoder, decoder=json.JSONDecoder(), caller_name=None):
self.socket_path = socket_path
self.encoder_cls = encoder_cls
self.decoder = decoder
self.executor = executor
self.logger = logger
self._notify = None
if caller_name is None:
self.caller_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]
else:
self.caller_name = caller_name
self.cmdprefix = None
self.next_id = 1
@ -323,7 +329,11 @@ class UnixDomainSocketRpc(object):
return self.call(name, payload=kwargs)
return wrapper
def call(self, method, payload=None):
def call(self, method, payload=None, cmdprefix=None):
"""Generic call API: you can set cmdprefix here, or set self.cmdprefix
before the call is made.
"""
self.logger.debug("Calling %s with payload %r", method, payload)
if payload is None:
@ -332,10 +342,16 @@ class UnixDomainSocketRpc(object):
if isinstance(payload, dict):
payload = {k: v for k, v in payload.items() if v is not None}
this_id = "{}:{}#{}".format(self.caller_name, method, str(self.next_id))
self.next_id += 1
# FIXME: we open a new socket for every readobj call...
sock = UnixSocket(self.socket_path)
this_id = self.next_id
self.next_id += 0
if cmdprefix is None:
cmdprefix = self.cmdprefix
if cmdprefix:
this_id = cmdprefix + '/' + this_id
buf = b''
if self._notify is not None:
@ -343,7 +359,7 @@ class UnixDomainSocketRpc(object):
self._writeobj(sock, {
"jsonrpc": "2.0",
"method": "notifications",
"id": 0,
"id": this_id + "+notify-enable",
"params": {
"enable": True
},

View file

@ -607,17 +607,25 @@ class Plugin(object):
def _exec_func(self, func: Callable[..., Any],
request: Request) -> JSONType:
# By default, any RPC calls this makes will have JSON id prefixed by incoming id.
if self.rpc:
self.rpc.cmdprefix = request.id
params = request.params
if isinstance(params, list):
ba = self._bind_pos(func, params, request)
return func(*ba.args, **ba.kwargs)
ret = func(*ba.args, **ba.kwargs)
elif isinstance(params, dict):
ba = self._bind_kwargs(func, params, request)
return func(*ba.args, **ba.kwargs)
ret = func(*ba.args, **ba.kwargs)
else:
if self.rpc:
self.rpc.cmdprefix = None
raise TypeError(
"Parameters to function call must be either a dict or a list."
)
if self.rpc:
self.rpc.cmdprefix = None
return ret
def _dispatch_request(self, request: Request) -> None:
name = request.method

View file

@ -4,7 +4,9 @@ from pyln.client import RpcError, Millisatoshi
from utils import only_one, wait_for, wait_channel_quiescent, mine_funding_to_announce
import os
import pytest
import sys
import time
import unittest
@ -18,7 +20,8 @@ def test_invoice(node_factory, chainparams):
inv = l1.rpc.invoice(123000, 'label', 'description', 3700, [addr1, addr2])
# Side note: invoice calls out to listincoming, so check JSON id is as expected
l1.daemon.wait_for_log(": OUT:id=1/cln:listincoming#[0-9]*")
myname = os.path.splitext(os.path.basename(sys.argv[0]))[0]
l1.daemon.wait_for_log(": OUT:id={}:invoice#[0-9]*/cln:listincoming#[0-9]*".format(myname))
after = int(time.time())
b11 = l1.rpc.decodepay(inv['bolt11'])

View file

@ -24,6 +24,7 @@ import signal
import sqlite3
import stat
import subprocess
import sys
import time
import unittest
@ -1486,8 +1487,10 @@ def test_libplugin(node_factory):
l1.rpc.plugin_start(plugin)
l1.rpc.check("helloworld")
myname = os.path.splitext(os.path.basename(sys.argv[0]))[0]
# Side note: getmanifest will trace back to plugin_start
l1.daemon.wait_for_log(": OUT:id=[0-9]*/cln:getmanifest#[0-9]*")
l1.daemon.wait_for_log(": OUT:id={}:plugin#[0-9]*/cln:getmanifest#[0-9]*".format(myname))
# Test commands
assert l1.rpc.call("helloworld") == {"hello": "world"}
@ -1503,7 +1506,7 @@ def test_libplugin(node_factory):
# Test hooks and notifications (add plugin, so we can test hook id)
l2 = node_factory.get_node(options={"plugin": plugin})
l2.connect(l1)
l2.daemon.wait_for_log(": OUT:id=[0-9]*/cln:peer_connected#[0-9]*")
l2.daemon.wait_for_log(": OUT:id={}:connect#[0-9]*/cln:peer_connected#[0-9]*".format(myname))
l1.daemon.wait_for_log("{} peer_connected".format(l2.info["id"]))
l1.daemon.wait_for_log("{} connected".format(l2.info["id"]))

View file

@ -12,6 +12,7 @@ from utils import (
import os
import pytest
import subprocess
import sys
import time
import unittest
@ -60,7 +61,8 @@ def test_withdraw(node_factory, bitcoind):
out = l1.rpc.withdraw(waddr, 2 * amount)
# Side note: sendrawtransaction will trace back to withdrawl
l1.daemon.wait_for_log(": OUT:id=[0-9]*/cln:withdraw#[0-9]*/txprepare:sendpsbt#[0-9]*/cln:sendrawtransaction#[0-9]*")
myname = os.path.splitext(os.path.basename(sys.argv[0]))[0]
l1.daemon.wait_for_log(": OUT:id={}:withdraw#[0-9]*/cln:withdraw#[0-9]*/txprepare:sendpsbt#[0-9]*/cln:sendrawtransaction#[0-9]*".format(myname))
# Make sure bitcoind received the withdrawal
unspent = l1.bitcoin.rpc.listunspent(0)