lnbits-legend/lnbits/core/crud/extensions.py
2024-11-05 13:26:12 +02:00

137 lines
3.9 KiB
Python

from typing import Optional
from lnbits.core.db import db
from lnbits.core.models.extensions import (
InstallableExtension,
UserExtension,
)
from lnbits.db import Connection, Database
async def create_installed_extension(
    ext: InstallableExtension,
    conn: Optional[Connection] = None,
) -> None:
    """Insert ``ext`` as a new row of the ``installed_extensions`` table.

    Args:
        ext: The extension record to store.
        conn: Connection to run the insert on; when it is omitted (or
            falsy) the module-level ``db`` is used instead.
    """
    executor = conn or db
    await executor.insert("installed_extensions", ext)
async def update_installed_extension(
    ext: InstallableExtension,
    conn: Optional[Connection] = None,
) -> None:
    """Write ``ext`` back to the ``installed_extensions`` table.

    The call is delegated to the ``update`` helper of the connection; no
    explicit WHERE clause is passed, so row matching is whatever that
    helper does by default.

    Args:
        ext: The extension record carrying the new values.
        conn: Connection to run the update on; when it is omitted (or
            falsy) the module-level ``db`` is used instead.
    """
    executor = conn or db
    await executor.update("installed_extensions", ext)
async def update_installed_extension_state(
    *, ext_id: str, active: bool, conn: Optional[Connection] = None
) -> None:
    """Set the ``active`` column of one installed extension.

    Args:
        ext_id: Value matched against ``installed_extensions.id``.
        active: New value for the ``active`` column.
        conn: Connection to execute on; when it is omitted (or falsy) the
            module-level ``db`` is used instead.
    """
    statement = (
        "\n"
        "UPDATE installed_extensions SET active = :active WHERE id = :ext\n"
    )
    params = {"ext": ext_id, "active": active}
    executor = conn or db
    await executor.execute(statement, params)
async def delete_installed_extension(
    *, ext_id: str, conn: Optional[Connection] = None
) -> None:
    """Delete the ``installed_extensions`` row whose ``id`` is ``ext_id``.

    Args:
        ext_id: Value matched against ``installed_extensions.id``.
        conn: Connection to execute on; when it is omitted (or falsy) the
            module-level ``db`` is used instead.
    """
    statement = "\nDELETE from installed_extensions WHERE id = :ext\n"
    params = {"ext": ext_id}
    executor = conn or db
    await executor.execute(statement, params)
async def drop_extension_db(ext_id: str, conn: Optional[Connection] = None) -> None:
    """Remove the database belonging to the extension ``ext_id``.

    The extension id is first looked up in ``dbversions``. Only when a row
    exists is the extension's storage removed: file-based databases are
    cleaned through ``Database.clean_ext_db_files``; otherwise the schema
    named after the extension is dropped.

    Args:
        ext_id: Identifier of the extension whose database is removed.
        conn: Connection to execute on; when it is omitted (or falsy) the
            module-level ``db`` is used instead.

    Raises:
        AssertionError: If ``ext_id`` has no entry in ``dbversions``.
    """
    executor = conn or db
    row: Optional[dict] = await executor.fetchone(
        "SELECT * FROM dbversions WHERE db = :id",
        {"id": ext_id},
    )

    # Check that 'ext_id' is a valid extension id and not a malicious string.
    # This guards the string-formatted statement below, so it must be an
    # explicit raise: a bare `assert` is removed when Python runs with -O.
    # AssertionError is kept so existing callers see the same exception type.
    if not row:
        raise AssertionError(f"Extension '{ext_id}' db version cannot be found")

    is_file_based_db = await Database.clean_ext_db_files(ext_id)
    if is_file_based_db:
        return

    # String formatting is required, params are not accepted for 'DROP SCHEMA'.
    # The `ext_id` value is verified above.
    await executor.execute(
        f"DROP SCHEMA IF EXISTS {ext_id} CASCADE",
    )
async def get_installed_extension(
    ext_id: str, conn: Optional[Connection] = None
) -> Optional[InstallableExtension]:
    """Fetch a single installed extension by its id.

    Args:
        ext_id: Value matched against ``installed_extensions.id``.
        conn: Connection to query; when it is omitted (or falsy) the
            module-level ``db`` is used instead.

    Returns:
        Whatever ``fetchone`` yields for the query, requested as an
        ``InstallableExtension``.
    """
    query = "SELECT * FROM installed_extensions WHERE id = :id"
    params = {"id": ext_id}
    return await (conn or db).fetchone(query, params, InstallableExtension)
async def get_installed_extensions(
    active: Optional[bool] = None,
    conn: Optional[Connection] = None,
) -> list[InstallableExtension]:
    """List installed extensions, optionally filtered by their ``active`` flag.

    Args:
        active: When ``None`` every row is selected; otherwise only rows
            whose ``active`` column equals this value.
        conn: Connection to query; when it is omitted (or falsy) the
            module-level ``db`` is used instead.

    Returns:
        The rows of ``installed_extensions`` selected by the query,
        requested as ``InstallableExtension`` objects.
    """
    # Build the optional filter and its bound parameters together so the
    # two can never get out of sync.
    if active is None:
        where_clause = ""
        params = {}
    else:
        where_clause = "WHERE active = :active"
        params = {"active": active}

    query = "SELECT * FROM installed_extensions " + where_clause
    executor = conn or db
    return await executor.fetchall(query, params, model=InstallableExtension)
async def get_user_extension(
    user_id: str, extension: str, conn: Optional[Connection] = None
) -> Optional[UserExtension]:
    """Fetch the ``extensions`` row linking one user to one extension.

    Args:
        user_id: Value matched against the ``"user"`` column.
        extension: Value matched against the ``extension`` column.
        conn: Connection to query; when it is omitted (or falsy) the
            module-level ``db`` is used instead.

    Returns:
        Whatever ``fetchone`` yields for the query, requested as a
        ``UserExtension``.
    """
    query = (
        "\n"
        "SELECT * FROM extensions\n"
        'WHERE "user" = :user AND extension = :ext\n'
    )
    params = {"user": user_id, "ext": extension}
    executor = conn or db
    found = await executor.fetchone(query, params, model=UserExtension)
    return found
async def get_user_extensions(
    user_id: str, conn: Optional[Connection] = None
) -> list[UserExtension]:
    """List every ``extensions`` row belonging to ``user_id``.

    Args:
        user_id: Value matched against the ``"user"`` column.
        conn: Connection to query; when it is omitted (or falsy) the
            module-level ``db`` is used instead.

    Returns:
        The selected rows, requested as ``UserExtension`` objects.
    """
    query = 'SELECT * FROM extensions WHERE "user" = :user'
    params = {"user": user_id}
    executor = conn or db
    rows = await executor.fetchall(query, params, model=UserExtension)
    return rows
async def create_user_extension(
    user_extension: UserExtension, conn: Optional[Connection] = None
) -> None:
    """Insert ``user_extension`` as a new row of the ``extensions`` table.

    Args:
        user_extension: The user/extension record to store.
        conn: Connection to run the insert on; when it is omitted (or
            falsy) the module-level ``db`` is used instead.
    """
    executor = conn or db
    await executor.insert("extensions", user_extension)
async def update_user_extension(
    user_extension: UserExtension, conn: Optional[Connection] = None
) -> None:
    """Write ``user_extension`` back to the ``extensions`` table.

    The row is addressed by an explicit WHERE clause on the ``extension``
    and ``"user"`` columns, which is handed to the connection's ``update``
    helper.

    Args:
        user_extension: The user/extension record carrying the new values.
        conn: Connection to run the update on; when it is omitted (or
            falsy) the module-level ``db`` is used instead.
    """
    row_filter = 'WHERE extension = :extension AND "user" = :user'
    executor = conn or db
    await executor.update("extensions", user_extension, row_filter)
async def get_user_active_extensions_ids(
    user_id: str, conn: Optional[Connection] = None
) -> list[str]:
    """Return the extension ids that ``user_id`` currently has active.

    Args:
        user_id: Value matched against the ``"user"`` column.
        conn: Connection to query; when it is omitted (or falsy) the
            module-level ``db`` is used instead.

    Returns:
        The ``extension`` attribute of every selected row, in the order
        the rows were returned.
    """
    query = '\nSELECT * FROM extensions WHERE "user" = :user AND active\n'
    params = {"user": user_id}
    executor = conn or db
    active_rows = await executor.fetchall(query, params, UserExtension)

    extension_ids = []
    for user_ext in active_rows:
        extension_ids.append(user_ext.extension)
    return extension_ids