lnbits-legend/lnbits/core/helpers.py
2024-11-05 13:26:12 +02:00

130 lines
4.3 KiB
Python

import importlib
import re
from typing import Any, Optional
from urllib.parse import urlparse
from uuid import UUID
from loguru import logger
from lnbits.core import migrations as core_migrations
from lnbits.core.crud import (
get_db_versions,
get_installed_extensions,
update_migration_version,
)
from lnbits.core.db import db as core_db
from lnbits.core.models import DbVersion
from lnbits.core.models.extensions import InstallableExtension
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection
from lnbits.settings import settings
async def migrate_extension_database(
ext: InstallableExtension, current_version: Optional[DbVersion] = None
):
try:
ext_migrations = importlib.import_module(f"{ext.module_name}.migrations")
ext_db = importlib.import_module(ext.module_name).db
except ImportError as exc:
logger.error(exc)
raise ImportError(f"Cannot import module for extension '{ext.id}'.") from exc
async with ext_db.connect() as ext_conn:
await run_migration(ext_conn, ext_migrations, ext.id, current_version)
async def run_migration(
db: Connection,
migrations_module: Any,
db_name: str,
current_version: Optional[DbVersion] = None,
):
matcher = re.compile(r"^m(\d\d\d)_")
for key, migrate in list(migrations_module.__dict__.items()):
match = matcher.match(key)
if match:
version = int(match.group(1))
if not current_version or version > current_version.version:
logger.debug(f"running migration {db_name}.{version}")
print(f"running migration {db_name}.{version}")
await migrate(db)
if db.schema is None:
await update_migration_version(db, db_name, version)
else:
async with core_db.connect() as conn:
await update_migration_version(conn, db_name, version)
def to_valid_user_id(user_id: str) -> UUID:
if len(user_id) < 32:
raise ValueError("User ID must have at least 128 bits")
try:
int(user_id, 16)
except Exception as exc:
raise ValueError("Invalid hex string for User ID.") from exc
return UUID(hex=user_id[:32], version=4)
async def load_disabled_extension_list() -> None:
"""Update list of extensions that have been explicitly disabled"""
inactive_extensions = await get_installed_extensions(active=False)
settings.lnbits_deactivated_extensions.update([e.id for e in inactive_extensions])
async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them."""
async with core_db.connect() as conn:
exists = False
if conn.type == SQLITE:
exists = await conn.fetchone(
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
)
elif conn.type in {POSTGRES, COCKROACH}:
exists = await conn.fetchone(
"SELECT * FROM information_schema.tables WHERE table_schema = 'public'"
" AND table_name = 'dbversions'"
)
if not exists:
await core_migrations.m000_create_migrations_table(conn)
current_versions = await get_db_versions(conn)
core_version = next(
(v for v in current_versions if v.db == "core"),
DbVersion(db="core", version=0),
)
await run_migration(conn, core_migrations, "core", core_version)
# here is the first place we can be sure that the
# `installed_extensions` table has been created
await load_disabled_extension_list()
for ext in await get_installed_extensions():
current_version = next(
(v for v in current_versions if v.db == ext.id),
DbVersion(db=ext.id, version=0),
)
if current_version is None:
logger.warning(
f"Extension {ext.id} has no migration version. This should not happen."
)
continue
try:
await migrate_extension_database(ext, current_version)
except Exception as e:
logger.exception(f"Error migrating extension {ext.id}: {e}")
logger.info("✔️ All migrations done.")
def is_valid_url(url):
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except ValueError:
return False