test: unit tests for lndrpc (#2442)

This commit is contained in:
Vlad Stan 2024-04-19 14:21:21 +03:00 committed by GitHub
parent 4f118c5f98
commit 67fdb77339
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 1082 additions and 112 deletions

View file

@ -116,7 +116,7 @@ class LndWallet(Wallet):
try:
resp = await self.rpc.ChannelBalance(ln.ChannelBalanceRequest())
except Exception as exc:
return StatusResponse(str(exc), 0)
return StatusResponse(f"Unable to connect, got: '{exc}'", 0)
return StatusResponse(None, resp.balance * 1000)
@ -147,6 +147,7 @@ class LndWallet(Wallet):
req = ln.Invoice(**data)
resp = await self.rpc.AddInvoice(req)
except Exception as exc:
logger.warning(exc)
error_message = str(exc)
return InvoiceResponse(False, None, None, error_message)
@ -165,6 +166,7 @@ class LndWallet(Wallet):
try:
resp = await self.routerpc.SendPaymentV2(req).read()
except Exception as exc:
logger.warning(exc)
return PaymentResponse(False, None, None, None, str(exc))
# PaymentStatus from https://github.com/lightningnetwork/lnd/blob/master/channeldb/payments.go#L178
@ -176,12 +178,12 @@ class LndWallet(Wallet):
}
failure_reasons = {
0: "No error given.",
1: "Payment timed out.",
2: "No route to destination.",
3: "Error.",
4: "Incorrect payment details.",
5: "Insufficient balance.",
0: "Payment failed: No error given.",
1: "Payment failed: Payment timed out.",
2: "Payment failed: No route to destination.",
3: "Payment failed: Error.",
4: "Payment failed: Incorrect payment details.",
5: "Payment failed: Insufficient balance.",
}
fee_msat = None
@ -204,19 +206,23 @@ class LndWallet(Wallet):
try:
r_hash = hex_to_bytes(checking_id)
if len(r_hash) != 32:
# this may happen if we switch between backend wallets
# that use different checking_id formats
raise ValueError
except ValueError:
# this may happen if we switch between backend wallets
# that use different checking_id formats
return PaymentPendingStatus()
try:
resp = await self.rpc.LookupInvoice(ln.PaymentHash(r_hash=r_hash))
except grpc.RpcError:
return PaymentPendingStatus()
if resp.settled:
return PaymentSuccessStatus()
return PaymentPendingStatus()
resp = await self.rpc.LookupInvoice(ln.PaymentHash(r_hash=r_hash))
# todo: where is the FAILED status
if resp.settled:
return PaymentSuccessStatus()
return PaymentPendingStatus()
except grpc.RpcError as exc:
logger.warning(exc)
return PaymentPendingStatus()
except Exception as exc:
logger.warning(exc)
return PaymentPendingStatus()
async def get_payment_status(self, checking_id: str) -> PaymentStatus:
"""
@ -231,10 +237,6 @@ class LndWallet(Wallet):
# that use different checking_id formats
return PaymentPendingStatus()
resp = self.routerpc.TrackPaymentV2(
router.TrackPaymentRequest(payment_hash=r_hash)
)
# # HTLCAttempt.HTLCStatus:
# # https://github.com/lightningnetwork/lnd/blob/master/lnrpc/lightning.proto#L3641
# htlc_statuses = {
@ -250,6 +252,9 @@ class LndWallet(Wallet):
}
try:
resp = self.routerpc.TrackPaymentV2(
router.TrackPaymentRequest(payment_hash=r_hash)
)
async for payment in resp:
if len(payment.htlcs) and statuses[payment.status]:
return PaymentSuccessStatus(

File diff suppressed because it is too large Load diff

View file

@ -7,7 +7,6 @@ class FundingSourceConfig(BaseModel):
name: str
skip: Optional[bool]
wallet_class: str
client_field: Optional[str]
settings: dict
@ -28,12 +27,16 @@ class TestMock(BaseModel):
class Mock(FunctionMock, TestMock):
name: str
@staticmethod
def combine_mocks(fs_mock, test_mock):
def combine_mocks(mock_name, fs_mock, test_mock):
_mock = fs_mock | test_mock
if "response" in _mock and "response" in fs_mock:
_mock["response"] |= fs_mock["response"]
return Mock(**_mock)
m = Mock(name=mock_name, **_mock)
return m
class FunctionMocks(BaseModel):
@ -93,35 +96,58 @@ class WalletTest(BaseModel):
return [t]
def _tests_from_fs_mocks(self, fn, test, fs_name: str) -> List["WalletTest"]:
tests: List[WalletTest] = []
fs_mocks = fn["mocks"][fs_name]
test_mocks = test["mocks"][fs_name]
for mock_name in fs_mocks:
tests += self._tests_from_mocks(fs_mocks[mock_name], test_mocks[mock_name])
return tests
mocks = self._build_mock_objects(list(fs_mocks), fs_mocks, test_mocks)
def _tests_from_mocks(self, fs_mock, test_mocks) -> List["WalletTest"]:
tests: List[WalletTest] = []
for test_mock in test_mocks:
# different mocks that result in the same
# return value for the tested function
unique_test = self._test_from_mocks(fs_mock, test_mock)
return [self._tests_from_mock(m) for m in mocks]
tests.append(unique_test)
return tests
def _build_mock_objects(self, mock_names, fs_mocks, test_mocks):
mocks = []
def _test_from_mocks(self, fs_mock, test_mock) -> "WalletTest":
mock = Mock.combine_mocks(fs_mock, test_mock)
for mock_name in mock_names:
if mock_name not in test_mocks:
continue
for test_mock in test_mocks[mock_name]:
mock = {"fs_mock": fs_mocks[mock_name], "test_mock": test_mock}
if len(mock_names) == 1:
mocks.append({mock_name: mock})
else:
sub_mocks = self._build_mock_objects(
mock_names[1:], fs_mocks, test_mocks
)
for sub_mock in sub_mocks:
mocks.append({mock_name: mock} | sub_mock)
return mocks
return mocks
def _tests_from_mock(self, mock_obj) -> "WalletTest":
test_mocks: List[Mock] = [
Mock.combine_mocks(
mock_name,
mock_obj[mock_name]["fs_mock"],
mock_obj[mock_name]["test_mock"],
)
for mock_name in mock_obj
]
any_mock_skipped = len([m for m in test_mocks if m.skip])
extra_description = ";".join(
[m.description for m in test_mocks if m.description]
)
return WalletTest(
**(
self.dict()
| {
"description": f"""{self.description}:{mock.description or ""}""",
"mocks": [*self.mocks, mock],
"skip": self.skip or mock.skip,
"description": f"{self.description}:{extra_description}",
"mocks": test_mocks,
"skip": self.skip or any_mock_skipped,
}
)
)
@ -131,3 +157,12 @@ class DataObject:
def __init__(self, **kwargs):
for k in kwargs:
setattr(self, k, kwargs[k])
def __str__(self):
data = []
for k in self.__dict__:
value = getattr(self, k)
if isinstance(value, list):
value = [f"{k}={v}" for v in value]
data.append(f"{k}={value}")
return ";".join(data)

View file

@ -55,7 +55,7 @@ def _tests_for_funding_source(
def build_test_id(test: WalletTest):
return f"{test.funding_source}.{test.function}({test.description})"
return f"{test.funding_source.name}.{test.function}({test.description})"
def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
@ -83,7 +83,13 @@ async def check_assertions(wallet, _test_data: WalletTest):
call_params = _test_data.call_params
if "expect" in test_data:
await _assert_data(wallet, tested_func, call_params, _test_data.expect)
await _assert_data(
wallet,
tested_func,
call_params,
_test_data.expect,
_test_data.description,
)
# if len(_test_data.mocks) == 0:
# # all calls should fail after this method is called
# await wallet.cleanup()
@ -96,14 +102,25 @@ async def check_assertions(wallet, _test_data: WalletTest):
raise AssertionError("Expected outcome not specified")
async def _assert_data(wallet, tested_func, call_params, expect):
async def _assert_data(wallet, tested_func, call_params, expect, description):
resp = await getattr(wallet, tested_func)(**call_params)
fn_prefix = "__eval__:"
for key in expect:
received = getattr(resp, key)
expected = expect[key]
assert (
getattr(resp, key) == expect[key]
), f"""Field "{key}". Received: "{received}". Expected: "{expected}"."""
if key.startswith(fn_prefix):
key = key[len(fn_prefix) :]
received = getattr(resp, key)
expected = expected.format(**{key: received, "description": description})
_assert = eval(expected)
else:
received = getattr(resp, key)
_assert = getattr(resp, key) == expect[key]
assert _assert, (
f""" Field "{key}"."""
f""" Received: "{received}"."""
f""" Expected: "{expected}"."""
)
async def _assert_error(wallet, tested_func, call_params, expect_error):

View file

@ -1,6 +1,6 @@
import importlib
from typing import Dict, List, Optional
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
import pytest
from pytest_mock.plugin import MockerFixture
@ -46,7 +46,12 @@ def _apply_rpc_mock(mocker: MockerFixture, mock: RpcMock):
value = mock.response[field_name]
values = value if isinstance(value, list) else [value]
return_value[field_name] = Mock(side_effect=[_mock_field(f) for f in values])
_mock_class = (
AsyncMock if values[0]["request_type"] == "async-function" else Mock
)
return_value[field_name] = _mock_class(
side_effect=[_mock_field(f) for f in values]
)
m = _data_mock(return_value)
assert mock.method, "Missing method for RPC mock."
@ -59,7 +64,8 @@ def _check_calls(expected_calls):
for func_call in func_calls:
req = func_call["request_data"]
args = req["args"] if "args" in req else {}
kwargs = req["kwargs"] if "kwargs" in req else {}
kwargs = _eval_dict(req["kwargs"]) if "kwargs" in req else {}
if "klass" in req:
*rest, cls = req["klass"].split(".")
req_module = importlib.import_module(".".join(rest))
@ -70,12 +76,9 @@ def _check_calls(expected_calls):
def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet):
assert (
test_data.funding_source.client_field
), f"Missing client field for wallet {wallet}"
client_field = getattr(wallet, test_data.funding_source.client_field)
expected_calls: Dict[str, List] = {}
for mock in test_data.mocks:
client_field = getattr(wallet, mock.name)
spy = _spy_mock(mocker, mock, client_field)
expected_calls |= spy
@ -83,6 +86,7 @@ def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet)
def _spy_mock(mocker: MockerFixture, mock: RpcMock, client_field):
expected_calls: Dict[str, List] = {}
assert isinstance(mock.response, dict), "Expected data RPC response"
for field_name in mock.response:
@ -95,37 +99,95 @@ def _spy_mock(mocker: MockerFixture, mock: RpcMock, client_field):
"request_data": f["request_data"],
}
for f in values
if f["request_type"] == "function" and "request_data" in f
if (
f["request_type"] == "function" or f["request_type"] == "async-function"
)
and "request_data" in f
]
return expected_calls
def _async_generator(data):
async def f1():
for d in data:
value = _eval_dict(d)
yield _dict_to_object(value)
return f1()
def _mock_field(field):
response_type = field["response_type"]
request_type = field["request_type"]
response = field["response"]
response = _eval_dict(field["response"])
if request_type == "data":
return _dict_to_object(response)
if request_type == "function":
if request_type == "function" or request_type == "async-function":
if response_type == "data":
return _dict_to_object(response)
if response_type == "exception":
return _raise(response)
if response_type == "__aiter__":
# todo: support dict
return _async_generator(field["response"])
if response_type == "function" or response_type == "async-function":
return_value = {}
for field_name in field["response"]:
value = field["response"][field_name]
_mock_class = (
AsyncMock if value["request_type"] == "async-function" else Mock
)
return_value[field_name] = _mock_class(side_effect=[_mock_field(value)])
return _dict_to_object(return_value)
return response
def _eval_dict(data: Optional[dict]) -> Optional[dict]:
fn_prefix = "__eval__:"
if not data:
return data
# if isinstance(data, list):
# return [_eval_dict(i) for i in data]
if not isinstance(data, dict):
return data
d = {}
for k in data:
if k.startswith(fn_prefix):
field = k[len(fn_prefix) :]
d[field] = eval(data[k])
elif isinstance(data[k], dict):
d[k] = _eval_dict(data[k])
elif isinstance(data[k], list):
d[k] = [_eval_dict(i) for i in data[k]]
else:
d[k] = data[k]
return d
def _dict_to_object(data: Optional[dict]) -> Optional[DataObject]:
if not data:
return None
# if isinstance(data, list):
# return [_dict_to_object(i) for i in data]
if not isinstance(data, dict):
return data
d = {**data}
for k in data:
value = data[k]
if isinstance(value, dict):
d[k] = _dict_to_object(value)
elif isinstance(value, list):
d[k] = [_dict_to_object(v) for v in value]
return DataObject(**d)
@ -134,7 +196,9 @@ def _data_mock(data: dict) -> Mock:
return Mock(return_value=_dict_to_object(data))
def _raise(error: dict):
def _raise(error: Optional[dict]):
if not error:
return Exception()
data = error["data"] if "data" in error else None
if "module" not in error or "class" not in error:
return Exception(data)