cancel all long-running tasks (#1793)

* add centralized task management

in order to properly cleanup all long-running tasks we have to keep a list of them

* use new task management functions

* unify shutdown events

* vlads suggestions

rename variable for create_task
wrap cancel() with try/catch

fixup

* rename func to coro

---------

Co-authored-by: dni  <office@dnilabs.com>
This commit is contained in:
jackstar12 2023-08-18 11:25:33 +02:00 committed by GitHub
parent 65db43ace4
commit 1fd4d9d514
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 52 deletions

View File

@ -28,9 +28,9 @@ from lnbits.core.services import websocketUpdater
from lnbits.core.tasks import ( # register_watchdog,; unregister_watchdog,
register_killswitch,
register_task_listeners,
unregister_killswitch,
)
from lnbits.settings import settings
from lnbits.tasks import cancel_all_tasks, create_permanent_task
from lnbits.wallets import get_wallet_class, set_wallet_class
from .commands import db_versions, load_disabled_extension_list, migrate_databases
@ -52,7 +52,6 @@ from .middleware import (
)
from .requestvars import g
from .tasks import (
catch_everything_and_restart,
check_pending_payments,
internal_invoice_listener,
invoice_listener,
@ -366,6 +365,9 @@ def register_startup(app: FastAPI):
def register_shutdown(app: FastAPI):
@app.on_event("shutdown")
async def on_shutdown():
cancel_all_tasks()
# wait a bit to allow them to finish, so that cleanup can run without problems
await asyncio.sleep(0.1)
WALLET = get_wallet_class()
await WALLET.cleanup()
@ -380,7 +382,7 @@ def initialize_server_logger():
msg = await serverlog_queue.get()
await websocketUpdater(super_user_hash, msg)
asyncio.create_task(update_websocket_serverlog())
create_permanent_task(update_websocket_serverlog)
logger.add(
lambda msg: serverlog_queue.put_nowait(msg),
@ -421,21 +423,13 @@ def register_async_tasks(app):
@app.on_event("startup")
async def listeners():
loop = asyncio.get_event_loop()
loop.create_task(catch_everything_and_restart(check_pending_payments))
loop.create_task(catch_everything_and_restart(invoice_listener))
loop.create_task(catch_everything_and_restart(internal_invoice_listener))
await register_task_listeners()
# await register_watchdog()
await register_killswitch()
create_permanent_task(check_pending_payments)
create_permanent_task(invoice_listener)
create_permanent_task(internal_invoice_listener)
register_task_listeners()
register_killswitch()
# await run_deferred_async() # calle: doesn't do anyting?
@app.on_event("shutdown")
async def stop_listeners():
# await unregister_watchdog()
await unregister_killswitch()
pass
def register_exception_handlers(app: FastAPI):
@app.exception_handler(Exception)

View File

@ -1,11 +1,16 @@
import asyncio
from typing import Dict, Optional
from typing import Dict
import httpx
from loguru import logger
from lnbits.settings import get_wallet_class, settings
from lnbits.tasks import SseListenersDict, register_invoice_listener
from lnbits.tasks import (
SseListenersDict,
create_permanent_task,
create_task,
register_invoice_listener,
)
from . import db
from .crud import get_balance_notify, get_wallet
@ -16,28 +21,14 @@ api_invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict(
"api_invoice_listeners"
)
killswitch: Optional[asyncio.Task] = None
watchdog: Optional[asyncio.Task] = None
async def register_killswitch():
def register_killswitch():
"""
Registers a killswitch which will check lnbits-status repository
for a signal from LNbits and will switch to VoidWallet if the killswitch is triggered.
Registers a killswitch which will check lnbits-status repository for a signal from
LNbits and will switch to VoidWallet if the killswitch is triggered.
"""
logger.debug("Starting killswitch task")
global killswitch
killswitch = asyncio.create_task(killswitch_task())
async def unregister_killswitch():
"""
Unregisters a killswitch taskl
"""
global killswitch
if killswitch:
logger.debug("Stopping killswitch task")
killswitch.cancel()
create_permanent_task(killswitch_task)
async def killswitch_task():
@ -67,20 +58,9 @@ async def register_watchdog():
Registers a watchdog which will check lnbits balance and nodebalance
and will switch to VoidWallet if the watchdog delta is reached.
"""
# TODO: implement watchdog porperly
# TODO: implement watchdog properly
# logger.debug("Starting watchdog task")
# global watchdog
# watchdog = asyncio.create_task(watchdog_task())
async def unregister_watchdog():
"""
Unregisters a watchdog task
"""
global watchdog
if watchdog:
logger.debug("Stopping watchdog task")
watchdog.cancel()
# create_permanent_task(watchdog_task)
async def watchdog_task():
@ -98,7 +78,7 @@ async def watchdog_task():
await asyncio.sleep(settings.lnbits_watchdog_interval * 60)
async def register_task_listeners():
def register_task_listeners():
"""
Registers an invoice listener queue for the core tasks.
Incoming payaments in this queue will eventually trigger the signals sent to all other extensions
@ -108,7 +88,7 @@ async def register_task_listeners():
# we register invoice_paid_queue to receive all incoming invoices
register_invoice_listener(invoice_paid_queue, "core/tasks.py")
# register a worker that will react to invoices
asyncio.create_task(wait_for_paid_invoices(invoice_paid_queue))
create_task(wait_for_paid_invoices(invoice_paid_queue))
async def wait_for_paid_invoices(invoice_paid_queue: asyncio.Queue):

View File

@ -3,7 +3,7 @@ import time
import traceback
import uuid
from http import HTTPStatus
from typing import Dict, Optional
from typing import Dict, List, Optional
from fastapi.exceptions import HTTPException
from loguru import logger
@ -19,6 +19,26 @@ from lnbits.wallets import get_wallet_class
from .core import db
tasks: List[asyncio.Task] = []
def create_task(coro):
task = asyncio.create_task(coro)
tasks.append(task)
return task
def create_permanent_task(func):
return create_task(catch_everything_and_restart(func))
def cancel_all_tasks():
for task in tasks:
try:
task.cancel()
except Exception as exc:
logger.warning(f"error while cancelling task: {str(exc)}")
async def catch_everything_and_restart(func):
try: