diff --git a/contrib/pyln-testing/pyln/testing/utils.py b/contrib/pyln-testing/pyln/testing/utils.py index e5d95404a..beea7a966 100644 --- a/contrib/pyln-testing/pyln/testing/utils.py +++ b/contrib/pyln-testing/pyln/testing/utils.py @@ -7,10 +7,10 @@ from pyln.client import RpcError from pyln.testing.btcproxy import BitcoinRpcProxy from collections import OrderedDict from decimal import Decimal -from ephemeral_port_reserve import reserve # type: ignore from pyln.client import LightningRpc from pyln.client import Millisatoshi +import ephemeral_port_reserve # type: ignore import json import logging import lzma @@ -146,6 +146,27 @@ def get_tx_p2wsh_outnum(bitcoind, tx, amount): return None +unused_port_lock = threading.Lock() +unused_port_set = set() + + +def reserve_unused_port(): + """Get an unused port: avoids handing out the same port unless it's been + returned""" + with unused_port_lock: + while True: + port = ephemeral_port_reserve.reserve() + if port not in unused_port_set: + break + unused_port_set.add(port) + + return port + + +def drop_unused_port(port): + unused_port_set.remove(port) + + class TailableProc(object): """A monitorable process that we can start, stop and tail. @@ -367,7 +388,10 @@ class BitcoinD(TailableProc): TailableProc.__init__(self, bitcoin_dir, verbose=False) if rpcport is None: - rpcport = reserve() + self.reserved_rpcport = reserve_unused_port() + rpcport = self.reserved_rpcport + else: + self.reserved_rpcport = None self.bitcoin_dir = bitcoin_dir self.rpcport = rpcport @@ -398,6 +422,10 @@ class BitcoinD(TailableProc): self.rpc = SimpleBitcoinProxy(btc_conf_file=self.conf_file) self.proxies = [] + def __del__(self): + if self.reserved_rpcport is not None: + drop_unused_port(self.reserved_rpcport) + def start(self): TailableProc.start(self) self.wait_for_log("Done loading", timeout=TIMEOUT) @@ -1281,6 +1309,7 @@ class NodeFactory(object): self.testname = testname self.next_id = 1 self.nodes = [] + self.reserved_ports = [] self.executor = executor self.bitcoind = bitcoind self.directory = directory @@ -1313,10 +1342,6 @@ class NodeFactory(object): cli_opts = {k: v for k, v in opts.items() if k not in node_opt_keys} return node_opts, cli_opts - def get_next_port(self): - with self.lock: - return reserve() - def get_node_id(self): """Generate a unique numeric ID for a lightning node """ @@ -1361,7 +1386,7 @@ class NodeFactory(object): expect_fail=False, cleandir=True, **kwargs): self.throttler.wait() node_id = self.get_node_id() if not node_id else node_id - port = self.get_next_port() + port = reserve_unused_port() lightning_dir = os.path.join( self.directory, "lightning-{}/".format(node_id)) @@ -1383,6 +1408,7 @@ class NodeFactory(object): node.set_feerates(feerates, False) self.nodes.append(node) + self.reserved_ports.append(port) if dbfile: out = open(os.path.join(node.daemon.lightning_dir, TEST_NETWORK, 'lightningd.sqlite3'), 'xb') @@ -1495,4 +1521,7 @@ class NodeFactory(object): json.dumps(leaks, sort_keys=True, indent=4) )) + for p in self.reserved_ports: + drop_unused_port(p) + return not unexpected_fail, err_msgs diff --git a/tests/test_misc.py b/tests/test_misc.py index 1995a0bd1..6fb65f324 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -2104,7 +2104,7 @@ def test_unix_socket_path_length(node_factory, bitcoind, directory, executor, db os.makedirs(lightning_dir) db = db_provider.get_db(lightning_dir, "test_unix_socket_path_length", 1) - l1 = LightningNode(1, lightning_dir, bitcoind, executor, VALGRIND, db=db, port=node_factory.get_next_port()) + l1 = LightningNode(1, lightning_dir, bitcoind, executor, VALGRIND, db=db, port=reserve()) # `LightningNode.start()` internally calls `LightningRpc.getinfo()` which # exercises the socket logic, and raises an issue if it fails.