mirror of
https://github.com/lnbits/lnbits-legend.git
synced 2025-03-15 12:20:21 +01:00
[fix] check user extension access (#2519)
* feat: check user extension access * fix: handle upgraded extensions
This commit is contained in:
parent
d4da96597e
commit
44b458ebb8
8 changed files with 66 additions and 56 deletions
|
@ -261,10 +261,10 @@ async def build_all_installed_extensions_list(
|
|||
MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL).
|
||||
"""
|
||||
installed_extensions = await get_installed_extensions()
|
||||
settings.lnbits_all_extensions_ids = {e.id for e in installed_extensions}
|
||||
|
||||
installed_extensions_ids = [e.id for e in installed_extensions]
|
||||
for ext_id in settings.lnbits_extensions_default_install:
|
||||
if ext_id in installed_extensions_ids:
|
||||
if ext_id in settings.lnbits_all_extensions_ids:
|
||||
continue
|
||||
|
||||
ext_releases = await InstallableExtension.get_extension_releases(ext_id)
|
||||
|
@ -318,8 +318,7 @@ async def restore_installed_extension(app: FastAPI, ext: InstallableExtension):
|
|||
|
||||
# mount routes for the new version
|
||||
core_app_extra.register_new_ext_routes(extension)
|
||||
if extension.upgrade_hash:
|
||||
ext.notify_upgrade()
|
||||
ext.notify_upgrade(extension.upgrade_hash)
|
||||
|
||||
|
||||
def register_custom_extensions_path():
|
||||
|
|
|
@ -316,7 +316,7 @@ async def check_invalid_payments(
|
|||
async def load_disabled_extension_list() -> None:
|
||||
"""Update list of extensions that have been explicitly disabled"""
|
||||
inactive_extensions = await get_inactive_extensions()
|
||||
settings.lnbits_deactivated_extensions += inactive_extensions
|
||||
settings.lnbits_deactivated_extensions.update(inactive_extensions)
|
||||
|
||||
|
||||
@extensions.command("list")
|
||||
|
|
|
@ -322,10 +322,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
|
|||
)
|
||||
|
||||
if user:
|
||||
extensions = await (conn or db).fetchall(
|
||||
"""SELECT extension FROM extensions WHERE "user" = ? AND active""",
|
||||
(user_id,),
|
||||
)
|
||||
extensions = await get_user_active_extensions_ids(user_id, conn)
|
||||
wallets = await (conn or db).fetchall(
|
||||
"""
|
||||
SELECT *, COALESCE((
|
||||
|
@ -344,7 +341,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
|
|||
email=user["email"],
|
||||
username=user["username"],
|
||||
extensions=[
|
||||
e[0] for e in extensions if User.is_extension_for_user(e[0], user["id"])
|
||||
e for e in extensions if User.is_extension_for_user(e[0], user["id"])
|
||||
],
|
||||
wallets=[Wallet(**w) for w in wallets],
|
||||
admin=user["id"] == settings.super_user
|
||||
|
@ -482,6 +479,16 @@ async def update_user_extension(
|
|||
)
|
||||
|
||||
|
||||
async def get_user_active_extensions_ids(
|
||||
user_id: str, conn: Optional[Connection] = None
|
||||
) -> List[str]:
|
||||
rows = await (conn or db).fetchall(
|
||||
"""SELECT extension FROM extensions WHERE "user" = ? AND active""",
|
||||
(user_id,),
|
||||
)
|
||||
return [e[0] for e in rows]
|
||||
|
||||
|
||||
# wallets
|
||||
# -------
|
||||
|
||||
|
|
|
@ -94,14 +94,12 @@ async def api_install_extension(
|
|||
# call stop while the old routes are still active
|
||||
await stop_extension_background_work(data.ext_id, user.id, access_token)
|
||||
|
||||
if data.ext_id not in settings.lnbits_deactivated_extensions:
|
||||
settings.lnbits_deactivated_extensions += [data.ext_id]
|
||||
settings.lnbits_deactivated_extensions.add(data.ext_id)
|
||||
|
||||
# mount routes for the new version
|
||||
core_app_extra.register_new_ext_routes(extension)
|
||||
|
||||
if extension.upgrade_hash:
|
||||
ext_info.notify_upgrade()
|
||||
ext_info.notify_upgrade(extension.upgrade_hash)
|
||||
|
||||
return extension
|
||||
except AssertionError as exc:
|
||||
|
@ -151,8 +149,7 @@ async def api_uninstall_extension(
|
|||
# call stop while the old routes are still active
|
||||
await stop_extension_background_work(ext_id, user.id, access_token)
|
||||
|
||||
if ext_id not in settings.lnbits_deactivated_extensions:
|
||||
settings.lnbits_deactivated_extensions += [ext_id]
|
||||
settings.lnbits_deactivated_extensions.add(ext_id)
|
||||
|
||||
for ext_info in extensions:
|
||||
ext_info.clean_extension_files()
|
||||
|
|
|
@ -115,19 +115,15 @@ async def extensions_install(
|
|||
all_extensions = get_valid_extensions()
|
||||
ext = next((e for e in all_extensions if e.code == ext_id), None)
|
||||
if ext_id and user.admin:
|
||||
if deactivate and deactivate not in settings.lnbits_deactivated_extensions:
|
||||
settings.lnbits_deactivated_extensions += [deactivate]
|
||||
if deactivate:
|
||||
settings.lnbits_deactivated_extensions.add(deactivate)
|
||||
elif activate:
|
||||
# if extension never loaded (was deactivated on server startup)
|
||||
if ext_id not in sys.modules.keys():
|
||||
# run extension start-up routine
|
||||
core_app_extra.register_new_ext_routes(ext)
|
||||
|
||||
settings.lnbits_deactivated_extensions = list(
|
||||
filter(
|
||||
lambda e: e != activate, settings.lnbits_deactivated_extensions
|
||||
)
|
||||
)
|
||||
settings.lnbits_deactivated_extensions.remove(activate)
|
||||
|
||||
await update_installed_extension_state(
|
||||
ext_id=ext_id, active=activate is not None
|
||||
|
|
|
@ -15,6 +15,7 @@ from lnbits.core.crud import (
|
|||
get_account_by_email,
|
||||
get_account_by_username,
|
||||
get_user,
|
||||
get_user_active_extensions_ids,
|
||||
get_wallet_for_key,
|
||||
)
|
||||
from lnbits.core.models import KeyType, User, WalletTypeInfo
|
||||
|
@ -88,16 +89,7 @@ class KeyChecker(SecurityBase):
|
|||
detail="Invalid adminkey.",
|
||||
)
|
||||
|
||||
if (
|
||||
wallet.user != settings.super_user
|
||||
and wallet.user not in settings.lnbits_admin_users
|
||||
and settings.lnbits_admin_extensions
|
||||
and request["path"].split("/")[1] in settings.lnbits_admin_extensions
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="User not authorized for this extension.",
|
||||
)
|
||||
await _check_user_extension_access(wallet.user, request["path"])
|
||||
|
||||
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
|
||||
return WalletTypeInfo(key_type, wallet)
|
||||
|
@ -161,15 +153,7 @@ async def check_user_exists(
|
|||
user = await get_user(account.id)
|
||||
assert user, "User not found for account."
|
||||
|
||||
if (
|
||||
user.id != settings.super_user
|
||||
and user.id not in settings.lnbits_admin_users
|
||||
and settings.lnbits_admin_extensions
|
||||
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
|
||||
):
|
||||
raise HTTPException(
|
||||
HTTPStatus.UNAUTHORIZED, "User not authorized for extension."
|
||||
)
|
||||
await _check_user_extension_access(user.id, r["path"])
|
||||
|
||||
return user
|
||||
|
||||
|
@ -226,6 +210,28 @@ def parse_filters(model: Type[TFilterModel]):
|
|||
return dependency
|
||||
|
||||
|
||||
async def _check_user_extension_access(user_id: str, current_path: str):
|
||||
"""
|
||||
Check if the user has access to a particular extension.
|
||||
Raises HTTP Forbidden if the user is not allowed.
|
||||
"""
|
||||
path = current_path.split("/")
|
||||
ext_id = path[3] if path[1] == "upgrades" else path[1]
|
||||
if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id):
|
||||
raise HTTPException(
|
||||
HTTPStatus.FORBIDDEN,
|
||||
f"User not authorized for extension '{ext_id}'.",
|
||||
)
|
||||
|
||||
if settings.is_extension_id(ext_id):
|
||||
ext_ids = await get_user_active_extensions_ids(user_id)
|
||||
if ext_id not in ext_ids:
|
||||
raise HTTPException(
|
||||
HTTPStatus.FORBIDDEN,
|
||||
f"User extension '{ext_id}' not enabled.",
|
||||
)
|
||||
|
||||
|
||||
async def _get_account_from_token(access_token):
|
||||
try:
|
||||
payload = jwt.decode(access_token, settings.auth_secret_key, "HS256")
|
||||
|
|
|
@ -479,22 +479,15 @@ class InstallableExtension(BaseModel):
|
|||
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
|
||||
logger.success(f"Extension {self.name} ({self.installed_version}) installed.")
|
||||
|
||||
def notify_upgrade(self) -> None:
|
||||
def notify_upgrade(self, upgrade_hash: Optional[str]) -> None:
|
||||
"""
|
||||
Update the list of upgraded extensions. The middleware will perform
|
||||
redirects based on this
|
||||
"""
|
||||
if upgrade_hash:
|
||||
settings.lnbits_upgraded_extensions.add(f"{self.hash}/{self.id}")
|
||||
|
||||
clean_upgraded_exts = list(
|
||||
filter(
|
||||
lambda old_ext: not old_ext.endswith(f"/{self.id}"),
|
||||
settings.lnbits_upgraded_extensions,
|
||||
)
|
||||
)
|
||||
settings.lnbits_upgraded_extensions = [
|
||||
*clean_upgraded_exts,
|
||||
f"{self.hash}/{self.id}",
|
||||
]
|
||||
settings.lnbits_all_extensions_ids.add(self.id)
|
||||
|
||||
def clean_extension_files(self):
|
||||
# remove downloaded archive
|
||||
|
|
|
@ -63,12 +63,15 @@ class ExtensionsInstallSettings(LNbitsSettings):
|
|||
|
||||
class InstalledExtensionsSettings(LNbitsSettings):
|
||||
# installed extensions that have been deactivated
|
||||
lnbits_deactivated_extensions: list[str] = Field(default=[])
|
||||
lnbits_deactivated_extensions: set[str] = Field(default=[])
|
||||
# upgraded extensions that require API redirects
|
||||
lnbits_upgraded_extensions: list[str] = Field(default=[])
|
||||
lnbits_upgraded_extensions: set[str] = Field(default=[])
|
||||
# list of redirects that extensions want to perform
|
||||
lnbits_extensions_redirects: list[Any] = Field(default=[])
|
||||
|
||||
# list of all extension ids
|
||||
lnbits_all_extensions_ids: set[Any] = Field(default=[])
|
||||
|
||||
def extension_upgrade_path(self, ext_id: str) -> Optional[str]:
|
||||
return next(
|
||||
(e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")),
|
||||
|
@ -481,7 +484,7 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin
|
|||
case_sensitive = False
|
||||
json_loads = list_parse_fallback
|
||||
|
||||
def is_user_allowed(self, user_id: str):
|
||||
def is_user_allowed(self, user_id: str) -> bool:
|
||||
return (
|
||||
len(self.lnbits_allowed_users) == 0
|
||||
or user_id in self.lnbits_allowed_users
|
||||
|
@ -489,6 +492,15 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin
|
|||
or user_id == self.super_user
|
||||
)
|
||||
|
||||
def is_admin_user(self, user_id: str) -> bool:
|
||||
return user_id in self.lnbits_admin_users or user_id == self.super_user
|
||||
|
||||
def is_admin_extension(self, ext_id: str) -> bool:
|
||||
return ext_id in self.lnbits_admin_extensions
|
||||
|
||||
def is_extension_id(self, ext_id: str) -> bool:
|
||||
return ext_id in self.lnbits_all_extensions_ids
|
||||
|
||||
|
||||
class SuperSettings(EditableSettings):
|
||||
super_user: str
|
||||
|
|
Loading…
Add table
Reference in a new issue