mirror of
https://github.com/lnbits/lnbits-legend.git
synced 2025-03-15 12:20:21 +01:00
test: unit tests for lndrpc (#2442)
This commit is contained in:
parent
4f118c5f98
commit
67fdb77339
5 changed files with 1082 additions and 112 deletions
|
@ -116,7 +116,7 @@ class LndWallet(Wallet):
|
|||
try:
|
||||
resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest())
|
||||
except Exception as exc:
|
||||
return StatusResponse(str(exc), 0)
|
||||
return StatusResponse(f"Unable to connect, got: '{exc}'", 0)
|
||||
|
||||
return StatusResponse(None, resp.balance * 1000)
|
||||
|
||||
|
@ -147,6 +147,7 @@ class LndWallet(Wallet):
|
|||
req = ln.Invoice(**data)
|
||||
resp = await self.rpc.AddInvoice(req)
|
||||
except Exception as exc:
|
||||
logger.warning(exc)
|
||||
error_message = str(exc)
|
||||
return InvoiceResponse(False, None, None, error_message)
|
||||
|
||||
|
@ -165,6 +166,7 @@ class LndWallet(Wallet):
|
|||
try:
|
||||
resp = await self.routerpc.SendPaymentV2(req).read()
|
||||
except Exception as exc:
|
||||
logger.warning(exc)
|
||||
return PaymentResponse(False, None, None, None, str(exc))
|
||||
|
||||
# PaymentStatus from https://github.com/lightningnetwork/lnd/blob/master/channeldb/payments.go#L178
|
||||
|
@ -176,12 +178,12 @@ class LndWallet(Wallet):
|
|||
}
|
||||
|
||||
failure_reasons = {
|
||||
0: "No error given.",
|
||||
1: "Payment timed out.",
|
||||
2: "No route to destination.",
|
||||
3: "Error.",
|
||||
4: "Incorrect payment details.",
|
||||
5: "Insufficient balance.",
|
||||
0: "Payment failed: No error given.",
|
||||
1: "Payment failed: Payment timed out.",
|
||||
2: "Payment failed: No route to destination.",
|
||||
3: "Payment failed: Error.",
|
||||
4: "Payment failed: Incorrect payment details.",
|
||||
5: "Payment failed: Insufficient balance.",
|
||||
}
|
||||
|
||||
fee_msat = None
|
||||
|
@ -204,19 +206,23 @@ class LndWallet(Wallet):
|
|||
try:
|
||||
r_hash = hex_to_bytes(checking_id)
|
||||
if len(r_hash) != 32:
|
||||
# this may happen if we switch between backend wallets
|
||||
# that use different checking_id formats
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
# this may happen if we switch between backend wallets
|
||||
# that use different checking_id formats
|
||||
return PaymentPendingStatus()
|
||||
try:
|
||||
resp = await self.rpc.LookupInvoice(ln.PaymentHash(r_hash=r_hash))
|
||||
except grpc.RpcError:
|
||||
return PaymentPendingStatus()
|
||||
if resp.settled:
|
||||
return PaymentSuccessStatus()
|
||||
|
||||
return PaymentPendingStatus()
|
||||
resp = await self.rpc.LookupInvoice(ln.PaymentHash(r_hash=r_hash))
|
||||
|
||||
# todo: where is the FAILED status
|
||||
if resp.settled:
|
||||
return PaymentSuccessStatus()
|
||||
|
||||
return PaymentPendingStatus()
|
||||
except grpc.RpcError as exc:
|
||||
logger.warning(exc)
|
||||
return PaymentPendingStatus()
|
||||
except Exception as exc:
|
||||
logger.warning(exc)
|
||||
return PaymentPendingStatus()
|
||||
|
||||
async def get_payment_status(self, checking_id: str) -> PaymentStatus:
|
||||
"""
|
||||
|
@ -231,10 +237,6 @@ class LndWallet(Wallet):
|
|||
# that use different checking_id formats
|
||||
return PaymentPendingStatus()
|
||||
|
||||
resp = self.routerpc.TrackPaymentV2(
|
||||
router.TrackPaymentRequest(payment_hash=r_hash)
|
||||
)
|
||||
|
||||
# # HTLCAttempt.HTLCStatus:
|
||||
# # https://github.com/lightningnetwork/lnd/blob/master/lnrpc/lightning.proto#L3641
|
||||
# htlc_statuses = {
|
||||
|
@ -250,6 +252,9 @@ class LndWallet(Wallet):
|
|||
}
|
||||
|
||||
try:
|
||||
resp = self.routerpc.TrackPaymentV2(
|
||||
router.TrackPaymentRequest(payment_hash=r_hash)
|
||||
)
|
||||
async for payment in resp:
|
||||
if len(payment.htlcs) and statuses[payment.status]:
|
||||
return PaymentSuccessStatus(
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -7,7 +7,6 @@ class FundingSourceConfig(BaseModel):
|
|||
name: str
|
||||
skip: Optional[bool]
|
||||
wallet_class: str
|
||||
client_field: Optional[str]
|
||||
settings: dict
|
||||
|
||||
|
||||
|
@ -28,12 +27,16 @@ class TestMock(BaseModel):
|
|||
|
||||
|
||||
class Mock(FunctionMock, TestMock):
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def combine_mocks(fs_mock, test_mock):
|
||||
def combine_mocks(mock_name, fs_mock, test_mock):
|
||||
_mock = fs_mock | test_mock
|
||||
if "response" in _mock and "response" in fs_mock:
|
||||
_mock["response"] |= fs_mock["response"]
|
||||
return Mock(**_mock)
|
||||
m = Mock(name=mock_name, **_mock)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class FunctionMocks(BaseModel):
|
||||
|
@ -93,35 +96,58 @@ class WalletTest(BaseModel):
|
|||
return [t]
|
||||
|
||||
def _tests_from_fs_mocks(self, fn, test, fs_name: str) -> List["WalletTest"]:
|
||||
tests: List[WalletTest] = []
|
||||
|
||||
fs_mocks = fn["mocks"][fs_name]
|
||||
test_mocks = test["mocks"][fs_name]
|
||||
|
||||
for mock_name in fs_mocks:
|
||||
tests += self._tests_from_mocks(fs_mocks[mock_name], test_mocks[mock_name])
|
||||
return tests
|
||||
mocks = self._build_mock_objects(list(fs_mocks), fs_mocks, test_mocks)
|
||||
|
||||
def _tests_from_mocks(self, fs_mock, test_mocks) -> List["WalletTest"]:
|
||||
tests: List[WalletTest] = []
|
||||
for test_mock in test_mocks:
|
||||
# different mocks that result in the same
|
||||
# return value for the tested function
|
||||
unique_test = self._test_from_mocks(fs_mock, test_mock)
|
||||
return [self._tests_from_mock(m) for m in mocks]
|
||||
|
||||
tests.append(unique_test)
|
||||
return tests
|
||||
def _build_mock_objects(self, mock_names, fs_mocks, test_mocks):
|
||||
mocks = []
|
||||
|
||||
def _test_from_mocks(self, fs_mock, test_mock) -> "WalletTest":
|
||||
mock = Mock.combine_mocks(fs_mock, test_mock)
|
||||
for mock_name in mock_names:
|
||||
if mock_name not in test_mocks:
|
||||
continue
|
||||
for test_mock in test_mocks[mock_name]:
|
||||
mock = {"fs_mock": fs_mocks[mock_name], "test_mock": test_mock}
|
||||
|
||||
if len(mock_names) == 1:
|
||||
mocks.append({mock_name: mock})
|
||||
else:
|
||||
sub_mocks = self._build_mock_objects(
|
||||
mock_names[1:], fs_mocks, test_mocks
|
||||
)
|
||||
for sub_mock in sub_mocks:
|
||||
mocks.append({mock_name: mock} | sub_mock)
|
||||
return mocks
|
||||
|
||||
return mocks
|
||||
|
||||
def _tests_from_mock(self, mock_obj) -> "WalletTest":
|
||||
|
||||
test_mocks: List[Mock] = [
|
||||
Mock.combine_mocks(
|
||||
mock_name,
|
||||
mock_obj[mock_name]["fs_mock"],
|
||||
mock_obj[mock_name]["test_mock"],
|
||||
)
|
||||
for mock_name in mock_obj
|
||||
]
|
||||
|
||||
any_mock_skipped = len([m for m in test_mocks if m.skip])
|
||||
extra_description = ";".join(
|
||||
[m.description for m in test_mocks if m.description]
|
||||
)
|
||||
|
||||
return WalletTest(
|
||||
**(
|
||||
self.dict()
|
||||
| {
|
||||
"description": f"""{self.description}:{mock.description or ""}""",
|
||||
"mocks": [*self.mocks, mock],
|
||||
"skip": self.skip or mock.skip,
|
||||
"description": f"{self.description}:{extra_description}",
|
||||
"mocks": test_mocks,
|
||||
"skip": self.skip or any_mock_skipped,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
@ -131,3 +157,12 @@ class DataObject:
|
|||
def __init__(self, **kwargs):
|
||||
for k in kwargs:
|
||||
setattr(self, k, kwargs[k])
|
||||
|
||||
def __str__(self):
|
||||
data = []
|
||||
for k in self.__dict__:
|
||||
value = getattr(self, k)
|
||||
if isinstance(value, list):
|
||||
value = [f"{k}={v}" for v in value]
|
||||
data.append(f"{k}={value}")
|
||||
return ";".join(data)
|
||||
|
|
|
@ -55,7 +55,7 @@ def _tests_for_funding_source(
|
|||
|
||||
|
||||
def build_test_id(test: WalletTest):
|
||||
return f"{test.funding_source}.{test.function}({test.description})"
|
||||
return f"{test.funding_source.name}.{test.function}({test.description})"
|
||||
|
||||
|
||||
def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
|
||||
|
@ -83,7 +83,13 @@ async def check_assertions(wallet, _test_data: WalletTest):
|
|||
call_params = _test_data.call_params
|
||||
|
||||
if "expect" in test_data:
|
||||
await _assert_data(wallet, tested_func, call_params, _test_data.expect)
|
||||
await _assert_data(
|
||||
wallet,
|
||||
tested_func,
|
||||
call_params,
|
||||
_test_data.expect,
|
||||
_test_data.description,
|
||||
)
|
||||
# if len(_test_data.mocks) == 0:
|
||||
# # all calls should fail after this method is called
|
||||
# await wallet.cleanup()
|
||||
|
@ -96,14 +102,25 @@ async def check_assertions(wallet, _test_data: WalletTest):
|
|||
raise AssertionError("Expected outcome not specified")
|
||||
|
||||
|
||||
async def _assert_data(wallet, tested_func, call_params, expect):
|
||||
async def _assert_data(wallet, tested_func, call_params, expect, description):
|
||||
resp = await getattr(wallet, tested_func)(**call_params)
|
||||
fn_prefix = "__eval__:"
|
||||
for key in expect:
|
||||
received = getattr(resp, key)
|
||||
expected = expect[key]
|
||||
assert (
|
||||
getattr(resp, key) == expect[key]
|
||||
), f"""Field "{key}". Received: "{received}". Expected: "{expected}"."""
|
||||
if key.startswith(fn_prefix):
|
||||
key = key[len(fn_prefix) :]
|
||||
received = getattr(resp, key)
|
||||
expected = expected.format(**{key: received, "description": description})
|
||||
_assert = eval(expected)
|
||||
else:
|
||||
received = getattr(resp, key)
|
||||
_assert = getattr(resp, key) == expect[key]
|
||||
|
||||
assert _assert, (
|
||||
f""" Field "{key}"."""
|
||||
f""" Received: "{received}"."""
|
||||
f""" Expected: "{expected}"."""
|
||||
)
|
||||
|
||||
|
||||
async def _assert_error(wallet, tested_func, call_params, expect_error):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import importlib
|
||||
from typing import Dict, List, Optional
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from pytest_mock.plugin import MockerFixture
|
||||
|
@ -46,7 +46,12 @@ def _apply_rpc_mock(mocker: MockerFixture, mock: RpcMock):
|
|||
value = mock.response[field_name]
|
||||
values = value if isinstance(value, list) else [value]
|
||||
|
||||
return_value[field_name] = Mock(side_effect=[_mock_field(f) for f in values])
|
||||
_mock_class = (
|
||||
AsyncMock if values[0]["request_type"] == "async-function" else Mock
|
||||
)
|
||||
return_value[field_name] = _mock_class(
|
||||
side_effect=[_mock_field(f) for f in values]
|
||||
)
|
||||
|
||||
m = _data_mock(return_value)
|
||||
assert mock.method, "Missing method for RPC mock."
|
||||
|
@ -59,7 +64,8 @@ def _check_calls(expected_calls):
|
|||
for func_call in func_calls:
|
||||
req = func_call["request_data"]
|
||||
args = req["args"] if "args" in req else {}
|
||||
kwargs = req["kwargs"] if "kwargs" in req else {}
|
||||
kwargs = _eval_dict(req["kwargs"]) if "kwargs" in req else {}
|
||||
|
||||
if "klass" in req:
|
||||
*rest, cls = req["klass"].split(".")
|
||||
req_module = importlib.import_module(".".join(rest))
|
||||
|
@ -70,12 +76,9 @@ def _check_calls(expected_calls):
|
|||
|
||||
|
||||
def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet):
|
||||
assert (
|
||||
test_data.funding_source.client_field
|
||||
), f"Missing client field for wallet {wallet}"
|
||||
client_field = getattr(wallet, test_data.funding_source.client_field)
|
||||
expected_calls: Dict[str, List] = {}
|
||||
for mock in test_data.mocks:
|
||||
client_field = getattr(wallet, mock.name)
|
||||
spy = _spy_mock(mocker, mock, client_field)
|
||||
expected_calls |= spy
|
||||
|
||||
|
@ -83,6 +86,7 @@ def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet)
|
|||
|
||||
|
||||
def _spy_mock(mocker: MockerFixture, mock: RpcMock, client_field):
|
||||
|
||||
expected_calls: Dict[str, List] = {}
|
||||
assert isinstance(mock.response, dict), "Expected data RPC response"
|
||||
for field_name in mock.response:
|
||||
|
@ -95,37 +99,95 @@ def _spy_mock(mocker: MockerFixture, mock: RpcMock, client_field):
|
|||
"request_data": f["request_data"],
|
||||
}
|
||||
for f in values
|
||||
if f["request_type"] == "function" and "request_data" in f
|
||||
if (
|
||||
f["request_type"] == "function" or f["request_type"] == "async-function"
|
||||
)
|
||||
and "request_data" in f
|
||||
]
|
||||
return expected_calls
|
||||
|
||||
|
||||
def _async_generator(data):
|
||||
async def f1():
|
||||
for d in data:
|
||||
value = _eval_dict(d)
|
||||
yield _dict_to_object(value)
|
||||
|
||||
return f1()
|
||||
|
||||
|
||||
def _mock_field(field):
|
||||
response_type = field["response_type"]
|
||||
request_type = field["request_type"]
|
||||
response = field["response"]
|
||||
response = _eval_dict(field["response"])
|
||||
|
||||
if request_type == "data":
|
||||
return _dict_to_object(response)
|
||||
|
||||
if request_type == "function":
|
||||
if request_type == "function" or request_type == "async-function":
|
||||
if response_type == "data":
|
||||
return _dict_to_object(response)
|
||||
|
||||
if response_type == "exception":
|
||||
return _raise(response)
|
||||
|
||||
if response_type == "__aiter__":
|
||||
# todo: support dict
|
||||
return _async_generator(field["response"])
|
||||
|
||||
if response_type == "function" or response_type == "async-function":
|
||||
return_value = {}
|
||||
for field_name in field["response"]:
|
||||
value = field["response"][field_name]
|
||||
_mock_class = (
|
||||
AsyncMock if value["request_type"] == "async-function" else Mock
|
||||
)
|
||||
|
||||
return_value[field_name] = _mock_class(side_effect=[_mock_field(value)])
|
||||
|
||||
return _dict_to_object(return_value)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _eval_dict(data: Optional[dict]) -> Optional[dict]:
|
||||
fn_prefix = "__eval__:"
|
||||
if not data:
|
||||
return data
|
||||
# if isinstance(data, list):
|
||||
# return [_eval_dict(i) for i in data]
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
d = {}
|
||||
for k in data:
|
||||
if k.startswith(fn_prefix):
|
||||
field = k[len(fn_prefix) :]
|
||||
d[field] = eval(data[k])
|
||||
elif isinstance(data[k], dict):
|
||||
d[k] = _eval_dict(data[k])
|
||||
elif isinstance(data[k], list):
|
||||
d[k] = [_eval_dict(i) for i in data[k]]
|
||||
else:
|
||||
d[k] = data[k]
|
||||
return d
|
||||
|
||||
|
||||
def _dict_to_object(data: Optional[dict]) -> Optional[DataObject]:
|
||||
if not data:
|
||||
return None
|
||||
# if isinstance(data, list):
|
||||
# return [_dict_to_object(i) for i in data]
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
d = {**data}
|
||||
for k in data:
|
||||
value = data[k]
|
||||
if isinstance(value, dict):
|
||||
d[k] = _dict_to_object(value)
|
||||
elif isinstance(value, list):
|
||||
d[k] = [_dict_to_object(v) for v in value]
|
||||
|
||||
return DataObject(**d)
|
||||
|
||||
|
@ -134,7 +196,9 @@ def _data_mock(data: dict) -> Mock:
|
|||
return Mock(return_value=_dict_to_object(data))
|
||||
|
||||
|
||||
def _raise(error: dict):
|
||||
def _raise(error: Optional[dict]):
|
||||
if not error:
|
||||
return Exception()
|
||||
data = error["data"] if "data" in error else None
|
||||
if "module" not in error or "class" not in error:
|
||||
return Exception(data)
|
||||
|
|
Loading…
Add table
Reference in a new issue