Fix overlapping redirect paths (#2671)

This commit is contained in:
Vlad Stan 2024-09-11 12:41:37 +03:00 committed by GitHub
parent 7a5e7fbd8c
commit 5f4f1288d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 684 additions and 483 deletions

4
.gitignore vendored
View file

@ -49,8 +49,8 @@ fly.toml
lnbits-backup.zip
# Ignore extensions (post installable extension PR)
extensions
upgrades/
/extensions
/upgrades/
# builded python package
dist

View file

@ -17,10 +17,13 @@ from slowapi.util import get_remote_address
from starlette.middleware.sessions import SessionMiddleware
from lnbits.core.crud import (
add_installed_extension,
get_dbversions,
get_installed_extensions,
update_installed_extension_state,
)
from lnbits.core.extensions.extension_manager import deactivate_extension
from lnbits.core.extensions.helpers import version_parse
from lnbits.core.helpers import migrate_extension_database
from lnbits.core.tasks import ( # watchdog_task
killswitch_task,
@ -44,14 +47,8 @@ from lnbits.wallets import get_funding_source, set_funding_source
from .commands import migrate_databases
from .core import init_core_routers
from .core.db import core_app_extra
from .core.extensions.models import Extension, InstallableExtension
from .core.services import check_admin_settings, check_webpush_settings
from .core.views.extension_api import add_installed_extension
from .extension_manager import (
Extension,
InstallableExtension,
get_valid_extensions,
version_parse,
)
from .middleware import (
CustomGZipMiddleware,
ExtensionsRedirectMiddleware,
@ -243,6 +240,7 @@ async def check_installed_extensions(app: FastAPI):
)
except Exception as e:
logger.warning(e)
await deactivate_extension(ext.id)
logger.warning(
f"Failed to re-install extension: {ext.id} ({ext.installed_version})"
)
@ -317,7 +315,6 @@ async def restore_installed_extension(app: FastAPI, ext: InstallableExtension):
# mount routes for the new version
core_app_extra.register_new_ext_routes(extension)
ext.notify_upgrade(extension.upgrade_hash)
def register_custom_extensions_path():
@ -380,24 +377,22 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None:
)
app.mount(s["path"], StaticFiles(directory=static_dir), s["name"])
if hasattr(ext_module, f"{ext.code}_redirect_paths"):
ext_redirects = getattr(ext_module, f"{ext.code}_redirect_paths")
settings.lnbits_extensions_redirects = [
r for r in settings.lnbits_extensions_redirects if r["ext_id"] != ext.code
]
for r in ext_redirects:
r["ext_id"] = ext.code
settings.lnbits_extensions_redirects.append(r)
ext_redirects = (
getattr(ext_module, f"{ext.code}_redirect_paths")
if hasattr(ext_module, f"{ext.code}_redirect_paths")
else []
)
logger.trace(f"adding route for extension {ext_module}")
settings.activate_extension_paths(ext.code, ext.upgrade_hash, ext_redirects)
logger.trace(f"Adding route for extension {ext_module}.")
prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else ""
app.include_router(router=ext_route, prefix=prefix)
async def check_and_register_extensions(app: FastAPI):
await check_installed_extensions(app)
for ext in get_valid_extensions(False):
for ext in Extension.get_valid_extensions(False):
try:
register_ext_routes(app, ext)
except Exception as exc:

View file

@ -25,18 +25,18 @@ from lnbits.core.crud import (
remove_deleted_wallets,
update_payment_status,
)
from lnbits.core.extensions.models import (
CreateExtension,
ExtensionRelease,
InstallableExtension,
)
from lnbits.core.helpers import migrate_databases
from lnbits.core.models import Payment, PaymentState, User
from lnbits.core.models import Payment, PaymentState
from lnbits.core.services import check_admin_settings
from lnbits.core.views.extension_api import (
api_install_extension,
api_uninstall_extension,
)
from lnbits.extension_manager import (
CreateExtension,
ExtensionRelease,
InstallableExtension,
)
from lnbits.settings import settings
from lnbits.wallets.base import Wallet
@ -611,7 +611,7 @@ async def _call_install_extension(
)
resp.raise_for_status()
else:
await api_install_extension(data, User(id="mock_id"))
await api_install_extension(data)
async def _call_uninstall_extension(
@ -625,7 +625,7 @@ async def _call_uninstall_extension(
)
resp.raise_for_status()
else:
await api_uninstall_extension(extension, User(id="mock_id"))
await api_uninstall_extension(extension)
async def _can_run_operation(url) -> bool:

View file

@ -8,14 +8,14 @@ import shortuuid
from passlib.context import CryptContext
from lnbits.core.db import db
from lnbits.core.models import PaymentState
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
from lnbits.extension_manager import (
from lnbits.core.extensions.models import (
InstallableExtension,
PayToEnableInfo,
UserExtension,
UserExtensionInfo,
)
from lnbits.core.models import PaymentState
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
from lnbits.settings import (
AdminSettings,
EditableSettings,
@ -430,7 +430,7 @@ async def get_installed_extension(
async def get_installed_extensions(
active: Optional[bool] = None,
conn: Optional[Connection] = None,
) -> List["InstallableExtension"]:
) -> List[InstallableExtension]:
rows = await (conn or db).fetchall(
"SELECT * FROM installed_extensions",
(),

View file

@ -0,0 +1,93 @@
import asyncio
import importlib
from loguru import logger
from lnbits.core.crud import (
add_installed_extension,
delete_installed_extension,
get_dbversions,
get_installed_extension,
update_installed_extension_state,
)
from lnbits.core.db import core_app_extra
from lnbits.core.helpers import migrate_extension_database
from lnbits.settings import settings
from .models import Extension, InstallableExtension
async def install_extension(ext_info: InstallableExtension) -> Extension:
extension = Extension.from_installable_ext(ext_info)
installed_ext = await get_installed_extension(ext_info.id)
ext_info.payments = installed_ext.payments if installed_ext else []
await ext_info.download_archive()
ext_info.extract_archive()
db_version = (await get_dbversions()).get(ext_info.id, 0)
await migrate_extension_database(extension, db_version)
await add_installed_extension(ext_info)
if extension.is_upgrade_extension:
# call stop while the old routes are still active
await stop_extension_background_work(ext_info.id)
return extension
async def uninstall_extension(ext_id: str):
await stop_extension_background_work(ext_id)
settings.deactivate_extension_paths(ext_id)
extension = await get_installed_extension(ext_id)
if extension:
extension.clean_extension_files()
await delete_installed_extension(ext_id=ext_id)
async def activate_extension(ext: Extension):
core_app_extra.register_new_ext_routes(ext)
await update_installed_extension_state(ext_id=ext.code, active=True)
async def deactivate_extension(ext_id: str):
settings.deactivate_extension_paths(ext_id)
await update_installed_extension_state(ext_id=ext_id, active=False)
async def stop_extension_background_work(ext_id: str) -> bool:
"""
Stop background work for extension (like asyncio.Tasks, WebSockets, etc).
Extensions SHOULD expose a `api_stop()` function.
"""
upgrade_hash = settings.lnbits_upgraded_extensions.get(ext_id, "")
ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash)
try:
logger.info(f"Stopping background work for extension '{ext.module_name}'.")
old_module = importlib.import_module(ext.module_name)
# Extensions must expose an `{ext_id}_stop()` function at the module level
# The `api_stop()` function is for backwards compatibility (will be deprecated)
stop_fns = [f"{ext_id}_stop", "api_stop"]
stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None)
assert stop_fn_name, "No stop function found for '{ext.module_name}'"
stop_fn = getattr(old_module, stop_fn_name)
if stop_fn:
if asyncio.iscoroutinefunction(stop_fn):
await stop_fn()
else:
stop_fn()
logger.info(f"Stopped background work for extension '{ext.module_name}'.")
except Exception as ex:
logger.warning(f"Failed to stop background work for '{ext.module_name}'.")
logger.warning(ex)
return False
return True

View file

@ -0,0 +1,56 @@
import hashlib
from typing import Any, Optional
from urllib import request
import httpx
from loguru import logger
from packaging import version
from lnbits.settings import settings
def version_parse(v: str):
"""
Wrapper for version.parse() that does not throw if the version is invalid.
Instead it return the lowest possible version ("0.0.0")
"""
try:
return version.parse(v)
except Exception:
return version.parse("0.0.0")
async def github_api_get(url: str, error_msg: Optional[str]) -> Any:
headers = {"User-Agent": settings.user_agent}
if settings.lnbits_ext_github_token:
headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}"
async with httpx.AsyncClient(headers=headers) as client:
resp = await client.get(url)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()
def download_url(url, save_path):
with request.urlopen(url, timeout=60) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import hashlib
import json
@ -6,16 +8,22 @@ import shutil
import sys
import zipfile
from pathlib import Path
from typing import Any, List, NamedTuple, Optional, Tuple
from urllib import request
from typing import Any, NamedTuple, Optional
import httpx
from loguru import logger
from packaging import version
from pydantic import BaseModel
from lnbits.settings import settings
from .helpers import (
download_url,
file_hash,
github_api_get,
icon_to_github_url,
version_parse,
)
class ExplicitRelease(BaseModel):
id: str
@ -23,7 +31,7 @@ class ExplicitRelease(BaseModel):
version: str
archive: str
hash: str
dependencies: List[str] = []
dependencies: list[str] = []
repo: Optional[str]
icon: Optional[str]
short_description: Optional[str]
@ -48,9 +56,9 @@ class GitHubRelease(BaseModel):
class Manifest(BaseModel):
featured: List[str] = []
extensions: List["ExplicitRelease"] = []
repos: List["GitHubRelease"] = []
featured: list[str] = []
extensions: list[ExplicitRelease] = []
repos: list[GitHubRelease] = []
class GitHubRepoRelease(BaseModel):
@ -81,6 +89,17 @@ class ExtensionConfig(BaseModel):
return True
return version_parse(self.min_lnbits_version) <= version_parse(settings.version)
@classmethod
async def fetch_github_release_config(
cls, org: str, repo: str, tag_name: str
) -> Optional[ExtensionConfig]:
config_url = (
f"https://raw.githubusercontent.com/{org}/{repo}/{tag_name}/config.json"
)
error_msg = "Cannot fetch GitHub extension config"
config = await github_api_get(config_url, error_msg)
return ExtensionConfig.parse_obj(config)
class ReleasePaymentInfo(BaseModel):
amount: Optional[int] = None
@ -112,7 +131,7 @@ class UserExtension(BaseModel):
return self.extra.paid_to_enable is True
@classmethod
def from_row(cls, data: dict) -> "UserExtension":
def from_row(cls, data: dict) -> UserExtension:
ext = UserExtension(**data)
ext.extra = (
UserExtensionInfo(**json.loads(data["_extra"] or "{}"))
@ -122,124 +141,6 @@ class UserExtension(BaseModel):
return ext
def download_url(url, save_path):
with request.urlopen(url, timeout=60) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
async def fetch_github_repo_info(
org: str, repository: str
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
repo_url = f"https://api.github.com/repos/{org}/{repository}"
error_msg = "Cannot fetch extension repo"
repo = await github_api_get(repo_url, error_msg)
github_repo = GitHubRepo.parse_obj(repo)
lates_release_url = (
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
)
error_msg = "Cannot fetch extension releases"
latest_release: Any = await github_api_get(lates_release_url, error_msg)
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
error_msg = "Cannot fetch config for extension"
config = await github_api_get(config_url, error_msg)
return (
github_repo,
GitHubRepoRelease.parse_obj(latest_release),
ExtensionConfig.parse_obj(config),
)
async def fetch_manifest(url) -> Manifest:
error_msg = "Cannot fetch extensions manifest"
manifest = await github_api_get(url, error_msg)
return Manifest.parse_obj(manifest)
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
error_msg = "Cannot fetch extension releases"
releases = await github_api_get(releases_url, error_msg)
return [GitHubRepoRelease.parse_obj(r) for r in releases]
async def fetch_github_release_config(
org: str, repo: str, tag_name: str
) -> Optional[ExtensionConfig]:
config_url = (
f"https://raw.githubusercontent.com/{org}/{repo}/{tag_name}/config.json"
)
error_msg = "Cannot fetch GitHub extension config"
config = await github_api_get(config_url, error_msg)
return ExtensionConfig.parse_obj(config)
async def github_api_get(url: str, error_msg: Optional[str]) -> Any:
headers = {"User-Agent": settings.user_agent}
if settings.lnbits_ext_github_token:
headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}"
async with httpx.AsyncClient(headers=headers) as client:
resp = await client.get(url)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()
async def fetch_release_payment_info(
url: str, amount: Optional[int] = None
) -> Optional[ReleasePaymentInfo]:
if amount:
url = f"{url}?amount={amount}"
try:
async with httpx.AsyncClient() as client:
resp = await client.get(url)
resp.raise_for_status()
return ReleasePaymentInfo(**resp.json())
except Exception as e:
logger.warning(e)
return None
async def fetch_release_details(details_link: str) -> Optional[dict]:
try:
async with httpx.AsyncClient() as client:
resp = await client.get(details_link)
resp.raise_for_status()
data = resp.json()
if "description_md" in data:
resp = await client.get(data["description_md"])
if not resp.is_error:
data["description_md"] = resp.text
return data
except Exception as e:
logger.warning(e)
return None
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"
class Extension(NamedTuple):
code: str
is_valid: bool
@ -247,7 +148,7 @@ class Extension(NamedTuple):
name: Optional[str] = None
short_description: Optional[str] = None
tile: Optional[str] = None
contributors: Optional[List[str]] = None
contributors: Optional[list[str]] = None
hidden: bool = False
migration_module: Optional[str] = None
db_name: Optional[str] = None
@ -269,7 +170,7 @@ class Extension(NamedTuple):
return self.upgrade_hash != ""
@classmethod
def from_installable_ext(cls, ext_info: "InstallableExtension") -> "Extension":
def from_installable_ext(cls, ext_info: InstallableExtension) -> Extension:
return Extension(
code=ext_info.id,
is_valid=True,
@ -278,22 +179,43 @@ class Extension(NamedTuple):
upgrade_hash=ext_info.hash if ext_info.module_installed else "",
)
@classmethod
def get_valid_extensions(
cls, include_deactivated: Optional[bool] = True
) -> list[Extension]:
valid_extensions = [
extension for extension in cls._extensions() if extension.is_valid
]
# All subdirectories in the current directory, not recursive.
if include_deactivated:
return valid_extensions
if settings.lnbits_extensions_deactivate_all:
return []
class ExtensionManager:
def __init__(self) -> None:
return [
e
for e in valid_extensions
if e.code not in settings.lnbits_deactivated_extensions
]
@classmethod
def get_valid_extension(
cls, ext_id: str, include_deactivated: Optional[bool] = True
) -> Optional[Extension]:
all_extensions = cls.get_valid_extensions(include_deactivated)
return next((e for e in all_extensions if e.code == ext_id), None)
@classmethod
def _extensions(cls) -> list[Extension]:
p = Path(settings.lnbits_extensions_path, "extensions")
Path(p).mkdir(parents=True, exist_ok=True)
self._extension_folders: List[Path] = [f for f in p.iterdir() if f.is_dir()]
extension_folders: list[Path] = [f for f in p.iterdir() if f.is_dir()]
@property
def extensions(self) -> List[Extension]:
# todo: remove this property somehow, it is too expensive
output: List[Extension] = []
output: list[Extension] = []
for extension_folder in self._extension_folders:
for extension_folder in extension_folders:
extension_code = extension_folder.parts[-1]
try:
with open(extension_folder / "config.json") as json_file:
@ -356,13 +278,27 @@ class ExtensionRelease(BaseModel):
if not self.pay_link:
return
payment_info = await fetch_release_payment_info(self.pay_link)
payment_info = await self.fetch_release_payment_info()
self.cost_sats = payment_info.amount if payment_info else None
async def fetch_release_payment_info(
self, amount: Optional[int] = None
) -> Optional[ReleasePaymentInfo]:
url = f"{self.pay_link}?amount={amount}" if amount else self.pay_link
assert url, "Missing URL for payment info."
try:
async with httpx.AsyncClient() as client:
resp = await client.get(url)
resp.raise_for_status()
return ReleasePaymentInfo(**resp.json())
except Exception as e:
logger.warning(e)
return None
@classmethod
def from_github_release(
cls, source_repo: str, r: "GitHubRepoRelease"
) -> "ExtensionRelease":
cls, source_repo: str, r: GitHubRepoRelease
) -> ExtensionRelease:
return ExtensionRelease(
name=r.name,
description=r.name,
@ -377,8 +313,8 @@ class ExtensionRelease(BaseModel):
@classmethod
def from_explicit_release(
cls, source_repo: str, e: "ExplicitRelease"
) -> "ExtensionRelease":
cls, source_repo: str, e: ExplicitRelease
) -> ExtensionRelease:
return ExtensionRelease(
name=e.name,
version=e.version,
@ -397,9 +333,9 @@ class ExtensionRelease(BaseModel):
)
@classmethod
async def get_github_releases(cls, org: str, repo: str) -> List["ExtensionRelease"]:
async def get_github_releases(cls, org: str, repo: str) -> list[ExtensionRelease]:
try:
github_releases = await fetch_github_releases(org, repo)
github_releases = await cls.fetch_github_releases(org, repo)
return [
ExtensionRelease.from_github_release(f"{org}/{repo}", r)
for r in github_releases
@ -408,6 +344,33 @@ class ExtensionRelease(BaseModel):
logger.warning(e)
return []
@classmethod
async def fetch_github_releases(
cls, org: str, repo: str
) -> list[GitHubRepoRelease]:
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
error_msg = "Cannot fetch extension releases"
releases = await github_api_get(releases_url, error_msg)
return [GitHubRepoRelease.parse_obj(r) for r in releases]
@classmethod
async def fetch_release_details(cls, details_link: str) -> Optional[dict]:
try:
async with httpx.AsyncClient() as client:
resp = await client.get(details_link)
resp.raise_for_status()
data = resp.json()
if "description_md" in data:
resp = await client.get(data["description_md"])
if not resp.is_error:
data["description_md"] = resp.text
return data
except Exception as e:
logger.warning(e)
return None
class InstallableExtension(BaseModel):
id: str
@ -415,13 +378,13 @@ class InstallableExtension(BaseModel):
active: Optional[bool] = False
short_description: Optional[str] = None
icon: Optional[str] = None
dependencies: List[str] = []
dependencies: list[str] = []
is_admin_only: bool = False
stars: int = 0
featured = False
latest_release: Optional[ExtensionRelease] = None
installed_release: Optional[ExtensionRelease] = None
payments: List[ReleasePaymentInfo] = []
payments: list[ReleasePaymentInfo] = []
pay_to_enable: Optional[PayToEnableInfo] = None
archive: Optional[str] = None
@ -546,16 +509,6 @@ class InstallableExtension(BaseModel):
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
logger.success(f"Extension {self.name} ({self.installed_version}) installed.")
def notify_upgrade(self, upgrade_hash: Optional[str]) -> None:
"""
Update the list of upgraded extensions. The middleware will perform
redirects based on this
"""
if upgrade_hash:
settings.lnbits_upgraded_extensions.add(f"{self.hash}/{self.id}")
settings.lnbits_all_extensions_ids.add(self.id)
def clean_extension_files(self):
# remove downloaded archive
if self.zip_path.is_file():
@ -610,7 +563,7 @@ class InstallableExtension(BaseModel):
self.payments.append(payment_info)
@classmethod
def from_row(cls, data: dict) -> "InstallableExtension":
def from_row(cls, data: dict) -> InstallableExtension:
meta = json.loads(data["meta"])
ext = InstallableExtension(**data)
if "installed_release" in meta:
@ -623,9 +576,7 @@ class InstallableExtension(BaseModel):
return ext
@classmethod
def from_rows(
cls, rows: Optional[List[Any]] = None
) -> List["InstallableExtension"]:
def from_rows(cls, rows: Optional[list[Any]] = None) -> list[InstallableExtension]:
if rows is None:
rows = []
return [InstallableExtension.from_row(row) for row in rows]
@ -633,9 +584,9 @@ class InstallableExtension(BaseModel):
@classmethod
async def from_github_release(
cls, github_release: GitHubRelease
) -> Optional["InstallableExtension"]:
) -> Optional[InstallableExtension]:
try:
repo, latest_release, config = await fetch_github_repo_info(
repo, latest_release, config = await cls.fetch_github_repo_info(
github_release.organisation, github_release.repository
)
source_repo = f"{github_release.organisation}/{github_release.repository}"
@ -657,7 +608,7 @@ class InstallableExtension(BaseModel):
return None
@classmethod
def from_explicit_release(cls, e: ExplicitRelease) -> "InstallableExtension":
def from_explicit_release(cls, e: ExplicitRelease) -> InstallableExtension:
return InstallableExtension(
id=e.id,
name=e.name,
@ -670,13 +621,13 @@ class InstallableExtension(BaseModel):
@classmethod
async def get_installable_extensions(
cls,
) -> List["InstallableExtension"]:
extension_list: List[InstallableExtension] = []
extension_id_list: List[str] = []
) -> list[InstallableExtension]:
extension_list: list[InstallableExtension] = []
extension_id_list: list[str] = []
for url in settings.lnbits_extensions_manifests:
try:
manifest = await fetch_manifest(url)
manifest = await cls.fetch_manifest(url)
for r in manifest.repos:
ext = await InstallableExtension.from_github_release(r)
@ -712,12 +663,12 @@ class InstallableExtension(BaseModel):
return extension_list
@classmethod
async def get_extension_releases(cls, ext_id: str) -> List["ExtensionRelease"]:
extension_releases: List[ExtensionRelease] = []
async def get_extension_releases(cls, ext_id: str) -> list[ExtensionRelease]:
extension_releases: list[ExtensionRelease] = []
for url in settings.lnbits_extensions_manifests:
try:
manifest = await fetch_manifest(url)
manifest = await cls.fetch_manifest(url)
for r in manifest.repos:
if r.id != ext_id:
continue
@ -741,8 +692,8 @@ class InstallableExtension(BaseModel):
@classmethod
async def get_extension_release(
cls, ext_id: str, source_repo: str, archive: str, version: str
) -> Optional["ExtensionRelease"]:
all_releases: List[ExtensionRelease] = (
) -> Optional[ExtensionRelease]:
all_releases: list[ExtensionRelease] = (
await InstallableExtension.get_extension_releases(ext_id)
)
selected_release = [
@ -755,6 +706,37 @@ class InstallableExtension(BaseModel):
return selected_release[0] if len(selected_release) != 0 else None
@classmethod
async def fetch_github_repo_info(
cls, org: str, repository: str
) -> tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
repo_url = f"https://api.github.com/repos/{org}/{repository}"
error_msg = "Cannot fetch extension repo"
repo = await github_api_get(repo_url, error_msg)
github_repo = GitHubRepo.parse_obj(repo)
lates_release_url = (
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
)
error_msg = "Cannot fetch extension releases"
latest_release: Any = await github_api_get(lates_release_url, error_msg)
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
error_msg = "Cannot fetch config for extension"
config = await github_api_get(config_url, error_msg)
return (
github_repo,
GitHubRepoRelease.parse_obj(latest_release),
ExtensionConfig.parse_obj(config),
)
@classmethod
async def fetch_manifest(cls, url) -> Manifest:
error_msg = "Cannot fetch extensions manifest"
manifest = await github_api_get(url, error_msg)
return Manifest.parse_obj(manifest)
class CreateExtension(BaseModel):
ext_id: str
@ -769,32 +751,3 @@ class ExtensionDetailsRequest(BaseModel):
ext_id: str
source_repo: str
version: str
def get_valid_extensions(include_deactivated: Optional[bool] = True) -> List[Extension]:
valid_extensions = [
extension for extension in ExtensionManager().extensions if extension.is_valid
]
if include_deactivated:
return valid_extensions
if settings.lnbits_extensions_deactivate_all:
return []
return [
e
for e in valid_extensions
if e.code not in settings.lnbits_deactivated_extensions
]
def version_parse(v: str):
"""
Wrapper for version.parse() that does not throw if the version is invalid.
Instead it return the lowest possible version ("0.0.0")
"""
try:
return version.parse(v)
except Exception:
return version.parse("0.0.0")

View file

@ -1,9 +1,8 @@
import importlib
import re
from typing import Any, Optional
from typing import Any
from uuid import UUID
import httpx
from loguru import logger
from lnbits.core import migrations as core_migrations
@ -13,11 +12,10 @@ from lnbits.core.crud import (
update_migration_version,
)
from lnbits.core.db import db as core_db
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection
from lnbits.extension_manager import (
from lnbits.core.extensions.models import (
Extension,
get_valid_extensions,
)
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection
from lnbits.settings import settings
@ -55,68 +53,6 @@ async def run_migration(
await update_migration_version(conn, db_name, version)
async def stop_extension_background_work(
ext_id: str, user: str, access_token: Optional[str] = None
):
"""
Stop background work for extension (like asyncio.Tasks, WebSockets, etc).
Extensions SHOULD expose a `api_stop()` function and/or a DELETE enpoint
at the root level of their API.
"""
stopped = await _stop_extension_background_work(ext_id)
if not stopped:
# fallback to REST API call
await _stop_extension_background_work_via_api(ext_id, user, access_token)
async def _stop_extension_background_work(ext_id) -> bool:
upgrade_hash = settings.extension_upgrade_hash(ext_id) or ""
ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash)
try:
logger.info(f"Stopping background work for extension '{ext.module_name}'.")
old_module = importlib.import_module(ext.module_name)
# Extensions must expose an `{ext_id}_stop()` function at the module level
# The `api_stop()` function is for backwards compatibility (will be deprecated)
stop_fns = [f"{ext_id}_stop", "api_stop"]
stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None)
assert stop_fn_name, "No stop function found for '{ext.module_name}'"
stop_fn = getattr(old_module, stop_fn_name)
if stop_fn:
await stop_fn()
logger.info(f"Stopped background work for extension '{ext.module_name}'.")
except Exception as ex:
logger.warning(f"Failed to stop background work for '{ext.module_name}'.")
logger.warning(ex)
return False
return True
async def _stop_extension_background_work_via_api(ext_id, user, access_token):
logger.info(
f"Stopping background work for extension '{ext_id}' using the REST API."
)
async with httpx.AsyncClient() as client:
try:
url = f"http://{settings.host}:{settings.port}/{ext_id}/api/v1?usr={user}"
headers = (
{"Authorization": "Bearer " + access_token} if access_token else None
)
resp = await client.delete(url=url, headers=headers)
resp.raise_for_status()
logger.info(f"Stopped background work for extension '{ext_id}'.")
except Exception as ex:
logger.warning(
f"Failed to stop background work for '{ext_id}' using the REST API."
)
logger.warning(ex)
def to_valid_user_id(user_id: str) -> UUID:
if len(user_id) < 32:
raise ValueError("User ID must have at least 128 bits")
@ -161,7 +97,7 @@ async def migrate_databases():
await load_disabled_extension_list()
# todo: revisit, use installed extensions
for ext in get_valid_extensions(False):
for ext in Extension.get_valid_extensions(False):
current_version = current_versions.get(ext.code, 0)
try:
await migrate_extension_database(ext, current_version)

View file

@ -1,8 +1,6 @@
import sys
from http import HTTPStatus
from typing import (
List,
Optional,
)
from bolt11 import decode as bolt11_decode
@ -13,10 +11,21 @@ from fastapi import (
)
from loguru import logger
from lnbits.core.db import core_app_extra
from lnbits.core.helpers import (
migrate_extension_database,
stop_extension_background_work,
from lnbits.core.extensions.extension_manager import (
activate_extension,
deactivate_extension,
install_extension,
uninstall_extension,
)
from lnbits.core.extensions.models import (
CreateExtension,
Extension,
ExtensionConfig,
ExtensionRelease,
InstallableExtension,
PayToEnableInfo,
ReleasePaymentInfo,
UserExtensionInfo,
)
from lnbits.core.models import (
SimpleStatus,
@ -24,36 +33,18 @@ from lnbits.core.models import (
)
from lnbits.core.services import check_transaction_status, create_invoice
from lnbits.decorators import (
check_access_token,
check_admin,
check_user_exists,
)
from lnbits.extension_manager import (
CreateExtension,
Extension,
ExtensionRelease,
InstallableExtension,
PayToEnableInfo,
ReleasePaymentInfo,
UserExtensionInfo,
fetch_github_release_config,
fetch_release_details,
fetch_release_payment_info,
get_valid_extensions,
)
from lnbits.settings import settings
from ..crud import (
add_installed_extension,
delete_dbversion,
delete_installed_extension,
drop_extension_db,
get_dbversions,
get_installed_extension,
get_installed_extensions,
get_user_extension,
update_extension_pay_to_enable,
update_installed_extension_state,
update_user_extension,
update_user_extension_extra,
)
@ -64,12 +55,8 @@ extension_router = APIRouter(
)
@extension_router.post("")
async def api_install_extension(
data: CreateExtension,
user: User = Depends(check_admin),
access_token: Optional[str] = Depends(check_access_token),
):
@extension_router.post("", dependencies=[Depends(check_admin)])
async def api_install_extension(data: CreateExtension):
release = await InstallableExtension.get_extension_release(
data.ext_id, data.source_repo, data.archive, data.version
)
@ -89,43 +76,36 @@ async def api_install_extension(
)
try:
installed_ext = await get_installed_extension(data.ext_id)
ext_info.payments = installed_ext.payments if installed_ext else []
extension = await install_extension(ext_info)
await ext_info.download_archive()
ext_info.extract_archive()
extension = Extension.from_installable_ext(ext_info)
db_version = (await get_dbversions()).get(data.ext_id, 0)
await migrate_extension_database(extension, db_version)
ext_info.active = True
await add_installed_extension(ext_info)
if extension.is_upgrade_extension:
# call stop while the old routes are still active
await stop_extension_background_work(data.ext_id, user.id, access_token)
# mount routes for the new version
core_app_extra.register_new_ext_routes(extension)
ext_info.notify_upgrade(extension.upgrade_hash)
settings.lnbits_deactivated_extensions.discard(data.ext_id)
return extension
except AssertionError as exc:
raise HTTPException(HTTPStatus.BAD_REQUEST, str(exc)) from exc
except Exception as exc:
logger.warning(exc)
ext_info.clean_extension_files()
detail = (
str(exc)
if isinstance(exc, AssertionError)
else f"Failed to install extension '{ext_info.id}'."
f"({ext_info.installed_version})."
)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=(
f"Failed to install extension {ext_info.id} "
f"({ext_info.installed_version})."
),
detail=detail,
) from exc
try:
await activate_extension(extension)
return extension
except Exception as exc:
logger.warning(exc)
await deactivate_extension(extension.code)
detail = (
str(exc)
if isinstance(exc, AssertionError)
else f"Extension `{extension.code}` installed, but activation failed."
)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=detail,
) from exc
@ -143,7 +123,7 @@ async def api_extension_details(
)
assert release, "Details not found for release"
release_details = await fetch_release_details(details_link)
release_details = await ExtensionRelease.fetch_release_details(details_link)
assert release_details, "Cannot fetch details for release"
release_details["icon"] = release.icon
release_details["repo"] = release.repo
@ -186,7 +166,7 @@ async def api_update_pay_to_enable(
async def api_enable_extension(
ext_id: str, user: User = Depends(check_user_exists)
) -> SimpleStatus:
if ext_id not in [e.code for e in get_valid_extensions()]:
if ext_id not in [e.code for e in Extension.get_valid_extensions()]:
raise HTTPException(
HTTPStatus.NOT_FOUND, f"Extension '{ext_id}' doesn't exist."
)
@ -249,7 +229,7 @@ async def api_enable_extension(
async def api_disable_extension(
ext_id: str, user: User = Depends(check_user_exists)
) -> SimpleStatus:
if ext_id not in [e.code for e in get_valid_extensions()]:
if ext_id not in [e.code for e in Extension.get_valid_extensions()]:
raise HTTPException(
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
)
@ -270,20 +250,14 @@ async def api_activate_extension(ext_id: str) -> SimpleStatus:
try:
logger.info(f"Activating extension: '{ext_id}'.")
all_extensions = get_valid_extensions()
ext = next((e for e in all_extensions if e.code == ext_id), None)
ext = Extension.get_valid_extension(ext_id)
assert ext, f"Extension '{ext_id}' doesn't exist."
# if extension never loaded (was deactivated on server startup)
if ext_id not in sys.modules.keys():
# run extension start-up routine
core_app_extra.register_new_ext_routes(ext)
settings.lnbits_deactivated_extensions.discard(ext_id)
await update_installed_extension_state(ext_id=ext_id, active=True)
await activate_extension(ext)
return SimpleStatus(success=True, message=f"Extension '{ext_id}' activated.")
except Exception as exc:
logger.warning(exc)
await deactivate_extension(ext_id)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=(f"Failed to activate '{ext_id}'."),
@ -295,13 +269,10 @@ async def api_deactivate_extension(ext_id: str) -> SimpleStatus:
try:
logger.info(f"Deactivating extension: '{ext_id}'.")
all_extensions = get_valid_extensions()
ext = next((e for e in all_extensions if e.code == ext_id), None)
ext = Extension.get_valid_extension(ext_id)
assert ext, f"Extension '{ext_id}' doesn't exist."
settings.lnbits_deactivated_extensions.add(ext_id)
await update_installed_extension_state(ext_id=ext_id, active=False)
await deactivate_extension(ext_id)
return SimpleStatus(success=True, message=f"Extension '{ext_id}' deactivated.")
except Exception as exc:
logger.warning(exc)
@ -311,23 +282,19 @@ async def api_deactivate_extension(ext_id: str) -> SimpleStatus:
) from exc
@extension_router.delete("/{ext_id}")
async def api_uninstall_extension(
ext_id: str,
user: User = Depends(check_admin),
access_token: Optional[str] = Depends(check_access_token),
) -> SimpleStatus:
installed_extensions = await get_installed_extensions()
@extension_router.delete("/{ext_id}", dependencies=[Depends(check_admin)])
async def api_uninstall_extension(ext_id: str) -> SimpleStatus:
extensions = [e for e in installed_extensions if e.id == ext_id]
if len(extensions) == 0:
extension = await get_installed_extension(ext_id)
if not extension:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"Unknown extension id: {ext_id}",
)
installed_extensions = await get_installed_extensions()
# check that other extensions do not depend on this one
for valid_ext_id in [ext.code for ext in get_valid_extensions()]:
for valid_ext_id in [ext.code for ext in Extension.get_valid_extensions()]:
installed_ext = next(
(ext for ext in installed_extensions if ext.id == valid_ext_id), None
)
@ -341,14 +308,7 @@ async def api_uninstall_extension(
)
try:
# call stop while the old routes are still active
await stop_extension_background_work(ext_id, user.id, access_token)
settings.lnbits_deactivated_extensions.add(ext_id)
for ext_info in extensions:
ext_info.clean_extension_files()
await delete_installed_extension(ext_id=ext_info.id)
await uninstall_extension(ext_id)
logger.success(f"Extension '{ext_id}' uninstalled.")
return SimpleStatus(success=True, message=f"Extension '{ext_id}' uninstalled.")
@ -397,9 +357,8 @@ async def get_pay_to_install_invoice(
assert release, "Release not found."
assert release.pay_link, "Pay link not found for release."
payment_info = await fetch_release_payment_info(
release.pay_link, data.cost_sats
)
payment_info = await release.fetch_release_payment_info(data.cost_sats)
assert payment_info and payment_info.payment_request, "Cannot request invoice."
invoice = bolt11_decode(payment_info.payment_request)
@ -474,7 +433,7 @@ async def get_pay_to_enable_invoice(
)
async def get_extension_release(org: str, repo: str, tag_name: str):
try:
config = await fetch_github_release_config(org, repo, tag_name)
config = await ExtensionConfig.fetch_github_release_config(org, repo, tag_name)
if not config:
return {}

View file

@ -12,6 +12,7 @@ from lnurl import decode as lnurl_decode
from loguru import logger
from pydantic.types import UUID4
from lnbits.core.extensions.models import Extension, InstallableExtension
from lnbits.core.helpers import to_valid_user_id
from lnbits.core.models import User
from lnbits.core.services import create_invoice
@ -20,7 +21,6 @@ from lnbits.helpers import template_renderer
from lnbits.settings import settings
from lnbits.wallets import get_funding_source
from ...extension_manager import InstallableExtension, get_valid_extensions
from ...utils.exchange_rates import allowed_currencies, currencies
from ..crud import (
create_account,
@ -104,7 +104,7 @@ async def extensions(request: Request, user: User = Depends(check_user_exists)):
installed_exts_ids = []
try:
all_ext_ids = [ext.code for ext in get_valid_extensions()]
all_ext_ids = [ext.code for ext in Extension.get_valid_extensions()]
inactive_extensions = [
e.id for e in await get_installed_extensions(active=False)
]

View file

@ -10,6 +10,7 @@ import shortuuid
from pydantic import BaseModel
from pydantic.schema import field_schema
from lnbits.core.extensions.models import Extension
from lnbits.db import get_placeholder
from lnbits.jinja2_templating import Jinja2Templates
from lnbits.nodes import get_node_class
@ -18,7 +19,6 @@ from lnbits.settings import settings
from lnbits.utils.crypto import AESCipher
from .db import FilterModel
from .extension_manager import get_valid_extensions
def get_db_vendor_name():
@ -93,7 +93,7 @@ def template_renderer(additional_folders: Optional[List] = None) -> Jinja2Templa
settings.lnbits_node_ui and get_node_class() is not None
)
t.env.globals["LNBITS_NODE_UI_AVAILABLE"] = get_node_class() is not None
t.env.globals["EXTENSIONS"] = get_valid_extensions(False)
t.env.globals["EXTENSIONS"] = Extension.get_valid_extensions(False)
if settings.lnbits_custom_logo:
t.env.globals["USE_CUSTOM_LOGO"] = settings.lnbits_custom_logo

View file

@ -1,5 +1,5 @@
from http import HTTPStatus
from typing import Any, List, Tuple, Union
from typing import Any, List, Union
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
@ -45,16 +45,11 @@ class InstalledExtensionMiddleware:
await self.app(scope, receive, send)
return
upgrade_path = next(
(
e
for e in settings.lnbits_upgraded_extensions
if e.endswith(f"/{top_path}")
),
None,
)
# re-route all trafic if the extension has been upgraded
if upgrade_path:
if top_path in settings.lnbits_upgraded_extensions:
upgrade_path = (
f"""{settings.lnbits_upgraded_extensions[top_path]}/{top_path}"""
)
tail = "/".join(rest)
scope["path"] = f"/upgrades/{upgrade_path}/{tail}"
@ -118,72 +113,12 @@ class ExtensionsRedirectMiddleware:
return
req_headers = scope["headers"] if "headers" in scope else []
redirect = self._find_redirect(scope["path"], req_headers)
redirect = settings.find_extension_redirect(scope["path"], req_headers)
if redirect:
scope["path"] = self._new_path(redirect, scope["path"])
scope["path"] = redirect.new_path_from(scope["path"])
await self.app(scope, receive, send)
def _find_redirect(self, path: str, req_headers: List[Tuple[bytes, bytes]]):
return next(
(
r
for r in settings.lnbits_extensions_redirects
if self._redirect_matches(r, path, req_headers)
),
None,
)
def _redirect_matches(
self, redirect: dict, path: str, req_headers: List[Tuple[bytes, bytes]]
) -> bool:
if "from_path" not in redirect:
return False
header_filters = (
redirect["header_filters"] if "header_filters" in redirect else {}
)
return self._has_common_path(redirect["from_path"], path) and self._has_headers(
header_filters, req_headers
)
def _has_headers(
self, filter_headers: dict, req_headers: List[Tuple[bytes, bytes]]
) -> bool:
for h in filter_headers:
if not self._has_header(req_headers, (str(h), str(filter_headers[h]))):
return False
return True
def _has_header(
self, req_headers: List[Tuple[bytes, bytes]], header: Tuple[str, str]
) -> bool:
for h in req_headers:
if (
h[0].decode().lower() == header[0].lower()
and h[1].decode() == header[1]
):
return True
return False
def _has_common_path(self, redirect_path: str, req_path: str) -> bool:
redirect_path_elements = redirect_path.split("/")
req_path_elements = req_path.split("/")
if len(redirect_path) > len(req_path):
return False
sub_path = req_path_elements[: len(redirect_path_elements)]
return redirect_path == "/".join(sub_path)
def _new_path(self, redirect: dict, req_path: str) -> str:
from_path = redirect["from_path"].split("/")
redirect_to = redirect["redirect_to_path"].split("/")
req_tail_path = req_path.split("/")[len(from_path) :]
elements = [
e for e in ([redirect["ext_id"], *redirect_to, *req_tail_path]) if e != ""
]
return "/" + "/".join(elements)
def add_ratelimit_middleware(app: FastAPI):
core_app_extra.register_new_ratelimiter()

View file

@ -62,26 +62,132 @@ class ExtensionsInstallSettings(LNbitsSettings):
lnbits_ext_github_token: str = Field(default="")
class RedirectPath(BaseModel):
ext_id: str
from_path: str
redirect_to_path: str
header_filters: dict = {}
def in_conflict(self, other: RedirectPath) -> bool:
if self.ext_id == other.ext_id:
return False
return self.redirect_matches(
other.from_path, list(other.header_filters.items())
) or other.redirect_matches(self.from_path, list(self.header_filters.items()))
def find_in_conflict(self, others: list[RedirectPath]) -> Optional[RedirectPath]:
for other in others:
if self.in_conflict(other):
return other
return None
def new_path_from(self, req_path: str) -> str:
from_path = self.from_path.split("/")
redirect_to = self.redirect_to_path.split("/")
req_tail_path = req_path.split("/")[len(from_path) :]
elements = [e for e in ([self.ext_id, *redirect_to, *req_tail_path]) if e != ""]
return "/" + "/".join(elements)
def redirect_matches(self, path: str, req_headers: list[tuple[str, str]]) -> bool:
return self._has_common_path(path) and self._has_headers(req_headers)
def _has_common_path(self, req_path: str) -> bool:
if len(self.from_path) > len(req_path):
return False
redirect_path_elements = self.from_path.split("/")
req_path_elements = req_path.split("/")
sub_path = req_path_elements[: len(redirect_path_elements)]
return self.from_path == "/".join(sub_path)
def _has_headers(self, req_headers: list[tuple[str, str]]) -> bool:
for h in self.header_filters:
if not self._has_header(req_headers, (str(h), str(self.header_filters[h]))):
return False
return True
def _has_header(
self, req_headers: list[tuple[str, str]], header: tuple[str, str]
) -> bool:
for h in req_headers:
if h[0].lower() == header[0].lower() and h[1].lower() == header[1].lower():
return True
return False
class InstalledExtensionsSettings(LNbitsSettings):
# installed extensions that have been deactivated
lnbits_deactivated_extensions: set[str] = Field(default=[])
# upgraded extensions that require API redirects
lnbits_upgraded_extensions: set[str] = Field(default=[])
lnbits_upgraded_extensions: dict[str, str] = Field(default={})
# list of redirects that extensions want to perform
lnbits_extensions_redirects: list[Any] = Field(default=[])
lnbits_extensions_redirects: list[RedirectPath] = Field(default=[])
# list of all extension ids
lnbits_all_extensions_ids: set[Any] = Field(default=[])
def extension_upgrade_path(self, ext_id: str) -> Optional[str]:
def find_extension_redirect(
self, path: str, req_headers: list[tuple[bytes, bytes]]
) -> Optional[RedirectPath]:
headers = [(k.decode(), v.decode()) for k, v in req_headers]
return next(
(e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")),
(
r
for r in self.lnbits_extensions_redirects
if r.redirect_matches(path, headers)
),
None,
)
def extension_upgrade_hash(self, ext_id: str) -> Optional[str]:
path = settings.extension_upgrade_path(ext_id)
return path.split("/")[0] if path else None
def activate_extension_paths(
self,
ext_id: str,
upgrade_hash: Optional[str] = None,
ext_redirects: Optional[list[dict]] = None,
):
self.lnbits_deactivated_extensions.discard(ext_id)
"""
Update the list of upgraded extensions. The middleware will perform
redirects based on this
"""
if upgrade_hash:
self.lnbits_upgraded_extensions[ext_id] = upgrade_hash
if ext_redirects:
self._activate_extension_redirects(ext_id, ext_redirects)
self.lnbits_all_extensions_ids.add(ext_id)
def deactivate_extension_paths(self, ext_id: str):
self.lnbits_deactivated_extensions.add(ext_id)
self._remove_extension_redirects(ext_id)
def _activate_extension_redirects(self, ext_id: str, ext_redirects: list[dict]):
ext_redirect_paths = [
RedirectPath(**{"ext_id": ext_id, **er}) for er in ext_redirects
]
existing_redirects = {
r.ext_id
for r in self.lnbits_extensions_redirects
if r.find_in_conflict(ext_redirect_paths)
}
assert len(existing_redirects) == 0, (
f"Cannot redirect for extension '{ext_id}'."
f" Already mapped by {existing_redirects}."
)
self._remove_extension_redirects(ext_id)
self.lnbits_extensions_redirects += ext_redirect_paths
def _remove_extension_redirects(self, ext_id: str):
self.lnbits_extensions_redirects = [
er for er in self.lnbits_extensions_redirects if er.ext_id != ext_id
]
class ThemesSettings(LNbitsSettings):

168
tests/unit/test_settings.py Normal file
View file

@ -0,0 +1,168 @@
import pytest
from lnbits.settings import RedirectPath
lnurlp_redirect_path = {
"from_path": "/.well-known/lnurlp",
"redirect_to_path": "/api/v1/well-known",
}
lnurlp_redirect_path_with_headers = {
"from_path": "/.well-known/lnurlp",
"redirect_to_path": "/api/v1/well-known",
"header_filters": {"accept": "application/nostr+json"},
}
lnaddress_redirect_path = {
"from_path": "/.well-known/lnurlp",
"redirect_to_path": "/api/v1/well-known",
}
nostrrelay_redirect_path = {
"from_path": "/",
"redirect_to_path": "/api/v1/relay-info",
"header_filters": {"accept": "application/nostr+json"},
}
@pytest.fixture()
def lnurlp():
return RedirectPath(ext_id="lnurlp", **lnurlp_redirect_path)
@pytest.fixture()
def lnurlp_with_headers():
return RedirectPath(
ext_id="lnurlp_with_headers", **lnurlp_redirect_path_with_headers
)
@pytest.fixture()
def lnaddress():
return RedirectPath(ext_id="lnaddress", **lnaddress_redirect_path)
@pytest.fixture()
def nostrrelay():
return RedirectPath(ext_id="nostrrelay", **nostrrelay_redirect_path)
def test_redirect_path_self_not_in_conflict(
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
):
assert not lnurlp.in_conflict(lnurlp), "Path is not in conflict with itself."
assert not lnaddress.in_conflict(lnaddress), "Path is not in conflict with itself."
assert not nostrrelay.in_conflict(
nostrrelay
), "Path is not in conflict with itself."
assert not lnurlp.in_conflict(nostrrelay)
assert not nostrrelay.in_conflict(lnurlp)
def test_redirect_path_not_in_conflict(
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
):
assert not lnurlp.in_conflict(nostrrelay)
assert not nostrrelay.in_conflict(lnurlp)
assert not lnaddress.in_conflict(nostrrelay)
assert not nostrrelay.in_conflict(lnaddress)
def test_redirect_path_in_conflict(lnurlp: RedirectPath, lnaddress: RedirectPath):
assert lnurlp.in_conflict(lnaddress)
assert lnaddress.in_conflict(lnurlp)
def test_redirect_path_find_conflict(
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
):
assert lnurlp.find_in_conflict([nostrrelay, lnaddress])
assert lnurlp.find_in_conflict([lnaddress, nostrrelay])
assert lnaddress.find_in_conflict([nostrrelay, lnurlp])
assert lnaddress.find_in_conflict([lnurlp, nostrrelay])
def test_redirect_path_find_no_conflict(
lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath
):
assert not nostrrelay.find_in_conflict([lnurlp, lnaddress])
assert not lnurlp.find_in_conflict([nostrrelay])
assert not lnaddress.find_in_conflict([nostrrelay])
def test_redirect_path_in_conflict_with_headers(
lnurlp: RedirectPath, lnurlp_with_headers: RedirectPath
):
assert lnurlp.in_conflict(lnurlp_with_headers)
assert lnurlp_with_headers.in_conflict(lnurlp)
def test_redirect_path_matches_with_headers(
lnurlp: RedirectPath, lnurlp_with_headers: RedirectPath
):
headers_list = list(lnurlp_with_headers.header_filters.items())
assert lnurlp.redirect_matches(
path=lnurlp_with_headers.from_path,
req_headers=headers_list,
)
assert lnurlp_with_headers.redirect_matches(
path=lnurlp_redirect_path["from_path"],
req_headers=[("ACCEPT", "APPlication/nostr+json")],
)
assert lnurlp_with_headers.redirect_matches(
path=lnurlp_redirect_path["from_path"],
req_headers=[("accept", "application/nostr+json"), ("my_header", "my_value")],
)
assert not lnurlp_with_headers.redirect_matches(
path=lnurlp_redirect_path["from_path"], req_headers=[]
)
assert not lnurlp_with_headers.redirect_matches(
path=lnurlp_redirect_path["from_path"],
req_headers=[("accept", "application/json")],
)
assert not lnurlp_with_headers.redirect_matches(path="/random/path", req_headers=[])
assert not lnurlp_with_headers.redirect_matches(path="/random_path", req_headers=[])
assert not lnurlp_with_headers.redirect_matches(
path="/.well-known/lnurlp", req_headers=[]
)
assert lnurlp.redirect_matches(path="/.well-known/lnurlp", req_headers=[])
assert lnurlp.redirect_matches(
path="/.well-known/lnurlp/some/other/path", req_headers=[]
)
assert lnurlp.redirect_matches(
path="/.well-known/lnurlp/some/other/path",
req_headers=headers_list,
)
assert not lnurlp_with_headers.redirect_matches(
path="/.well-known/lnurlp", req_headers=[]
)
assert not lnurlp_with_headers.redirect_matches(
path="/.well-known/lnurlp/some/other/path", req_headers=[]
)
assert lnurlp_with_headers.redirect_matches(
path="/.well-known/lnurlp/some/other/path",
req_headers=headers_list,
)
def test_redirect_path_new_path_from(lnurlp: RedirectPath):
assert lnurlp.new_path_from("") == "/lnurlp/api/v1/well-known"
assert lnurlp.new_path_from("/") == "/lnurlp/api/v1/well-known"
assert lnurlp.new_path_from("/path") == "/lnurlp/api/v1/well-known"
assert lnurlp.new_path_from("/path/more") == "/lnurlp/api/v1/well-known"
assert lnurlp.new_path_from("/.well-known/lnurlp") == "/lnurlp/api/v1/well-known"
assert (
lnurlp.new_path_from("/.well-known/lnurlp/path")
== "/lnurlp/api/v1/well-known/path"
)
assert (
lnurlp.new_path_from("/.well-known/lnurlp/path/more")
== "/lnurlp/api/v1/well-known/path/more"
)