mirror of
https://github.com/ElementsProject/lightning.git
synced 2024-11-20 10:39:49 +01:00
1575181ef5
In particular, my bitcoind is in /usr/local/bin Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
199 lines
6.3 KiB
Python
199 lines
6.3 KiB
Python
from binascii import hexlify, unhexlify
|
|
from concurrent import futures
|
|
from hashlib import sha256
|
|
from utils import BitcoinD, LightningD, LightningRpc, LightningNode
|
|
|
|
import logging
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
import unittest
|
|
|
|
# Handle to the shared BitcoinD instance; assigned by setupBitcoind().
bitcoind = None

# Scratch directory for this test run.  Note that it is created eagerly, as a
# side effect of importing this module.
TEST_DIR = tempfile.mkdtemp(prefix='lightning-')

# Daemons are wrapped in valgrind unless NOVALGRIND is present in the
# environment (any value, even an empty string, disables valgrind).
VALGRIND = os.getenv("NOVALGRIND") is None

# Setting TEST_DEBUG (to any value) routes DEBUG-level logging to stdout.
if os.getenv("TEST_DEBUG") is not None:
    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
# NOTE(review): the scraped source lost its indentation; this call is assumed
# to sit at module level (it runs regardless of TEST_DEBUG) -- confirm.
logging.info("Tests running in '%s'", TEST_DIR)
|
|
|
|
|
|
def setupBitcoind():
    """Start the shared `bitcoind` and prepare its chain for the tests.

    A ``BitcoinD`` listening for RPC on port 28332 is created, started and
    stored in the module-level ``bitcoind`` variable.  Afterwards blocks are
    generated if the chain has fewer than 432 blocks or, when the chain is
    already long enough, if the reported balance is below 1.
    """
    global bitcoind
    bitcoind = BitcoinD(rpcport=28332)
    bitcoind.start()

    # Make sure we have segwit and some funds
    chain_info = bitcoind.rpc.getinfo()
    height = chain_info['blocks']
    if height < 432:
        logging.debug("SegWit not active, generating some more blocks")
        bitcoind.rpc.generate(432 - height)
        return

    # The balance is only looked at once the chain is long enough.
    if chain_info['balance'] < 1:
        logging.debug("Insufficient balance, generating 1 block")
        bitcoind.rpc.generate(1)
|
|
|
|
|
|
def tearDownBitcoind():
    """Shut down the shared `bitcoind` held in the module-level variable.

    The daemon is first asked to stop over RPC.  If that request fails the
    process is killed instead; in both cases the function then waits for the
    process to exit.  When no daemon was ever stored (``bitcoind`` is still
    ``None``) there is nothing to do.
    """
    if bitcoind is None:
        return

    try:
        bitcoind.rpc.stop()
    except Exception:
        # Ordinary failure of the stop request: fall back to killing the
        # process and carry on with the teardown.
        logging.debug("bitcoind RPC stop failed, killing the process",
                      exc_info=True)
        bitcoind.proc.kill()
    except BaseException:
        # KeyboardInterrupt / SystemExit: still avoid leaving the daemon
        # behind, but do not swallow the interruption like a bare `except:`.
        bitcoind.proc.kill()
        raise
    bitcoind.proc.wait()
|
|
|
|
|
|
def setUpModule():
    """Module-level fixture: delegate to `setupBitcoind`."""
    setupBitcoind()
|
|
|
|
|
|
def tearDownModule():
    """Module-level fixture: delegate to `tearDownBitcoind`."""
    tearDownBitcoind()
|
|
|
|
|
|
class NodeFactory(object):
    """A factory to setup and start `lightningd` daemons.

    Nodes are numbered sequentially starting at 1.  The number selects both
    the node's directory below ``TEST_DIR`` and its port (16330 + number).
    Every node handed out is remembered so `killall` can stop it later.
    """

    def __init__(self, func, executor):
        """Create an empty factory.

        Args:
            func: The running test case; only its ``_testMethodName``
                attribute is read, to name the per-test directory.
            executor: Executor handed on to every ``LightningRpc`` and
                ``LightningNode`` this factory creates.
        """
        self.func = func
        self.next_id = 1
        self.nodes = []
        self.executor = executor

    def get_node(self):
        """Create, start and return a new node.

        The daemon is started (wrapped in valgrind when ``VALGRIND`` is
        set), the RPC connection is opened and the result of ``getinfo`` is
        cached on the node as ``node.info``.
        """
        node_id = self.next_id
        self.next_id += 1

        lightning_dir = os.path.join(
            TEST_DIR, self.func._testMethodName, "lightning-{}/".format(node_id))

        # The directory component is already formatted above.  Running
        # str.format() over the joined path again (as the code used to) adds
        # nothing and would fail on any literal brace in TEST_DIR or in the
        # test name.
        socket_path = os.path.join(lightning_dir, "lightning-rpc")
        daemon = LightningD(lightning_dir, bitcoind.bitcoin_dir, port=16330+node_id)
        rpc = LightningRpc(socket_path, self.executor)
        node = LightningNode(daemon, rpc, bitcoind, self.executor)
        self.nodes.append(node)

        if VALGRIND:
            # Prefix the daemon command line so it runs under valgrind and
            # writes its findings into the node's own directory.
            node.daemon.cmd_line = [
                'valgrind',
                '-q',
                '--error-exitcode=7',
                '--log-file={}/valgrind-errors'.format(node.daemon.lightning_dir)
            ] + node.daemon.cmd_line

        node.daemon.start()
        node.rpc.connect_rpc()
        # Cache `getinfo`, we'll be using it a lot
        node.info = node.rpc.getinfo()
        return node

    def killall(self):
        """Stop the daemon of every node created by this factory.

        All nodes are attempted even if stopping one of them fails, so a
        single failure does not leave the remaining daemons running.

        Raises:
            Exception: the first error raised by a ``stop()`` call, re-raised
                once every node has been attempted.
        """
        first_error = None
        for node in self.nodes:
            try:
                node.daemon.stop()
            except Exception as exc:
                logging.exception("Failed to stop a lightningd daemon")
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error
|
|
|
|
|
|
class LightningDTests(unittest.TestCase):
|
|
|
|
def setUp(self):
|
|
# Most of the executor threads will be waiting for IO, so
|
|
# let's have a few of them
|
|
self.executor = futures.ThreadPoolExecutor(max_workers=20)
|
|
self.node_factory = NodeFactory(self, self.executor)
|
|
|
|
def tearDown(self):
|
|
self.node_factory.killall()
|
|
self.executor.shutdown(wait=False)
|
|
# TODO(cdecker) Check that valgrind didn't find any errors
|
|
|
|
def test_connect(self):
|
|
l1 = self.node_factory.get_node()
|
|
l2 = self.node_factory.get_node()
|
|
l1.connect(l2, 0.01, async=False)
|
|
|
|
def test_successful_payment(self):
|
|
l1 = self.node_factory.get_node()
|
|
l2 = self.node_factory.get_node()
|
|
bitcoind = l1.bitcoin
|
|
|
|
capacity = 0.01 * 10**8 * 10**3
|
|
htlc_amount = 10000
|
|
l1.connect(l2, 0.01)
|
|
|
|
invoice = l2.rpc.invoice(htlc_amount, "successful_payment")
|
|
|
|
# TODO(cdecker) Assert that we have an invoice
|
|
rhash = invoice['rhash']
|
|
assert len(rhash) == 64
|
|
|
|
route = l1.rpc.getroute(l2.info['id'], htlc_amount, 1)
|
|
assert len(route) == 1
|
|
assert route[0] == {'msatoshi': htlc_amount, 'id': l2.info['id'], 'delay': 6}
|
|
|
|
receipt = l1.rpc.sendpay(route, invoice['rhash'])
|
|
assert sha256(unhexlify(receipt['preimage'])).hexdigest() == rhash
|
|
|
|
# Now go for the combined RPC call
|
|
invoice = l2.rpc.invoice(100, "one_shot_payment")
|
|
l1.rpc.pay(l2.info['id'], 100, invoice['rhash'])
|
|
|
|
def test_multihop_payment(self):
|
|
nodes = [self.node_factory.get_node() for _ in range(5)]
|
|
conn_futures = [nodes[i].connect(nodes[i+1], 0.01, async=True) for i in range(len(nodes)-1)]
|
|
|
|
htlc_amount = 10000
|
|
|
|
# Now wait for all of them
|
|
[f.result() for f in conn_futures]
|
|
|
|
time.sleep(1)
|
|
|
|
# Manually add channel l2 -> l3 to l1 so that it can compute the route
|
|
for i in range(len(nodes)-1):
|
|
nodes[0].rpc.dev_add_route(nodes[i].info['id'], nodes[i+1].info['id'], 1, 1, 6, 6)
|
|
#l1.rpc.dev_add_route(l2.info['id'], l3.info['id'], 1, 1, 6, 6)
|
|
|
|
invoice = nodes[-1].rpc.invoice(htlc_amount, "multihop_payment")
|
|
route = nodes[0].rpc.getroute(nodes[-1].info['id'], htlc_amount, 1)
|
|
receipt = nodes[0].rpc.sendpay(route, invoice['rhash'])
|
|
|
|
nodes[-1].daemon.wait_for_log("STATE_NORMAL_COMMITTING => STATE_NORMAL")
|
|
nodes[0].daemon.wait_for_log("STATE_NORMAL_COMMITTING => STATE_NORMAL")
|
|
|
|
@unittest.skip('Too damn long')
|
|
def test_routing_gossip(self):
|
|
nodes = [self.node_factory.get_node() for _ in range(20)]
|
|
l1 = nodes[0]
|
|
l5 = nodes[4]
|
|
|
|
for i in range(len(nodes)-1):
|
|
nodes[i].connect(nodes[i+1], 0.01)
|
|
start_time = time.time()
|
|
|
|
while time.time() - start_time < len(nodes) * 30:
|
|
if sum([c['active'] for c in l1.rpc.getchannels()]) == 2*(len(nodes)-1):
|
|
break
|
|
time.sleep(1)
|
|
l1.bitcoin.rpc.getinfo()
|
|
|
|
while time.time() - start_time < len(nodes) * 30:
|
|
if sum([c['active'] for c in l5.rpc.getchannels()]) == 2*(len(nodes)-1):
|
|
break
|
|
time.sleep(1)
|
|
l1.bitcoin.rpc.getinfo()
|
|
|
|
# Quick check that things are reasonable
|
|
assert sum([len(l.rpc.getchannels()) for l in nodes]) == 5*2*(len(nodes) - 1)
|
|
|
|
# Deep check that all channels are in there
|
|
comb = []
|
|
for i in range(len(nodes) - 1):
|
|
comb.append((nodes[i].info['id'], nodes[i+1].info['id']))
|
|
comb.append((nodes[i+1].info['id'], nodes[i].info['id']))
|
|
|
|
for n in nodes:
|
|
seen = []
|
|
for c in n.rpc.getchannels():
|
|
seen.append((c['from'],c['to']))
|
|
assert set(seen) == set(comb)
|
|
|
|
# Run the whole suite with per-test output when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|