diff --git a/tests/test_connection.py b/tests/test_connection.py index 879077a7d..2feeeb570 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -22,6 +22,7 @@ import re import time import unittest import websocket +import ssl def test_connect_basic(node_factory): @@ -4521,3 +4522,78 @@ def test_last_stable_connection(node_factory): assert only_one(l1.rpc.listpeerchannels()['channels'])['last_stable_connection'] > recon_time + STABLE_TIME assert only_one(l2.rpc.listpeerchannels()['channels'])['last_stable_connection'] > recon_time + STABLE_TIME + + +def test_wss_proxy(node_factory): + wss_port = reserve() + ws_port = reserve() + port = reserve() + wss_proxy_certs = node_factory.directory + '/wss-proxy-certs' + l1 = node_factory.get_node(options={'addr': ':' + str(port), + 'bind-addr': 'ws:127.0.0.1:' + str(ws_port), + 'wss-bind-addr': '127.0.0.1:' + str(wss_port), + 'wss-certs': wss_proxy_certs, + 'dev-allow-localhost': None}) + + # Some depend on ipv4 vs ipv6 behaviour... + for b in l1.rpc.getinfo()['binding']: + if b['type'] == 'ipv4': + assert b == {'type': 'ipv4', 'address': '0.0.0.0', 'port': port} + elif b['type'] == 'ipv6': + assert b == {'type': 'ipv6', 'address': '::', 'port': port} + else: + assert b == {'type': 'websocket', + 'address': '127.0.0.1', + 'subtype': 'ipv4', + 'port': ws_port} + + # Adapter to turn web secure socket into a stream "connection" + class BindWebSecureSocket(object): + def __init__(self, hostname, port): + certfile = f'{wss_proxy_certs}/client.pem' + keyfile = f'{wss_proxy_certs}/client-key.pem' + self.ws = websocket.WebSocket(sslopt={"cert_reqs": ssl.CERT_NONE, "ssl_version": ssl.PROTOCOL_TLS_CLIENT, "certfile": certfile, "keyfile": keyfile}) + self.ws.connect("wss://" + hostname + ":" + str(port)) + self.recvbuf = bytes() + + def send(self, data): + self.ws.send(data, websocket.ABNF.OPCODE_BINARY) + + def recv(self, maxlen): + while len(self.recvbuf) < maxlen: + self.recvbuf += self.ws.recv() + + ret = self.recvbuf[:maxlen] + self.recvbuf = self.recvbuf[maxlen:] + return ret + + wss = BindWebSecureSocket('localhost', wss_port) + + lconn = wire.LightningConnection(wss, + wire.PublicKey(bytes.fromhex(l1.info['id'])), + wire.PrivateKey(bytes([1] * 32)), + is_initiator=True) + + # This might happen really early! + l1.daemon.logsearch_start = 0 + l1.daemon.wait_for_log(r'Websocket Secure Server Started') + + # Perform handshake. + lconn.shake() + + # Expect to receive init msg. + msg = lconn.read_message() + assert int.from_bytes(msg[0:2], 'big') == 16 + + # Echo same message back. + lconn.send_message(msg) + + # Now try sending a ping, ask for 50 bytes + msg = bytes((0, 18, 0, 50, 0, 0)) + lconn.send_message(msg) + + # Could actually reply with some gossip msg! + while True: + msg = lconn.read_message() + if int.from_bytes(msg[0:2], 'big') == 19: + break