diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index b234b5d41..c1253802b 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -62,7 +62,9 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[ return User( id=user["id"], email=user["email"], - extensions=[e[0] for e in extensions], + extensions=[ + e[0] for e in extensions if User.is_extension_for_user(e[0], user["id"]) + ], wallets=[Wallet(**w) for w in wallets], admin=user["id"] == settings.super_user or user["id"] in settings.lnbits_admin_users, diff --git a/lnbits/core/models.py b/lnbits/core/models.py index c3ff6fd9c..4bcdd3311 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from lnbits.db import Connection from lnbits.helpers import url_for -from lnbits.settings import get_wallet_class +from lnbits.settings import get_wallet_class, settings from lnbits.wallets.base import PaymentStatus @@ -75,6 +75,16 @@ class User(BaseModel): w = [wallet for wallet in self.wallets if wallet.id == wallet_id] return w[0] if w else None + @classmethod + def is_extension_for_user(cls, ext: str, user: str) -> bool: + if ext not in settings.lnbits_admin_extensions: + return True + if user == settings.super_user: + return True + if user in settings.lnbits_admin_users: + return True + return False + class Payment(BaseModel): checking_id: str diff --git a/lnbits/core/templates/admin/_tab_server.html b/lnbits/core/templates/admin/_tab_server.html index 814a490f4..f4d61bbf6 100644 --- a/lnbits/core/templates/admin/_tab_server.html +++ b/lnbits/core/templates/admin/_tab_server.html @@ -63,7 +63,7 @@ multiple hint="Extensions only user with admin privileges can use" label="Admin extensions" - :options="g.extensions.map(e => e.name)" + :options="g.extensions.map(e => e.code)" >
diff --git a/lnbits/middleware.py b/lnbits/middleware.py index 93a5671c4..2815ddde9 100644 --- a/lnbits/middleware.py +++ b/lnbits/middleware.py @@ -1,9 +1,11 @@ from http import HTTPStatus -from typing import List, Tuple +from typing import Any, List, Tuple, Union +from urllib.parse import parse_qs -from fastapi.responses import JSONResponse +from fastapi.responses import HTMLResponse, JSONResponse from starlette.types import ASGIApp, Receive, Scope, Send +from lnbits.helpers import template_renderer from lnbits.settings import settings @@ -28,11 +30,19 @@ class InstalledExtensionMiddleware: path_type = None rest = [] + headers = scope.get("headers", []) + # block path for all users if the extension is disabled if path_name in settings.lnbits_deactivated_extensions: - response = JSONResponse( - status_code=HTTPStatus.NOT_FOUND, - content={"detail": f"Extension '{path_name}' disabled"}, + response = self._response_by_accepted_type( + headers, f"Extension '{path_name}' disabled", HTTPStatus.NOT_FOUND + ) + await response(scope, receive, send) + return + + if not self._user_allowed_to_extension(path_name, scope): + response = self._response_by_accepted_type( + headers, "User not authorized.", HTTPStatus.FORBIDDEN ) await response(scope, receive, send) return @@ -52,6 +62,53 @@ class InstalledExtensionMiddleware: await self.app(scope, receive, send) + def _user_allowed_to_extension(self, ext_name: str, scope) -> bool: + if ext_name not in settings.lnbits_admin_extensions: + return True + if "query_string" not in scope: + return True + + # parse the URL query string into a `dict` + q = parse_qs(scope["query_string"].decode("UTF-8")) + user = q.get("usr", [""])[0] + if not user: + return True + + if user == settings.super_user or user in settings.lnbits_admin_users: + return True + + return False + + def _response_by_accepted_type( + self, headers: List[Any], msg: str, status_code: HTTPStatus + ) -> Union[HTMLResponse, JSONResponse]: + """ + Build an HTTP response containing the `msg` as HTTP body and the `status_code` as HTTP code. + If the `accept` HTTP header is present int the request and contains the value of `text/html` + then return an `HTMLResponse`, otherwise return an `JSONResponse`. + """ + accept_header: str = next( + ( + h[1].decode("UTF-8") + for h in headers + if len(h) >= 2 and h[0].decode("UTF-8") == "accept" + ), + "", + ) + + if "text/html" in [a for a in accept_header.split(",")]: + return HTMLResponse( + status_code=status_code, + content=template_renderer() + .TemplateResponse("error.html", {"request": {}, "err": msg}) + .body, + ) + + return JSONResponse( + status_code=status_code, + content={"detail": msg}, + ) + class ExtensionsRedirectMiddleware: # Extensions are allowed to specify redirect paths.