mirror of
https://github.com/lnbits/lnbits-legend.git
synced 2025-02-23 14:40:47 +01:00
Fix overlapping redirect paths (#2671)
This commit is contained in:
parent
7a5e7fbd8c
commit
5f4f1288d7
14 changed files with 684 additions and 483 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -49,8 +49,8 @@ fly.toml
|
||||||
lnbits-backup.zip
|
lnbits-backup.zip
|
||||||
|
|
||||||
# Ignore extensions (post installable extension PR)
|
# Ignore extensions (post installable extension PR)
|
||||||
extensions
|
/extensions
|
||||||
upgrades/
|
/upgrades/
|
||||||
|
|
||||||
# builded python package
|
# builded python package
|
||||||
dist
|
dist
|
||||||
|
|
|
@ -17,10 +17,13 @@ from slowapi.util import get_remote_address
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
|
||||||
from lnbits.core.crud import (
|
from lnbits.core.crud import (
|
||||||
|
add_installed_extension,
|
||||||
get_dbversions,
|
get_dbversions,
|
||||||
get_installed_extensions,
|
get_installed_extensions,
|
||||||
update_installed_extension_state,
|
update_installed_extension_state,
|
||||||
)
|
)
|
||||||
|
from lnbits.core.extensions.extension_manager import deactivate_extension
|
||||||
|
from lnbits.core.extensions.helpers import version_parse
|
||||||
from lnbits.core.helpers import migrate_extension_database
|
from lnbits.core.helpers import migrate_extension_database
|
||||||
from lnbits.core.tasks import ( # watchdog_task
|
from lnbits.core.tasks import ( # watchdog_task
|
||||||
killswitch_task,
|
killswitch_task,
|
||||||
|
@ -44,14 +47,8 @@ from lnbits.wallets import get_funding_source, set_funding_source
|
||||||
from .commands import migrate_databases
|
from .commands import migrate_databases
|
||||||
from .core import init_core_routers
|
from .core import init_core_routers
|
||||||
from .core.db import core_app_extra
|
from .core.db import core_app_extra
|
||||||
|
from .core.extensions.models import Extension, InstallableExtension
|
||||||
from .core.services import check_admin_settings, check_webpush_settings
|
from .core.services import check_admin_settings, check_webpush_settings
|
||||||
from .core.views.extension_api import add_installed_extension
|
|
||||||
from .extension_manager import (
|
|
||||||
Extension,
|
|
||||||
InstallableExtension,
|
|
||||||
get_valid_extensions,
|
|
||||||
version_parse,
|
|
||||||
)
|
|
||||||
from .middleware import (
|
from .middleware import (
|
||||||
CustomGZipMiddleware,
|
CustomGZipMiddleware,
|
||||||
ExtensionsRedirectMiddleware,
|
ExtensionsRedirectMiddleware,
|
||||||
|
@ -243,6 +240,7 @@ async def check_installed_extensions(app: FastAPI):
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(e)
|
logger.warning(e)
|
||||||
|
await deactivate_extension(ext.id)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to re-install extension: {ext.id} ({ext.installed_version})"
|
f"Failed to re-install extension: {ext.id} ({ext.installed_version})"
|
||||||
)
|
)
|
||||||
|
@ -317,7 +315,6 @@ async def restore_installed_extension(app: FastAPI, ext: InstallableExtension):
|
||||||
|
|
||||||
# mount routes for the new version
|
# mount routes for the new version
|
||||||
core_app_extra.register_new_ext_routes(extension)
|
core_app_extra.register_new_ext_routes(extension)
|
||||||
ext.notify_upgrade(extension.upgrade_hash)
|
|
||||||
|
|
||||||
|
|
||||||
def register_custom_extensions_path():
|
def register_custom_extensions_path():
|
||||||
|
@ -380,24 +377,22 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None:
|
||||||
)
|
)
|
||||||
app.mount(s["path"], StaticFiles(directory=static_dir), s["name"])
|
app.mount(s["path"], StaticFiles(directory=static_dir), s["name"])
|
||||||
|
|
||||||
if hasattr(ext_module, f"{ext.code}_redirect_paths"):
|
ext_redirects = (
|
||||||
ext_redirects = getattr(ext_module, f"{ext.code}_redirect_paths")
|
getattr(ext_module, f"{ext.code}_redirect_paths")
|
||||||
settings.lnbits_extensions_redirects = [
|
if hasattr(ext_module, f"{ext.code}_redirect_paths")
|
||||||
r for r in settings.lnbits_extensions_redirects if r["ext_id"] != ext.code
|
else []
|
||||||
]
|
)
|
||||||
for r in ext_redirects:
|
|
||||||
r["ext_id"] = ext.code
|
|
||||||
settings.lnbits_extensions_redirects.append(r)
|
|
||||||
|
|
||||||
logger.trace(f"adding route for extension {ext_module}")
|
settings.activate_extension_paths(ext.code, ext.upgrade_hash, ext_redirects)
|
||||||
|
|
||||||
|
logger.trace(f"Adding route for extension {ext_module}.")
|
||||||
prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else ""
|
prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else ""
|
||||||
app.include_router(router=ext_route, prefix=prefix)
|
app.include_router(router=ext_route, prefix=prefix)
|
||||||
|
|
||||||
|
|
||||||
async def check_and_register_extensions(app: FastAPI):
|
async def check_and_register_extensions(app: FastAPI):
|
||||||
await check_installed_extensions(app)
|
await check_installed_extensions(app)
|
||||||
for ext in get_valid_extensions(False):
|
for ext in Extension.get_valid_extensions(False):
|
||||||
try:
|
try:
|
||||||
register_ext_routes(app, ext)
|
register_ext_routes(app, ext)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|
|
@ -25,18 +25,18 @@ from lnbits.core.crud import (
|
||||||
remove_deleted_wallets,
|
remove_deleted_wallets,
|
||||||
update_payment_status,
|
update_payment_status,
|
||||||
)
|
)
|
||||||
|
from lnbits.core.extensions.models import (
|
||||||
|
CreateExtension,
|
||||||
|
ExtensionRelease,
|
||||||
|
InstallableExtension,
|
||||||
|
)
|
||||||
from lnbits.core.helpers import migrate_databases
|
from lnbits.core.helpers import migrate_databases
|
||||||
from lnbits.core.models import Payment, PaymentState, User
|
from lnbits.core.models import Payment, PaymentState
|
||||||
from lnbits.core.services import check_admin_settings
|
from lnbits.core.services import check_admin_settings
|
||||||
from lnbits.core.views.extension_api import (
|
from lnbits.core.views.extension_api import (
|
||||||
api_install_extension,
|
api_install_extension,
|
||||||
api_uninstall_extension,
|
api_uninstall_extension,
|
||||||
)
|
)
|
||||||
from lnbits.extension_manager import (
|
|
||||||
CreateExtension,
|
|
||||||
ExtensionRelease,
|
|
||||||
InstallableExtension,
|
|
||||||
)
|
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
from lnbits.wallets.base import Wallet
|
from lnbits.wallets.base import Wallet
|
||||||
|
|
||||||
|
@ -611,7 +611,7 @@ async def _call_install_extension(
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
else:
|
else:
|
||||||
await api_install_extension(data, User(id="mock_id"))
|
await api_install_extension(data)
|
||||||
|
|
||||||
|
|
||||||
async def _call_uninstall_extension(
|
async def _call_uninstall_extension(
|
||||||
|
@ -625,7 +625,7 @@ async def _call_uninstall_extension(
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
else:
|
else:
|
||||||
await api_uninstall_extension(extension, User(id="mock_id"))
|
await api_uninstall_extension(extension)
|
||||||
|
|
||||||
|
|
||||||
async def _can_run_operation(url) -> bool:
|
async def _can_run_operation(url) -> bool:
|
||||||
|
|
|
@ -8,14 +8,14 @@ import shortuuid
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
from lnbits.core.db import db
|
from lnbits.core.db import db
|
||||||
from lnbits.core.models import PaymentState
|
from lnbits.core.extensions.models import (
|
||||||
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
|
|
||||||
from lnbits.extension_manager import (
|
|
||||||
InstallableExtension,
|
InstallableExtension,
|
||||||
PayToEnableInfo,
|
PayToEnableInfo,
|
||||||
UserExtension,
|
UserExtension,
|
||||||
UserExtensionInfo,
|
UserExtensionInfo,
|
||||||
)
|
)
|
||||||
|
from lnbits.core.models import PaymentState
|
||||||
|
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
|
||||||
from lnbits.settings import (
|
from lnbits.settings import (
|
||||||
AdminSettings,
|
AdminSettings,
|
||||||
EditableSettings,
|
EditableSettings,
|
||||||
|
@ -430,7 +430,7 @@ async def get_installed_extension(
|
||||||
async def get_installed_extensions(
|
async def get_installed_extensions(
|
||||||
active: Optional[bool] = None,
|
active: Optional[bool] = None,
|
||||||
conn: Optional[Connection] = None,
|
conn: Optional[Connection] = None,
|
||||||
) -> List["InstallableExtension"]:
|
) -> List[InstallableExtension]:
|
||||||
rows = await (conn or db).fetchall(
|
rows = await (conn or db).fetchall(
|
||||||
"SELECT * FROM installed_extensions",
|
"SELECT * FROM installed_extensions",
|
||||||
(),
|
(),
|
||||||
|
|
93
lnbits/core/extensions/extension_manager.py
Normal file
93
lnbits/core/extensions/extension_manager.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
import asyncio
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lnbits.core.crud import (
|
||||||
|
add_installed_extension,
|
||||||
|
delete_installed_extension,
|
||||||
|
get_dbversions,
|
||||||
|
get_installed_extension,
|
||||||
|
update_installed_extension_state,
|
||||||
|
)
|
||||||
|
from lnbits.core.db import core_app_extra
|
||||||
|
from lnbits.core.helpers import migrate_extension_database
|
||||||
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
from .models import Extension, InstallableExtension
|
||||||
|
|
||||||
|
|
||||||
|
async def install_extension(ext_info: InstallableExtension) -> Extension:
|
||||||
|
extension = Extension.from_installable_ext(ext_info)
|
||||||
|
installed_ext = await get_installed_extension(ext_info.id)
|
||||||
|
ext_info.payments = installed_ext.payments if installed_ext else []
|
||||||
|
|
||||||
|
await ext_info.download_archive()
|
||||||
|
|
||||||
|
ext_info.extract_archive()
|
||||||
|
|
||||||
|
db_version = (await get_dbversions()).get(ext_info.id, 0)
|
||||||
|
await migrate_extension_database(extension, db_version)
|
||||||
|
|
||||||
|
await add_installed_extension(ext_info)
|
||||||
|
|
||||||
|
if extension.is_upgrade_extension:
|
||||||
|
# call stop while the old routes are still active
|
||||||
|
await stop_extension_background_work(ext_info.id)
|
||||||
|
|
||||||
|
return extension
|
||||||
|
|
||||||
|
|
||||||
|
async def uninstall_extension(ext_id: str):
|
||||||
|
await stop_extension_background_work(ext_id)
|
||||||
|
|
||||||
|
settings.deactivate_extension_paths(ext_id)
|
||||||
|
|
||||||
|
extension = await get_installed_extension(ext_id)
|
||||||
|
if extension:
|
||||||
|
extension.clean_extension_files()
|
||||||
|
await delete_installed_extension(ext_id=ext_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def activate_extension(ext: Extension):
|
||||||
|
core_app_extra.register_new_ext_routes(ext)
|
||||||
|
await update_installed_extension_state(ext_id=ext.code, active=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def deactivate_extension(ext_id: str):
|
||||||
|
settings.deactivate_extension_paths(ext_id)
|
||||||
|
await update_installed_extension_state(ext_id=ext_id, active=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def stop_extension_background_work(ext_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Stop background work for extension (like asyncio.Tasks, WebSockets, etc).
|
||||||
|
Extensions SHOULD expose a `api_stop()` function.
|
||||||
|
"""
|
||||||
|
upgrade_hash = settings.lnbits_upgraded_extensions.get(ext_id, "")
|
||||||
|
ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Stopping background work for extension '{ext.module_name}'.")
|
||||||
|
old_module = importlib.import_module(ext.module_name)
|
||||||
|
|
||||||
|
# Extensions must expose an `{ext_id}_stop()` function at the module level
|
||||||
|
# The `api_stop()` function is for backwards compatibility (will be deprecated)
|
||||||
|
stop_fns = [f"{ext_id}_stop", "api_stop"]
|
||||||
|
stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None)
|
||||||
|
assert stop_fn_name, "No stop function found for '{ext.module_name}'"
|
||||||
|
|
||||||
|
stop_fn = getattr(old_module, stop_fn_name)
|
||||||
|
if stop_fn:
|
||||||
|
if asyncio.iscoroutinefunction(stop_fn):
|
||||||
|
await stop_fn()
|
||||||
|
else:
|
||||||
|
stop_fn()
|
||||||
|
|
||||||
|
logger.info(f"Stopped background work for extension '{ext.module_name}'.")
|
||||||
|
except Exception as ex:
|
||||||
|
logger.warning(f"Failed to stop background work for '{ext.module_name}'.")
|
||||||
|
logger.warning(ex)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
56
lnbits/core/extensions/helpers.py
Normal file
56
lnbits/core/extensions/helpers.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
import hashlib
|
||||||
|
from typing import Any, Optional
|
||||||
|
from urllib import request
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from loguru import logger
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def version_parse(v: str):
|
||||||
|
"""
|
||||||
|
Wrapper for version.parse() that does not throw if the version is invalid.
|
||||||
|
Instead it return the lowest possible version ("0.0.0")
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return version.parse(v)
|
||||||
|
except Exception:
|
||||||
|
return version.parse("0.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
async def github_api_get(url: str, error_msg: Optional[str]) -> Any:
|
||||||
|
headers = {"User-Agent": settings.user_agent}
|
||||||
|
if settings.lnbits_ext_github_token:
|
||||||
|
headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}"
|
||||||
|
async with httpx.AsyncClient(headers=headers) as client:
|
||||||
|
resp = await client.get(url)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
logger.warning(f"{error_msg} ({url}): {resp.text}")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
def download_url(url, save_path):
|
||||||
|
with request.urlopen(url, timeout=60) as dl_file:
|
||||||
|
with open(save_path, "wb") as out_file:
|
||||||
|
out_file.write(dl_file.read())
|
||||||
|
|
||||||
|
|
||||||
|
def file_hash(filename):
|
||||||
|
h = hashlib.sha256()
|
||||||
|
b = bytearray(128 * 1024)
|
||||||
|
mv = memoryview(b)
|
||||||
|
with open(filename, "rb", buffering=0) as f:
|
||||||
|
while n := f.readinto(mv):
|
||||||
|
h.update(mv[:n])
|
||||||
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
|
||||||
|
if not path:
|
||||||
|
return ""
|
||||||
|
_, _, *rest = path.split("/")
|
||||||
|
tail = "/".join(rest)
|
||||||
|
return f"https://github.com/{source_repo}/raw/main/{tail}"
|
|
@ -1,3 +1,5 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
@ -6,16 +8,22 @@ import shutil
|
||||||
import sys
|
import sys
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, NamedTuple, Optional, Tuple
|
from typing import Any, NamedTuple, Optional
|
||||||
from urllib import request
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from packaging import version
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
from .helpers import (
|
||||||
|
download_url,
|
||||||
|
file_hash,
|
||||||
|
github_api_get,
|
||||||
|
icon_to_github_url,
|
||||||
|
version_parse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExplicitRelease(BaseModel):
|
class ExplicitRelease(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
|
@ -23,7 +31,7 @@ class ExplicitRelease(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
archive: str
|
archive: str
|
||||||
hash: str
|
hash: str
|
||||||
dependencies: List[str] = []
|
dependencies: list[str] = []
|
||||||
repo: Optional[str]
|
repo: Optional[str]
|
||||||
icon: Optional[str]
|
icon: Optional[str]
|
||||||
short_description: Optional[str]
|
short_description: Optional[str]
|
||||||
|
@ -48,9 +56,9 @@ class GitHubRelease(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Manifest(BaseModel):
|
class Manifest(BaseModel):
|
||||||
featured: List[str] = []
|
featured: list[str] = []
|
||||||
extensions: List["ExplicitRelease"] = []
|
extensions: list[ExplicitRelease] = []
|
||||||
repos: List["GitHubRelease"] = []
|
repos: list[GitHubRelease] = []
|
||||||
|
|
||||||
|
|
||||||
class GitHubRepoRelease(BaseModel):
|
class GitHubRepoRelease(BaseModel):
|
||||||
|
@ -81,6 +89,17 @@ class ExtensionConfig(BaseModel):
|
||||||
return True
|
return True
|
||||||
return version_parse(self.min_lnbits_version) <= version_parse(settings.version)
|
return version_parse(self.min_lnbits_version) <= version_parse(settings.version)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_github_release_config(
|
||||||
|
cls, org: str, repo: str, tag_name: str
|
||||||
|
) -> Optional[ExtensionConfig]:
|
||||||
|
config_url = (
|
||||||
|
f"https://raw.githubusercontent.com/{org}/{repo}/{tag_name}/config.json"
|
||||||
|
)
|
||||||
|
error_msg = "Cannot fetch GitHub extension config"
|
||||||
|
config = await github_api_get(config_url, error_msg)
|
||||||
|
return ExtensionConfig.parse_obj(config)
|
||||||
|
|
||||||
|
|
||||||
class ReleasePaymentInfo(BaseModel):
|
class ReleasePaymentInfo(BaseModel):
|
||||||
amount: Optional[int] = None
|
amount: Optional[int] = None
|
||||||
|
@ -112,7 +131,7 @@ class UserExtension(BaseModel):
|
||||||
return self.extra.paid_to_enable is True
|
return self.extra.paid_to_enable is True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_row(cls, data: dict) -> "UserExtension":
|
def from_row(cls, data: dict) -> UserExtension:
|
||||||
ext = UserExtension(**data)
|
ext = UserExtension(**data)
|
||||||
ext.extra = (
|
ext.extra = (
|
||||||
UserExtensionInfo(**json.loads(data["_extra"] or "{}"))
|
UserExtensionInfo(**json.loads(data["_extra"] or "{}"))
|
||||||
|
@ -122,124 +141,6 @@ class UserExtension(BaseModel):
|
||||||
return ext
|
return ext
|
||||||
|
|
||||||
|
|
||||||
def download_url(url, save_path):
|
|
||||||
with request.urlopen(url, timeout=60) as dl_file:
|
|
||||||
with open(save_path, "wb") as out_file:
|
|
||||||
out_file.write(dl_file.read())
|
|
||||||
|
|
||||||
|
|
||||||
def file_hash(filename):
|
|
||||||
h = hashlib.sha256()
|
|
||||||
b = bytearray(128 * 1024)
|
|
||||||
mv = memoryview(b)
|
|
||||||
with open(filename, "rb", buffering=0) as f:
|
|
||||||
while n := f.readinto(mv):
|
|
||||||
h.update(mv[:n])
|
|
||||||
return h.hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_github_repo_info(
|
|
||||||
org: str, repository: str
|
|
||||||
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
|
|
||||||
repo_url = f"https://api.github.com/repos/{org}/{repository}"
|
|
||||||
error_msg = "Cannot fetch extension repo"
|
|
||||||
repo = await github_api_get(repo_url, error_msg)
|
|
||||||
github_repo = GitHubRepo.parse_obj(repo)
|
|
||||||
|
|
||||||
lates_release_url = (
|
|
||||||
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
|
|
||||||
)
|
|
||||||
error_msg = "Cannot fetch extension releases"
|
|
||||||
latest_release: Any = await github_api_get(lates_release_url, error_msg)
|
|
||||||
|
|
||||||
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
|
|
||||||
error_msg = "Cannot fetch config for extension"
|
|
||||||
config = await github_api_get(config_url, error_msg)
|
|
||||||
|
|
||||||
return (
|
|
||||||
github_repo,
|
|
||||||
GitHubRepoRelease.parse_obj(latest_release),
|
|
||||||
ExtensionConfig.parse_obj(config),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_manifest(url) -> Manifest:
|
|
||||||
error_msg = "Cannot fetch extensions manifest"
|
|
||||||
manifest = await github_api_get(url, error_msg)
|
|
||||||
return Manifest.parse_obj(manifest)
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
|
|
||||||
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
|
|
||||||
error_msg = "Cannot fetch extension releases"
|
|
||||||
releases = await github_api_get(releases_url, error_msg)
|
|
||||||
return [GitHubRepoRelease.parse_obj(r) for r in releases]
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_github_release_config(
|
|
||||||
org: str, repo: str, tag_name: str
|
|
||||||
) -> Optional[ExtensionConfig]:
|
|
||||||
config_url = (
|
|
||||||
f"https://raw.githubusercontent.com/{org}/{repo}/{tag_name}/config.json"
|
|
||||||
)
|
|
||||||
error_msg = "Cannot fetch GitHub extension config"
|
|
||||||
config = await github_api_get(config_url, error_msg)
|
|
||||||
return ExtensionConfig.parse_obj(config)
|
|
||||||
|
|
||||||
|
|
||||||
async def github_api_get(url: str, error_msg: Optional[str]) -> Any:
|
|
||||||
headers = {"User-Agent": settings.user_agent}
|
|
||||||
if settings.lnbits_ext_github_token:
|
|
||||||
headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}"
|
|
||||||
async with httpx.AsyncClient(headers=headers) as client:
|
|
||||||
resp = await client.get(url)
|
|
||||||
if resp.status_code != 200:
|
|
||||||
logger.warning(f"{error_msg} ({url}): {resp.text}")
|
|
||||||
resp.raise_for_status()
|
|
||||||
return resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_release_payment_info(
|
|
||||||
url: str, amount: Optional[int] = None
|
|
||||||
) -> Optional[ReleasePaymentInfo]:
|
|
||||||
if amount:
|
|
||||||
url = f"{url}?amount={amount}"
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.get(url)
|
|
||||||
resp.raise_for_status()
|
|
||||||
return ReleasePaymentInfo(**resp.json())
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(e)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_release_details(details_link: str) -> Optional[dict]:
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.get(details_link)
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
if "description_md" in data:
|
|
||||||
resp = await client.get(data["description_md"])
|
|
||||||
if not resp.is_error:
|
|
||||||
data["description_md"] = resp.text
|
|
||||||
|
|
||||||
return data
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(e)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
|
|
||||||
if not path:
|
|
||||||
return ""
|
|
||||||
_, _, *rest = path.split("/")
|
|
||||||
tail = "/".join(rest)
|
|
||||||
return f"https://github.com/{source_repo}/raw/main/{tail}"
|
|
||||||
|
|
||||||
|
|
||||||
class Extension(NamedTuple):
|
class Extension(NamedTuple):
|
||||||
code: str
|
code: str
|
||||||
is_valid: bool
|
is_valid: bool
|
||||||
|
@ -247,7 +148,7 @@ class Extension(NamedTuple):
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
short_description: Optional[str] = None
|
short_description: Optional[str] = None
|
||||||
tile: Optional[str] = None
|
tile: Optional[str] = None
|
||||||
contributors: Optional[List[str]] = None
|
contributors: Optional[list[str]] = None
|
||||||
hidden: bool = False
|
hidden: bool = False
|
||||||
migration_module: Optional[str] = None
|
migration_module: Optional[str] = None
|
||||||
db_name: Optional[str] = None
|
db_name: Optional[str] = None
|
||||||
|
@ -269,7 +170,7 @@ class Extension(NamedTuple):
|
||||||
return self.upgrade_hash != ""
|
return self.upgrade_hash != ""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_installable_ext(cls, ext_info: "InstallableExtension") -> "Extension":
|
def from_installable_ext(cls, ext_info: InstallableExtension) -> Extension:
|
||||||
return Extension(
|
return Extension(
|
||||||
code=ext_info.id,
|
code=ext_info.id,
|
||||||
is_valid=True,
|
is_valid=True,
|
||||||
|
@ -278,22 +179,43 @@ class Extension(NamedTuple):
|
||||||
upgrade_hash=ext_info.hash if ext_info.module_installed else "",
|
upgrade_hash=ext_info.hash if ext_info.module_installed else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_valid_extensions(
|
||||||
|
cls, include_deactivated: Optional[bool] = True
|
||||||
|
) -> list[Extension]:
|
||||||
|
valid_extensions = [
|
||||||
|
extension for extension in cls._extensions() if extension.is_valid
|
||||||
|
]
|
||||||
|
|
||||||
# All subdirectories in the current directory, not recursive.
|
if include_deactivated:
|
||||||
|
return valid_extensions
|
||||||
|
|
||||||
|
if settings.lnbits_extensions_deactivate_all:
|
||||||
|
return []
|
||||||
|
|
||||||
class ExtensionManager:
|
return [
|
||||||
def __init__(self) -> None:
|
e
|
||||||
|
for e in valid_extensions
|
||||||
|
if e.code not in settings.lnbits_deactivated_extensions
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_valid_extension(
|
||||||
|
cls, ext_id: str, include_deactivated: Optional[bool] = True
|
||||||
|
) -> Optional[Extension]:
|
||||||
|
all_extensions = cls.get_valid_extensions(include_deactivated)
|
||||||
|
return next((e for e in all_extensions if e.code == ext_id), None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extensions(cls) -> list[Extension]:
|
||||||
p = Path(settings.lnbits_extensions_path, "extensions")
|
p = Path(settings.lnbits_extensions_path, "extensions")
|
||||||
Path(p).mkdir(parents=True, exist_ok=True)
|
Path(p).mkdir(parents=True, exist_ok=True)
|
||||||
self._extension_folders: List[Path] = [f for f in p.iterdir() if f.is_dir()]
|
extension_folders: list[Path] = [f for f in p.iterdir() if f.is_dir()]
|
||||||
|
|
||||||
@property
|
|
||||||
def extensions(self) -> List[Extension]:
|
|
||||||
# todo: remove this property somehow, it is too expensive
|
# todo: remove this property somehow, it is too expensive
|
||||||
output: List[Extension] = []
|
output: list[Extension] = []
|
||||||
|
|
||||||
for extension_folder in self._extension_folders:
|
for extension_folder in extension_folders:
|
||||||
extension_code = extension_folder.parts[-1]
|
extension_code = extension_folder.parts[-1]
|
||||||
try:
|
try:
|
||||||
with open(extension_folder / "config.json") as json_file:
|
with open(extension_folder / "config.json") as json_file:
|
||||||
|
@ -356,13 +278,27 @@ class ExtensionRelease(BaseModel):
|
||||||
if not self.pay_link:
|
if not self.pay_link:
|
||||||
return
|
return
|
||||||
|
|
||||||
payment_info = await fetch_release_payment_info(self.pay_link)
|
payment_info = await self.fetch_release_payment_info()
|
||||||
self.cost_sats = payment_info.amount if payment_info else None
|
self.cost_sats = payment_info.amount if payment_info else None
|
||||||
|
|
||||||
|
async def fetch_release_payment_info(
|
||||||
|
self, amount: Optional[int] = None
|
||||||
|
) -> Optional[ReleasePaymentInfo]:
|
||||||
|
url = f"{self.pay_link}?amount={amount}" if amount else self.pay_link
|
||||||
|
assert url, "Missing URL for payment info."
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return ReleasePaymentInfo(**resp.json())
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(e)
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_github_release(
|
def from_github_release(
|
||||||
cls, source_repo: str, r: "GitHubRepoRelease"
|
cls, source_repo: str, r: GitHubRepoRelease
|
||||||
) -> "ExtensionRelease":
|
) -> ExtensionRelease:
|
||||||
return ExtensionRelease(
|
return ExtensionRelease(
|
||||||
name=r.name,
|
name=r.name,
|
||||||
description=r.name,
|
description=r.name,
|
||||||
|
@ -377,8 +313,8 @@ class ExtensionRelease(BaseModel):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_explicit_release(
|
def from_explicit_release(
|
||||||
cls, source_repo: str, e: "ExplicitRelease"
|
cls, source_repo: str, e: ExplicitRelease
|
||||||
) -> "ExtensionRelease":
|
) -> ExtensionRelease:
|
||||||
return ExtensionRelease(
|
return ExtensionRelease(
|
||||||
name=e.name,
|
name=e.name,
|
||||||
version=e.version,
|
version=e.version,
|
||||||
|
@ -397,9 +333,9 @@ class ExtensionRelease(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_github_releases(cls, org: str, repo: str) -> List["ExtensionRelease"]:
|
async def get_github_releases(cls, org: str, repo: str) -> list[ExtensionRelease]:
|
||||||
try:
|
try:
|
||||||
github_releases = await fetch_github_releases(org, repo)
|
github_releases = await cls.fetch_github_releases(org, repo)
|
||||||
return [
|
return [
|
||||||
ExtensionRelease.from_github_release(f"{org}/{repo}", r)
|
ExtensionRelease.from_github_release(f"{org}/{repo}", r)
|
||||||
for r in github_releases
|
for r in github_releases
|
||||||
|
@ -408,6 +344,33 @@ class ExtensionRelease(BaseModel):
|
||||||
logger.warning(e)
|
logger.warning(e)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_github_releases(
|
||||||
|
cls, org: str, repo: str
|
||||||
|
) -> list[GitHubRepoRelease]:
|
||||||
|
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
|
||||||
|
error_msg = "Cannot fetch extension releases"
|
||||||
|
releases = await github_api_get(releases_url, error_msg)
|
||||||
|
return [GitHubRepoRelease.parse_obj(r) for r in releases]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_release_details(cls, details_link: str) -> Optional[dict]:
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(details_link)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
if "description_md" in data:
|
||||||
|
resp = await client.get(data["description_md"])
|
||||||
|
if not resp.is_error:
|
||||||
|
data["description_md"] = resp.text
|
||||||
|
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class InstallableExtension(BaseModel):
|
class InstallableExtension(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
|
@ -415,13 +378,13 @@ class InstallableExtension(BaseModel):
|
||||||
active: Optional[bool] = False
|
active: Optional[bool] = False
|
||||||
short_description: Optional[str] = None
|
short_description: Optional[str] = None
|
||||||
icon: Optional[str] = None
|
icon: Optional[str] = None
|
||||||
dependencies: List[str] = []
|
dependencies: list[str] = []
|
||||||
is_admin_only: bool = False
|
is_admin_only: bool = False
|
||||||
stars: int = 0
|
stars: int = 0
|
||||||
featured = False
|
featured = False
|
||||||
latest_release: Optional[ExtensionRelease] = None
|
latest_release: Optional[ExtensionRelease] = None
|
||||||
installed_release: Optional[ExtensionRelease] = None
|
installed_release: Optional[ExtensionRelease] = None
|
||||||
payments: List[ReleasePaymentInfo] = []
|
payments: list[ReleasePaymentInfo] = []
|
||||||
pay_to_enable: Optional[PayToEnableInfo] = None
|
pay_to_enable: Optional[PayToEnableInfo] = None
|
||||||
archive: Optional[str] = None
|
archive: Optional[str] = None
|
||||||
|
|
||||||
|
@ -546,16 +509,6 @@ class InstallableExtension(BaseModel):
|
||||||
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
|
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
|
||||||
logger.success(f"Extension {self.name} ({self.installed_version}) installed.")
|
logger.success(f"Extension {self.name} ({self.installed_version}) installed.")
|
||||||
|
|
||||||
def notify_upgrade(self, upgrade_hash: Optional[str]) -> None:
|
|
||||||
"""
|
|
||||||
Update the list of upgraded extensions. The middleware will perform
|
|
||||||
redirects based on this
|
|
||||||
"""
|
|
||||||
if upgrade_hash:
|
|
||||||
settings.lnbits_upgraded_extensions.add(f"{self.hash}/{self.id}")
|
|
||||||
|
|
||||||
settings.lnbits_all_extensions_ids.add(self.id)
|
|
||||||
|
|
||||||
def clean_extension_files(self):
|
def clean_extension_files(self):
|
||||||
# remove downloaded archive
|
# remove downloaded archive
|
||||||
if self.zip_path.is_file():
|
if self.zip_path.is_file():
|
||||||
|
@ -610,7 +563,7 @@ class InstallableExtension(BaseModel):
|
||||||
self.payments.append(payment_info)
|
self.payments.append(payment_info)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_row(cls, data: dict) -> "InstallableExtension":
|
def from_row(cls, data: dict) -> InstallableExtension:
|
||||||
meta = json.loads(data["meta"])
|
meta = json.loads(data["meta"])
|
||||||
ext = InstallableExtension(**data)
|
ext = InstallableExtension(**data)
|
||||||
if "installed_release" in meta:
|
if "installed_release" in meta:
|
||||||
|
@ -623,9 +576,7 @@ class InstallableExtension(BaseModel):
|
||||||
return ext
|
return ext
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_rows(
|
def from_rows(cls, rows: Optional[list[Any]] = None) -> list[InstallableExtension]:
|
||||||
cls, rows: Optional[List[Any]] = None
|
|
||||||
) -> List["InstallableExtension"]:
|
|
||||||
if rows is None:
|
if rows is None:
|
||||||
rows = []
|
rows = []
|
||||||
return [InstallableExtension.from_row(row) for row in rows]
|
return [InstallableExtension.from_row(row) for row in rows]
|
||||||
|
@ -633,9 +584,9 @@ class InstallableExtension(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_github_release(
|
async def from_github_release(
|
||||||
cls, github_release: GitHubRelease
|
cls, github_release: GitHubRelease
|
||||||
) -> Optional["InstallableExtension"]:
|
) -> Optional[InstallableExtension]:
|
||||||
try:
|
try:
|
||||||
repo, latest_release, config = await fetch_github_repo_info(
|
repo, latest_release, config = await cls.fetch_github_repo_info(
|
||||||
github_release.organisation, github_release.repository
|
github_release.organisation, github_release.repository
|
||||||
)
|
)
|
||||||
source_repo = f"{github_release.organisation}/{github_release.repository}"
|
source_repo = f"{github_release.organisation}/{github_release.repository}"
|
||||||
|
@ -657,7 +608,7 @@ class InstallableExtension(BaseModel):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_explicit_release(cls, e: ExplicitRelease) -> "InstallableExtension":
|
def from_explicit_release(cls, e: ExplicitRelease) -> InstallableExtension:
|
||||||
return InstallableExtension(
|
return InstallableExtension(
|
||||||
id=e.id,
|
id=e.id,
|
||||||
name=e.name,
|
name=e.name,
|
||||||
|
@ -670,13 +621,13 @@ class InstallableExtension(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_installable_extensions(
|
async def get_installable_extensions(
|
||||||
cls,
|
cls,
|
||||||
) -> List["InstallableExtension"]:
|
) -> list[InstallableExtension]:
|
||||||
extension_list: List[InstallableExtension] = []
|
extension_list: list[InstallableExtension] = []
|
||||||
extension_id_list: List[str] = []
|
extension_id_list: list[str] = []
|
||||||
|
|
||||||
for url in settings.lnbits_extensions_manifests:
|
for url in settings.lnbits_extensions_manifests:
|
||||||
try:
|
try:
|
||||||
manifest = await fetch_manifest(url)
|
manifest = await cls.fetch_manifest(url)
|
||||||
|
|
||||||
for r in manifest.repos:
|
for r in manifest.repos:
|
||||||
ext = await InstallableExtension.from_github_release(r)
|
ext = await InstallableExtension.from_github_release(r)
|
||||||
|
@ -712,12 +663,12 @@ class InstallableExtension(BaseModel):
|
||||||
return extension_list
|
return extension_list
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_extension_releases(cls, ext_id: str) -> List["ExtensionRelease"]:
|
async def get_extension_releases(cls, ext_id: str) -> list[ExtensionRelease]:
|
||||||
extension_releases: List[ExtensionRelease] = []
|
extension_releases: list[ExtensionRelease] = []
|
||||||
|
|
||||||
for url in settings.lnbits_extensions_manifests:
|
for url in settings.lnbits_extensions_manifests:
|
||||||
try:
|
try:
|
||||||
manifest = await fetch_manifest(url)
|
manifest = await cls.fetch_manifest(url)
|
||||||
for r in manifest.repos:
|
for r in manifest.repos:
|
||||||
if r.id != ext_id:
|
if r.id != ext_id:
|
||||||
continue
|
continue
|
||||||
|
@ -741,8 +692,8 @@ class InstallableExtension(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_extension_release(
|
async def get_extension_release(
|
||||||
cls, ext_id: str, source_repo: str, archive: str, version: str
|
cls, ext_id: str, source_repo: str, archive: str, version: str
|
||||||
) -> Optional["ExtensionRelease"]:
|
) -> Optional[ExtensionRelease]:
|
||||||
all_releases: List[ExtensionRelease] = (
|
all_releases: list[ExtensionRelease] = (
|
||||||
await InstallableExtension.get_extension_releases(ext_id)
|
await InstallableExtension.get_extension_releases(ext_id)
|
||||||
)
|
)
|
||||||
selected_release = [
|
selected_release = [
|
||||||
|
@ -755,6 +706,37 @@ class InstallableExtension(BaseModel):
|
||||||
|
|
||||||
return selected_release[0] if len(selected_release) != 0 else None
|
return selected_release[0] if len(selected_release) != 0 else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_github_repo_info(
|
||||||
|
cls, org: str, repository: str
|
||||||
|
) -> tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
|
||||||
|
repo_url = f"https://api.github.com/repos/{org}/{repository}"
|
||||||
|
error_msg = "Cannot fetch extension repo"
|
||||||
|
repo = await github_api_get(repo_url, error_msg)
|
||||||
|
github_repo = GitHubRepo.parse_obj(repo)
|
||||||
|
|
||||||
|
lates_release_url = (
|
||||||
|
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
|
||||||
|
)
|
||||||
|
error_msg = "Cannot fetch extension releases"
|
||||||
|
latest_release: Any = await github_api_get(lates_release_url, error_msg)
|
||||||
|
|
||||||
|
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
|
||||||
|
error_msg = "Cannot fetch config for extension"
|
||||||
|
config = await github_api_get(config_url, error_msg)
|
||||||
|
|
||||||
|
return (
|
||||||
|
github_repo,
|
||||||
|
GitHubRepoRelease.parse_obj(latest_release),
|
||||||
|
ExtensionConfig.parse_obj(config),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch_manifest(cls, url) -> Manifest:
|
||||||
|
error_msg = "Cannot fetch extensions manifest"
|
||||||
|
manifest = await github_api_get(url, error_msg)
|
||||||
|
return Manifest.parse_obj(manifest)
|
||||||
|
|
||||||
|
|
||||||
class CreateExtension(BaseModel):
|
class CreateExtension(BaseModel):
|
||||||
ext_id: str
|
ext_id: str
|
||||||
|
@ -769,32 +751,3 @@ class ExtensionDetailsRequest(BaseModel):
|
||||||
ext_id: str
|
ext_id: str
|
||||||
source_repo: str
|
source_repo: str
|
||||||
version: str
|
version: str
|
||||||
|
|
||||||
|
|
||||||
def get_valid_extensions(include_deactivated: Optional[bool] = True) -> List[Extension]:
|
|
||||||
valid_extensions = [
|
|
||||||
extension for extension in ExtensionManager().extensions if extension.is_valid
|
|
||||||
]
|
|
||||||
|
|
||||||
if include_deactivated:
|
|
||||||
return valid_extensions
|
|
||||||
|
|
||||||
if settings.lnbits_extensions_deactivate_all:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [
|
|
||||||
e
|
|
||||||
for e in valid_extensions
|
|
||||||
if e.code not in settings.lnbits_deactivated_extensions
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def version_parse(v: str):
|
|
||||||
"""
|
|
||||||
Wrapper for version.parse() that does not throw if the version is invalid.
|
|
||||||
Instead it return the lowest possible version ("0.0.0")
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return version.parse(v)
|
|
||||||
except Exception:
|
|
||||||
return version.parse("0.0.0")
|
|
|
@ -1,9 +1,8 @@
|
||||||
import importlib
|
import importlib
|
||||||
import re
|
import re
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import httpx
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lnbits.core import migrations as core_migrations
|
from lnbits.core import migrations as core_migrations
|
||||||
|
@ -13,11 +12,10 @@ from lnbits.core.crud import (
|
||||||
update_migration_version,
|
update_migration_version,
|
||||||
)
|
)
|
||||||
from lnbits.core.db import db as core_db
|
from lnbits.core.db import db as core_db
|
||||||
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection
|
from lnbits.core.extensions.models import (
|
||||||
from lnbits.extension_manager import (
|
|
||||||
Extension,
|
Extension,
|
||||||
get_valid_extensions,
|
|
||||||
)
|
)
|
||||||
|
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,68 +53,6 @@ async def run_migration(
|
||||||
await update_migration_version(conn, db_name, version)
|
await update_migration_version(conn, db_name, version)
|
||||||
|
|
||||||
|
|
||||||
async def stop_extension_background_work(
|
|
||||||
ext_id: str, user: str, access_token: Optional[str] = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Stop background work for extension (like asyncio.Tasks, WebSockets, etc).
|
|
||||||
Extensions SHOULD expose a `api_stop()` function and/or a DELETE enpoint
|
|
||||||
at the root level of their API.
|
|
||||||
"""
|
|
||||||
stopped = await _stop_extension_background_work(ext_id)
|
|
||||||
|
|
||||||
if not stopped:
|
|
||||||
# fallback to REST API call
|
|
||||||
await _stop_extension_background_work_via_api(ext_id, user, access_token)
|
|
||||||
|
|
||||||
|
|
||||||
async def _stop_extension_background_work(ext_id) -> bool:
|
|
||||||
upgrade_hash = settings.extension_upgrade_hash(ext_id) or ""
|
|
||||||
ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash)
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"Stopping background work for extension '{ext.module_name}'.")
|
|
||||||
old_module = importlib.import_module(ext.module_name)
|
|
||||||
|
|
||||||
# Extensions must expose an `{ext_id}_stop()` function at the module level
|
|
||||||
# The `api_stop()` function is for backwards compatibility (will be deprecated)
|
|
||||||
stop_fns = [f"{ext_id}_stop", "api_stop"]
|
|
||||||
stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None)
|
|
||||||
assert stop_fn_name, "No stop function found for '{ext.module_name}'"
|
|
||||||
|
|
||||||
stop_fn = getattr(old_module, stop_fn_name)
|
|
||||||
if stop_fn:
|
|
||||||
await stop_fn()
|
|
||||||
|
|
||||||
logger.info(f"Stopped background work for extension '{ext.module_name}'.")
|
|
||||||
except Exception as ex:
|
|
||||||
logger.warning(f"Failed to stop background work for '{ext.module_name}'.")
|
|
||||||
logger.warning(ex)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
async def _stop_extension_background_work_via_api(ext_id, user, access_token):
|
|
||||||
logger.info(
|
|
||||||
f"Stopping background work for extension '{ext_id}' using the REST API."
|
|
||||||
)
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
try:
|
|
||||||
url = f"http://{settings.host}:{settings.port}/{ext_id}/api/v1?usr={user}"
|
|
||||||
headers = (
|
|
||||||
{"Authorization": "Bearer " + access_token} if access_token else None
|
|
||||||
)
|
|
||||||
resp = await client.delete(url=url, headers=headers)
|
|
||||||
resp.raise_for_status()
|
|
||||||
logger.info(f"Stopped background work for extension '{ext_id}'.")
|
|
||||||
except Exception as ex:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to stop background work for '{ext_id}' using the REST API."
|
|
||||||
)
|
|
||||||
logger.warning(ex)
|
|
||||||
|
|
||||||
|
|
||||||
def to_valid_user_id(user_id: str) -> UUID:
|
def to_valid_user_id(user_id: str) -> UUID:
|
||||||
if len(user_id) < 32:
|
if len(user_id) < 32:
|
||||||
raise ValueError("User ID must have at least 128 bits")
|
raise ValueError("User ID must have at least 128 bits")
|
||||||
|
@ -161,7 +97,7 @@ async def migrate_databases():
|
||||||
await load_disabled_extension_list()
|
await load_disabled_extension_list()
|
||||||
|
|
||||||
# todo: revisit, use installed extensions
|
# todo: revisit, use installed extensions
|
||||||
for ext in get_valid_extensions(False):
|
for ext in Extension.get_valid_extensions(False):
|
||||||
current_version = current_versions.get(ext.code, 0)
|
current_version = current_versions.get(ext.code, 0)
|
||||||
try:
|
try:
|
||||||
await migrate_extension_database(ext, current_version)
|
await migrate_extension_database(ext, current_version)
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
import sys
|
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import (
|
from typing import (
|
||||||
List,
|
List,
|
||||||
Optional,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from bolt11 import decode as bolt11_decode
|
from bolt11 import decode as bolt11_decode
|
||||||
|
@ -13,10 +11,21 @@ from fastapi import (
|
||||||
)
|
)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lnbits.core.db import core_app_extra
|
from lnbits.core.extensions.extension_manager import (
|
||||||
from lnbits.core.helpers import (
|
activate_extension,
|
||||||
migrate_extension_database,
|
deactivate_extension,
|
||||||
stop_extension_background_work,
|
install_extension,
|
||||||
|
uninstall_extension,
|
||||||
|
)
|
||||||
|
from lnbits.core.extensions.models import (
|
||||||
|
CreateExtension,
|
||||||
|
Extension,
|
||||||
|
ExtensionConfig,
|
||||||
|
ExtensionRelease,
|
||||||
|
InstallableExtension,
|
||||||
|
PayToEnableInfo,
|
||||||
|
ReleasePaymentInfo,
|
||||||
|
UserExtensionInfo,
|
||||||
)
|
)
|
||||||
from lnbits.core.models import (
|
from lnbits.core.models import (
|
||||||
SimpleStatus,
|
SimpleStatus,
|
||||||
|
@ -24,36 +33,18 @@ from lnbits.core.models import (
|
||||||
)
|
)
|
||||||
from lnbits.core.services import check_transaction_status, create_invoice
|
from lnbits.core.services import check_transaction_status, create_invoice
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
check_access_token,
|
|
||||||
check_admin,
|
check_admin,
|
||||||
check_user_exists,
|
check_user_exists,
|
||||||
)
|
)
|
||||||
from lnbits.extension_manager import (
|
|
||||||
CreateExtension,
|
|
||||||
Extension,
|
|
||||||
ExtensionRelease,
|
|
||||||
InstallableExtension,
|
|
||||||
PayToEnableInfo,
|
|
||||||
ReleasePaymentInfo,
|
|
||||||
UserExtensionInfo,
|
|
||||||
fetch_github_release_config,
|
|
||||||
fetch_release_details,
|
|
||||||
fetch_release_payment_info,
|
|
||||||
get_valid_extensions,
|
|
||||||
)
|
|
||||||
from lnbits.settings import settings
|
|
||||||
|
|
||||||
from ..crud import (
|
from ..crud import (
|
||||||
add_installed_extension,
|
|
||||||
delete_dbversion,
|
delete_dbversion,
|
||||||
delete_installed_extension,
|
|
||||||
drop_extension_db,
|
drop_extension_db,
|
||||||
get_dbversions,
|
get_dbversions,
|
||||||
get_installed_extension,
|
get_installed_extension,
|
||||||
get_installed_extensions,
|
get_installed_extensions,
|
||||||
get_user_extension,
|
get_user_extension,
|
||||||
update_extension_pay_to_enable,
|
update_extension_pay_to_enable,
|
||||||
update_installed_extension_state,
|
|
||||||
update_user_extension,
|
update_user_extension,
|
||||||
update_user_extension_extra,
|
update_user_extension_extra,
|
||||||
)
|
)
|
||||||
|
@ -64,12 +55,8 @@ extension_router = APIRouter(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@extension_router.post("")
|
@extension_router.post("", dependencies=[Depends(check_admin)])
|
||||||
async def api_install_extension(
|
async def api_install_extension(data: CreateExtension):
|
||||||
data: CreateExtension,
|
|
||||||
user: User = Depends(check_admin),
|
|
||||||
access_token: Optional[str] = Depends(check_access_token),
|
|
||||||
):
|
|
||||||
release = await InstallableExtension.get_extension_release(
|
release = await InstallableExtension.get_extension_release(
|
||||||
data.ext_id, data.source_repo, data.archive, data.version
|
data.ext_id, data.source_repo, data.archive, data.version
|
||||||
)
|
)
|
||||||
|
@ -89,43 +76,36 @@ async def api_install_extension(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installed_ext = await get_installed_extension(data.ext_id)
|
extension = await install_extension(ext_info)
|
||||||
ext_info.payments = installed_ext.payments if installed_ext else []
|
|
||||||
|
|
||||||
await ext_info.download_archive()
|
|
||||||
|
|
||||||
ext_info.extract_archive()
|
|
||||||
|
|
||||||
extension = Extension.from_installable_ext(ext_info)
|
|
||||||
|
|
||||||
db_version = (await get_dbversions()).get(data.ext_id, 0)
|
|
||||||
await migrate_extension_database(extension, db_version)
|
|
||||||
|
|
||||||
ext_info.active = True
|
|
||||||
await add_installed_extension(ext_info)
|
|
||||||
|
|
||||||
if extension.is_upgrade_extension:
|
|
||||||
# call stop while the old routes are still active
|
|
||||||
await stop_extension_background_work(data.ext_id, user.id, access_token)
|
|
||||||
|
|
||||||
# mount routes for the new version
|
|
||||||
core_app_extra.register_new_ext_routes(extension)
|
|
||||||
|
|
||||||
ext_info.notify_upgrade(extension.upgrade_hash)
|
|
||||||
settings.lnbits_deactivated_extensions.discard(data.ext_id)
|
|
||||||
|
|
||||||
return extension
|
|
||||||
except AssertionError as exc:
|
|
||||||
raise HTTPException(HTTPStatus.BAD_REQUEST, str(exc)) from exc
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(exc)
|
logger.warning(exc)
|
||||||
ext_info.clean_extension_files()
|
ext_info.clean_extension_files()
|
||||||
|
detail = (
|
||||||
|
str(exc)
|
||||||
|
if isinstance(exc, AssertionError)
|
||||||
|
else f"Failed to install extension '{ext_info.id}'."
|
||||||
|
f"({ext_info.installed_version})."
|
||||||
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
detail=(
|
detail=detail,
|
||||||
f"Failed to install extension {ext_info.id} "
|
) from exc
|
||||||
f"({ext_info.installed_version})."
|
|
||||||
),
|
try:
|
||||||
|
await activate_extension(extension)
|
||||||
|
return extension
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(exc)
|
||||||
|
await deactivate_extension(extension.code)
|
||||||
|
detail = (
|
||||||
|
str(exc)
|
||||||
|
if isinstance(exc, AssertionError)
|
||||||
|
else f"Extension `{extension.code}` installed, but activation failed."
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
detail=detail,
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@ -143,7 +123,7 @@ async def api_extension_details(
|
||||||
)
|
)
|
||||||
assert release, "Details not found for release"
|
assert release, "Details not found for release"
|
||||||
|
|
||||||
release_details = await fetch_release_details(details_link)
|
release_details = await ExtensionRelease.fetch_release_details(details_link)
|
||||||
assert release_details, "Cannot fetch details for release"
|
assert release_details, "Cannot fetch details for release"
|
||||||
release_details["icon"] = release.icon
|
release_details["icon"] = release.icon
|
||||||
release_details["repo"] = release.repo
|
release_details["repo"] = release.repo
|
||||||
|
@ -186,7 +166,7 @@ async def api_update_pay_to_enable(
|
||||||
async def api_enable_extension(
|
async def api_enable_extension(
|
||||||
ext_id: str, user: User = Depends(check_user_exists)
|
ext_id: str, user: User = Depends(check_user_exists)
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
if ext_id not in [e.code for e in get_valid_extensions()]:
|
if ext_id not in [e.code for e in Extension.get_valid_extensions()]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
HTTPStatus.NOT_FOUND, f"Extension '{ext_id}' doesn't exist."
|
HTTPStatus.NOT_FOUND, f"Extension '{ext_id}' doesn't exist."
|
||||||
)
|
)
|
||||||
|
@ -249,7 +229,7 @@ async def api_enable_extension(
|
||||||
async def api_disable_extension(
|
async def api_disable_extension(
|
||||||
ext_id: str, user: User = Depends(check_user_exists)
|
ext_id: str, user: User = Depends(check_user_exists)
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
if ext_id not in [e.code for e in get_valid_extensions()]:
|
if ext_id not in [e.code for e in Extension.get_valid_extensions()]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
|
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
|
||||||
)
|
)
|
||||||
|
@ -270,20 +250,14 @@ async def api_activate_extension(ext_id: str) -> SimpleStatus:
|
||||||
try:
|
try:
|
||||||
logger.info(f"Activating extension: '{ext_id}'.")
|
logger.info(f"Activating extension: '{ext_id}'.")
|
||||||
|
|
||||||
all_extensions = get_valid_extensions()
|
ext = Extension.get_valid_extension(ext_id)
|
||||||
ext = next((e for e in all_extensions if e.code == ext_id), None)
|
|
||||||
assert ext, f"Extension '{ext_id}' doesn't exist."
|
assert ext, f"Extension '{ext_id}' doesn't exist."
|
||||||
# if extension never loaded (was deactivated on server startup)
|
|
||||||
if ext_id not in sys.modules.keys():
|
|
||||||
# run extension start-up routine
|
|
||||||
core_app_extra.register_new_ext_routes(ext)
|
|
||||||
|
|
||||||
settings.lnbits_deactivated_extensions.discard(ext_id)
|
await activate_extension(ext)
|
||||||
|
|
||||||
await update_installed_extension_state(ext_id=ext_id, active=True)
|
|
||||||
return SimpleStatus(success=True, message=f"Extension '{ext_id}' activated.")
|
return SimpleStatus(success=True, message=f"Extension '{ext_id}' activated.")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(exc)
|
logger.warning(exc)
|
||||||
|
await deactivate_extension(ext_id)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
detail=(f"Failed to activate '{ext_id}'."),
|
detail=(f"Failed to activate '{ext_id}'."),
|
||||||
|
@ -295,13 +269,10 @@ async def api_deactivate_extension(ext_id: str) -> SimpleStatus:
|
||||||
try:
|
try:
|
||||||
logger.info(f"Deactivating extension: '{ext_id}'.")
|
logger.info(f"Deactivating extension: '{ext_id}'.")
|
||||||
|
|
||||||
all_extensions = get_valid_extensions()
|
ext = Extension.get_valid_extension(ext_id)
|
||||||
ext = next((e for e in all_extensions if e.code == ext_id), None)
|
|
||||||
assert ext, f"Extension '{ext_id}' doesn't exist."
|
assert ext, f"Extension '{ext_id}' doesn't exist."
|
||||||
|
|
||||||
settings.lnbits_deactivated_extensions.add(ext_id)
|
await deactivate_extension(ext_id)
|
||||||
|
|
||||||
await update_installed_extension_state(ext_id=ext_id, active=False)
|
|
||||||
return SimpleStatus(success=True, message=f"Extension '{ext_id}' deactivated.")
|
return SimpleStatus(success=True, message=f"Extension '{ext_id}' deactivated.")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(exc)
|
logger.warning(exc)
|
||||||
|
@ -311,23 +282,19 @@ async def api_deactivate_extension(ext_id: str) -> SimpleStatus:
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
@extension_router.delete("/{ext_id}")
|
@extension_router.delete("/{ext_id}", dependencies=[Depends(check_admin)])
|
||||||
async def api_uninstall_extension(
|
async def api_uninstall_extension(ext_id: str) -> SimpleStatus:
|
||||||
ext_id: str,
|
|
||||||
user: User = Depends(check_admin),
|
|
||||||
access_token: Optional[str] = Depends(check_access_token),
|
|
||||||
) -> SimpleStatus:
|
|
||||||
installed_extensions = await get_installed_extensions()
|
|
||||||
|
|
||||||
extensions = [e for e in installed_extensions if e.id == ext_id]
|
extension = await get_installed_extension(ext_id)
|
||||||
if len(extensions) == 0:
|
if not extension:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.NOT_FOUND,
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
detail=f"Unknown extension id: {ext_id}",
|
detail=f"Unknown extension id: {ext_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
installed_extensions = await get_installed_extensions()
|
||||||
# check that other extensions do not depend on this one
|
# check that other extensions do not depend on this one
|
||||||
for valid_ext_id in [ext.code for ext in get_valid_extensions()]:
|
for valid_ext_id in [ext.code for ext in Extension.get_valid_extensions()]:
|
||||||
installed_ext = next(
|
installed_ext = next(
|
||||||
(ext for ext in installed_extensions if ext.id == valid_ext_id), None
|
(ext for ext in installed_extensions if ext.id == valid_ext_id), None
|
||||||
)
|
)
|
||||||
|
@ -341,14 +308,7 @@ async def api_uninstall_extension(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# call stop while the old routes are still active
|
await uninstall_extension(ext_id)
|
||||||
await stop_extension_background_work(ext_id, user.id, access_token)
|
|
||||||
|
|
||||||
settings.lnbits_deactivated_extensions.add(ext_id)
|
|
||||||
|
|
||||||
for ext_info in extensions:
|
|
||||||
ext_info.clean_extension_files()
|
|
||||||
await delete_installed_extension(ext_id=ext_info.id)
|
|
||||||
|
|
||||||
logger.success(f"Extension '{ext_id}' uninstalled.")
|
logger.success(f"Extension '{ext_id}' uninstalled.")
|
||||||
return SimpleStatus(success=True, message=f"Extension '{ext_id}' uninstalled.")
|
return SimpleStatus(success=True, message=f"Extension '{ext_id}' uninstalled.")
|
||||||
|
@ -397,9 +357,8 @@ async def get_pay_to_install_invoice(
|
||||||
assert release, "Release not found."
|
assert release, "Release not found."
|
||||||
assert release.pay_link, "Pay link not found for release."
|
assert release.pay_link, "Pay link not found for release."
|
||||||
|
|
||||||
payment_info = await fetch_release_payment_info(
|
payment_info = await release.fetch_release_payment_info(data.cost_sats)
|
||||||
release.pay_link, data.cost_sats
|
|
||||||
)
|
|
||||||
assert payment_info and payment_info.payment_request, "Cannot request invoice."
|
assert payment_info and payment_info.payment_request, "Cannot request invoice."
|
||||||
invoice = bolt11_decode(payment_info.payment_request)
|
invoice = bolt11_decode(payment_info.payment_request)
|
||||||
|
|
||||||
|
@ -474,7 +433,7 @@ async def get_pay_to_enable_invoice(
|
||||||
)
|
)
|
||||||
async def get_extension_release(org: str, repo: str, tag_name: str):
|
async def get_extension_release(org: str, repo: str, tag_name: str):
|
||||||
try:
|
try:
|
||||||
config = await fetch_github_release_config(org, repo, tag_name)
|
config = await ExtensionConfig.fetch_github_release_config(org, repo, tag_name)
|
||||||
if not config:
|
if not config:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from lnurl import decode as lnurl_decode
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic.types import UUID4
|
from pydantic.types import UUID4
|
||||||
|
|
||||||
|
from lnbits.core.extensions.models import Extension, InstallableExtension
|
||||||
from lnbits.core.helpers import to_valid_user_id
|
from lnbits.core.helpers import to_valid_user_id
|
||||||
from lnbits.core.models import User
|
from lnbits.core.models import User
|
||||||
from lnbits.core.services import create_invoice
|
from lnbits.core.services import create_invoice
|
||||||
|
@ -20,7 +21,6 @@ from lnbits.helpers import template_renderer
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
from lnbits.wallets import get_funding_source
|
from lnbits.wallets import get_funding_source
|
||||||
|
|
||||||
from ...extension_manager import InstallableExtension, get_valid_extensions
|
|
||||||
from ...utils.exchange_rates import allowed_currencies, currencies
|
from ...utils.exchange_rates import allowed_currencies, currencies
|
||||||
from ..crud import (
|
from ..crud import (
|
||||||
create_account,
|
create_account,
|
||||||
|
@ -104,7 +104,7 @@ async def extensions(request: Request, user: User = Depends(check_user_exists)):
|
||||||
installed_exts_ids = []
|
installed_exts_ids = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
all_ext_ids = [ext.code for ext in get_valid_extensions()]
|
all_ext_ids = [ext.code for ext in Extension.get_valid_extensions()]
|
||||||
inactive_extensions = [
|
inactive_extensions = [
|
||||||
e.id for e in await get_installed_extensions(active=False)
|
e.id for e in await get_installed_extensions(active=False)
|
||||||
]
|
]
|
||||||
|
|
|
@ -10,6 +10,7 @@ import shortuuid
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.schema import field_schema
|
from pydantic.schema import field_schema
|
||||||
|
|
||||||
|
from lnbits.core.extensions.models import Extension
|
||||||
from lnbits.db import get_placeholder
|
from lnbits.db import get_placeholder
|
||||||
from lnbits.jinja2_templating import Jinja2Templates
|
from lnbits.jinja2_templating import Jinja2Templates
|
||||||
from lnbits.nodes import get_node_class
|
from lnbits.nodes import get_node_class
|
||||||
|
@ -18,7 +19,6 @@ from lnbits.settings import settings
|
||||||
from lnbits.utils.crypto import AESCipher
|
from lnbits.utils.crypto import AESCipher
|
||||||
|
|
||||||
from .db import FilterModel
|
from .db import FilterModel
|
||||||
from .extension_manager import get_valid_extensions
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_vendor_name():
|
def get_db_vendor_name():
|
||||||
|
@ -93,7 +93,7 @@ def template_renderer(additional_folders: Optional[List] = None) -> Jinja2Templa
|
||||||
settings.lnbits_node_ui and get_node_class() is not None
|
settings.lnbits_node_ui and get_node_class() is not None
|
||||||
)
|
)
|
||||||
t.env.globals["LNBITS_NODE_UI_AVAILABLE"] = get_node_class() is not None
|
t.env.globals["LNBITS_NODE_UI_AVAILABLE"] = get_node_class() is not None
|
||||||
t.env.globals["EXTENSIONS"] = get_valid_extensions(False)
|
t.env.globals["EXTENSIONS"] = Extension.get_valid_extensions(False)
|
||||||
if settings.lnbits_custom_logo:
|
if settings.lnbits_custom_logo:
|
||||||
t.env.globals["USE_CUSTOM_LOGO"] = settings.lnbits_custom_logo
|
t.env.globals["USE_CUSTOM_LOGO"] = settings.lnbits_custom_logo
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, List, Tuple, Union
|
from typing import Any, List, Union
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||||
|
@ -45,16 +45,11 @@ class InstalledExtensionMiddleware:
|
||||||
await self.app(scope, receive, send)
|
await self.app(scope, receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
upgrade_path = next(
|
|
||||||
(
|
|
||||||
e
|
|
||||||
for e in settings.lnbits_upgraded_extensions
|
|
||||||
if e.endswith(f"/{top_path}")
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
# re-route all trafic if the extension has been upgraded
|
# re-route all trafic if the extension has been upgraded
|
||||||
if upgrade_path:
|
if top_path in settings.lnbits_upgraded_extensions:
|
||||||
|
upgrade_path = (
|
||||||
|
f"""{settings.lnbits_upgraded_extensions[top_path]}/{top_path}"""
|
||||||
|
)
|
||||||
tail = "/".join(rest)
|
tail = "/".join(rest)
|
||||||
scope["path"] = f"/upgrades/{upgrade_path}/{tail}"
|
scope["path"] = f"/upgrades/{upgrade_path}/{tail}"
|
||||||
|
|
||||||
|
@ -118,72 +113,12 @@ class ExtensionsRedirectMiddleware:
|
||||||
return
|
return
|
||||||
|
|
||||||
req_headers = scope["headers"] if "headers" in scope else []
|
req_headers = scope["headers"] if "headers" in scope else []
|
||||||
redirect = self._find_redirect(scope["path"], req_headers)
|
redirect = settings.find_extension_redirect(scope["path"], req_headers)
|
||||||
if redirect:
|
if redirect:
|
||||||
scope["path"] = self._new_path(redirect, scope["path"])
|
scope["path"] = redirect.new_path_from(scope["path"])
|
||||||
|
|
||||||
await self.app(scope, receive, send)
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
def _find_redirect(self, path: str, req_headers: List[Tuple[bytes, bytes]]):
|
|
||||||
return next(
|
|
||||||
(
|
|
||||||
r
|
|
||||||
for r in settings.lnbits_extensions_redirects
|
|
||||||
if self._redirect_matches(r, path, req_headers)
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _redirect_matches(
|
|
||||||
self, redirect: dict, path: str, req_headers: List[Tuple[bytes, bytes]]
|
|
||||||
) -> bool:
|
|
||||||
if "from_path" not in redirect:
|
|
||||||
return False
|
|
||||||
header_filters = (
|
|
||||||
redirect["header_filters"] if "header_filters" in redirect else {}
|
|
||||||
)
|
|
||||||
return self._has_common_path(redirect["from_path"], path) and self._has_headers(
|
|
||||||
header_filters, req_headers
|
|
||||||
)
|
|
||||||
|
|
||||||
def _has_headers(
|
|
||||||
self, filter_headers: dict, req_headers: List[Tuple[bytes, bytes]]
|
|
||||||
) -> bool:
|
|
||||||
for h in filter_headers:
|
|
||||||
if not self._has_header(req_headers, (str(h), str(filter_headers[h]))):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _has_header(
|
|
||||||
self, req_headers: List[Tuple[bytes, bytes]], header: Tuple[str, str]
|
|
||||||
) -> bool:
|
|
||||||
for h in req_headers:
|
|
||||||
if (
|
|
||||||
h[0].decode().lower() == header[0].lower()
|
|
||||||
and h[1].decode() == header[1]
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _has_common_path(self, redirect_path: str, req_path: str) -> bool:
|
|
||||||
redirect_path_elements = redirect_path.split("/")
|
|
||||||
req_path_elements = req_path.split("/")
|
|
||||||
if len(redirect_path) > len(req_path):
|
|
||||||
return False
|
|
||||||
sub_path = req_path_elements[: len(redirect_path_elements)]
|
|
||||||
return redirect_path == "/".join(sub_path)
|
|
||||||
|
|
||||||
def _new_path(self, redirect: dict, req_path: str) -> str:
|
|
||||||
from_path = redirect["from_path"].split("/")
|
|
||||||
redirect_to = redirect["redirect_to_path"].split("/")
|
|
||||||
req_tail_path = req_path.split("/")[len(from_path) :]
|
|
||||||
|
|
||||||
elements = [
|
|
||||||
e for e in ([redirect["ext_id"], *redirect_to, *req_tail_path]) if e != ""
|
|
||||||
]
|
|
||||||
|
|
||||||
return "/" + "/".join(elements)
|
|
||||||
|
|
||||||
|
|
||||||
def add_ratelimit_middleware(app: FastAPI):
|
def add_ratelimit_middleware(app: FastAPI):
|
||||||
core_app_extra.register_new_ratelimiter()
|
core_app_extra.register_new_ratelimiter()
|
||||||
|
|
|
@ -62,26 +62,132 @@ class ExtensionsInstallSettings(LNbitsSettings):
|
||||||
lnbits_ext_github_token: str = Field(default="")
|
lnbits_ext_github_token: str = Field(default="")
|
||||||
|
|
||||||
|
|
||||||
|
class RedirectPath(BaseModel):
|
||||||
|
ext_id: str
|
||||||
|
from_path: str
|
||||||
|
redirect_to_path: str
|
||||||
|
header_filters: dict = {}
|
||||||
|
|
||||||
|
def in_conflict(self, other: RedirectPath) -> bool:
|
||||||
|
if self.ext_id == other.ext_id:
|
||||||
|
return False
|
||||||
|
return self.redirect_matches(
|
||||||
|
other.from_path, list(other.header_filters.items())
|
||||||
|
) or other.redirect_matches(self.from_path, list(self.header_filters.items()))
|
||||||
|
|
||||||
|
def find_in_conflict(self, others: list[RedirectPath]) -> Optional[RedirectPath]:
|
||||||
|
for other in others:
|
||||||
|
if self.in_conflict(other):
|
||||||
|
return other
|
||||||
|
return None
|
||||||
|
|
||||||
|
def new_path_from(self, req_path: str) -> str:
|
||||||
|
from_path = self.from_path.split("/")
|
||||||
|
redirect_to = self.redirect_to_path.split("/")
|
||||||
|
req_tail_path = req_path.split("/")[len(from_path) :]
|
||||||
|
|
||||||
|
elements = [e for e in ([self.ext_id, *redirect_to, *req_tail_path]) if e != ""]
|
||||||
|
|
||||||
|
return "/" + "/".join(elements)
|
||||||
|
|
||||||
|
def redirect_matches(self, path: str, req_headers: list[tuple[str, str]]) -> bool:
|
||||||
|
return self._has_common_path(path) and self._has_headers(req_headers)
|
||||||
|
|
||||||
|
def _has_common_path(self, req_path: str) -> bool:
|
||||||
|
if len(self.from_path) > len(req_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
redirect_path_elements = self.from_path.split("/")
|
||||||
|
req_path_elements = req_path.split("/")
|
||||||
|
|
||||||
|
sub_path = req_path_elements[: len(redirect_path_elements)]
|
||||||
|
return self.from_path == "/".join(sub_path)
|
||||||
|
|
||||||
|
def _has_headers(self, req_headers: list[tuple[str, str]]) -> bool:
|
||||||
|
for h in self.header_filters:
|
||||||
|
if not self._has_header(req_headers, (str(h), str(self.header_filters[h]))):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _has_header(
|
||||||
|
self, req_headers: list[tuple[str, str]], header: tuple[str, str]
|
||||||
|
) -> bool:
|
||||||
|
for h in req_headers:
|
||||||
|
if h[0].lower() == header[0].lower() and h[1].lower() == header[1].lower():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class InstalledExtensionsSettings(LNbitsSettings):
|
class InstalledExtensionsSettings(LNbitsSettings):
|
||||||
# installed extensions that have been deactivated
|
# installed extensions that have been deactivated
|
||||||
lnbits_deactivated_extensions: set[str] = Field(default=[])
|
lnbits_deactivated_extensions: set[str] = Field(default=[])
|
||||||
# upgraded extensions that require API redirects
|
# upgraded extensions that require API redirects
|
||||||
lnbits_upgraded_extensions: set[str] = Field(default=[])
|
lnbits_upgraded_extensions: dict[str, str] = Field(default={})
|
||||||
# list of redirects that extensions want to perform
|
# list of redirects that extensions want to perform
|
||||||
lnbits_extensions_redirects: list[Any] = Field(default=[])
|
lnbits_extensions_redirects: list[RedirectPath] = Field(default=[])
|
||||||
|
|
||||||
# list of all extension ids
|
# list of all extension ids
|
||||||
lnbits_all_extensions_ids: set[Any] = Field(default=[])
|
lnbits_all_extensions_ids: set[Any] = Field(default=[])
|
||||||
|
|
||||||
def extension_upgrade_path(self, ext_id: str) -> Optional[str]:
|
def find_extension_redirect(
|
||||||
|
self, path: str, req_headers: list[tuple[bytes, bytes]]
|
||||||
|
) -> Optional[RedirectPath]:
|
||||||
|
headers = [(k.decode(), v.decode()) for k, v in req_headers]
|
||||||
return next(
|
return next(
|
||||||
(e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")),
|
(
|
||||||
|
r
|
||||||
|
for r in self.lnbits_extensions_redirects
|
||||||
|
if r.redirect_matches(path, headers)
|
||||||
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def extension_upgrade_hash(self, ext_id: str) -> Optional[str]:
|
def activate_extension_paths(
|
||||||
path = settings.extension_upgrade_path(ext_id)
|
self,
|
||||||
return path.split("/")[0] if path else None
|
ext_id: str,
|
||||||
|
upgrade_hash: Optional[str] = None,
|
||||||
|
ext_redirects: Optional[list[dict]] = None,
|
||||||
|
):
|
||||||
|
self.lnbits_deactivated_extensions.discard(ext_id)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Update the list of upgraded extensions. The middleware will perform
|
||||||
|
redirects based on this
|
||||||
|
"""
|
||||||
|
if upgrade_hash:
|
||||||
|
self.lnbits_upgraded_extensions[ext_id] = upgrade_hash
|
||||||
|
|
||||||
|
if ext_redirects:
|
||||||
|
self._activate_extension_redirects(ext_id, ext_redirects)
|
||||||
|
|
||||||
|
self.lnbits_all_extensions_ids.add(ext_id)
|
||||||
|
|
||||||
|
def deactivate_extension_paths(self, ext_id: str):
|
||||||
|
self.lnbits_deactivated_extensions.add(ext_id)
|
||||||
|
self._remove_extension_redirects(ext_id)
|
||||||
|
|
||||||
|
def _activate_extension_redirects(self, ext_id: str, ext_redirects: list[dict]):
|
||||||
|
ext_redirect_paths = [
|
||||||
|
RedirectPath(**{"ext_id": ext_id, **er}) for er in ext_redirects
|
||||||
|
]
|
||||||
|
existing_redirects = {
|
||||||
|
r.ext_id
|
||||||
|
for r in self.lnbits_extensions_redirects
|
||||||
|
if r.find_in_conflict(ext_redirect_paths)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert len(existing_redirects) == 0, (
|
||||||
|
f"Cannot redirect for extension '{ext_id}'."
|
||||||
|
f" Already mapped by {existing_redirects}."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._remove_extension_redirects(ext_id)
|
||||||
|
self.lnbits_extensions_redirects += ext_redirect_paths
|
||||||
|
|
||||||
|
def _remove_extension_redirects(self, ext_id: str):
|
||||||
|
self.lnbits_extensions_redirects = [
|
||||||
|
er for er in self.lnbits_extensions_redirects if er.ext_id != ext_id
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ThemesSettings(LNbitsSettings):
|
class ThemesSettings(LNbitsSettings):
|
||||||
|
|
168
tests/unit/test_settings.py
Normal file
168
tests/unit/test_settings.py
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lnbits.settings import RedirectPath
|
||||||
|
|
||||||
|
lnurlp_redirect_path = {
|
||||||
|
"from_path": "/.well-known/lnurlp",
|
||||||
|
"redirect_to_path": "/api/v1/well-known",
|
||||||
|
}
|
||||||
|
lnurlp_redirect_path_with_headers = {
|
||||||
|
"from_path": "/.well-known/lnurlp",
|
||||||
|
"redirect_to_path": "/api/v1/well-known",
|
||||||
|
"header_filters": {"accept": "application/nostr+json"},
|
||||||
|
}
|
||||||
|
|
||||||
|
lnaddress_redirect_path = {
|
||||||
|
"from_path": "/.well-known/lnurlp",
|
||||||
|
"redirect_to_path": "/api/v1/well-known",
|
||||||
|
}
|
||||||
|
|
||||||
|
nostrrelay_redirect_path = {
|
||||||
|
"from_path": "/",
|
||||||
|
"redirect_to_path": "/api/v1/relay-info",
|
||||||
|
"header_filters": {"accept": "application/nostr+json"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def lnurlp():
|
||||||
|
return RedirectPath(ext_id="lnurlp", **lnurlp_redirect_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def lnurlp_with_headers():
|
||||||
|
return RedirectPath(
|
||||||
|
ext_id="lnurlp_with_headers", **lnurlp_redirect_path_with_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def lnaddress():
|
||||||
|
return RedirectPath(ext_id="lnaddress", **lnaddress_redirect_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def nostrrelay():
|
||||||
|
return RedirectPath(ext_id="nostrrelay", **nostrrelay_redirect_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_path_self_not_in_conflict(
|
||||||
|
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
|
||||||
|
):
|
||||||
|
assert not lnurlp.in_conflict(lnurlp), "Path is not in conflict with itself."
|
||||||
|
assert not lnaddress.in_conflict(lnaddress), "Path is not in conflict with itself."
|
||||||
|
assert not nostrrelay.in_conflict(
|
||||||
|
nostrrelay
|
||||||
|
), "Path is not in conflict with itself."
|
||||||
|
|
||||||
|
assert not lnurlp.in_conflict(nostrrelay)
|
||||||
|
|
||||||
|
assert not nostrrelay.in_conflict(lnurlp)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_path_not_in_conflict(
|
||||||
|
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
|
||||||
|
):
|
||||||
|
|
||||||
|
assert not lnurlp.in_conflict(nostrrelay)
|
||||||
|
|
||||||
|
assert not nostrrelay.in_conflict(lnurlp)
|
||||||
|
|
||||||
|
assert not lnaddress.in_conflict(nostrrelay)
|
||||||
|
|
||||||
|
assert not nostrrelay.in_conflict(lnaddress)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_path_in_conflict(lnurlp: RedirectPath, lnaddress: RedirectPath):
|
||||||
|
assert lnurlp.in_conflict(lnaddress)
|
||||||
|
assert lnaddress.in_conflict(lnurlp)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_path_find_conflict(
|
||||||
|
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
|
||||||
|
):
|
||||||
|
assert lnurlp.find_in_conflict([nostrrelay, lnaddress])
|
||||||
|
assert lnurlp.find_in_conflict([lnaddress, nostrrelay])
|
||||||
|
assert lnaddress.find_in_conflict([nostrrelay, lnurlp])
|
||||||
|
assert lnaddress.find_in_conflict([lnurlp, nostrrelay])
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_path_find_no_conflict(
|
||||||
|
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
|
||||||
|
):
|
||||||
|
assert not nostrrelay.find_in_conflict([lnurlp, lnaddress])
|
||||||
|
assert not lnurlp.find_in_conflict([nostrrelay])
|
||||||
|
assert not lnaddress.find_in_conflict([nostrrelay])
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_path_in_conflict_with_headers(
|
||||||
|
lnurlp: RedirectPath, lnurlp_with_headers: RedirectPath
|
||||||
|
):
|
||||||
|
assert lnurlp.in_conflict(lnurlp_with_headers)
|
||||||
|
assert lnurlp_with_headers.in_conflict(lnurlp)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_path_matches_with_headers(
|
||||||
|
lnurlp: RedirectPath, lnurlp_with_headers: RedirectPath
|
||||||
|
):
|
||||||
|
headers_list = list(lnurlp_with_headers.header_filters.items())
|
||||||
|
assert lnurlp.redirect_matches(
|
||||||
|
path=lnurlp_with_headers.from_path,
|
||||||
|
req_headers=headers_list,
|
||||||
|
)
|
||||||
|
assert lnurlp_with_headers.redirect_matches(
|
||||||
|
path=lnurlp_redirect_path["from_path"],
|
||||||
|
req_headers=[("ACCEPT", "APPlication/nostr+json")],
|
||||||
|
)
|
||||||
|
assert lnurlp_with_headers.redirect_matches(
|
||||||
|
path=lnurlp_redirect_path["from_path"],
|
||||||
|
req_headers=[("accept", "application/nostr+json"), ("my_header", "my_value")],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not lnurlp_with_headers.redirect_matches(
|
||||||
|
path=lnurlp_redirect_path["from_path"], req_headers=[]
|
||||||
|
)
|
||||||
|
assert not lnurlp_with_headers.redirect_matches(
|
||||||
|
path=lnurlp_redirect_path["from_path"],
|
||||||
|
req_headers=[("accept", "application/json")],
|
||||||
|
)
|
||||||
|
assert not lnurlp_with_headers.redirect_matches(path="/random/path", req_headers=[])
|
||||||
|
assert not lnurlp_with_headers.redirect_matches(path="/random_path", req_headers=[])
|
||||||
|
assert not lnurlp_with_headers.redirect_matches(
|
||||||
|
path="/.well-known/lnurlp", req_headers=[]
|
||||||
|
)
|
||||||
|
assert lnurlp.redirect_matches(path="/.well-known/lnurlp", req_headers=[])
|
||||||
|
assert lnurlp.redirect_matches(
|
||||||
|
path="/.well-known/lnurlp/some/other/path", req_headers=[]
|
||||||
|
)
|
||||||
|
assert lnurlp.redirect_matches(
|
||||||
|
path="/.well-known/lnurlp/some/other/path",
|
||||||
|
req_headers=headers_list,
|
||||||
|
)
|
||||||
|
assert not lnurlp_with_headers.redirect_matches(
|
||||||
|
path="/.well-known/lnurlp", req_headers=[]
|
||||||
|
)
|
||||||
|
assert not lnurlp_with_headers.redirect_matches(
|
||||||
|
path="/.well-known/lnurlp/some/other/path", req_headers=[]
|
||||||
|
)
|
||||||
|
assert lnurlp_with_headers.redirect_matches(
|
||||||
|
path="/.well-known/lnurlp/some/other/path",
|
||||||
|
req_headers=headers_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_path_new_path_from(lnurlp: RedirectPath):
|
||||||
|
assert lnurlp.new_path_from("") == "/lnurlp/api/v1/well-known"
|
||||||
|
assert lnurlp.new_path_from("/") == "/lnurlp/api/v1/well-known"
|
||||||
|
assert lnurlp.new_path_from("/path") == "/lnurlp/api/v1/well-known"
|
||||||
|
assert lnurlp.new_path_from("/path/more") == "/lnurlp/api/v1/well-known"
|
||||||
|
|
||||||
|
assert lnurlp.new_path_from("/.well-known/lnurlp") == "/lnurlp/api/v1/well-known"
|
||||||
|
assert (
|
||||||
|
lnurlp.new_path_from("/.well-known/lnurlp/path")
|
||||||
|
== "/lnurlp/api/v1/well-known/path"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
lnurlp.new_path_from("/.well-known/lnurlp/path/more")
|
||||||
|
== "/lnurlp/api/v1/well-known/path/more"
|
||||||
|
)
|
Loading…
Add table
Reference in a new issue