mirror of
https://github.com/lnbits/lnbits-legend.git
synced 2025-02-23 22:47:05 +01:00
130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
import importlib
|
|
import re
|
|
from typing import Any, Optional
|
|
from urllib.parse import urlparse
|
|
from uuid import UUID
|
|
|
|
from loguru import logger
|
|
|
|
from lnbits.core import migrations as core_migrations
|
|
from lnbits.core.crud import (
|
|
get_db_versions,
|
|
get_installed_extensions,
|
|
update_migration_version,
|
|
)
|
|
from lnbits.core.db import db as core_db
|
|
from lnbits.core.models import DbVersion
|
|
from lnbits.core.models.extensions import InstallableExtension
|
|
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection
|
|
from lnbits.settings import settings
|
|
|
|
|
|
async def migrate_extension_database(
    ext: InstallableExtension, current_version: Optional[DbVersion] = None
):
    """Import an extension's migrations and run them on the extension's database.

    Args:
        ext: The extension to migrate. Its ``module_name`` is used to import
            both ``<module_name>.migrations`` and the extension module itself,
            whose ``db`` attribute provides the database to connect to.
        current_version: The migration version currently recorded for the
            extension, passed through unchanged to ``run_migration``.

    Raises:
        ImportError: If either module cannot be imported. The original error
            is logged and chained as the cause.
    """
    module_name = ext.module_name
    try:
        migrations_module = importlib.import_module(f"{module_name}.migrations")
        extension_module = importlib.import_module(module_name)
        database = extension_module.db
    except ImportError as import_error:
        logger.error(import_error)
        raise ImportError(
            f"Cannot import module for extension '{ext.id}'."
        ) from import_error

    async with database.connect() as connection:
        await run_migration(connection, migrations_module, ext.id, current_version)
|
|
|
|
|
|
async def run_migration(
    db: Connection,
    migrations_module: Any,
    db_name: str,
    current_version: Optional[DbVersion] = None,
):
    """Run the pending ``mNNN_*`` migrations of a module in ascending version order.

    Every attribute of ``migrations_module`` whose name starts with ``m``
    followed by exactly three digits and an underscore is treated as a
    migration coroutine; the three digits are its version number.

    Args:
        db: Connection handed to each migration coroutine.
        migrations_module: Object whose ``__dict__`` holds the migrations.
        db_name: Name under which the applied version is recorded; also used
            in the log output.
        current_version: Version already applied. When it is missing (falsy),
            every migration found is run; otherwise only migrations with a
            strictly greater version number are run.
    """
    matcher = re.compile(r"^m(\d\d\d)_")

    # Collect first, then order by version: the module namespace is in
    # definition order, which is not guaranteed to be version order. Running a
    # lower-numbered migration last would also record a lower version last.
    found = []
    for key, migrate in list(migrations_module.__dict__.items()):
        match = matcher.match(key)
        if match:
            found.append((int(match.group(1)), migrate))

    # list.sort is stable, so migrations sharing a version number keep their
    # definition order.
    found.sort(key=lambda entry: entry[0])

    for version, migrate in found:
        if not current_version or version > current_version.version:
            logger.debug(f"running migration {db_name}.{version}")
            # The same message is also written to stdout.
            print(f"running migration {db_name}.{version}")
            await migrate(db)

            # The applied version is recorded through the migration's own
            # connection when it has no schema, otherwise through a fresh
            # connection to the core database.
            if db.schema is None:
                await update_migration_version(db, db_name, version)
            else:
                async with core_db.connect() as conn:
                    await update_migration_version(conn, db_name, version)
|
|
|
|
|
|
def to_valid_user_id(user_id: str) -> UUID:
    """Convert a hexadecimal user ID into a version-4 ``UUID``.

    Only the first 32 hex digits (128 bits) are used for the UUID; any
    further digits are validated but otherwise ignored.

    Args:
        user_id: String made up solely of the characters ``0-9``, ``a-f`` and
            ``A-F``, at least 32 characters long.

    Returns:
        The UUID built from the first 32 hex digits, with the version-4 bits
        set by ``UUID``.

    Raises:
        ValueError: If ``user_id`` is shorter than 32 characters or contains
            anything other than hex digits.
    """
    if len(user_id) < 32:
        raise ValueError("User ID must have at least 128 bits")

    # ``int(user_id, 16)`` is too lenient for validation: it also accepts
    # surrounding whitespace, a sign, a ``0x`` prefix and ``_`` separators,
    # which would let IDs through that carry fewer than 128 bits in their
    # first 32 characters. Require plain hex digits only.
    if re.fullmatch(r"[0-9a-fA-F]+", user_id) is None:
        raise ValueError("Invalid hex string for User ID.")

    return UUID(hex=user_id[:32], version=4)
|
|
|
|
|
|
async def load_disabled_extension_list() -> None:
    """Add the ids of all inactive installed extensions to the settings.

    Fetches the installed extensions with ``active=False`` and merges their
    ids into ``settings.lnbits_deactivated_extensions``.
    """
    disabled = await get_installed_extensions(active=False)
    disabled_ids = [extension.id for extension in disabled]
    settings.lnbits_deactivated_extensions.update(disabled_ids)
|
|
|
|
|
|
async def migrate_databases():
    """Creates the necessary databases if they don't exist already; or migrates them.

    Steps, in order:

    1. On a core database connection, look up the ``dbversions`` table and run
       ``m000_create_migrations_table`` when the lookup finds nothing.
    2. Read the recorded versions and run the pending core migrations.
    3. Load the list of explicitly disabled extensions into the settings.
    4. Migrate every installed extension. A failure in one extension is logged
       with its traceback and does not stop the remaining extensions.
    """

    async with core_db.connect() as conn:
        # For a connection type other than the three handled below, no lookup
        # is made and the table is treated as missing.
        exists = False
        if conn.type == SQLITE:
            exists = await conn.fetchone(
                "SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
            )
        elif conn.type in {POSTGRES, COCKROACH}:
            exists = await conn.fetchone(
                "SELECT * FROM information_schema.tables WHERE table_schema = 'public'"
                " AND table_name = 'dbversions'"
            )

        if not exists:
            await core_migrations.m000_create_migrations_table(conn)

        current_versions = await get_db_versions(conn)
        core_version = next(
            (v for v in current_versions if v.db == "core"),
            DbVersion(db="core", version=0),
        )
        await run_migration(conn, core_migrations, "core", core_version)

    # here is the first place we can be sure that the
    # `installed_extensions` table has been created
    await load_disabled_extension_list()

    for ext in await get_installed_extensions():
        # `next` is given a default, so an extension without a recorded
        # version starts from version 0 and no separate "missing version"
        # check is needed.
        current_version = next(
            (v for v in current_versions if v.db == ext.id),
            DbVersion(db=ext.id, version=0),
        )
        try:
            await migrate_extension_database(ext, current_version)
        except Exception as e:
            # Deliberately broad: one broken extension must not prevent the
            # others from being migrated.
            logger.exception(f"Error migrating extension {ext.id}: {e}")

    logger.info("✔️ All migrations done.")
|
|
|
|
|
|
def is_valid_url(url):
    """Return True when *url* parses with both a scheme and a network location.

    Input that ``urlparse`` rejects with ``ValueError`` is reported as invalid
    (False) instead of raising.
    """
    try:
        parsed = urlparse(url)
    except ValueError:
        return False
    return bool(parsed.scheme and parsed.netloc)
|