Fix overlapping redirect paths (#2671)

This commit is contained in:
Vlad Stan 2024-09-11 12:41:37 +03:00 committed by GitHub
parent 7a5e7fbd8c
commit 5f4f1288d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 684 additions and 483 deletions

4
.gitignore vendored
View file

@ -49,8 +49,8 @@ fly.toml
lnbits-backup.zip lnbits-backup.zip
# Ignore extensions (post installable extension PR) # Ignore extensions (post installable extension PR)
extensions /extensions
upgrades/ /upgrades/
# builded python package # builded python package
dist dist

View file

@ -17,10 +17,13 @@ from slowapi.util import get_remote_address
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from lnbits.core.crud import ( from lnbits.core.crud import (
add_installed_extension,
get_dbversions, get_dbversions,
get_installed_extensions, get_installed_extensions,
update_installed_extension_state, update_installed_extension_state,
) )
from lnbits.core.extensions.extension_manager import deactivate_extension
from lnbits.core.extensions.helpers import version_parse
from lnbits.core.helpers import migrate_extension_database from lnbits.core.helpers import migrate_extension_database
from lnbits.core.tasks import ( # watchdog_task from lnbits.core.tasks import ( # watchdog_task
killswitch_task, killswitch_task,
@ -44,14 +47,8 @@ from lnbits.wallets import get_funding_source, set_funding_source
from .commands import migrate_databases from .commands import migrate_databases
from .core import init_core_routers from .core import init_core_routers
from .core.db import core_app_extra from .core.db import core_app_extra
from .core.extensions.models import Extension, InstallableExtension
from .core.services import check_admin_settings, check_webpush_settings from .core.services import check_admin_settings, check_webpush_settings
from .core.views.extension_api import add_installed_extension
from .extension_manager import (
Extension,
InstallableExtension,
get_valid_extensions,
version_parse,
)
from .middleware import ( from .middleware import (
CustomGZipMiddleware, CustomGZipMiddleware,
ExtensionsRedirectMiddleware, ExtensionsRedirectMiddleware,
@ -243,6 +240,7 @@ async def check_installed_extensions(app: FastAPI):
) )
except Exception as e: except Exception as e:
logger.warning(e) logger.warning(e)
await deactivate_extension(ext.id)
logger.warning( logger.warning(
f"Failed to re-install extension: {ext.id} ({ext.installed_version})" f"Failed to re-install extension: {ext.id} ({ext.installed_version})"
) )
@ -317,7 +315,6 @@ async def restore_installed_extension(app: FastAPI, ext: InstallableExtension):
# mount routes for the new version # mount routes for the new version
core_app_extra.register_new_ext_routes(extension) core_app_extra.register_new_ext_routes(extension)
ext.notify_upgrade(extension.upgrade_hash)
def register_custom_extensions_path(): def register_custom_extensions_path():
@ -380,24 +377,22 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None:
) )
app.mount(s["path"], StaticFiles(directory=static_dir), s["name"]) app.mount(s["path"], StaticFiles(directory=static_dir), s["name"])
if hasattr(ext_module, f"{ext.code}_redirect_paths"): ext_redirects = (
ext_redirects = getattr(ext_module, f"{ext.code}_redirect_paths") getattr(ext_module, f"{ext.code}_redirect_paths")
settings.lnbits_extensions_redirects = [ if hasattr(ext_module, f"{ext.code}_redirect_paths")
r for r in settings.lnbits_extensions_redirects if r["ext_id"] != ext.code else []
] )
for r in ext_redirects:
r["ext_id"] = ext.code
settings.lnbits_extensions_redirects.append(r)
logger.trace(f"adding route for extension {ext_module}") settings.activate_extension_paths(ext.code, ext.upgrade_hash, ext_redirects)
logger.trace(f"Adding route for extension {ext_module}.")
prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else "" prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else ""
app.include_router(router=ext_route, prefix=prefix) app.include_router(router=ext_route, prefix=prefix)
async def check_and_register_extensions(app: FastAPI): async def check_and_register_extensions(app: FastAPI):
await check_installed_extensions(app) await check_installed_extensions(app)
for ext in get_valid_extensions(False): for ext in Extension.get_valid_extensions(False):
try: try:
register_ext_routes(app, ext) register_ext_routes(app, ext)
except Exception as exc: except Exception as exc:

View file

@ -25,18 +25,18 @@ from lnbits.core.crud import (
remove_deleted_wallets, remove_deleted_wallets,
update_payment_status, update_payment_status,
) )
from lnbits.core.extensions.models import (
CreateExtension,
ExtensionRelease,
InstallableExtension,
)
from lnbits.core.helpers import migrate_databases from lnbits.core.helpers import migrate_databases
from lnbits.core.models import Payment, PaymentState, User from lnbits.core.models import Payment, PaymentState
from lnbits.core.services import check_admin_settings from lnbits.core.services import check_admin_settings
from lnbits.core.views.extension_api import ( from lnbits.core.views.extension_api import (
api_install_extension, api_install_extension,
api_uninstall_extension, api_uninstall_extension,
) )
from lnbits.extension_manager import (
CreateExtension,
ExtensionRelease,
InstallableExtension,
)
from lnbits.settings import settings from lnbits.settings import settings
from lnbits.wallets.base import Wallet from lnbits.wallets.base import Wallet
@ -611,7 +611,7 @@ async def _call_install_extension(
) )
resp.raise_for_status() resp.raise_for_status()
else: else:
await api_install_extension(data, User(id="mock_id")) await api_install_extension(data)
async def _call_uninstall_extension( async def _call_uninstall_extension(
@ -625,7 +625,7 @@ async def _call_uninstall_extension(
) )
resp.raise_for_status() resp.raise_for_status()
else: else:
await api_uninstall_extension(extension, User(id="mock_id")) await api_uninstall_extension(extension)
async def _can_run_operation(url) -> bool: async def _can_run_operation(url) -> bool:

View file

@ -8,14 +8,14 @@ import shortuuid
from passlib.context import CryptContext from passlib.context import CryptContext
from lnbits.core.db import db from lnbits.core.db import db
from lnbits.core.models import PaymentState from lnbits.core.extensions.models import (
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
from lnbits.extension_manager import (
InstallableExtension, InstallableExtension,
PayToEnableInfo, PayToEnableInfo,
UserExtension, UserExtension,
UserExtensionInfo, UserExtensionInfo,
) )
from lnbits.core.models import PaymentState
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
from lnbits.settings import ( from lnbits.settings import (
AdminSettings, AdminSettings,
EditableSettings, EditableSettings,
@ -430,7 +430,7 @@ async def get_installed_extension(
async def get_installed_extensions( async def get_installed_extensions(
active: Optional[bool] = None, active: Optional[bool] = None,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List["InstallableExtension"]: ) -> List[InstallableExtension]:
rows = await (conn or db).fetchall( rows = await (conn or db).fetchall(
"SELECT * FROM installed_extensions", "SELECT * FROM installed_extensions",
(), (),

View file

@ -0,0 +1,93 @@
import asyncio
import importlib
from loguru import logger
from lnbits.core.crud import (
add_installed_extension,
delete_installed_extension,
get_dbversions,
get_installed_extension,
update_installed_extension_state,
)
from lnbits.core.db import core_app_extra
from lnbits.core.helpers import migrate_extension_database
from lnbits.settings import settings
from .models import Extension, InstallableExtension
async def install_extension(ext_info: InstallableExtension) -> Extension:
extension = Extension.from_installable_ext(ext_info)
installed_ext = await get_installed_extension(ext_info.id)
ext_info.payments = installed_ext.payments if installed_ext else []
await ext_info.download_archive()
ext_info.extract_archive()
db_version = (await get_dbversions()).get(ext_info.id, 0)
await migrate_extension_database(extension, db_version)
await add_installed_extension(ext_info)
if extension.is_upgrade_extension:
# call stop while the old routes are still active
await stop_extension_background_work(ext_info.id)
return extension
async def uninstall_extension(ext_id: str):
await stop_extension_background_work(ext_id)
settings.deactivate_extension_paths(ext_id)
extension = await get_installed_extension(ext_id)
if extension:
extension.clean_extension_files()
await delete_installed_extension(ext_id=ext_id)
async def activate_extension(ext: Extension):
core_app_extra.register_new_ext_routes(ext)
await update_installed_extension_state(ext_id=ext.code, active=True)
async def deactivate_extension(ext_id: str):
settings.deactivate_extension_paths(ext_id)
await update_installed_extension_state(ext_id=ext_id, active=False)
async def stop_extension_background_work(ext_id: str) -> bool:
"""
Stop background work for extension (like asyncio.Tasks, WebSockets, etc).
Extensions SHOULD expose a `api_stop()` function.
"""
upgrade_hash = settings.lnbits_upgraded_extensions.get(ext_id, "")
ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash)
try:
logger.info(f"Stopping background work for extension '{ext.module_name}'.")
old_module = importlib.import_module(ext.module_name)
# Extensions must expose an `{ext_id}_stop()` function at the module level
# The `api_stop()` function is for backwards compatibility (will be deprecated)
stop_fns = [f"{ext_id}_stop", "api_stop"]
stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None)
assert stop_fn_name, "No stop function found for '{ext.module_name}'"
stop_fn = getattr(old_module, stop_fn_name)
if stop_fn:
if asyncio.iscoroutinefunction(stop_fn):
await stop_fn()
else:
stop_fn()
logger.info(f"Stopped background work for extension '{ext.module_name}'.")
except Exception as ex:
logger.warning(f"Failed to stop background work for '{ext.module_name}'.")
logger.warning(ex)
return False
return True

View file

@ -0,0 +1,56 @@
import hashlib
from typing import Any, Optional
from urllib import request
import httpx
from loguru import logger
from packaging import version
from lnbits.settings import settings
def version_parse(v: str):
"""
Wrapper for version.parse() that does not throw if the version is invalid.
Instead it return the lowest possible version ("0.0.0")
"""
try:
return version.parse(v)
except Exception:
return version.parse("0.0.0")
async def github_api_get(url: str, error_msg: Optional[str]) -> Any:
headers = {"User-Agent": settings.user_agent}
if settings.lnbits_ext_github_token:
headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}"
async with httpx.AsyncClient(headers=headers) as client:
resp = await client.get(url)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()
def download_url(url, save_path):
with request.urlopen(url, timeout=60) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio import asyncio
import hashlib import hashlib
import json import json
@ -6,16 +8,22 @@ import shutil
import sys import sys
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from typing import Any, List, NamedTuple, Optional, Tuple from typing import Any, NamedTuple, Optional
from urllib import request
import httpx import httpx
from loguru import logger from loguru import logger
from packaging import version
from pydantic import BaseModel from pydantic import BaseModel
from lnbits.settings import settings from lnbits.settings import settings
from .helpers import (
download_url,
file_hash,
github_api_get,
icon_to_github_url,
version_parse,
)
class ExplicitRelease(BaseModel): class ExplicitRelease(BaseModel):
id: str id: str
@ -23,7 +31,7 @@ class ExplicitRelease(BaseModel):
version: str version: str
archive: str archive: str
hash: str hash: str
dependencies: List[str] = [] dependencies: list[str] = []
repo: Optional[str] repo: Optional[str]
icon: Optional[str] icon: Optional[str]
short_description: Optional[str] short_description: Optional[str]
@ -48,9 +56,9 @@ class GitHubRelease(BaseModel):
class Manifest(BaseModel): class Manifest(BaseModel):
featured: List[str] = [] featured: list[str] = []
extensions: List["ExplicitRelease"] = [] extensions: list[ExplicitRelease] = []
repos: List["GitHubRelease"] = [] repos: list[GitHubRelease] = []
class GitHubRepoRelease(BaseModel): class GitHubRepoRelease(BaseModel):
@ -81,6 +89,17 @@ class ExtensionConfig(BaseModel):
return True return True
return version_parse(self.min_lnbits_version) <= version_parse(settings.version) return version_parse(self.min_lnbits_version) <= version_parse(settings.version)
@classmethod
async def fetch_github_release_config(
cls, org: str, repo: str, tag_name: str
) -> Optional[ExtensionConfig]:
config_url = (
f"https://raw.githubusercontent.com/{org}/{repo}/{tag_name}/config.json"
)
error_msg = "Cannot fetch GitHub extension config"
config = await github_api_get(config_url, error_msg)
return ExtensionConfig.parse_obj(config)
class ReleasePaymentInfo(BaseModel): class ReleasePaymentInfo(BaseModel):
amount: Optional[int] = None amount: Optional[int] = None
@ -112,7 +131,7 @@ class UserExtension(BaseModel):
return self.extra.paid_to_enable is True return self.extra.paid_to_enable is True
@classmethod @classmethod
def from_row(cls, data: dict) -> "UserExtension": def from_row(cls, data: dict) -> UserExtension:
ext = UserExtension(**data) ext = UserExtension(**data)
ext.extra = ( ext.extra = (
UserExtensionInfo(**json.loads(data["_extra"] or "{}")) UserExtensionInfo(**json.loads(data["_extra"] or "{}"))
@ -122,124 +141,6 @@ class UserExtension(BaseModel):
return ext return ext
def download_url(url, save_path):
with request.urlopen(url, timeout=60) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
async def fetch_github_repo_info(
org: str, repository: str
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
repo_url = f"https://api.github.com/repos/{org}/{repository}"
error_msg = "Cannot fetch extension repo"
repo = await github_api_get(repo_url, error_msg)
github_repo = GitHubRepo.parse_obj(repo)
lates_release_url = (
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
)
error_msg = "Cannot fetch extension releases"
latest_release: Any = await github_api_get(lates_release_url, error_msg)
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
error_msg = "Cannot fetch config for extension"
config = await github_api_get(config_url, error_msg)
return (
github_repo,
GitHubRepoRelease.parse_obj(latest_release),
ExtensionConfig.parse_obj(config),
)
async def fetch_manifest(url) -> Manifest:
error_msg = "Cannot fetch extensions manifest"
manifest = await github_api_get(url, error_msg)
return Manifest.parse_obj(manifest)
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
error_msg = "Cannot fetch extension releases"
releases = await github_api_get(releases_url, error_msg)
return [GitHubRepoRelease.parse_obj(r) for r in releases]
async def fetch_github_release_config(
org: str, repo: str, tag_name: str
) -> Optional[ExtensionConfig]:
config_url = (
f"https://raw.githubusercontent.com/{org}/{repo}/{tag_name}/config.json"
)
error_msg = "Cannot fetch GitHub extension config"
config = await github_api_get(config_url, error_msg)
return ExtensionConfig.parse_obj(config)
async def github_api_get(url: str, error_msg: Optional[str]) -> Any:
headers = {"User-Agent": settings.user_agent}
if settings.lnbits_ext_github_token:
headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}"
async with httpx.AsyncClient(headers=headers) as client:
resp = await client.get(url)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()
async def fetch_release_payment_info(
url: str, amount: Optional[int] = None
) -> Optional[ReleasePaymentInfo]:
if amount:
url = f"{url}?amount={amount}"
try:
async with httpx.AsyncClient() as client:
resp = await client.get(url)
resp.raise_for_status()
return ReleasePaymentInfo(**resp.json())
except Exception as e:
logger.warning(e)
return None
async def fetch_release_details(details_link: str) -> Optional[dict]:
try:
async with httpx.AsyncClient() as client:
resp = await client.get(details_link)
resp.raise_for_status()
data = resp.json()
if "description_md" in data:
resp = await client.get(data["description_md"])
if not resp.is_error:
data["description_md"] = resp.text
return data
except Exception as e:
logger.warning(e)
return None
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"
class Extension(NamedTuple): class Extension(NamedTuple):
code: str code: str
is_valid: bool is_valid: bool
@ -247,7 +148,7 @@ class Extension(NamedTuple):
name: Optional[str] = None name: Optional[str] = None
short_description: Optional[str] = None short_description: Optional[str] = None
tile: Optional[str] = None tile: Optional[str] = None
contributors: Optional[List[str]] = None contributors: Optional[list[str]] = None
hidden: bool = False hidden: bool = False
migration_module: Optional[str] = None migration_module: Optional[str] = None
db_name: Optional[str] = None db_name: Optional[str] = None
@ -269,7 +170,7 @@ class Extension(NamedTuple):
return self.upgrade_hash != "" return self.upgrade_hash != ""
@classmethod @classmethod
def from_installable_ext(cls, ext_info: "InstallableExtension") -> "Extension": def from_installable_ext(cls, ext_info: InstallableExtension) -> Extension:
return Extension( return Extension(
code=ext_info.id, code=ext_info.id,
is_valid=True, is_valid=True,
@ -278,22 +179,43 @@ class Extension(NamedTuple):
upgrade_hash=ext_info.hash if ext_info.module_installed else "", upgrade_hash=ext_info.hash if ext_info.module_installed else "",
) )
@classmethod
def get_valid_extensions(
cls, include_deactivated: Optional[bool] = True
) -> list[Extension]:
valid_extensions = [
extension for extension in cls._extensions() if extension.is_valid
]
# All subdirectories in the current directory, not recursive. if include_deactivated:
return valid_extensions
if settings.lnbits_extensions_deactivate_all:
return []
class ExtensionManager: return [
def __init__(self) -> None: e
for e in valid_extensions
if e.code not in settings.lnbits_deactivated_extensions
]
@classmethod
def get_valid_extension(
cls, ext_id: str, include_deactivated: Optional[bool] = True
) -> Optional[Extension]:
all_extensions = cls.get_valid_extensions(include_deactivated)
return next((e for e in all_extensions if e.code == ext_id), None)
@classmethod
def _extensions(cls) -> list[Extension]:
p = Path(settings.lnbits_extensions_path, "extensions") p = Path(settings.lnbits_extensions_path, "extensions")
Path(p).mkdir(parents=True, exist_ok=True) Path(p).mkdir(parents=True, exist_ok=True)
self._extension_folders: List[Path] = [f for f in p.iterdir() if f.is_dir()] extension_folders: list[Path] = [f for f in p.iterdir() if f.is_dir()]
@property
def extensions(self) -> List[Extension]:
# todo: remove this property somehow, it is too expensive # todo: remove this property somehow, it is too expensive
output: List[Extension] = [] output: list[Extension] = []
for extension_folder in self._extension_folders: for extension_folder in extension_folders:
extension_code = extension_folder.parts[-1] extension_code = extension_folder.parts[-1]
try: try:
with open(extension_folder / "config.json") as json_file: with open(extension_folder / "config.json") as json_file:
@ -356,13 +278,27 @@ class ExtensionRelease(BaseModel):
if not self.pay_link: if not self.pay_link:
return return
payment_info = await fetch_release_payment_info(self.pay_link) payment_info = await self.fetch_release_payment_info()
self.cost_sats = payment_info.amount if payment_info else None self.cost_sats = payment_info.amount if payment_info else None
async def fetch_release_payment_info(
self, amount: Optional[int] = None
) -> Optional[ReleasePaymentInfo]:
url = f"{self.pay_link}?amount={amount}" if amount else self.pay_link
assert url, "Missing URL for payment info."
try:
async with httpx.AsyncClient() as client:
resp = await client.get(url)
resp.raise_for_status()
return ReleasePaymentInfo(**resp.json())
except Exception as e:
logger.warning(e)
return None
@classmethod @classmethod
def from_github_release( def from_github_release(
cls, source_repo: str, r: "GitHubRepoRelease" cls, source_repo: str, r: GitHubRepoRelease
) -> "ExtensionRelease": ) -> ExtensionRelease:
return ExtensionRelease( return ExtensionRelease(
name=r.name, name=r.name,
description=r.name, description=r.name,
@ -377,8 +313,8 @@ class ExtensionRelease(BaseModel):
@classmethod @classmethod
def from_explicit_release( def from_explicit_release(
cls, source_repo: str, e: "ExplicitRelease" cls, source_repo: str, e: ExplicitRelease
) -> "ExtensionRelease": ) -> ExtensionRelease:
return ExtensionRelease( return ExtensionRelease(
name=e.name, name=e.name,
version=e.version, version=e.version,
@ -397,9 +333,9 @@ class ExtensionRelease(BaseModel):
) )
@classmethod @classmethod
async def get_github_releases(cls, org: str, repo: str) -> List["ExtensionRelease"]: async def get_github_releases(cls, org: str, repo: str) -> list[ExtensionRelease]:
try: try:
github_releases = await fetch_github_releases(org, repo) github_releases = await cls.fetch_github_releases(org, repo)
return [ return [
ExtensionRelease.from_github_release(f"{org}/{repo}", r) ExtensionRelease.from_github_release(f"{org}/{repo}", r)
for r in github_releases for r in github_releases
@ -408,6 +344,33 @@ class ExtensionRelease(BaseModel):
logger.warning(e) logger.warning(e)
return [] return []
@classmethod
async def fetch_github_releases(
cls, org: str, repo: str
) -> list[GitHubRepoRelease]:
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
error_msg = "Cannot fetch extension releases"
releases = await github_api_get(releases_url, error_msg)
return [GitHubRepoRelease.parse_obj(r) for r in releases]
@classmethod
async def fetch_release_details(cls, details_link: str) -> Optional[dict]:
try:
async with httpx.AsyncClient() as client:
resp = await client.get(details_link)
resp.raise_for_status()
data = resp.json()
if "description_md" in data:
resp = await client.get(data["description_md"])
if not resp.is_error:
data["description_md"] = resp.text
return data
except Exception as e:
logger.warning(e)
return None
class InstallableExtension(BaseModel): class InstallableExtension(BaseModel):
id: str id: str
@ -415,13 +378,13 @@ class InstallableExtension(BaseModel):
active: Optional[bool] = False active: Optional[bool] = False
short_description: Optional[str] = None short_description: Optional[str] = None
icon: Optional[str] = None icon: Optional[str] = None
dependencies: List[str] = [] dependencies: list[str] = []
is_admin_only: bool = False is_admin_only: bool = False
stars: int = 0 stars: int = 0
featured = False featured = False
latest_release: Optional[ExtensionRelease] = None latest_release: Optional[ExtensionRelease] = None
installed_release: Optional[ExtensionRelease] = None installed_release: Optional[ExtensionRelease] = None
payments: List[ReleasePaymentInfo] = [] payments: list[ReleasePaymentInfo] = []
pay_to_enable: Optional[PayToEnableInfo] = None pay_to_enable: Optional[PayToEnableInfo] = None
archive: Optional[str] = None archive: Optional[str] = None
@ -546,16 +509,6 @@ class InstallableExtension(BaseModel):
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir)) shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
logger.success(f"Extension {self.name} ({self.installed_version}) installed.") logger.success(f"Extension {self.name} ({self.installed_version}) installed.")
def notify_upgrade(self, upgrade_hash: Optional[str]) -> None:
"""
Update the list of upgraded extensions. The middleware will perform
redirects based on this
"""
if upgrade_hash:
settings.lnbits_upgraded_extensions.add(f"{self.hash}/{self.id}")
settings.lnbits_all_extensions_ids.add(self.id)
def clean_extension_files(self): def clean_extension_files(self):
# remove downloaded archive # remove downloaded archive
if self.zip_path.is_file(): if self.zip_path.is_file():
@ -610,7 +563,7 @@ class InstallableExtension(BaseModel):
self.payments.append(payment_info) self.payments.append(payment_info)
@classmethod @classmethod
def from_row(cls, data: dict) -> "InstallableExtension": def from_row(cls, data: dict) -> InstallableExtension:
meta = json.loads(data["meta"]) meta = json.loads(data["meta"])
ext = InstallableExtension(**data) ext = InstallableExtension(**data)
if "installed_release" in meta: if "installed_release" in meta:
@ -623,9 +576,7 @@ class InstallableExtension(BaseModel):
return ext return ext
@classmethod @classmethod
def from_rows( def from_rows(cls, rows: Optional[list[Any]] = None) -> list[InstallableExtension]:
cls, rows: Optional[List[Any]] = None
) -> List["InstallableExtension"]:
if rows is None: if rows is None:
rows = [] rows = []
return [InstallableExtension.from_row(row) for row in rows] return [InstallableExtension.from_row(row) for row in rows]
@ -633,9 +584,9 @@ class InstallableExtension(BaseModel):
@classmethod @classmethod
async def from_github_release( async def from_github_release(
cls, github_release: GitHubRelease cls, github_release: GitHubRelease
) -> Optional["InstallableExtension"]: ) -> Optional[InstallableExtension]:
try: try:
repo, latest_release, config = await fetch_github_repo_info( repo, latest_release, config = await cls.fetch_github_repo_info(
github_release.organisation, github_release.repository github_release.organisation, github_release.repository
) )
source_repo = f"{github_release.organisation}/{github_release.repository}" source_repo = f"{github_release.organisation}/{github_release.repository}"
@ -657,7 +608,7 @@ class InstallableExtension(BaseModel):
return None return None
@classmethod @classmethod
def from_explicit_release(cls, e: ExplicitRelease) -> "InstallableExtension": def from_explicit_release(cls, e: ExplicitRelease) -> InstallableExtension:
return InstallableExtension( return InstallableExtension(
id=e.id, id=e.id,
name=e.name, name=e.name,
@ -670,13 +621,13 @@ class InstallableExtension(BaseModel):
@classmethod @classmethod
async def get_installable_extensions( async def get_installable_extensions(
cls, cls,
) -> List["InstallableExtension"]: ) -> list[InstallableExtension]:
extension_list: List[InstallableExtension] = [] extension_list: list[InstallableExtension] = []
extension_id_list: List[str] = [] extension_id_list: list[str] = []
for url in settings.lnbits_extensions_manifests: for url in settings.lnbits_extensions_manifests:
try: try:
manifest = await fetch_manifest(url) manifest = await cls.fetch_manifest(url)
for r in manifest.repos: for r in manifest.repos:
ext = await InstallableExtension.from_github_release(r) ext = await InstallableExtension.from_github_release(r)
@ -712,12 +663,12 @@ class InstallableExtension(BaseModel):
return extension_list return extension_list
@classmethod @classmethod
async def get_extension_releases(cls, ext_id: str) -> List["ExtensionRelease"]: async def get_extension_releases(cls, ext_id: str) -> list[ExtensionRelease]:
extension_releases: List[ExtensionRelease] = [] extension_releases: list[ExtensionRelease] = []
for url in settings.lnbits_extensions_manifests: for url in settings.lnbits_extensions_manifests:
try: try:
manifest = await fetch_manifest(url) manifest = await cls.fetch_manifest(url)
for r in manifest.repos: for r in manifest.repos:
if r.id != ext_id: if r.id != ext_id:
continue continue
@ -741,8 +692,8 @@ class InstallableExtension(BaseModel):
@classmethod @classmethod
async def get_extension_release( async def get_extension_release(
cls, ext_id: str, source_repo: str, archive: str, version: str cls, ext_id: str, source_repo: str, archive: str, version: str
) -> Optional["ExtensionRelease"]: ) -> Optional[ExtensionRelease]:
all_releases: List[ExtensionRelease] = ( all_releases: list[ExtensionRelease] = (
await InstallableExtension.get_extension_releases(ext_id) await InstallableExtension.get_extension_releases(ext_id)
) )
selected_release = [ selected_release = [
@ -755,6 +706,37 @@ class InstallableExtension(BaseModel):
return selected_release[0] if len(selected_release) != 0 else None return selected_release[0] if len(selected_release) != 0 else None
@classmethod
async def fetch_github_repo_info(
cls, org: str, repository: str
) -> tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
repo_url = f"https://api.github.com/repos/{org}/{repository}"
error_msg = "Cannot fetch extension repo"
repo = await github_api_get(repo_url, error_msg)
github_repo = GitHubRepo.parse_obj(repo)
lates_release_url = (
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
)
error_msg = "Cannot fetch extension releases"
latest_release: Any = await github_api_get(lates_release_url, error_msg)
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
error_msg = "Cannot fetch config for extension"
config = await github_api_get(config_url, error_msg)
return (
github_repo,
GitHubRepoRelease.parse_obj(latest_release),
ExtensionConfig.parse_obj(config),
)
@classmethod
async def fetch_manifest(cls, url) -> Manifest:
error_msg = "Cannot fetch extensions manifest"
manifest = await github_api_get(url, error_msg)
return Manifest.parse_obj(manifest)
class CreateExtension(BaseModel): class CreateExtension(BaseModel):
ext_id: str ext_id: str
@ -769,32 +751,3 @@ class ExtensionDetailsRequest(BaseModel):
ext_id: str ext_id: str
source_repo: str source_repo: str
version: str version: str
def get_valid_extensions(include_deactivated: Optional[bool] = True) -> List[Extension]:
valid_extensions = [
extension for extension in ExtensionManager().extensions if extension.is_valid
]
if include_deactivated:
return valid_extensions
if settings.lnbits_extensions_deactivate_all:
return []
return [
e
for e in valid_extensions
if e.code not in settings.lnbits_deactivated_extensions
]
def version_parse(v: str):
"""
Wrapper for version.parse() that does not throw if the version is invalid.
Instead it return the lowest possible version ("0.0.0")
"""
try:
return version.parse(v)
except Exception:
return version.parse("0.0.0")

View file

@ -1,9 +1,8 @@
import importlib import importlib
import re import re
from typing import Any, Optional from typing import Any
from uuid import UUID from uuid import UUID
import httpx
from loguru import logger from loguru import logger
from lnbits.core import migrations as core_migrations from lnbits.core import migrations as core_migrations
@ -13,11 +12,10 @@ from lnbits.core.crud import (
update_migration_version, update_migration_version,
) )
from lnbits.core.db import db as core_db from lnbits.core.db import db as core_db
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection from lnbits.core.extensions.models import (
from lnbits.extension_manager import (
Extension, Extension,
get_valid_extensions,
) )
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection
from lnbits.settings import settings from lnbits.settings import settings
@ -55,68 +53,6 @@ async def run_migration(
await update_migration_version(conn, db_name, version) await update_migration_version(conn, db_name, version)
async def stop_extension_background_work(
ext_id: str, user: str, access_token: Optional[str] = None
):
"""
Stop background work for extension (like asyncio.Tasks, WebSockets, etc).
Extensions SHOULD expose a `api_stop()` function and/or a DELETE enpoint
at the root level of their API.
"""
stopped = await _stop_extension_background_work(ext_id)
if not stopped:
# fallback to REST API call
await _stop_extension_background_work_via_api(ext_id, user, access_token)
async def _stop_extension_background_work(ext_id) -> bool:
upgrade_hash = settings.extension_upgrade_hash(ext_id) or ""
ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash)
try:
logger.info(f"Stopping background work for extension '{ext.module_name}'.")
old_module = importlib.import_module(ext.module_name)
# Extensions must expose an `{ext_id}_stop()` function at the module level
# The `api_stop()` function is for backwards compatibility (will be deprecated)
stop_fns = [f"{ext_id}_stop", "api_stop"]
stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None)
assert stop_fn_name, "No stop function found for '{ext.module_name}'"
stop_fn = getattr(old_module, stop_fn_name)
if stop_fn:
await stop_fn()
logger.info(f"Stopped background work for extension '{ext.module_name}'.")
except Exception as ex:
logger.warning(f"Failed to stop background work for '{ext.module_name}'.")
logger.warning(ex)
return False
return True
async def _stop_extension_background_work_via_api(ext_id, user, access_token):
logger.info(
f"Stopping background work for extension '{ext_id}' using the REST API."
)
async with httpx.AsyncClient() as client:
try:
url = f"http://{settings.host}:{settings.port}/{ext_id}/api/v1?usr={user}"
headers = (
{"Authorization": "Bearer " + access_token} if access_token else None
)
resp = await client.delete(url=url, headers=headers)
resp.raise_for_status()
logger.info(f"Stopped background work for extension '{ext_id}'.")
except Exception as ex:
logger.warning(
f"Failed to stop background work for '{ext_id}' using the REST API."
)
logger.warning(ex)
def to_valid_user_id(user_id: str) -> UUID: def to_valid_user_id(user_id: str) -> UUID:
if len(user_id) < 32: if len(user_id) < 32:
raise ValueError("User ID must have at least 128 bits") raise ValueError("User ID must have at least 128 bits")
@ -161,7 +97,7 @@ async def migrate_databases():
await load_disabled_extension_list() await load_disabled_extension_list()
# todo: revisit, use installed extensions # todo: revisit, use installed extensions
for ext in get_valid_extensions(False): for ext in Extension.get_valid_extensions(False):
current_version = current_versions.get(ext.code, 0) current_version = current_versions.get(ext.code, 0)
try: try:
await migrate_extension_database(ext, current_version) await migrate_extension_database(ext, current_version)

View file

@ -1,8 +1,6 @@
import sys
from http import HTTPStatus from http import HTTPStatus
from typing import ( from typing import (
List, List,
Optional,
) )
from bolt11 import decode as bolt11_decode from bolt11 import decode as bolt11_decode
@ -13,10 +11,21 @@ from fastapi import (
) )
from loguru import logger from loguru import logger
from lnbits.core.db import core_app_extra from lnbits.core.extensions.extension_manager import (
from lnbits.core.helpers import ( activate_extension,
migrate_extension_database, deactivate_extension,
stop_extension_background_work, install_extension,
uninstall_extension,
)
from lnbits.core.extensions.models import (
CreateExtension,
Extension,
ExtensionConfig,
ExtensionRelease,
InstallableExtension,
PayToEnableInfo,
ReleasePaymentInfo,
UserExtensionInfo,
) )
from lnbits.core.models import ( from lnbits.core.models import (
SimpleStatus, SimpleStatus,
@ -24,36 +33,18 @@ from lnbits.core.models import (
) )
from lnbits.core.services import check_transaction_status, create_invoice from lnbits.core.services import check_transaction_status, create_invoice
from lnbits.decorators import ( from lnbits.decorators import (
check_access_token,
check_admin, check_admin,
check_user_exists, check_user_exists,
) )
from lnbits.extension_manager import (
CreateExtension,
Extension,
ExtensionRelease,
InstallableExtension,
PayToEnableInfo,
ReleasePaymentInfo,
UserExtensionInfo,
fetch_github_release_config,
fetch_release_details,
fetch_release_payment_info,
get_valid_extensions,
)
from lnbits.settings import settings
from ..crud import ( from ..crud import (
add_installed_extension,
delete_dbversion, delete_dbversion,
delete_installed_extension,
drop_extension_db, drop_extension_db,
get_dbversions, get_dbversions,
get_installed_extension, get_installed_extension,
get_installed_extensions, get_installed_extensions,
get_user_extension, get_user_extension,
update_extension_pay_to_enable, update_extension_pay_to_enable,
update_installed_extension_state,
update_user_extension, update_user_extension,
update_user_extension_extra, update_user_extension_extra,
) )
@ -64,12 +55,8 @@ extension_router = APIRouter(
) )
@extension_router.post("") @extension_router.post("", dependencies=[Depends(check_admin)])
async def api_install_extension( async def api_install_extension(data: CreateExtension):
data: CreateExtension,
user: User = Depends(check_admin),
access_token: Optional[str] = Depends(check_access_token),
):
release = await InstallableExtension.get_extension_release( release = await InstallableExtension.get_extension_release(
data.ext_id, data.source_repo, data.archive, data.version data.ext_id, data.source_repo, data.archive, data.version
) )
@ -89,43 +76,36 @@ async def api_install_extension(
) )
try: try:
installed_ext = await get_installed_extension(data.ext_id) extension = await install_extension(ext_info)
ext_info.payments = installed_ext.payments if installed_ext else []
await ext_info.download_archive()
ext_info.extract_archive()
extension = Extension.from_installable_ext(ext_info)
db_version = (await get_dbversions()).get(data.ext_id, 0)
await migrate_extension_database(extension, db_version)
ext_info.active = True
await add_installed_extension(ext_info)
if extension.is_upgrade_extension:
# call stop while the old routes are still active
await stop_extension_background_work(data.ext_id, user.id, access_token)
# mount routes for the new version
core_app_extra.register_new_ext_routes(extension)
ext_info.notify_upgrade(extension.upgrade_hash)
settings.lnbits_deactivated_extensions.discard(data.ext_id)
return extension
except AssertionError as exc:
raise HTTPException(HTTPStatus.BAD_REQUEST, str(exc)) from exc
except Exception as exc: except Exception as exc:
logger.warning(exc) logger.warning(exc)
ext_info.clean_extension_files() ext_info.clean_extension_files()
detail = (
str(exc)
if isinstance(exc, AssertionError)
else f"Failed to install extension '{ext_info.id}'."
f"({ext_info.installed_version})."
)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=( detail=detail,
f"Failed to install extension {ext_info.id} " ) from exc
f"({ext_info.installed_version})."
), try:
await activate_extension(extension)
return extension
except Exception as exc:
logger.warning(exc)
await deactivate_extension(extension.code)
detail = (
str(exc)
if isinstance(exc, AssertionError)
else f"Extension `{extension.code}` installed, but activation failed."
)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=detail,
) from exc ) from exc
@ -143,7 +123,7 @@ async def api_extension_details(
) )
assert release, "Details not found for release" assert release, "Details not found for release"
release_details = await fetch_release_details(details_link) release_details = await ExtensionRelease.fetch_release_details(details_link)
assert release_details, "Cannot fetch details for release" assert release_details, "Cannot fetch details for release"
release_details["icon"] = release.icon release_details["icon"] = release.icon
release_details["repo"] = release.repo release_details["repo"] = release.repo
@ -186,7 +166,7 @@ async def api_update_pay_to_enable(
async def api_enable_extension( async def api_enable_extension(
ext_id: str, user: User = Depends(check_user_exists) ext_id: str, user: User = Depends(check_user_exists)
) -> SimpleStatus: ) -> SimpleStatus:
if ext_id not in [e.code for e in get_valid_extensions()]: if ext_id not in [e.code for e in Extension.get_valid_extensions()]:
raise HTTPException( raise HTTPException(
HTTPStatus.NOT_FOUND, f"Extension '{ext_id}' doesn't exist." HTTPStatus.NOT_FOUND, f"Extension '{ext_id}' doesn't exist."
) )
@ -249,7 +229,7 @@ async def api_enable_extension(
async def api_disable_extension( async def api_disable_extension(
ext_id: str, user: User = Depends(check_user_exists) ext_id: str, user: User = Depends(check_user_exists)
) -> SimpleStatus: ) -> SimpleStatus:
if ext_id not in [e.code for e in get_valid_extensions()]: if ext_id not in [e.code for e in Extension.get_valid_extensions()]:
raise HTTPException( raise HTTPException(
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist." HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
) )
@ -270,20 +250,14 @@ async def api_activate_extension(ext_id: str) -> SimpleStatus:
try: try:
logger.info(f"Activating extension: '{ext_id}'.") logger.info(f"Activating extension: '{ext_id}'.")
all_extensions = get_valid_extensions() ext = Extension.get_valid_extension(ext_id)
ext = next((e for e in all_extensions if e.code == ext_id), None)
assert ext, f"Extension '{ext_id}' doesn't exist." assert ext, f"Extension '{ext_id}' doesn't exist."
# if extension never loaded (was deactivated on server startup)
if ext_id not in sys.modules.keys():
# run extension start-up routine
core_app_extra.register_new_ext_routes(ext)
settings.lnbits_deactivated_extensions.discard(ext_id) await activate_extension(ext)
await update_installed_extension_state(ext_id=ext_id, active=True)
return SimpleStatus(success=True, message=f"Extension '{ext_id}' activated.") return SimpleStatus(success=True, message=f"Extension '{ext_id}' activated.")
except Exception as exc: except Exception as exc:
logger.warning(exc) logger.warning(exc)
await deactivate_extension(ext_id)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=(f"Failed to activate '{ext_id}'."), detail=(f"Failed to activate '{ext_id}'."),
@ -295,13 +269,10 @@ async def api_deactivate_extension(ext_id: str) -> SimpleStatus:
try: try:
logger.info(f"Deactivating extension: '{ext_id}'.") logger.info(f"Deactivating extension: '{ext_id}'.")
all_extensions = get_valid_extensions() ext = Extension.get_valid_extension(ext_id)
ext = next((e for e in all_extensions if e.code == ext_id), None)
assert ext, f"Extension '{ext_id}' doesn't exist." assert ext, f"Extension '{ext_id}' doesn't exist."
settings.lnbits_deactivated_extensions.add(ext_id) await deactivate_extension(ext_id)
await update_installed_extension_state(ext_id=ext_id, active=False)
return SimpleStatus(success=True, message=f"Extension '{ext_id}' deactivated.") return SimpleStatus(success=True, message=f"Extension '{ext_id}' deactivated.")
except Exception as exc: except Exception as exc:
logger.warning(exc) logger.warning(exc)
@ -311,23 +282,19 @@ async def api_deactivate_extension(ext_id: str) -> SimpleStatus:
) from exc ) from exc
@extension_router.delete("/{ext_id}") @extension_router.delete("/{ext_id}", dependencies=[Depends(check_admin)])
async def api_uninstall_extension( async def api_uninstall_extension(ext_id: str) -> SimpleStatus:
ext_id: str,
user: User = Depends(check_admin),
access_token: Optional[str] = Depends(check_access_token),
) -> SimpleStatus:
installed_extensions = await get_installed_extensions()
extensions = [e for e in installed_extensions if e.id == ext_id] extension = await get_installed_extension(ext_id)
if len(extensions) == 0: if not extension:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, status_code=HTTPStatus.NOT_FOUND,
detail=f"Unknown extension id: {ext_id}", detail=f"Unknown extension id: {ext_id}",
) )
installed_extensions = await get_installed_extensions()
# check that other extensions do not depend on this one # check that other extensions do not depend on this one
for valid_ext_id in [ext.code for ext in get_valid_extensions()]: for valid_ext_id in [ext.code for ext in Extension.get_valid_extensions()]:
installed_ext = next( installed_ext = next(
(ext for ext in installed_extensions if ext.id == valid_ext_id), None (ext for ext in installed_extensions if ext.id == valid_ext_id), None
) )
@ -341,14 +308,7 @@ async def api_uninstall_extension(
) )
try: try:
# call stop while the old routes are still active await uninstall_extension(ext_id)
await stop_extension_background_work(ext_id, user.id, access_token)
settings.lnbits_deactivated_extensions.add(ext_id)
for ext_info in extensions:
ext_info.clean_extension_files()
await delete_installed_extension(ext_id=ext_info.id)
logger.success(f"Extension '{ext_id}' uninstalled.") logger.success(f"Extension '{ext_id}' uninstalled.")
return SimpleStatus(success=True, message=f"Extension '{ext_id}' uninstalled.") return SimpleStatus(success=True, message=f"Extension '{ext_id}' uninstalled.")
@ -397,9 +357,8 @@ async def get_pay_to_install_invoice(
assert release, "Release not found." assert release, "Release not found."
assert release.pay_link, "Pay link not found for release." assert release.pay_link, "Pay link not found for release."
payment_info = await fetch_release_payment_info( payment_info = await release.fetch_release_payment_info(data.cost_sats)
release.pay_link, data.cost_sats
)
assert payment_info and payment_info.payment_request, "Cannot request invoice." assert payment_info and payment_info.payment_request, "Cannot request invoice."
invoice = bolt11_decode(payment_info.payment_request) invoice = bolt11_decode(payment_info.payment_request)
@ -474,7 +433,7 @@ async def get_pay_to_enable_invoice(
) )
async def get_extension_release(org: str, repo: str, tag_name: str): async def get_extension_release(org: str, repo: str, tag_name: str):
try: try:
config = await fetch_github_release_config(org, repo, tag_name) config = await ExtensionConfig.fetch_github_release_config(org, repo, tag_name)
if not config: if not config:
return {} return {}

View file

@ -12,6 +12,7 @@ from lnurl import decode as lnurl_decode
from loguru import logger from loguru import logger
from pydantic.types import UUID4 from pydantic.types import UUID4
from lnbits.core.extensions.models import Extension, InstallableExtension
from lnbits.core.helpers import to_valid_user_id from lnbits.core.helpers import to_valid_user_id
from lnbits.core.models import User from lnbits.core.models import User
from lnbits.core.services import create_invoice from lnbits.core.services import create_invoice
@ -20,7 +21,6 @@ from lnbits.helpers import template_renderer
from lnbits.settings import settings from lnbits.settings import settings
from lnbits.wallets import get_funding_source from lnbits.wallets import get_funding_source
from ...extension_manager import InstallableExtension, get_valid_extensions
from ...utils.exchange_rates import allowed_currencies, currencies from ...utils.exchange_rates import allowed_currencies, currencies
from ..crud import ( from ..crud import (
create_account, create_account,
@ -104,7 +104,7 @@ async def extensions(request: Request, user: User = Depends(check_user_exists)):
installed_exts_ids = [] installed_exts_ids = []
try: try:
all_ext_ids = [ext.code for ext in get_valid_extensions()] all_ext_ids = [ext.code for ext in Extension.get_valid_extensions()]
inactive_extensions = [ inactive_extensions = [
e.id for e in await get_installed_extensions(active=False) e.id for e in await get_installed_extensions(active=False)
] ]

View file

@ -10,6 +10,7 @@ import shortuuid
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.schema import field_schema from pydantic.schema import field_schema
from lnbits.core.extensions.models import Extension
from lnbits.db import get_placeholder from lnbits.db import get_placeholder
from lnbits.jinja2_templating import Jinja2Templates from lnbits.jinja2_templating import Jinja2Templates
from lnbits.nodes import get_node_class from lnbits.nodes import get_node_class
@ -18,7 +19,6 @@ from lnbits.settings import settings
from lnbits.utils.crypto import AESCipher from lnbits.utils.crypto import AESCipher
from .db import FilterModel from .db import FilterModel
from .extension_manager import get_valid_extensions
def get_db_vendor_name(): def get_db_vendor_name():
@ -93,7 +93,7 @@ def template_renderer(additional_folders: Optional[List] = None) -> Jinja2Templa
settings.lnbits_node_ui and get_node_class() is not None settings.lnbits_node_ui and get_node_class() is not None
) )
t.env.globals["LNBITS_NODE_UI_AVAILABLE"] = get_node_class() is not None t.env.globals["LNBITS_NODE_UI_AVAILABLE"] = get_node_class() is not None
t.env.globals["EXTENSIONS"] = get_valid_extensions(False) t.env.globals["EXTENSIONS"] = Extension.get_valid_extensions(False)
if settings.lnbits_custom_logo: if settings.lnbits_custom_logo:
t.env.globals["USE_CUSTOM_LOGO"] = settings.lnbits_custom_logo t.env.globals["USE_CUSTOM_LOGO"] = settings.lnbits_custom_logo

View file

@ -1,5 +1,5 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Any, List, Tuple, Union from typing import Any, List, Union
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
@ -45,16 +45,11 @@ class InstalledExtensionMiddleware:
await self.app(scope, receive, send) await self.app(scope, receive, send)
return return
upgrade_path = next(
(
e
for e in settings.lnbits_upgraded_extensions
if e.endswith(f"/{top_path}")
),
None,
)
# re-route all trafic if the extension has been upgraded # re-route all trafic if the extension has been upgraded
if upgrade_path: if top_path in settings.lnbits_upgraded_extensions:
upgrade_path = (
f"""{settings.lnbits_upgraded_extensions[top_path]}/{top_path}"""
)
tail = "/".join(rest) tail = "/".join(rest)
scope["path"] = f"/upgrades/{upgrade_path}/{tail}" scope["path"] = f"/upgrades/{upgrade_path}/{tail}"
@ -118,72 +113,12 @@ class ExtensionsRedirectMiddleware:
return return
req_headers = scope["headers"] if "headers" in scope else [] req_headers = scope["headers"] if "headers" in scope else []
redirect = self._find_redirect(scope["path"], req_headers) redirect = settings.find_extension_redirect(scope["path"], req_headers)
if redirect: if redirect:
scope["path"] = self._new_path(redirect, scope["path"]) scope["path"] = redirect.new_path_from(scope["path"])
await self.app(scope, receive, send) await self.app(scope, receive, send)
def _find_redirect(self, path: str, req_headers: List[Tuple[bytes, bytes]]):
return next(
(
r
for r in settings.lnbits_extensions_redirects
if self._redirect_matches(r, path, req_headers)
),
None,
)
def _redirect_matches(
self, redirect: dict, path: str, req_headers: List[Tuple[bytes, bytes]]
) -> bool:
if "from_path" not in redirect:
return False
header_filters = (
redirect["header_filters"] if "header_filters" in redirect else {}
)
return self._has_common_path(redirect["from_path"], path) and self._has_headers(
header_filters, req_headers
)
def _has_headers(
self, filter_headers: dict, req_headers: List[Tuple[bytes, bytes]]
) -> bool:
for h in filter_headers:
if not self._has_header(req_headers, (str(h), str(filter_headers[h]))):
return False
return True
def _has_header(
self, req_headers: List[Tuple[bytes, bytes]], header: Tuple[str, str]
) -> bool:
for h in req_headers:
if (
h[0].decode().lower() == header[0].lower()
and h[1].decode() == header[1]
):
return True
return False
def _has_common_path(self, redirect_path: str, req_path: str) -> bool:
redirect_path_elements = redirect_path.split("/")
req_path_elements = req_path.split("/")
if len(redirect_path) > len(req_path):
return False
sub_path = req_path_elements[: len(redirect_path_elements)]
return redirect_path == "/".join(sub_path)
def _new_path(self, redirect: dict, req_path: str) -> str:
from_path = redirect["from_path"].split("/")
redirect_to = redirect["redirect_to_path"].split("/")
req_tail_path = req_path.split("/")[len(from_path) :]
elements = [
e for e in ([redirect["ext_id"], *redirect_to, *req_tail_path]) if e != ""
]
return "/" + "/".join(elements)
def add_ratelimit_middleware(app: FastAPI): def add_ratelimit_middleware(app: FastAPI):
core_app_extra.register_new_ratelimiter() core_app_extra.register_new_ratelimiter()

View file

@ -62,26 +62,132 @@ class ExtensionsInstallSettings(LNbitsSettings):
lnbits_ext_github_token: str = Field(default="") lnbits_ext_github_token: str = Field(default="")
class RedirectPath(BaseModel):
ext_id: str
from_path: str
redirect_to_path: str
header_filters: dict = {}
def in_conflict(self, other: RedirectPath) -> bool:
if self.ext_id == other.ext_id:
return False
return self.redirect_matches(
other.from_path, list(other.header_filters.items())
) or other.redirect_matches(self.from_path, list(self.header_filters.items()))
def find_in_conflict(self, others: list[RedirectPath]) -> Optional[RedirectPath]:
for other in others:
if self.in_conflict(other):
return other
return None
def new_path_from(self, req_path: str) -> str:
from_path = self.from_path.split("/")
redirect_to = self.redirect_to_path.split("/")
req_tail_path = req_path.split("/")[len(from_path) :]
elements = [e for e in ([self.ext_id, *redirect_to, *req_tail_path]) if e != ""]
return "/" + "/".join(elements)
def redirect_matches(self, path: str, req_headers: list[tuple[str, str]]) -> bool:
return self._has_common_path(path) and self._has_headers(req_headers)
def _has_common_path(self, req_path: str) -> bool:
if len(self.from_path) > len(req_path):
return False
redirect_path_elements = self.from_path.split("/")
req_path_elements = req_path.split("/")
sub_path = req_path_elements[: len(redirect_path_elements)]
return self.from_path == "/".join(sub_path)
def _has_headers(self, req_headers: list[tuple[str, str]]) -> bool:
for h in self.header_filters:
if not self._has_header(req_headers, (str(h), str(self.header_filters[h]))):
return False
return True
def _has_header(
self, req_headers: list[tuple[str, str]], header: tuple[str, str]
) -> bool:
for h in req_headers:
if h[0].lower() == header[0].lower() and h[1].lower() == header[1].lower():
return True
return False
class InstalledExtensionsSettings(LNbitsSettings): class InstalledExtensionsSettings(LNbitsSettings):
# installed extensions that have been deactivated # installed extensions that have been deactivated
lnbits_deactivated_extensions: set[str] = Field(default=[]) lnbits_deactivated_extensions: set[str] = Field(default=[])
# upgraded extensions that require API redirects # upgraded extensions that require API redirects
lnbits_upgraded_extensions: set[str] = Field(default=[]) lnbits_upgraded_extensions: dict[str, str] = Field(default={})
# list of redirects that extensions want to perform # list of redirects that extensions want to perform
lnbits_extensions_redirects: list[Any] = Field(default=[]) lnbits_extensions_redirects: list[RedirectPath] = Field(default=[])
# list of all extension ids # list of all extension ids
lnbits_all_extensions_ids: set[Any] = Field(default=[]) lnbits_all_extensions_ids: set[Any] = Field(default=[])
def extension_upgrade_path(self, ext_id: str) -> Optional[str]: def find_extension_redirect(
self, path: str, req_headers: list[tuple[bytes, bytes]]
) -> Optional[RedirectPath]:
headers = [(k.decode(), v.decode()) for k, v in req_headers]
return next( return next(
(e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")), (
r
for r in self.lnbits_extensions_redirects
if r.redirect_matches(path, headers)
),
None, None,
) )
def extension_upgrade_hash(self, ext_id: str) -> Optional[str]: def activate_extension_paths(
path = settings.extension_upgrade_path(ext_id) self,
return path.split("/")[0] if path else None ext_id: str,
upgrade_hash: Optional[str] = None,
ext_redirects: Optional[list[dict]] = None,
):
self.lnbits_deactivated_extensions.discard(ext_id)
"""
Update the list of upgraded extensions. The middleware will perform
redirects based on this
"""
if upgrade_hash:
self.lnbits_upgraded_extensions[ext_id] = upgrade_hash
if ext_redirects:
self._activate_extension_redirects(ext_id, ext_redirects)
self.lnbits_all_extensions_ids.add(ext_id)
def deactivate_extension_paths(self, ext_id: str):
self.lnbits_deactivated_extensions.add(ext_id)
self._remove_extension_redirects(ext_id)
def _activate_extension_redirects(self, ext_id: str, ext_redirects: list[dict]):
ext_redirect_paths = [
RedirectPath(**{"ext_id": ext_id, **er}) for er in ext_redirects
]
existing_redirects = {
r.ext_id
for r in self.lnbits_extensions_redirects
if r.find_in_conflict(ext_redirect_paths)
}
assert len(existing_redirects) == 0, (
f"Cannot redirect for extension '{ext_id}'."
f" Already mapped by {existing_redirects}."
)
self._remove_extension_redirects(ext_id)
self.lnbits_extensions_redirects += ext_redirect_paths
def _remove_extension_redirects(self, ext_id: str):
self.lnbits_extensions_redirects = [
er for er in self.lnbits_extensions_redirects if er.ext_id != ext_id
]
class ThemesSettings(LNbitsSettings): class ThemesSettings(LNbitsSettings):

168
tests/unit/test_settings.py Normal file
View file

@ -0,0 +1,168 @@
import pytest
from lnbits.settings import RedirectPath
lnurlp_redirect_path = {
"from_path": "/.well-known/lnurlp",
"redirect_to_path": "/api/v1/well-known",
}
lnurlp_redirect_path_with_headers = {
"from_path": "/.well-known/lnurlp",
"redirect_to_path": "/api/v1/well-known",
"header_filters": {"accept": "application/nostr+json"},
}
lnaddress_redirect_path = {
"from_path": "/.well-known/lnurlp",
"redirect_to_path": "/api/v1/well-known",
}
nostrrelay_redirect_path = {
"from_path": "/",
"redirect_to_path": "/api/v1/relay-info",
"header_filters": {"accept": "application/nostr+json"},
}
@pytest.fixture()
def lnurlp():
return RedirectPath(ext_id="lnurlp", **lnurlp_redirect_path)
@pytest.fixture()
def lnurlp_with_headers():
return RedirectPath(
ext_id="lnurlp_with_headers", **lnurlp_redirect_path_with_headers
)
@pytest.fixture()
def lnaddress():
return RedirectPath(ext_id="lnaddress", **lnaddress_redirect_path)
@pytest.fixture()
def nostrrelay():
return RedirectPath(ext_id="nostrrelay", **nostrrelay_redirect_path)
def test_redirect_path_self_not_in_conflict(
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
):
assert not lnurlp.in_conflict(lnurlp), "Path is not in conflict with itself."
assert not lnaddress.in_conflict(lnaddress), "Path is not in conflict with itself."
assert not nostrrelay.in_conflict(
nostrrelay
), "Path is not in conflict with itself."
assert not lnurlp.in_conflict(nostrrelay)
assert not nostrrelay.in_conflict(lnurlp)
def test_redirect_path_not_in_conflict(
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
):
assert not lnurlp.in_conflict(nostrrelay)
assert not nostrrelay.in_conflict(lnurlp)
assert not lnaddress.in_conflict(nostrrelay)
assert not nostrrelay.in_conflict(lnaddress)
def test_redirect_path_in_conflict(lnurlp: RedirectPath, lnaddress: RedirectPath):
assert lnurlp.in_conflict(lnaddress)
assert lnaddress.in_conflict(lnurlp)
def test_redirect_path_find_conflict(
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
):
assert lnurlp.find_in_conflict([nostrrelay, lnaddress])
assert lnurlp.find_in_conflict([lnaddress, nostrrelay])
assert lnaddress.find_in_conflict([nostrrelay, lnurlp])
assert lnaddress.find_in_conflict([lnurlp, nostrrelay])
def test_redirect_path_find_no_conflict(
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
):
assert not nostrrelay.find_in_conflict([lnurlp, lnaddress])
assert not lnurlp.find_in_conflict([nostrrelay])
assert not lnaddress.find_in_conflict([nostrrelay])
def test_redirect_path_in_conflict_with_headers(
lnurlp: RedirectPath, lnurlp_with_headers: RedirectPath
):
assert lnurlp.in_conflict(lnurlp_with_headers)
assert lnurlp_with_headers.in_conflict(lnurlp)
def test_redirect_path_matches_with_headers(
lnurlp: RedirectPath, lnurlp_with_headers: RedirectPath
):
headers_list = list(lnurlp_with_headers.header_filters.items())
assert lnurlp.redirect_matches(
path=lnurlp_with_headers.from_path,
req_headers=headers_list,
)
assert lnurlp_with_headers.redirect_matches(
path=lnurlp_redirect_path["from_path"],
req_headers=[("ACCEPT", "APPlication/nostr+json")],
)
assert lnurlp_with_headers.redirect_matches(
path=lnurlp_redirect_path["from_path"],
req_headers=[("accept", "application/nostr+json"), ("my_header", "my_value")],
)
assert not lnurlp_with_headers.redirect_matches(
path=lnurlp_redirect_path["from_path"], req_headers=[]
)
assert not lnurlp_with_headers.redirect_matches(
path=lnurlp_redirect_path["from_path"],
req_headers=[("accept", "application/json")],
)
assert not lnurlp_with_headers.redirect_matches(path="/random/path", req_headers=[])
assert not lnurlp_with_headers.redirect_matches(path="/random_path", req_headers=[])
assert not lnurlp_with_headers.redirect_matches(
path="/.well-known/lnurlp", req_headers=[]
)
assert lnurlp.redirect_matches(path="/.well-known/lnurlp", req_headers=[])
assert lnurlp.redirect_matches(
path="/.well-known/lnurlp/some/other/path", req_headers=[]
)
assert lnurlp.redirect_matches(
path="/.well-known/lnurlp/some/other/path",
req_headers=headers_list,
)
assert not lnurlp_with_headers.redirect_matches(
path="/.well-known/lnurlp", req_headers=[]
)
assert not lnurlp_with_headers.redirect_matches(
path="/.well-known/lnurlp/some/other/path", req_headers=[]
)
assert lnurlp_with_headers.redirect_matches(
path="/.well-known/lnurlp/some/other/path",
req_headers=headers_list,
)
def test_redirect_path_new_path_from(lnurlp: RedirectPath):
assert lnurlp.new_path_from("") == "/lnurlp/api/v1/well-known"
assert lnurlp.new_path_from("/") == "/lnurlp/api/v1/well-known"
assert lnurlp.new_path_from("/path") == "/lnurlp/api/v1/well-known"
assert lnurlp.new_path_from("/path/more") == "/lnurlp/api/v1/well-known"
assert lnurlp.new_path_from("/.well-known/lnurlp") == "/lnurlp/api/v1/well-known"
assert (
lnurlp.new_path_from("/.well-known/lnurlp/path")
== "/lnurlp/api/v1/well-known/path"
)
assert (
lnurlp.new_path_from("/.well-known/lnurlp/path/more")
== "/lnurlp/api/v1/well-known/path/more"
)