2022-12-22 16:30:37 +02:00
|
|
|
import importlib
|
|
|
|
import re
|
2023-12-12 12:38:19 +02:00
|
|
|
from typing import Any, Optional
|
2023-05-09 11:22:19 +03:00
|
|
|
from uuid import UUID
|
2022-12-22 16:30:37 +02:00
|
|
|
|
2023-02-15 15:34:09 +02:00
|
|
|
import httpx
|
2022-12-22 16:30:37 +02:00
|
|
|
from loguru import logger
|
|
|
|
|
2023-09-12 12:25:05 +02:00
|
|
|
from lnbits.core.db import db as core_db
|
2023-01-20 17:08:44 +02:00
|
|
|
from lnbits.db import Connection
|
2023-01-20 10:06:32 +02:00
|
|
|
from lnbits.extension_manager import Extension
|
2023-02-15 17:25:58 +02:00
|
|
|
from lnbits.settings import settings
|
2023-01-11 10:57:19 +02:00
|
|
|
|
2022-12-22 16:30:37 +02:00
|
|
|
from .crud import update_migration_version
|
|
|
|
|
|
|
|
|
2023-01-11 10:57:19 +02:00
|
|
|
async def migrate_extension_database(ext: Extension, current_version):
    """Apply the pending database migrations of an installed extension.

    Imports ``<module_name>.migrations`` and the ``db`` attribute of the
    extension's top-level module, opens a connection on that database and
    hands both to ``run_migration`` together with ``ext.code`` and
    ``current_version``.

    Raises:
        ImportError: when either module cannot be imported. The original
            error is logged and chained as the cause.
    """
    module_name = ext.module_name
    try:
        migrations = importlib.import_module(f"{module_name}.migrations")
        database = importlib.import_module(module_name).db
    except ImportError as err:
        logger.error(err)
        raise ImportError(
            f"Please make sure that the extension `{ext.code}` has a migrations file."
        ) from err

    async with database.connect() as connection:
        await run_migration(connection, migrations, ext.code, current_version)
|
2022-12-22 16:30:37 +02:00
|
|
|
|
|
|
|
|
2023-10-17 11:56:38 +03:00
|
|
|
async def run_migration(
    db: Connection, migrations_module: Any, db_name: str, current_version: int
):
    """Run every pending migration found in ``migrations_module``.

    A migration is any module-level attribute whose name matches
    ``m<NNN>_...`` (three digits). Each one with ``NNN > current_version`` is
    awaited as ``migrate(db)``. After it finishes, the stored migration
    version for ``db_name`` is set to ``NNN``.

    Pending migrations are applied in ascending version order, not in
    module-definition order. Otherwise a lower-numbered migration defined
    later in the file would run last and move the stored version backwards.
    The sort is stable, so entries with the same number keep their
    definition order.
    """
    matcher = re.compile(r"^m(\d\d\d)_")

    # Collect first, then run. This also avoids iterating the module dict
    # while migrations are being awaited.
    pending = []
    for key, migrate in migrations_module.__dict__.items():
        match = matcher.match(key)
        if match:
            version = int(match.group(1))
            if version > current_version:
                pending.append((version, migrate))
    pending.sort(key=lambda item: item[0])

    for version, migrate in pending:
        logger.debug(f"running migration {db_name}.{version}")
        print(f"running migration {db_name}.{version}")
        await migrate(db)

        if db.schema is None:
            await update_migration_version(db, db_name, version)
        else:
            # NOTE(review): when the connection has a schema, the version is
            # recorded through a separate core-db connection; presumably the
            # version table lives in the core database in that setup — confirm.
            async with core_db.connect() as conn:
                await update_migration_version(conn, db_name, version)
|
2023-02-15 15:34:09 +02:00
|
|
|
|
|
|
|
|
2023-12-12 12:38:19 +02:00
|
|
|
async def stop_extension_background_work(
    ext_id: str, user: str, access_token: Optional[str] = None
):
    """
    Stop background work for an extension (asyncio Tasks, WebSockets, etc).

    Extensions SHOULD expose a module-level `{ext_id}_stop()` function (the
    legacy `api_stop()` is also accepted) and/or a DELETE endpoint at the
    root level of their API. The in-process stop function is tried first.
    Only if it does not succeed is the DELETE endpoint called, using `user`
    and, when given, `access_token` as a bearer token.
    """
    if await _stop_extension_background_work(ext_id):
        return
    # In-process stop failed: fall back to the REST API call.
    await _stop_extension_background_work_via_api(ext_id, user, access_token)
|
|
|
|
|
|
|
|
|
|
|
|
async def _stop_extension_background_work(ext_id) -> bool:
    """Stop an extension's background work by calling its module-level stop function.

    Imports the extension's module and invokes the first of
    ``{ext_id}_stop`` / ``api_stop`` that it exposes. Both coroutine
    functions and plain functions are supported.

    Returns:
        True if the stop function was found and completed. False on any
        failure (logged as a warning) so the caller can fall back to
        another strategy.
    """
    import inspect

    upgrade_hash = settings.extension_upgrade_hash(ext_id) or ""
    ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash)

    try:
        logger.info(f"Stopping background work for extension '{ext.module_name}'.")
        old_module = importlib.import_module(ext.module_name)

        # Extensions must expose an `{ext_id}_stop()` function at the module level
        # The `api_stop()` function is for backwards compatibility (will be deprecated)
        stop_fns = [f"{ext_id}_stop", "api_stop"]
        stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None)
        # Explicit check instead of `assert`: asserts vanish under `python -O`.
        # The raise is handled by the `except` below, like every other failure.
        if stop_fn_name is None:
            raise ValueError(f"No stop function found for '{ext.module_name}'")

        stop_fn = getattr(old_module, stop_fn_name)
        if stop_fn:
            result = stop_fn()
            # Only await real awaitables. A synchronous stop function has
            # already done its work at this point.
            if inspect.isawaitable(result):
                await result

        logger.info(f"Stopped background work for extension '{ext.module_name}'.")
    except Exception as ex:
        logger.warning(f"Failed to stop background work for '{ext.module_name}'.")
        logger.warning(ex)
        return False

    return True
|
|
|
|
|
|
|
|
|
|
|
|
async def _stop_extension_background_work_via_api(ext_id, user, access_token):
    """Ask an extension to stop its background work through its REST API.

    Sends ``DELETE http://{settings.host}:{settings.port}/{ext_id}/api/v1``
    with ``usr=<user>`` as a query parameter. When ``access_token`` is
    truthy, it is sent as a bearer token in the ``Authorization`` header.

    Best effort: any failure (connection error or non-2xx status) is logged
    as a warning and swallowed.
    """
    logger.info(
        f"Stopping background work for extension '{ext_id}' using the REST API."
    )
    async with httpx.AsyncClient() as client:
        try:
            url = f"http://{settings.host}:{settings.port}/{ext_id}/api/v1"
            headers = (
                {"Authorization": "Bearer " + access_token} if access_token else None
            )
            # Pass `usr` through `params` so httpx URL-encodes it instead of
            # splicing the raw value into the query string.
            resp = await client.delete(url=url, params={"usr": user}, headers=headers)
            resp.raise_for_status()
            logger.info(f"Stopped background work for extension '{ext_id}'.")
        except Exception as ex:
            logger.warning(
                f"Failed to stop background work for '{ext_id}' using the REST API."
            )
            logger.warning(ex)
|
2023-05-09 11:22:19 +03:00
|
|
|
|
|
|
|
|
|
|
|
def to_valid_user_id(user_id: str) -> UUID:
    """Derive a version-4 UUID from a hexadecimal user id.

    The whole ``user_id`` must consist of ASCII hex digits, and there must be
    at least 32 of them (128 bits). Only the first 32 digits are used. Note
    that ``UUID(..., version=4)`` overwrites the version and variant bits, so
    the returned UUID's hex is not always identical to the input prefix.

    Raises:
        ValueError: if ``user_id`` is shorter than 32 characters or contains
            anything other than hexadecimal digits.
    """
    if len(user_id) < 32:
        raise ValueError("User ID must have at least 128 bits")
    # `int(user_id, 16)` is too lenient for this check. It also accepts a
    # "0x" prefix, a leading sign, surrounding whitespace, "_" separators and
    # non-ASCII digits, and several of those then slip through UUID() as well.
    if not re.fullmatch(r"[0-9a-fA-F]+", user_id):
        raise ValueError("Invalid hex string for User ID.")
    return UUID(hex=user_id[:32], version=4)
|