[test] create unit-test framework for RPC wallets (#2396)

---------

Co-authored-by: dni  <office@dnilabs.com>
This commit is contained in:
Vlad Stan 2024-04-15 18:24:28 +03:00 committed by GitHub
parent b145bff566
commit 69ce0e565b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 2128 additions and 270 deletions

View file

@ -40,6 +40,9 @@
pytest-md = prev.pytest-md.overridePythonAttrs (
old: { buildInputs = (old.buildInputs or []) ++ [ prev.setuptools ]; }
);
types-mock = prev.pytest-md.overridePythonAttrs (
old: { buildInputs = (old.buildInputs or []) ++ [ prev.setuptools ]; }
);
});
};
});

View file

@ -54,12 +54,19 @@ class CoreLightningWallet(Wallet):
async def status(self) -> StatusResponse:
try:
funds: dict = self.ln.listfunds() # type: ignore
if len(funds) == 0:
return StatusResponse("no data", 0)
return StatusResponse(
None, sum([int(ch["our_amount_msat"]) for ch in funds["channels"]])
)
except RpcError as exc:
error_message = f"lightningd '{exc.method}' failed with '{exc.error}'."
logger.warning(exc)
error_message = f"RPC '{exc.method}' failed with '{exc.error}'."
return StatusResponse(error_message, 0)
except Exception as exc:
logger.warning(f"Failed to connect, got: '{exc}'")
return StatusResponse(f"Unable to connect, got: '{exc}'", 0)
async def create_invoice(
self,
@ -69,7 +76,7 @@ class CoreLightningWallet(Wallet):
unhashed_description: Optional[bytes] = None,
**kwargs,
) -> InvoiceResponse:
label = f"lbl{random.random()}"
label = kwargs.get("label", f"lbl{random.random()}")
msat: int = int(amount * 1000)
try:
if description_hash and not unhashed_description:
@ -95,14 +102,18 @@ class CoreLightningWallet(Wallet):
if r.get("code") and r.get("code") < 0: # type: ignore
raise Exception(r.get("message"))
return InvoiceResponse(True, r["payment_hash"], r["bolt11"], "")
return InvoiceResponse(True, r["payment_hash"], r["bolt11"], None)
except RpcError as exc:
error_message = (
f"CoreLightning method '{exc.method}' failed with"
f" '{exc.error.get('message') or exc.error}'." # type: ignore
)
logger.warning(exc)
error_message = f"RPC '{exc.method}' failed with '{exc.error}'."
return InvoiceResponse(False, None, None, error_message)
except KeyError as exc:
logger.warning(exc)
return InvoiceResponse(
False, None, None, "Server error: 'missing required fields'"
)
except Exception as e:
logger.warning(e)
return InvoiceResponse(False, None, None, str(e))
async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse:
@ -111,94 +122,111 @@ class CoreLightningWallet(Wallet):
except Bolt11Exception as exc:
return PaymentResponse(False, None, None, None, str(exc))
previous_payment = await self.get_payment_status(invoice.payment_hash)
if previous_payment.paid:
return PaymentResponse(False, None, None, None, "invoice already paid")
if not invoice.amount_msat or invoice.amount_msat <= 0:
return PaymentResponse(
False, None, None, None, "CLN 0 amount invoice not supported"
)
fee_limit_percent = fee_limit_msat / invoice.amount_msat * 100
# so fee_limit_percent is applied even on payments with fee < 5000 millisatoshi
# (which is default value of exemptfee)
payload = {
"bolt11": bolt11,
"maxfeepercent": f"{fee_limit_percent:.11}",
"exemptfee": 0,
# so fee_limit_percent is applied even on payments with fee < 5000
# millisatoshi (which is default value of exemptfee)
"description": invoice.description,
}
try:
previous_payment = await self.get_payment_status(invoice.payment_hash)
if previous_payment.paid:
return PaymentResponse(False, None, None, None, "invoice already paid")
if not invoice.amount_msat or invoice.amount_msat <= 0:
return PaymentResponse(
False, None, None, None, "CLN 0 amount invoice not supported"
)
fee_limit_percent = fee_limit_msat / invoice.amount_msat * 100
# so fee_limit_percent is applied even
# on payments with fee < 5000 millisatoshi
# (which is default value of exemptfee)
payload = {
"bolt11": bolt11,
"maxfeepercent": f"{fee_limit_percent:.11}",
"exemptfee": 0,
# so fee_limit_percent is applied even on payments with fee < 5000
# millisatoshi (which is default value of exemptfee)
"description": invoice.description,
}
r = await run_sync(lambda: self.ln.call("pay", payload))
fee_msat = -int(r["amount_sent_msat"] - r["amount_msat"])
return PaymentResponse(
True, r["payment_hash"], fee_msat, r["payment_preimage"], None
)
except RpcError as exc:
logger.warning(exc)
try:
error_message = exc.error["attempts"][-1]["fail_reason"] # type: ignore
except Exception:
error_message = (
f"CoreLightning method '{exc.method}' failed with"
f" '{exc.error.get('message') or exc.error}'." # type: ignore
)
error_message = f"RPC '{exc.method}' failed with '{exc.error}'."
return PaymentResponse(False, None, None, None, error_message)
fee_msat = -int(r["amount_sent_msat"] - r["amount_msat"])
return PaymentResponse(
True, r["payment_hash"], fee_msat, r["payment_preimage"], None
)
except KeyError as exc:
logger.warning(exc)
return PaymentResponse(
False, None, None, None, "Server error: 'missing required fields'"
)
except Exception as exc:
logger.info(f"Failed to pay invoice {bolt11}")
logger.warning(exc)
return PaymentResponse(False, None, None, None, f"Payment failed: '{exc}'.")
async def get_invoice_status(self, checking_id: str) -> PaymentStatus:
try:
r: dict = self.ln.listinvoices(payment_hash=checking_id) # type: ignore
except RpcError:
return PaymentPendingStatus()
if not r["invoices"]:
return PaymentPendingStatus()
invoice_resp = r["invoices"][-1]
if invoice_resp["payment_hash"] == checking_id:
if invoice_resp["status"] == "paid":
return PaymentSuccessStatus()
elif invoice_resp["status"] == "unpaid":
if not r["invoices"]:
return PaymentPendingStatus()
elif invoice_resp["status"] == "expired":
return PaymentFailedStatus()
else:
logger.warning(f"supplied an invalid checking_id: {checking_id}")
return PaymentPendingStatus()
invoice_resp = r["invoices"][-1]
if invoice_resp["payment_hash"] == checking_id:
if invoice_resp["status"] == "paid":
return PaymentSuccessStatus()
elif invoice_resp["status"] == "unpaid":
return PaymentPendingStatus()
elif invoice_resp["status"] == "expired":
return PaymentFailedStatus()
else:
logger.warning(f"supplied an invalid checking_id: {checking_id}")
return PaymentPendingStatus()
except RpcError as exc:
logger.warning(exc)
return PaymentPendingStatus()
except Exception as exc:
logger.warning(exc)
return PaymentPendingStatus()
async def get_payment_status(self, checking_id: str) -> PaymentStatus:
try:
r: dict = self.ln.listpays(payment_hash=checking_id) # type: ignore
except Exception:
return PaymentPendingStatus()
if "pays" not in r:
return PaymentPendingStatus()
if not r["pays"]:
# no payment with this payment_hash is found
return PaymentFailedStatus()
payment_resp = r["pays"][-1]
if payment_resp["payment_hash"] == checking_id:
status = payment_resp["status"]
if status == "complete":
fee_msat = -int(
payment_resp["amount_sent_msat"] - payment_resp["amount_msat"]
)
return PaymentSuccessStatus(
fee_msat=fee_msat, preimage=payment_resp["preimage"]
)
elif status == "failed":
return PaymentFailedStatus()
else:
if "pays" not in r:
return PaymentPendingStatus()
else:
logger.warning(f"supplied an invalid checking_id: {checking_id}")
return PaymentPendingStatus()
if not r["pays"]:
# no payment with this payment_hash is found
return PaymentFailedStatus()
payment_resp = r["pays"][-1]
if payment_resp["payment_hash"] == checking_id:
status = payment_resp["status"]
if status == "complete":
fee_msat = -int(
payment_resp["amount_sent_msat"] - payment_resp["amount_msat"]
)
return PaymentSuccessStatus(
fee_msat=fee_msat, preimage=payment_resp["preimage"]
)
elif status == "failed":
return PaymentFailedStatus()
else:
return PaymentPendingStatus()
else:
logger.warning(f"supplied an invalid checking_id: {checking_id}")
return PaymentPendingStatus()
except Exception as exc:
logger.warning(exc)
return PaymentPendingStatus()
async def paid_invoices_stream(self) -> AsyncGenerator[str, None]:
while True:

46
poetry.lock generated
View file

@ -1364,6 +1364,22 @@ docs = ["alabaster (==0.7.13)", "autodocsumm (==0.2.11)", "sphinx (==7.0.1)", "s
lint = ["flake8 (==6.0.0)", "flake8-bugbear (==23.7.10)", "mypy (==1.4.1)", "pre-commit (>=2.4,<4.0)"]
tests = ["pytest", "pytz", "simplejson"]
[[package]]
name = "mock"
version = "5.1.0"
description = "Rolling backport of unittest.mock for all Pythons"
optional = false
python-versions = ">=3.6"
files = [
{file = "mock-5.1.0-py3-none-any.whl", hash = "sha256:18c694e5ae8a208cdb3d2c20a993ca1a7b0efa258c247a1e565150f477f83744"},
{file = "mock-5.1.0.tar.gz", hash = "sha256:5e96aad5ccda4718e0a229ed94b2024df75cc2d55575ba5762d31f5767b8767d"},
]
[package.extras]
build = ["blurb", "twine", "wheel"]
docs = ["sphinx"]
test = ["pytest", "pytest-cov"]
[[package]]
name = "mypy"
version = "1.7.1"
@ -1996,6 +2012,23 @@ files = [
[package.dependencies]
pytest = ">=4.2.1"
[[package]]
name = "pytest-mock"
version = "3.14.0"
description = "Thin-wrapper around the mock package for easier use with pytest"
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
{file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
]
[package.dependencies]
pytest = ">=6.2.5"
[package.extras]
dev = ["pre-commit", "pytest-asyncio", "tox"]
[[package]]
name = "python-crontab"
version = "3.0.0"
@ -2592,6 +2625,17 @@ notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
name = "types-mock"
version = "5.1.0.20240311"
description = "Typing stubs for mock"
optional = false
python-versions = ">=3.8"
files = [
{file = "types-mock-5.1.0.20240311.tar.gz", hash = "sha256:7472797986d83016f96fde7f73577d129b0cd8a8d0b783487a7be330d57ba431"},
{file = "types_mock-5.1.0.20240311-py3-none-any.whl", hash = "sha256:0769cb376dfc75b45215619f17a9fd6333d771cc29ce4a38937f060b1e45530f"},
]
[[package]]
name = "types-passlib"
version = "1.7.7.13"
@ -3013,4 +3057,4 @@ liquid = ["wallycore"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10 | ^3.9"
content-hash = "4c11cc117beb703ebece5fac43adbabae76804f084c39ef90a67edcfb56795d7"
content-hash = "fd9ace1dada06a9a4556ffe888c9c391d1da4e2febd22084b6f53e6006eefa6e"

View file

@ -74,6 +74,9 @@ json5 = "^0.9.17"
asgi-lifespan = "^2.1.0"
pytest-md = "^0.2.0"
pytest-httpserver = "^1.0.10"
pytest-mock = "^3.14.0"
types-mock = "^5.1.0.20240311"
mock = "^5.1.0"
[build-system]
requires = ["poetry-core>=1.0.0"]

View file

@ -5,12 +5,11 @@ import random
import string
import time
from subprocess import PIPE, Popen, TimeoutExpired
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Tuple
from loguru import logger
from psycopg2 import connect
from psycopg2.errors import InvalidCatalogName
from pydantic import BaseModel
from lnbits import core
from lnbits.db import DB_TYPE, POSTGRES, FromRowModel
@ -179,119 +178,3 @@ def clean_database(settings):
# TODO: do this once mock data is removed from test data folder
# os.remove(settings.lnbits_data_folder + "/database.sqlite3")
pass
def rest_wallet_fixtures_from_json(path) -> List["WalletTest"]:
with open(path) as f:
data = json.load(f)
funding_sources = data["funding_sources"]
tests: Dict[str, List[WalletTest]] = {
fs_name: [] for fs_name in funding_sources
}
for fn_name in data["functions"]:
fn = data["functions"][fn_name]
for test in fn["tests"]:
"""create an unit test for each funding source"""
for fs_name in funding_sources:
t = WalletTest(
**{
"funding_source": FundingSourceConfig(
**funding_sources[fs_name]
),
"function": fn_name,
**test,
"mocks": [],
}
)
if "mocks" in test:
if fs_name not in test["mocks"]:
t.skip = True
tests[fs_name].append(t)
continue
test_mocks_names = test["mocks"][fs_name]
fs_mocks = fn["mocks"][fs_name]
for mock_name in fs_mocks:
for test_mock in test_mocks_names[mock_name]:
# different mocks that result in the same
# return value for the tested function
_mock = fs_mocks[mock_name] | test_mock
mock = Mock(**_mock)
unique_test = WalletTest(**t.dict())
unique_test.description = (
f"""{t.description}:{mock.description or ""}"""
)
unique_test.mocks = t.mocks + [mock]
unique_test.skip = mock.skip
tests[fs_name].append(unique_test)
else:
# add the test without mocks
tests[fs_name].append(t)
all_tests = sum([tests[fs_name] for fs_name in tests], [])
return all_tests
class FundingSourceConfig(BaseModel):
wallet_class: str
settings: dict
class FunctionMock(BaseModel):
uri: str
query_params: Optional[dict]
headers: dict
method: str
class TestMock(BaseModel):
skip: Optional[bool]
description: Optional[str]
request_type: Optional[str]
request_body: Optional[dict]
response_type: str
response: Union[str, dict]
class Mock(FunctionMock, TestMock):
pass
class FunctionMocks(BaseModel):
mocks: Dict[str, FunctionMock]
class FunctionTest(BaseModel):
description: str
call_params: dict
expect: dict
mocks: Dict[str, List[Dict[str, TestMock]]]
class FunctionData(BaseModel):
"""Data required for testing this function"""
"Function level mocks that apply for all tests of this function"
mocks: List[FunctionMock] = []
"All the tests for this function"
tests: List[FunctionTest] = []
class WalletTest(BaseModel):
skip: Optional[bool]
function: str
description: str
funding_source: FundingSourceConfig
call_params: Optional[dict] = {}
expect: Optional[dict]
expect_error: Optional[dict]
mocks: List[Mock] = []

Binary file not shown.

View file

@ -0,0 +1,32 @@
-----BEGIN CERTIFICATE-----
MIIFbzCCA1egAwIBAgIUfkee1G4E8QAadd517sY/9+6xr0AwDQYJKoZIhvcNAQEL
BQAwRjELMAkGA1UEBhMCU1YxFDASBgNVBAgMC0VsIFNhbHZhZG9yMSEwHwYDVQQK
DBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwIBcNMjQwNDAzMTMyMTM5WhgPMjA1
MTA4MjAxMzIxMzlaMEYxCzAJBgNVBAYTAlNWMRQwEgYDVQQIDAtFbCBTYWx2YWRv
cjEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIICIjANBgkqhkiG
9w0BAQEFAAOCAg8AMIICCgKCAgEAnW4MKs2Y3qZnn2+J/Bp21aUuJ7oE8ll82Q2C
uh8VAlsNnGDpTyOSRLHLmxV+cu82umvVPBpOVwAl17/VuxcLjFVSk7YOMj3MWoF5
hm+oBtetouSDt3H0+BoDuXN3eVsLI4b+e1F6ag7JIwsDQvRUbGTFiyHVvXolTZPb
wtFzlwQSB5i6KHKRQ+W6Q+cz4khIRO79IhaEiu5TWDrmx+6WkZxWYYO/g/I/S1gX
l1JP6gXQFabwUFn+CBAxPsi7f+igi6gIepXBQOIG1dkZ5ojJPabtvblO7mWJTsec
2D4Vb3L7OfboIYC85gY1cudWBX3oAASIVh9m9YoCZW2WOMNr6apnJSXx36ueJXAS
rPq3C2haPWO8z+0nYkaYTcTAxeCvs0ux2DGIniinC+u1cELg6REK2X1K8YsSsXrc
U1T8rNs2azyzTxglIHHac6ScG+Ac1nlY54C9UfZZcztE8nUBqJi+Eowpyr+y3QvT
zNdulc80xpi5arbzt85BNi+xX+NZC07QjgUJ/eexRglP3flfTbbnG8Pphe/M/l04
IfBWBqK2cF9Fd+1J+Zf7fXZrw+41QF8WukLoQ4JQEMqIIhDFzaoTi5ogsnhiGu0Z
iaCATfCLMsWvAPHw6afFw2/utdvCd2Dr22H16hj0xEkNOw702/AoNWMFmzIzuC9m
VjkH1KUCAwEAAaNTMFEwHQYDVR0OBBYEFJAQIGLZNVRwGIgb3cmPTAiduzreMB8G
A1UdIwQYMBaAFJAQIGLZNVRwGIgb3cmPTAiduzreMA8GA1UdEwEB/wQFMAMBAf8w
DQYJKoZIhvcNAQELBQADggIBAFOaWcLZSU46Zr43kQU+w+A70r+unmRfsANeREDi
Qvjg1ihJLO8g1l7Cu74QUqLwx8BG3KO7ZbDcN6uTeCrYgyERSVUxNAwu5hf2LnEr
MQ/L4h0j/8flj9oowTDCit/6YXTJ1Mf8OaKkSliUYVsoZCaIISZ2pvcZbU1cXCeX
JBM4Zr1ijM8qbghPoG6O7Ep/A3VHTozuAU9C7uREH+XJFepr9BXjrFqyzx/ArEZa
5HIO9nOqWqtwMFDE2jX3Ios3tjbU275ez2Xd7meDn0iPWMEgNbXX6b+FFlNkajR2
NchPmBigBpk9bt63HeIQb2t/VU7X9FvMTqCbp1R2MGiHTMyQ9IjeoYKNy/mur/GG
DQkG7rq52oPGI06CJ7uuMEhCm6jNVtIbnCTl2jRnkD1fqKVmQa9Cn7jqDqR2dhqX
AxTk01Vhinxhik0ckhcgViRgiBWSnnx4Vzk7wyV6O4EdtLTywkywTR/+WEisBVUV
LOXZEmxj+AVARARUds+a/IgdANFGr/yWI6WBOibjoEFZMEZqzwlcEErgxLRinUvb
9COmr6ig+zC1570V2ktmn1P/qodOD4tOL0ICSkKoTQLFPfevM2y0DdN48T2kxzZ5
TruiKHuAnOhvwKwUpF+TRFMUWft3VG9GJXm/4A9FWm/ALLrqw2oSXGrl5z8pq29z
SN2A
-----END CERTIFICATE-----

View file

@ -4,7 +4,8 @@
"wallet_class": "CoreLightningRestWallet",
"settings": {
"corelightning_rest_url": "http://127.0.0.1:8555",
"corelightning_rest_macaroon": "eNcRyPtEdMaCaRoOn"
"corelightning_rest_macaroon": "eNcRyPtEdMaCaRoOn",
"user_agent": "LNbits/Tests"
}
},
"lndrest": {
@ -12,14 +13,16 @@
"settings": {
"lnd_rest_endpoint": "http://127.0.0.1:8555",
"lnd_rest_macaroon": "eNcRyPtEdMaCaRoOn",
"lnd_rest_cert": ""
"lnd_rest_cert": "",
"user_agent": "LNbits/Tests"
}
},
"alby": {
"wallet_class": "AlbyWallet",
"settings": {
"alby_api_endpoint": "http://127.0.0.1:8555",
"alby_access_token": "mock-alby-access-token"
"alby_access_token": "mock-alby-access-token",
"user_agent": "LNbits/Tests"
}
}
},

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,134 @@
from typing import Dict, List, Optional, Union
from pydantic import BaseModel
class FundingSourceConfig(BaseModel):
name: str
skip: Optional[bool]
wallet_class: str
client_field: Optional[str]
settings: dict
class FunctionMock(BaseModel):
uri: Optional[str]
query_params: Optional[dict]
headers: Optional[dict]
method: Optional[str]
class TestMock(BaseModel):
skip: Optional[bool]
description: Optional[str]
request_type: Optional[str]
request_body: Optional[dict]
response_type: str
response: Union[str, dict]
class Mock(FunctionMock, TestMock):
@staticmethod
def combine_mocks(fs_mock, test_mock):
_mock = fs_mock | test_mock
if "response" in _mock and "response" in fs_mock:
_mock["response"] |= fs_mock["response"]
return Mock(**_mock)
class FunctionMocks(BaseModel):
mocks: Dict[str, FunctionMock]
class FunctionTest(BaseModel):
description: str
call_params: dict
expect: dict
mocks: Dict[str, List[Dict[str, TestMock]]]
class FunctionData(BaseModel):
"""Data required for testing this function"""
"Function level mocks that apply for all tests of this function"
mocks: List[FunctionMock] = []
"All the tests for this function"
tests: List[FunctionTest] = []
class WalletTest(BaseModel):
skip: Optional[bool]
function: str
description: str
funding_source: FundingSourceConfig
call_params: Optional[dict] = {}
expect: Optional[dict]
expect_error: Optional[dict]
mocks: List[Mock] = []
@staticmethod
def tests_for_funding_source(
fs: FundingSourceConfig,
fn_name: str,
fn,
test,
) -> List["WalletTest"]:
t = WalletTest(
**{
"funding_source": fs,
"function": fn_name,
**test,
"mocks": [],
"skip": fs.skip,
}
)
if "mocks" in test:
if fs.name not in test["mocks"]:
t.skip = True
return [t]
return t._tests_from_fs_mocks(fn, test, fs.name)
return [t]
def _tests_from_fs_mocks(self, fn, test, fs_name: str) -> List["WalletTest"]:
tests: List[WalletTest] = []
fs_mocks = fn["mocks"][fs_name]
test_mocks = test["mocks"][fs_name]
for mock_name in fs_mocks:
tests += self._tests_from_mocks(fs_mocks[mock_name], test_mocks[mock_name])
return tests
def _tests_from_mocks(self, fs_mock, test_mocks) -> List["WalletTest"]:
tests: List[WalletTest] = []
for test_mock in test_mocks:
# different mocks that result in the same
# return value for the tested function
unique_test = self._test_from_mocks(fs_mock, test_mock)
tests.append(unique_test)
return tests
def _test_from_mocks(self, fs_mock, test_mock) -> "WalletTest":
mock = Mock.combine_mocks(fs_mock, test_mock)
return WalletTest(
**(
self.dict()
| {
"description": f"""{self.description}:{mock.description or ""}""",
"mocks": self.mocks + [mock],
"skip": self.skip or mock.skip,
}
)
)
class DataObject:
def __init__(self, **kwargs):
for k in kwargs:
setattr(self, k, kwargs[k])

117
tests/wallets/helpers.py Normal file
View file

@ -0,0 +1,117 @@
import importlib
import json
from typing import Dict, List
import pytest
from lnbits.core.models import BaseWallet
from tests.wallets.fixtures.models import FundingSourceConfig, WalletTest
wallets_module = importlib.import_module("lnbits.wallets")
def wallet_fixtures_from_json(path) -> List["WalletTest"]:
with open(path) as f:
data = json.load(f)
funding_sources = [
FundingSourceConfig(name=fs_name, **data["funding_sources"][fs_name])
for fs_name in data["funding_sources"]
]
tests: Dict[str, List[WalletTest]] = {}
for fn_name in data["functions"]:
fn = data["functions"][fn_name]
fn_tests = _tests_for_function(funding_sources, fn_name, fn)
_merge_dict_of_lists(tests, fn_tests)
all_tests = sum([tests[fs_name] for fs_name in tests], [])
return all_tests
def _tests_for_function(
funding_sources: List[FundingSourceConfig], fn_name: str, fn
) -> Dict[str, List[WalletTest]]:
tests: Dict[str, List[WalletTest]] = {}
for test in fn["tests"]:
"""create an unit test for each funding source"""
fs_tests = _tests_for_funding_source(funding_sources, fn_name, fn, test)
_merge_dict_of_lists(tests, fs_tests)
return tests
def _tests_for_funding_source(
funding_sources: List[FundingSourceConfig], fn_name: str, fn, test
) -> Dict[str, List[WalletTest]]:
tests: Dict[str, List[WalletTest]] = {fs.name: [] for fs in funding_sources}
for fs in funding_sources:
tests[fs.name] += WalletTest.tests_for_funding_source(fs, fn_name, fn, test)
return tests
def build_test_id(test: WalletTest):
return f"{test.funding_source}.{test.function}({test.description})"
def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
custom_settings = funding_source.settings
original_settings = {}
settings = getattr(wallets_module, "settings")
for s in custom_settings:
original_settings[s] = getattr(settings, s)
setattr(settings, s, custom_settings[s])
fs_instance: BaseWallet = getattr(wallets_module, funding_source.wallet_class)()
# rollback settings (global variable)
for s in original_settings:
setattr(settings, s, original_settings[s])
return fs_instance
async def check_assertions(wallet, _test_data: WalletTest):
test_data = _test_data.dict()
tested_func = _test_data.function
call_params = _test_data.call_params
if "expect" in test_data:
await _assert_data(wallet, tested_func, call_params, _test_data.expect)
# if len(_test_data.mocks) == 0:
# # all calls should fail after this method is called
# await wallet.cleanup()
# # same behaviour expected is server canot be reached
# # or if the connection was closed
# await _assert_data(wallet, tested_func, call_params, _test_data.expect)
elif "expect_error" in test_data:
await _assert_error(wallet, tested_func, call_params, _test_data.expect_error)
else:
assert False, "Expected outcome not specified"
async def _assert_data(wallet, tested_func, call_params, expect):
resp = await getattr(wallet, tested_func)(**call_params)
for key in expect:
received = getattr(resp, key)
expected = expect[key]
assert (
getattr(resp, key) == expect[key]
), f"""Field "{key}". Received: "{received}". Expected: "{expected}"."""
async def _assert_error(wallet, tested_func, call_params, expect_error):
error_module = importlib.import_module(expect_error["module"])
error_class = getattr(error_module, expect_error["class"])
with pytest.raises(error_class) as e_info:
await getattr(wallet, tested_func)(**call_params)
assert e_info.match(expect_error["message"])
def _merge_dict_of_lists(v1: Dict[str, List], v2: Dict[str, List]):
"""Merge v2 into v1"""
for k in v2:
v1[k] = v2[k] if k not in v1 else v1[k] + v2[k]

View file

@ -1,4 +1,3 @@
import importlib
import json
from typing import Dict, Union
from urllib.parse import urlencode
@ -7,16 +6,15 @@ import pytest
from pytest_httpserver import HTTPServer
from werkzeug.wrappers import Response
from lnbits.core.models import BaseWallet
from tests.helpers import (
FundingSourceConfig,
Mock,
from tests.wallets.fixtures.models import Mock
from tests.wallets.helpers import (
WalletTest,
rest_wallet_fixtures_from_json,
build_test_id,
check_assertions,
load_funding_source,
wallet_fixtures_from_json,
)
wallets_module = importlib.import_module("lnbits.wallets")
# todo:
# - tests for extra fields
# - tests for paid_invoices_stream
@ -29,14 +27,10 @@ def httpserver_listen_address():
return ("127.0.0.1", 8555)
def build_test_id(test: WalletTest):
return f"{test.funding_source}.{test.function}({test.description})"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_data",
rest_wallet_fixtures_from_json("tests/wallets/fixtures.json"),
wallet_fixtures_from_json("tests/wallets/fixtures/json/fixtures_rest.json"),
ids=build_test_id,
)
async def test_rest_wallet(httpserver: HTTPServer, test_data: WalletTest):
@ -46,8 +40,8 @@ async def test_rest_wallet(httpserver: HTTPServer, test_data: WalletTest):
for mock in test_data.mocks:
_apply_mock(httpserver, mock)
wallet = _load_funding_source(test_data.funding_source)
await _check_assertions(wallet, test_data)
wallet = load_funding_source(test_data.funding_source)
await check_assertions(wallet, test_data)
def _apply_mock(httpserver: HTTPServer, mock: Mock):
@ -65,6 +59,8 @@ def _apply_mock(httpserver: HTTPServer, mock: Mock):
if mock.query_params:
request_data["query_string"] = mock.query_params
assert mock.uri, "Missing URI for HTTP mock."
assert mock.method, "Missing method for HTTP mock."
req = httpserver.expect_request(
uri=mock.uri,
headers=mock.headers,
@ -84,60 +80,3 @@ def _apply_mock(httpserver: HTTPServer, mock: Mock):
respond_with = f"respond_with_{response_type}"
getattr(req, respond_with)(server_response)
async def _check_assertions(wallet, _test_data: WalletTest):
test_data = _test_data.dict()
tested_func = _test_data.function
call_params = _test_data.call_params
if "expect" in test_data:
await _assert_data(wallet, tested_func, call_params, _test_data.expect)
# if len(_test_data.mocks) == 0:
# # all calls should fail after this method is called
# await wallet.cleanup()
# # same behaviour expected is server canot be reached
# # or if the connection was closed
# await _assert_data(wallet, tested_func, call_params, _test_data.expect)
elif "expect_error" in test_data:
await _assert_error(wallet, tested_func, call_params, _test_data.expect_error)
else:
assert False, "Expected outcome not specified"
async def _assert_data(wallet, tested_func, call_params, expect):
resp = await getattr(wallet, tested_func)(**call_params)
for key in expect:
received = getattr(resp, key)
expected = expect[key]
assert (
getattr(resp, key) == expect[key]
), f"""Field "{key}". Received: "{received}". Expected: "{expected}"."""
async def _assert_error(wallet, tested_func, call_params, expect_error):
error_module = importlib.import_module(expect_error["module"])
error_class = getattr(error_module, expect_error["class"])
with pytest.raises(error_class) as e_info:
await getattr(wallet, tested_func)(**call_params)
assert e_info.match(expect_error["message"])
def _load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
custom_settings = funding_source.settings | {"user_agent": "LNbits/Tests"}
original_settings = {}
settings = getattr(wallets_module, "settings")
for s in custom_settings:
original_settings[s] = getattr(settings, s)
setattr(settings, s, custom_settings[s])
fs_instance: BaseWallet = getattr(wallets_module, funding_source.wallet_class)()
# rollback settings (global variable)
for s in original_settings:
setattr(settings, s, original_settings[s])
return fs_instance

View file

@ -0,0 +1,145 @@
import importlib
from typing import Dict, List, Optional
import pytest
from mock import Mock
from pytest_mock.plugin import MockerFixture
from lnbits.core.models import BaseWallet
from tests.wallets.fixtures.models import DataObject
from tests.wallets.fixtures.models import Mock as RpcMock
from tests.wallets.helpers import (
WalletTest,
build_test_id,
check_assertions,
load_funding_source,
wallet_fixtures_from_json,
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_data",
wallet_fixtures_from_json("tests/wallets/fixtures/json/fixtures_rpc.json"),
ids=build_test_id,
)
async def test_wallets(mocker: MockerFixture, test_data: WalletTest):
if test_data.skip:
pytest.skip()
for mock in test_data.mocks:
_apply_rpc_mock(mocker, mock)
wallet = load_funding_source(test_data.funding_source)
expected_calls = _spy_mocks(mocker, test_data, wallet)
await check_assertions(wallet, test_data)
_check_calls(expected_calls)
def _apply_rpc_mock(mocker: MockerFixture, mock: RpcMock):
return_value = {}
assert isinstance(mock.response, dict), "Expected data RPC response"
for field_name in mock.response:
value = mock.response[field_name]
values = value if isinstance(value, list) else [value]
return_value[field_name] = Mock(side_effect=[_mock_field(f) for f in values])
m = _data_mock(return_value)
assert mock.method, "Missing method for RPC mock."
mocker.patch(mock.method, m)
def _check_calls(expected_calls):
for func in expected_calls:
func_calls = expected_calls[func]
for func_call in func_calls:
req = func_call["request_data"]
args = req["args"] if "args" in req else {}
kwargs = req["kwargs"] if "kwargs" in req else {}
if "klass" in req:
*rest, cls = req["klass"].split(".")
req_module = importlib.import_module(".".join(rest))
req_class = getattr(req_module, cls)
func_call["spy"].assert_called_with(req_class(*args, **kwargs))
else:
func_call["spy"].assert_called_with(*args, **kwargs)
def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet):
assert (
test_data.funding_source.client_field
), f"Missing client field for wallet {wallet}"
client_field = getattr(wallet, test_data.funding_source.client_field)
expected_calls: Dict[str, List] = {}
for mock in test_data.mocks:
spy = _spy_mock(mocker, mock, client_field)
expected_calls |= spy
return expected_calls
def _spy_mock(mocker: MockerFixture, mock: RpcMock, client_field):
expected_calls: Dict[str, List] = {}
assert isinstance(mock.response, dict), "Expected data RPC response"
for field_name in mock.response:
value = mock.response[field_name]
values = value if isinstance(value, list) else [value]
expected_calls[field_name] = [
{
"spy": mocker.spy(client_field, field_name),
"request_data": f["request_data"],
}
for f in values
if f["request_type"] == "function" and "request_data" in f
]
return expected_calls
def _mock_field(field):
response_type = field["response_type"]
request_type = field["request_type"]
response = field["response"]
if request_type == "data":
return _dict_to_object(response)
if request_type == "function":
if response_type == "data":
return _dict_to_object(response)
if response_type == "exception":
return _raise(response)
return response
def _dict_to_object(data: Optional[dict]) -> Optional[DataObject]:
if not data:
return None
d = {**data}
for k in data:
value = data[k]
if isinstance(value, dict):
d[k] = _dict_to_object(value)
return DataObject(**d)
def _data_mock(data: dict) -> Mock:
return Mock(return_value=_dict_to_object(data))
def _raise(error: dict):
data = error["data"] if "data" in error else None
if "module" not in error or "class" not in error:
return Exception(data)
error_module = importlib.import_module(error["module"])
error_class = getattr(error_module, error["class"])
return error_class(**data)