feat: add typing for tasks (#2629)

* feat: add typing for tasks

* fixup!
This commit is contained in:
dni ⚡ 2024-08-07 09:57:15 +02:00 committed by GitHub
parent 27b9e8254c
commit ddb8fcb986
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,7 +4,13 @@ import time
import traceback
import uuid
from http import HTTPStatus
from typing import Dict, List, Optional
from typing import (
Callable,
Coroutine,
Dict,
List,
Optional,
)
from loguru import logger
from py_vapid import Vapid
@ -17,7 +23,7 @@ from lnbits.core.crud import (
update_payment_details,
update_payment_status,
)
from lnbits.core.models import PaymentState
from lnbits.core.models import Payment, PaymentState
from lnbits.settings import settings
from lnbits.wallets import get_funding_source
@ -25,17 +31,13 @@ tasks: List[asyncio.Task] = []
unique_tasks: Dict[str, asyncio.Task] = {}
def create_task(coro):
def create_task(coro: Coroutine) -> asyncio.Task:
task = asyncio.create_task(coro)
tasks.append(task)
return task
def create_permanent_task(func):
return create_task(catch_everything_and_restart(func))
def create_unique_task(name: str, coro):
def create_unique_task(name: str, coro: Coroutine) -> asyncio.Task:
if unique_tasks.get(name):
logger.warning(f"task `{name}` already exists, cancelling it")
try:
@ -47,11 +49,17 @@ def create_unique_task(name: str, coro):
return task
def create_permanent_unique_task(name: str, coro):
def create_permanent_task(func: Callable[[], Coroutine]) -> asyncio.Task:
return create_task(catch_everything_and_restart(func))
def create_permanent_unique_task(
name: str, coro: Callable[[], Coroutine]
) -> asyncio.Task:
return create_unique_task(name, catch_everything_and_restart(coro, name))
def cancel_all_tasks():
def cancel_all_tasks() -> None:
for task in tasks:
try:
task.cancel()
@ -64,9 +72,12 @@ def cancel_all_tasks():
logger.warning(f"error while cancelling task `{name}`: {exc!s}")
async def catch_everything_and_restart(func, name: str = "unnamed"):
async def catch_everything_and_restart(
func: Callable[[], Coroutine],
name: str = "unnamed",
) -> Coroutine:
try:
await func()
return await func()
except asyncio.CancelledError:
raise # because we must pass this up
except Exception as exc:
@ -74,7 +85,7 @@ async def catch_everything_and_restart(func, name: str = "unnamed"):
logger.error(traceback.format_exc())
logger.error("will restart the task in 5 seconds.")
await asyncio.sleep(5)
await catch_everything_and_restart(func, name)
return catch_everything_and_restart(func, name)
invoice_listeners: Dict[str, asyncio.Queue] = {}
@ -101,7 +112,7 @@ def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = No
internal_invoice_queue: asyncio.Queue = asyncio.Queue(0)
async def internal_invoice_listener():
async def internal_invoice_listener() -> None:
"""
internal_invoice_queue will be filled directly in core/services.py
after the payment was deemed to be settled internally.
@ -111,10 +122,10 @@ async def internal_invoice_listener():
while settings.lnbits_running:
checking_id = await internal_invoice_queue.get()
logger.info(f"got an internal payment notification {checking_id}")
create_task(invoice_callback_dispatcher(checking_id, is_internal=True))
await invoice_callback_dispatcher(checking_id, is_internal=True)
async def invoice_listener():
async def invoice_listener() -> None:
"""
invoice_listener will collect all invoices that come directly
from the backend wallet.
@ -124,7 +135,22 @@ async def invoice_listener():
funding_source = get_funding_source()
async for checking_id in funding_source.paid_invoices_stream():
logger.info(f"got a payment notification {checking_id}")
create_task(invoice_callback_dispatcher(checking_id))
await invoice_callback_dispatcher(checking_id)
def wait_for_paid_invoices(
invoice_listener_name: str,
func: Callable[[Payment], Coroutine],
) -> Callable[[], Coroutine]:
async def wrapper() -> None:
invoice_queue: asyncio.Queue = asyncio.Queue()
register_invoice_listener(invoice_queue, invoice_listener_name)
while settings.lnbits_running:
payment = await invoice_queue.get()
await func(payment)
return wrapper
async def check_pending_payments():