pytest: Add db_provider and db instances for configurable backends

We will soon have a postgres backend as well, so we need a way to control the
postgres process and to provision DBs to the nodes. The two interfaces are the
DSN that we pass to the node, and the Python query interface that tests use to
query the database.

Signed-off-by: Christian Decker <decker.christian@gmail.com>
Christian Decker 2019-08-30 18:59:53 +02:00 committed by Rusty Russell
parent 62dc8dc110
commit 96a22b4003
5 changed files with 205 additions and 29 deletions
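
For context, a minimal sketch (not part of the commit) of how the new pieces fit together. The class names, methods and the TEST_DB_PROVIDER environment variable are taken from the diff below; the directory paths and test name used here are placeholders only:

import os
from db import SqliteDbProvider, PostgresDbProvider

# Same mapping as the db_provider fixture below; defaults to sqlite3.
providers = {'sqlite3': SqliteDbProvider, 'postgres': PostgresDbProvider}
provider = providers[os.getenv('TEST_DB_PROVIDER', 'sqlite3')]('/tmp/ltests')  # placeholder dir
provider.start()

# One database per node and test; SQLite3 has no DSN, postgres returns a
# postgres:// URI that the NodeFactory hands to lightningd as the wallet option.
db = provider.get_db('/tmp/ltests/node-1', 'test_example', 1)  # placeholder args
dsn = db.get_dsn()

# Inside a test the same object is available as the node's db attribute:
#     rows = l1.db.query('SELECT ...')   # list of dicts, one per row
#     l1.db.execute('UPDATE ...')

provider.stop()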

tests/db.py (new file)

@@ -0,0 +1,169 @@
from ephemeral_port_reserve import reserve
import logging
import os
import psycopg2
import random
import shutil
import signal
import sqlite3
import string
import subprocess
import time
class Sqlite3Db(object):
def __init__(self, path):
self.path = path
def get_dsn(self):
"""SQLite3 doesn't provide a DSN, resulting in no CLI-option.
"""
return None
def query(self, query):
orig = os.path.join(self.path)
copy = self.path + ".copy"
shutil.copyfile(orig, copy)
db = sqlite3.connect(copy)
db.row_factory = sqlite3.Row
c = db.cursor()
c.execute(query)
rows = c.fetchall()
result = []
for row in rows:
result.append(dict(zip(row.keys(), row)))
db.commit()
c.close()
db.close()
return result
def execute(self, query):
db = sqlite3.connect(self.path)
c = db.cursor()
c.execute(query)
db.commit()
c.close()
db.close()
class PostgresDb(object):
def __init__(self, dbname, port):
self.dbname = dbname
self.port = port
self.conn = psycopg2.connect("dbname={dbname} user=postgres host=localhost port={port}".format(
dbname=dbname, port=port
))
cur = self.conn.cursor()
cur.execute('SELECT 1')
cur.close()
def get_dsn(self):
return "postgres://postgres:password@localhost:{port}/{dbname}".format(
port=self.port, dbname=self.dbname
)
def query(self, query):
cur = self.conn.cursor()
cur.execute(query)
# Collect the results into a list of dicts.
res = []
for r in cur:
t = {}
# Zip the column definition with the value to get its name.
for c, v in zip(cur.description, r):
t[c.name] = v
res.append(t)
cur.close()
return res
def execute(self, query):
with self.conn, self.conn.cursor() as cur:
cur.execute(query)
class SqliteDbProvider(object):
def __init__(self, directory):
self.directory = directory
def start(self):
pass
def get_db(self, node_directory, testname, node_id):
path = os.path.join(
node_directory,
'lightningd.sqlite3'
)
return Sqlite3Db(path)
def stop(self):
pass
class PostgresDbProvider(object):
def __init__(self, directory):
self.directory = directory
self.port = None
self.proc = None
print("Starting PostgresDbProvider")
def start(self):
passfile = os.path.join(self.directory, "pgpass.txt")
self.pgdir = os.path.join(self.directory, 'pgsql')
# Need to write a tiny file containing the password so `initdb` can pick it up
with open(passfile, 'w') as f:
f.write('cltest\n')
subprocess.check_call([
'/usr/lib/postgresql/10/bin/initdb',
'--pwfile={}'.format(passfile),
'--pgdata={}'.format(self.pgdir),
'--auth=trust',
'--username=postgres',
])
self.port = reserve()
self.proc = subprocess.Popen([
'/usr/lib/postgresql/10/bin/postgres',
'-k', '/tmp/', # So we don't use /var/lib/...
'-D', self.pgdir,
'-p', str(self.port),
'-F',
'-i',
])
# Hacky but seems to work ok (might want to make the postgres proc a TailableProc as well if too flaky).
time.sleep(1)
self.conn = psycopg2.connect("dbname=template1 user=postgres host=localhost port={}".format(self.port))
# Required for CREATE DATABASE to work
self.conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
def get_db(self, node_directory, testname, node_id):
# Random suffix to avoid collisions on repeated tests
nonce = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8))
dbname = "{}_{}_{}".format(testname, node_id, nonce)
cur = self.conn.cursor()
cur.execute("CREATE DATABASE {};".format(dbname))
cur.close()
db = PostgresDb(dbname, self.port)
return db
def stop(self):
# Send fast shutdown signal, see [1] for details:
#
# SIGINT
#
# This is the Fast Shutdown mode. The server disallows new connections
# and sends all existing server processes SIGTERM, which will cause
# them to abort their current transactions and exit promptly. It then
# waits for all server processes to exit and finally shuts down. If
# the server is in online backup mode, backup mode will be terminated,
# rendering the backup useless.
#
# [1] https://www.postgresql.org/docs/9.1/server-shutdown.html
self.proc.send_signal(signal.SIGINT)
self.proc.wait()
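
The two backends above define the implicit interface that the fixtures and the NodeFactory rely on: start/stop/get_db on the provider, and get_dsn/query/execute on the per-node DB object. A hypothetical skeleton for an additional backend (not part of this commit) would look like this:

class ExampleDb(object):
    """Hypothetical per-node DB handle; mirrors Sqlite3Db/PostgresDb above."""
    def get_dsn(self):
        # DSN string passed to lightningd's wallet option, or None for no option.
        return None

    def query(self, query):
        # Read-only query; must return a list of dicts, one per row.
        return []

    def execute(self, query):
        # Statement that modifies the database.
        pass


class ExampleDbProvider(object):
    """Hypothetical provider; mirrors SqliteDbProvider/PostgresDbProvider above."""
    def __init__(self, directory):
        self.directory = directory

    def start(self):
        pass  # Launch any server process the backend needs.

    def get_db(self, node_directory, testname, node_id):
        return ExampleDb()  # One fresh database per node per test.

    def stop(self):
        pass  # Tear the server process down again.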


@@ -1,4 +1,5 @@
from concurrent import futures
from db import SqliteDbProvider, PostgresDbProvider
from utils import NodeFactory, BitcoinD
import logging
@@ -149,12 +150,13 @@ def teardown_checks(request):
@pytest.fixture
def node_factory(request, directory, test_name, bitcoind, executor, teardown_checks):
def node_factory(request, directory, test_name, bitcoind, executor, db_provider, teardown_checks):
nf = NodeFactory(
test_name,
bitcoind,
executor,
directory=directory,
db_provider=db_provider,
)
yield nf
@@ -275,6 +277,21 @@ def checkMemleak(node):
return 0
# Mapping from TEST_DB_PROVIDER env variable to class to be used
providers = {
'sqlite3': SqliteDbProvider,
'postgres': PostgresDbProvider,
}
@pytest.fixture(scope="session")
def db_provider(test_base_dir):
provider = providers[os.getenv('TEST_DB_PROVIDER', 'sqlite3')](test_base_dir)
provider.start()
yield provider
provider.stop()
@pytest.fixture
def executor(teardown_checks):
ex = futures.ThreadPoolExecutor(max_workers=20)


@@ -10,3 +10,4 @@ pytest-xdist==1.29.0
python-bitcoinlib==0.10.1
tqdm==4.32.2
pytest-timeout==1.3.3
psycopg2==2.8.3


@@ -969,12 +969,11 @@ def test_reserve_enforcement(node_factory, executor):
l2.stop()
# They should both aim for 1%.
reserves = l2.db_query('SELECT channel_reserve_satoshis FROM channel_configs')
reserves = l2.db.query('SELECT channel_reserve_satoshis FROM channel_configs')
assert reserves == [{'channel_reserve_satoshis': 10**6 // 100}] * 2
# Edit db to reduce reserve to 0 so it will try to violate it.
l2.db_query('UPDATE channel_configs SET channel_reserve_satoshis=0',
use_copy=False)
l2.db.execute('UPDATE channel_configs SET channel_reserve_satoshis=0')
l2.start()
wait_for(lambda: only_one(l2.rpc.listpeers(l1.info['id'])['peers'])['connected'])


@@ -462,7 +462,9 @@ class LightningD(TailableProc):
class LightningNode(object):
def __init__(self, daemon, rpc, btc, executor, may_fail=False, may_reconnect=False, allow_broken_log=False, allow_bad_gossip=False):
def __init__(self, daemon, rpc, btc, executor, may_fail=False,
may_reconnect=False, allow_broken_log=False,
allow_bad_gossip=False, db=None):
self.rpc = rpc
self.daemon = daemon
self.bitcoin = btc
@@ -471,6 +473,7 @@ class LightningNode(object):
self.may_reconnect = may_reconnect
self.allow_broken_log = allow_broken_log
self.allow_bad_gossip = allow_bad_gossip
self.db = db
def connect(self, remote_node):
self.rpc.connect(remote_node.info['id'], '127.0.0.1', remote_node.daemon.port)
@@ -510,28 +513,8 @@ class LightningNode(object):
def getactivechannels(self):
return [c for c in self.rpc.listchannels()['channels'] if c['active']]
def db_query(self, query, use_copy=True):
orig = os.path.join(self.daemon.lightning_dir, "lightningd.sqlite3")
if use_copy:
copy = os.path.join(self.daemon.lightning_dir, "lightningd-copy.sqlite3")
shutil.copyfile(orig, copy)
db = sqlite3.connect(copy)
else:
db = sqlite3.connect(orig)
db.row_factory = sqlite3.Row
c = db.cursor()
c.execute(query)
rows = c.fetchall()
result = []
for row in rows:
result.append(dict(zip(row.keys(), row)))
db.commit()
c.close()
db.close()
return result
def db_query(self, query):
return self.db.query(query)
# Assumes node is stopped!
def db_manip(self, query):
@@ -771,7 +754,7 @@ class LightningNode(object):
class NodeFactory(object):
"""A factory to setup and start `lightningd` daemons.
"""
def __init__(self, testname, bitcoind, executor, directory):
def __init__(self, testname, bitcoind, executor, directory, db_provider):
self.testname = testname
self.next_id = 1
self.nodes = []
@@ -779,6 +762,7 @@ class NodeFactory(object):
self.bitcoind = bitcoind
self.directory = directory
self.lock = threading.Lock()
self.db_provider = db_provider
def split_options(self, opts):
"""Split node options from cli options
@@ -880,11 +864,17 @@ class NodeFactory(object):
if options is not None:
daemon.opts.update(options)
# Get the DB backend DSN we should be using for this test and this node.
db = self.db_provider.get_db(lightning_dir, self.testname, node_id)
dsn = db.get_dsn()
if dsn is not None:
daemon.opts['wallet'] = dsn
rpc = LightningRpc(socket_path, self.executor)
node = LightningNode(daemon, rpc, self.bitcoind, self.executor, may_fail=may_fail,
may_reconnect=may_reconnect, allow_broken_log=allow_broken_log,
allow_bad_gossip=allow_bad_gossip)
allow_bad_gossip=allow_bad_gossip, db=db)
# Regtest estimatefee are unusable, so override.
node.set_feerates(feerates, False)
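
For illustration, with the postgres provider active the values flowing through this plumbing look roughly as follows. The formats are taken from PostgresDbProvider.get_db() and PostgresDb.get_dsn() above; the port, nonce and test name are random or placeholder values:

# Illustrative values only, not executed anywhere in the commit.
dbname = 'test_example_1_k3f9x2ab'                        # "{testname}_{node_id}_{nonce}"
dsn = 'postgres://postgres:password@localhost:39123/{}'.format(dbname)
daemon_opts = {'wallet': dsn}                             # as set in NodeFactory.get_node() above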