diff --git a/lnbits/wallets/clightning.py b/lnbits/wallets/clightning.py index bd2c313a4..bb9543896 100644 --- a/lnbits/wallets/clightning.py +++ b/lnbits/wallets/clightning.py @@ -5,21 +5,35 @@ except ImportError: # pragma: nocover import asyncio import random -import json - +from functools import partial, wraps from os import getenv -from typing import Optional, AsyncGenerator +from typing import AsyncGenerator, Optional from .base import ( - StatusResponse, InvoiceResponse, PaymentResponse, PaymentStatus, - Wallet, + StatusResponse, Unsupported, + Wallet, ) +def async_wrap(func): + @wraps(func) + async def run(*args, loop=None, executor=None, **kwargs): + if loop is None: + loop = asyncio.get_event_loop() + partial_func = partial(func, *args, **kwargs) + return await loop.run_in_executor(executor, partial_func) + + return run + + +def _paid_invoices_stream(ln, last_pay_index): + return ln.waitanyinvoice(last_pay_index) + + class CLightningWallet(Wallet): def __init__(self): if LightningRpc is None: # pragma: nocover @@ -115,21 +129,8 @@ class CLightningWallet(Wallet): raise KeyError("supplied an invalid checking_id") async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: - reader, writer = await asyncio.open_unix_connection(self.rpc) - - i = 0 while True: - call = json.dumps( - {"method": "waitanyinvoice", "id": 0, "params": [self.last_pay_index]} - ) - writer.write(call.encode()) - await writer.drain() - - data = await reader.read() - paid = json.loads(data.decode("ascii")) - - paid = self.ln.waitanyinvoice(self.last_pay_index) + wrapped = async_wrap(_paid_invoices_stream) + paid = await wrapped(self.ln, self.last_pay_index) self.last_pay_index = paid["pay_index"] yield paid["label"] - - i += 1