From ddb8fcb9865d1475d6b4064f7bb2a3b9a6ea5dc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dni=20=E2=9A=A1?= Date: Wed, 7 Aug 2024 09:57:15 +0200 Subject: [PATCH] feat: add typing for tasks (#2629) * feat: add typing for tasks * fixup! --- lnbits/tasks.py | 60 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/lnbits/tasks.py b/lnbits/tasks.py index 36a00ef30..e435fabb9 100644 --- a/lnbits/tasks.py +++ b/lnbits/tasks.py @@ -4,7 +4,13 @@ import time import traceback import uuid from http import HTTPStatus -from typing import Dict, List, Optional +from typing import ( + Callable, + Coroutine, + Dict, + List, + Optional, +) from loguru import logger from py_vapid import Vapid @@ -17,7 +23,7 @@ from lnbits.core.crud import ( update_payment_details, update_payment_status, ) -from lnbits.core.models import PaymentState +from lnbits.core.models import Payment, PaymentState from lnbits.settings import settings from lnbits.wallets import get_funding_source @@ -25,17 +31,13 @@ tasks: List[asyncio.Task] = [] unique_tasks: Dict[str, asyncio.Task] = {} -def create_task(coro): +def create_task(coro: Coroutine) -> asyncio.Task: task = asyncio.create_task(coro) tasks.append(task) return task -def create_permanent_task(func): - return create_task(catch_everything_and_restart(func)) - - -def create_unique_task(name: str, coro): +def create_unique_task(name: str, coro: Coroutine) -> asyncio.Task: if unique_tasks.get(name): logger.warning(f"task `{name}` already exists, cancelling it") try: @@ -47,11 +49,17 @@ def create_unique_task(name: str, coro): return task -def create_permanent_unique_task(name: str, coro): +def create_permanent_task(func: Callable[[], Coroutine]) -> asyncio.Task: + return create_task(catch_everything_and_restart(func)) + + +def create_permanent_unique_task( + name: str, coro: Callable[[], Coroutine] +) -> asyncio.Task: return create_unique_task(name, catch_everything_and_restart(coro, name)) -def cancel_all_tasks(): +def cancel_all_tasks() -> None: for task in tasks: try: task.cancel() @@ -64,9 +72,12 @@ def cancel_all_tasks(): logger.warning(f"error while cancelling task `{name}`: {exc!s}") -async def catch_everything_and_restart(func, name: str = "unnamed"): +async def catch_everything_and_restart( + func: Callable[[], Coroutine], + name: str = "unnamed", +) -> Coroutine: try: - await func() + return await func() except asyncio.CancelledError: raise # because we must pass this up except Exception as exc: @@ -74,7 +85,7 @@ async def catch_everything_and_restart(func, name: str = "unnamed"): logger.error(traceback.format_exc()) logger.error("will restart the task in 5 seconds.") await asyncio.sleep(5) - await catch_everything_and_restart(func, name) + return catch_everything_and_restart(func, name) invoice_listeners: Dict[str, asyncio.Queue] = {} @@ -101,7 +112,7 @@ def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = No internal_invoice_queue: asyncio.Queue = asyncio.Queue(0) -async def internal_invoice_listener(): +async def internal_invoice_listener() -> None: """ internal_invoice_queue will be filled directly in core/services.py after the payment was deemed to be settled internally. @@ -111,10 +122,10 @@ async def internal_invoice_listener(): while settings.lnbits_running: checking_id = await internal_invoice_queue.get() logger.info(f"got an internal payment notification {checking_id}") - create_task(invoice_callback_dispatcher(checking_id, is_internal=True)) + await invoice_callback_dispatcher(checking_id, is_internal=True) -async def invoice_listener(): +async def invoice_listener() -> None: """ invoice_listener will collect all invoices that come directly from the backend wallet. @@ -124,7 +135,22 @@ async def invoice_listener(): funding_source = get_funding_source() async for checking_id in funding_source.paid_invoices_stream(): logger.info(f"got a payment notification {checking_id}") - create_task(invoice_callback_dispatcher(checking_id)) + await invoice_callback_dispatcher(checking_id) + + +def wait_for_paid_invoices( + invoice_listener_name: str, + func: Callable[[Payment], Coroutine], +) -> Callable[[], Coroutine]: + + async def wrapper() -> None: + invoice_queue: asyncio.Queue = asyncio.Queue() + register_invoice_listener(invoice_queue, invoice_listener_name) + while settings.lnbits_running: + payment = await invoice_queue.get() + await func(payment) + + return wrapper async def check_pending_payments():