From 44b458ebb8dc67099b86c36ac27da6e492978365 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Tue, 21 May 2024 13:17:02 +0300 Subject: [PATCH] [fix] check user extension access (#2519) * feat: check user extension access * fix: handle upgraded extensions --- lnbits/app.py | 7 ++--- lnbits/commands.py | 2 +- lnbits/core/crud.py | 17 ++++++++---- lnbits/core/views/extension_api.py | 9 ++---- lnbits/core/views/generic.py | 10 ++----- lnbits/decorators.py | 44 +++++++++++++++++------------- lnbits/extension_manager.py | 15 +++------- lnbits/settings.py | 18 ++++++++++-- 8 files changed, 66 insertions(+), 56 deletions(-) diff --git a/lnbits/app.py b/lnbits/app.py index b8fb73ca5..97fca316a 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -261,10 +261,10 @@ async def build_all_installed_extensions_list( MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL). """ installed_extensions = await get_installed_extensions() + settings.lnbits_all_extensions_ids = {e.id for e in installed_extensions} - installed_extensions_ids = [e.id for e in installed_extensions] for ext_id in settings.lnbits_extensions_default_install: - if ext_id in installed_extensions_ids: + if ext_id in settings.lnbits_all_extensions_ids: continue ext_releases = await InstallableExtension.get_extension_releases(ext_id) @@ -318,8 +318,7 @@ async def restore_installed_extension(app: FastAPI, ext: InstallableExtension): # mount routes for the new version core_app_extra.register_new_ext_routes(extension) - if extension.upgrade_hash: - ext.notify_upgrade() + ext.notify_upgrade(extension.upgrade_hash) def register_custom_extensions_path(): diff --git a/lnbits/commands.py b/lnbits/commands.py index 30b889321..0ce5008e9 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -316,7 +316,7 @@ async def check_invalid_payments( async def load_disabled_extension_list() -> None: """Update list of extensions that have been explicitly disabled""" inactive_extensions = await get_inactive_extensions() - settings.lnbits_deactivated_extensions += inactive_extensions + settings.lnbits_deactivated_extensions.update(inactive_extensions) @extensions.command("list") diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 4ad1fca76..92b40d4f8 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -322,10 +322,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[ ) if user: - extensions = await (conn or db).fetchall( - """SELECT extension FROM extensions WHERE "user" = ? AND active""", - (user_id,), - ) + extensions = await get_user_active_extensions_ids(user_id, conn) wallets = await (conn or db).fetchall( """ SELECT *, COALESCE(( @@ -344,7 +341,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[ email=user["email"], username=user["username"], extensions=[ - e[0] for e in extensions if User.is_extension_for_user(e[0], user["id"]) + e for e in extensions if User.is_extension_for_user(e[0], user["id"]) ], wallets=[Wallet(**w) for w in wallets], admin=user["id"] == settings.super_user @@ -482,6 +479,16 @@ async def update_user_extension( ) +async def get_user_active_extensions_ids( + user_id: str, conn: Optional[Connection] = None +) -> List[str]: + rows = await (conn or db).fetchall( + """SELECT extension FROM extensions WHERE "user" = ? AND active""", + (user_id,), + ) + return [e[0] for e in rows] + + # wallets # ------- diff --git a/lnbits/core/views/extension_api.py b/lnbits/core/views/extension_api.py index 853cc9b77..0cba2f95a 100644 --- a/lnbits/core/views/extension_api.py +++ b/lnbits/core/views/extension_api.py @@ -94,14 +94,12 @@ async def api_install_extension( # call stop while the old routes are still active await stop_extension_background_work(data.ext_id, user.id, access_token) - if data.ext_id not in settings.lnbits_deactivated_extensions: - settings.lnbits_deactivated_extensions += [data.ext_id] + settings.lnbits_deactivated_extensions.add(data.ext_id) # mount routes for the new version core_app_extra.register_new_ext_routes(extension) - if extension.upgrade_hash: - ext_info.notify_upgrade() + ext_info.notify_upgrade(extension.upgrade_hash) return extension except AssertionError as exc: @@ -151,8 +149,7 @@ async def api_uninstall_extension( # call stop while the old routes are still active await stop_extension_background_work(ext_id, user.id, access_token) - if ext_id not in settings.lnbits_deactivated_extensions: - settings.lnbits_deactivated_extensions += [ext_id] + settings.lnbits_deactivated_extensions.add(ext_id) for ext_info in extensions: ext_info.clean_extension_files() diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index d9a9e1801..06142fa9c 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -115,19 +115,15 @@ async def extensions_install( all_extensions = get_valid_extensions() ext = next((e for e in all_extensions if e.code == ext_id), None) if ext_id and user.admin: - if deactivate and deactivate not in settings.lnbits_deactivated_extensions: - settings.lnbits_deactivated_extensions += [deactivate] + if deactivate: + settings.lnbits_deactivated_extensions.add(deactivate) elif activate: # if extension never loaded (was deactivated on server startup) if ext_id not in sys.modules.keys(): # run extension start-up routine core_app_extra.register_new_ext_routes(ext) - settings.lnbits_deactivated_extensions = list( - filter( - lambda e: e != activate, settings.lnbits_deactivated_extensions - ) - ) + settings.lnbits_deactivated_extensions.remove(activate) await update_installed_extension_state( ext_id=ext_id, active=activate is not None diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 497c8ffed..26bf38eeb 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -15,6 +15,7 @@ from lnbits.core.crud import ( get_account_by_email, get_account_by_username, get_user, + get_user_active_extensions_ids, get_wallet_for_key, ) from lnbits.core.models import KeyType, User, WalletTypeInfo @@ -88,16 +89,7 @@ class KeyChecker(SecurityBase): detail="Invalid adminkey.", ) - if ( - wallet.user != settings.super_user - and wallet.user not in settings.lnbits_admin_users - and settings.lnbits_admin_extensions - and request["path"].split("/")[1] in settings.lnbits_admin_extensions - ): - raise HTTPException( - status_code=HTTPStatus.FORBIDDEN, - detail="User not authorized for this extension.", - ) + await _check_user_extension_access(wallet.user, request["path"]) key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice return WalletTypeInfo(key_type, wallet) @@ -161,15 +153,7 @@ async def check_user_exists( user = await get_user(account.id) assert user, "User not found for account." - if ( - user.id != settings.super_user - and user.id not in settings.lnbits_admin_users - and settings.lnbits_admin_extensions - and r["path"].split("/")[1] in settings.lnbits_admin_extensions - ): - raise HTTPException( - HTTPStatus.UNAUTHORIZED, "User not authorized for extension." - ) + await _check_user_extension_access(user.id, r["path"]) return user @@ -226,6 +210,28 @@ def parse_filters(model: Type[TFilterModel]): return dependency +async def _check_user_extension_access(user_id: str, current_path: str): + """ + Check if the user has access to a particular extension. + Raises HTTP Forbidden if the user is not allowed. + """ + path = current_path.split("/") + ext_id = path[3] if path[1] == "upgrades" else path[1] + if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id): + raise HTTPException( + HTTPStatus.FORBIDDEN, + f"User not authorized for extension '{ext_id}'.", + ) + + if settings.is_extension_id(ext_id): + ext_ids = await get_user_active_extensions_ids(user_id) + if ext_id not in ext_ids: + raise HTTPException( + HTTPStatus.FORBIDDEN, + f"User extension '{ext_id}' not enabled.", + ) + + async def _get_account_from_token(access_token): try: payload = jwt.decode(access_token, settings.auth_secret_key, "HS256") diff --git a/lnbits/extension_manager.py b/lnbits/extension_manager.py index e0d3d3106..feae949f7 100644 --- a/lnbits/extension_manager.py +++ b/lnbits/extension_manager.py @@ -479,22 +479,15 @@ class InstallableExtension(BaseModel): shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir)) logger.success(f"Extension {self.name} ({self.installed_version}) installed.") - def notify_upgrade(self) -> None: + def notify_upgrade(self, upgrade_hash: Optional[str]) -> None: """ Update the list of upgraded extensions. The middleware will perform redirects based on this """ + if upgrade_hash: + settings.lnbits_upgraded_extensions.add(f"{self.hash}/{self.id}") - clean_upgraded_exts = list( - filter( - lambda old_ext: not old_ext.endswith(f"/{self.id}"), - settings.lnbits_upgraded_extensions, - ) - ) - settings.lnbits_upgraded_extensions = [ - *clean_upgraded_exts, - f"{self.hash}/{self.id}", - ] + settings.lnbits_all_extensions_ids.add(self.id) def clean_extension_files(self): # remove downloaded archive diff --git a/lnbits/settings.py b/lnbits/settings.py index c0ff64f07..76f0f52d2 100644 --- a/lnbits/settings.py +++ b/lnbits/settings.py @@ -63,12 +63,15 @@ class ExtensionsInstallSettings(LNbitsSettings): class InstalledExtensionsSettings(LNbitsSettings): # installed extensions that have been deactivated - lnbits_deactivated_extensions: list[str] = Field(default=[]) + lnbits_deactivated_extensions: set[str] = Field(default=[]) # upgraded extensions that require API redirects - lnbits_upgraded_extensions: list[str] = Field(default=[]) + lnbits_upgraded_extensions: set[str] = Field(default=[]) # list of redirects that extensions want to perform lnbits_extensions_redirects: list[Any] = Field(default=[]) + # list of all extension ids + lnbits_all_extensions_ids: set[Any] = Field(default=[]) + def extension_upgrade_path(self, ext_id: str) -> Optional[str]: return next( (e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")), @@ -481,7 +484,7 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin case_sensitive = False json_loads = list_parse_fallback - def is_user_allowed(self, user_id: str): + def is_user_allowed(self, user_id: str) -> bool: return ( len(self.lnbits_allowed_users) == 0 or user_id in self.lnbits_allowed_users @@ -489,6 +492,15 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin or user_id == self.super_user ) + def is_admin_user(self, user_id: str) -> bool: + return user_id in self.lnbits_admin_users or user_id == self.super_user + + def is_admin_extension(self, ext_id: str) -> bool: + return ext_id in self.lnbits_admin_extensions + + def is_extension_id(self, ext_id: str) -> bool: + return ext_id in self.lnbits_all_extensions_ids + class SuperSettings(EditableSettings): super_user: str