[fix] check user extension access (#2519)

* feat: check user extension access
* fix: handle upgraded extensions
This commit is contained in:
Vlad Stan 2024-05-21 13:17:02 +03:00 committed by dni ⚡
parent d4da96597e
commit 44b458ebb8
No known key found for this signature in database
GPG key ID: 886317704CC4E618
8 changed files with 66 additions and 56 deletions

View file

@ -261,10 +261,10 @@ async def build_all_installed_extensions_list(
MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL).
"""
installed_extensions = await get_installed_extensions()
settings.lnbits_all_extensions_ids = {e.id for e in installed_extensions}
installed_extensions_ids = [e.id for e in installed_extensions]
for ext_id in settings.lnbits_extensions_default_install:
if ext_id in installed_extensions_ids:
if ext_id in settings.lnbits_all_extensions_ids:
continue
ext_releases = await InstallableExtension.get_extension_releases(ext_id)
@ -318,8 +318,7 @@ async def restore_installed_extension(app: FastAPI, ext: InstallableExtension):
# mount routes for the new version
core_app_extra.register_new_ext_routes(extension)
if extension.upgrade_hash:
ext.notify_upgrade()
ext.notify_upgrade(extension.upgrade_hash)
def register_custom_extensions_path():

View file

@ -316,7 +316,7 @@ async def check_invalid_payments(
async def load_disabled_extension_list() -> None:
"""Update list of extensions that have been explicitly disabled"""
inactive_extensions = await get_inactive_extensions()
settings.lnbits_deactivated_extensions += inactive_extensions
settings.lnbits_deactivated_extensions.update(inactive_extensions)
@extensions.command("list")

View file

@ -322,10 +322,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
)
if user:
extensions = await (conn or db).fetchall(
"""SELECT extension FROM extensions WHERE "user" = ? AND active""",
(user_id,),
)
extensions = await get_user_active_extensions_ids(user_id, conn)
wallets = await (conn or db).fetchall(
"""
SELECT *, COALESCE((
@ -344,7 +341,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
email=user["email"],
username=user["username"],
extensions=[
e[0] for e in extensions if User.is_extension_for_user(e[0], user["id"])
e for e in extensions if User.is_extension_for_user(e[0], user["id"])
],
wallets=[Wallet(**w) for w in wallets],
admin=user["id"] == settings.super_user
@ -482,6 +479,16 @@ async def update_user_extension(
)
async def get_user_active_extensions_ids(
user_id: str, conn: Optional[Connection] = None
) -> List[str]:
rows = await (conn or db).fetchall(
"""SELECT extension FROM extensions WHERE "user" = ? AND active""",
(user_id,),
)
return [e[0] for e in rows]
# wallets
# -------

View file

@ -94,14 +94,12 @@ async def api_install_extension(
# call stop while the old routes are still active
await stop_extension_background_work(data.ext_id, user.id, access_token)
if data.ext_id not in settings.lnbits_deactivated_extensions:
settings.lnbits_deactivated_extensions += [data.ext_id]
settings.lnbits_deactivated_extensions.add(data.ext_id)
# mount routes for the new version
core_app_extra.register_new_ext_routes(extension)
if extension.upgrade_hash:
ext_info.notify_upgrade()
ext_info.notify_upgrade(extension.upgrade_hash)
return extension
except AssertionError as exc:
@ -151,8 +149,7 @@ async def api_uninstall_extension(
# call stop while the old routes are still active
await stop_extension_background_work(ext_id, user.id, access_token)
if ext_id not in settings.lnbits_deactivated_extensions:
settings.lnbits_deactivated_extensions += [ext_id]
settings.lnbits_deactivated_extensions.add(ext_id)
for ext_info in extensions:
ext_info.clean_extension_files()

View file

@ -115,19 +115,15 @@ async def extensions_install(
all_extensions = get_valid_extensions()
ext = next((e for e in all_extensions if e.code == ext_id), None)
if ext_id and user.admin:
if deactivate and deactivate not in settings.lnbits_deactivated_extensions:
settings.lnbits_deactivated_extensions += [deactivate]
if deactivate:
settings.lnbits_deactivated_extensions.add(deactivate)
elif activate:
# if extension never loaded (was deactivated on server startup)
if ext_id not in sys.modules.keys():
# run extension start-up routine
core_app_extra.register_new_ext_routes(ext)
settings.lnbits_deactivated_extensions = list(
filter(
lambda e: e != activate, settings.lnbits_deactivated_extensions
)
)
settings.lnbits_deactivated_extensions.remove(activate)
await update_installed_extension_state(
ext_id=ext_id, active=activate is not None

View file

@ -15,6 +15,7 @@ from lnbits.core.crud import (
get_account_by_email,
get_account_by_username,
get_user,
get_user_active_extensions_ids,
get_wallet_for_key,
)
from lnbits.core.models import KeyType, User, WalletTypeInfo
@ -88,16 +89,7 @@ class KeyChecker(SecurityBase):
detail="Invalid adminkey.",
)
if (
wallet.user != settings.super_user
and wallet.user not in settings.lnbits_admin_users
and settings.lnbits_admin_extensions
and request["path"].split("/")[1] in settings.lnbits_admin_extensions
):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User not authorized for this extension.",
)
await _check_user_extension_access(wallet.user, request["path"])
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
return WalletTypeInfo(key_type, wallet)
@ -161,15 +153,7 @@ async def check_user_exists(
user = await get_user(account.id)
assert user, "User not found for account."
if (
user.id != settings.super_user
and user.id not in settings.lnbits_admin_users
and settings.lnbits_admin_extensions
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
):
raise HTTPException(
HTTPStatus.UNAUTHORIZED, "User not authorized for extension."
)
await _check_user_extension_access(user.id, r["path"])
return user
@ -226,6 +210,28 @@ def parse_filters(model: Type[TFilterModel]):
return dependency
async def _check_user_extension_access(user_id: str, current_path: str):
"""
Check if the user has access to a particular extension.
Raises HTTP Forbidden if the user is not allowed.
"""
path = current_path.split("/")
ext_id = path[3] if path[1] == "upgrades" else path[1]
if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id):
raise HTTPException(
HTTPStatus.FORBIDDEN,
f"User not authorized for extension '{ext_id}'.",
)
if settings.is_extension_id(ext_id):
ext_ids = await get_user_active_extensions_ids(user_id)
if ext_id not in ext_ids:
raise HTTPException(
HTTPStatus.FORBIDDEN,
f"User extension '{ext_id}' not enabled.",
)
async def _get_account_from_token(access_token):
try:
payload = jwt.decode(access_token, settings.auth_secret_key, "HS256")

View file

@ -479,22 +479,15 @@ class InstallableExtension(BaseModel):
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
logger.success(f"Extension {self.name} ({self.installed_version}) installed.")
def notify_upgrade(self) -> None:
def notify_upgrade(self, upgrade_hash: Optional[str]) -> None:
"""
Update the list of upgraded extensions. The middleware will perform
redirects based on this
"""
if upgrade_hash:
settings.lnbits_upgraded_extensions.add(f"{self.hash}/{self.id}")
clean_upgraded_exts = list(
filter(
lambda old_ext: not old_ext.endswith(f"/{self.id}"),
settings.lnbits_upgraded_extensions,
)
)
settings.lnbits_upgraded_extensions = [
*clean_upgraded_exts,
f"{self.hash}/{self.id}",
]
settings.lnbits_all_extensions_ids.add(self.id)
def clean_extension_files(self):
# remove downloaded archive

View file

@ -63,12 +63,15 @@ class ExtensionsInstallSettings(LNbitsSettings):
class InstalledExtensionsSettings(LNbitsSettings):
# installed extensions that have been deactivated
lnbits_deactivated_extensions: list[str] = Field(default=[])
lnbits_deactivated_extensions: set[str] = Field(default=[])
# upgraded extensions that require API redirects
lnbits_upgraded_extensions: list[str] = Field(default=[])
lnbits_upgraded_extensions: set[str] = Field(default=[])
# list of redirects that extensions want to perform
lnbits_extensions_redirects: list[Any] = Field(default=[])
# list of all extension ids
lnbits_all_extensions_ids: set[Any] = Field(default=[])
def extension_upgrade_path(self, ext_id: str) -> Optional[str]:
return next(
(e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")),
@ -481,7 +484,7 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin
case_sensitive = False
json_loads = list_parse_fallback
def is_user_allowed(self, user_id: str):
def is_user_allowed(self, user_id: str) -> bool:
return (
len(self.lnbits_allowed_users) == 0
or user_id in self.lnbits_allowed_users
@ -489,6 +492,15 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin
or user_id == self.super_user
)
def is_admin_user(self, user_id: str) -> bool:
return user_id in self.lnbits_admin_users or user_id == self.super_user
def is_admin_extension(self, ext_id: str) -> bool:
return ext_id in self.lnbits_admin_extensions
def is_extension_id(self, ext_id: str) -> bool:
return ext_id in self.lnbits_all_extensions_ids
class SuperSettings(EditableSettings):
super_user: str