"""Authentication helpers: wallet API-key checkers and request validators."""

from functools import wraps
from http import HTTPStatus
from base64 import b64decode
from fastapi.security import api_key
from pydantic.types import UUID4
from lnbits.core.models import User, Wallet
from typing import List, Optional, Union
from uuid import UUID

from cerberus import Validator  # type: ignore
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.params import Security
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.base import SecurityBase
from fastapi import status
from starlette.requests import Request

from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.requestvars import g
from lnbits.settings import LNBITS_ALLOWED_USERS


class KeyChecker(SecurityBase):
    """Security dependency that looks up the wallet belonging to an API key.

    The key is the ``api_key`` passed to the constructor or, when none was
    given, the ``X-API-KEY`` request header, falling back to the ``api-key``
    query parameter. After a successful call the wallet is available as
    ``self.wallet``.
    """

    def __init__(
        self,
        scheme_name: Optional[str] = None,
        auto_error: bool = True,
        api_key: Optional[str] = None,
    ):
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error
        self._key_type = "invoice"
        self._api_key = api_key
        # The OpenAPI model only differs in where the key is advertised.
        # NOTE(review): the query variant is advertised as "X-API-KEY" while
        # __call__ reads the "api-key" query parameter - confirm intent.
        if api_key:
            location, description = APIKeyIn.query, "Wallet API Key - QUERY"
        else:
            location, description = APIKeyIn.header, "Wallet API Key - HEADER"
        # "in" is a Python keyword, hence the dict unpacking.
        self.model: APIKey = APIKey(
            **{"in": location}, name="X-API-KEY", description=description
        )
        self.wallet: Optional[Wallet] = None

    async def __call__(self, request: Request) -> None:
        """Resolve the wallet for the request's key and store it on ``self``.

        Raises:
            HTTPException: 400 when no key is supplied anywhere, 401 when no
                wallet matches the key for this checker's key type.
        """
        if self._api_key:
            key_value = self._api_key
        else:
            # Only the key extraction can legitimately raise KeyError; keep
            # the lookup below outside the try so its errors are not
            # misreported as a missing header.
            try:
                key_value = (
                    request.headers.get("X-API-KEY")
                    or request.query_params["api-key"]
                )
            except KeyError:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST,
                    detail="`X-API-KEY` header missing.",
                )

        # FIXME: Find another way to validate the key. A fetch from DB should be avoided here.
        #        Also, we should not return the wallet here - thats silly.
        #        Possibly store it in a Redis DB
        self.wallet = await get_wallet_for_key(key_value, self._key_type)
        if not self.wallet:
            raise HTTPException(
                status_code=HTTPStatus.UNAUTHORIZED,
                detail="Invalid key or expired key.",
            )


class WalletInvoiceKeyChecker(KeyChecker):
    """
    WalletInvoiceKeyChecker will ensure that the provided invoice
    wallet key is correct and populate ``self.wallet`` with the wallet
    for the key in `X-API-key`.

    The checker will raise an HTTPException when the key is wrong in some ways.
    """

    def __init__(
        self,
        scheme_name: Optional[str] = None,
        auto_error: bool = True,
        api_key: Optional[str] = None,
    ):
        super().__init__(scheme_name, auto_error, api_key)
        self._key_type = "invoice"


class WalletAdminKeyChecker(KeyChecker):
    """
    WalletAdminKeyChecker will ensure that the provided admin
    wallet key is correct and populate ``self.wallet`` with the wallet
    for the key in `X-API-key`.

    The checker will raise an HTTPException when the key is wrong in some ways.
    """

    def __init__(
        self,
        scheme_name: Optional[str] = None,
        auto_error: bool = True,
        api_key: Optional[str] = None,
    ):
        super().__init__(scheme_name, auto_error, api_key)
        self._key_type = "admin"


class WalletTypeInfo:
    """Result of classifying an API key.

    ``wallet_type`` is 0 for an admin key, 1 for an invoice key and 2 for an
    invalid key; ``wallet`` is ``None`` in the invalid case.
    """

    wallet_type: int
    wallet: Optional[Wallet]

    def __init__(self, wallet_type: int, wallet: Optional[Wallet]) -> None:
        self.wallet_type = wallet_type
        self.wallet = wallet


api_key_header = APIKeyHeader(
    name="X-API-KEY",
    auto_error=False,
    description="Admin or Invoice key for wallet API's",
)
api_key_query = APIKeyQuery(
    name="api-key",
    auto_error=False,
    description="Admin or Invoice key for wallet API's",
)


async def get_key_type(
    r: Request,
    api_key_header: str = Security(api_key_header),
    api_key_query: str = Security(api_key_query),
) -> WalletTypeInfo:
    """Classify the supplied API key as admin, invoice or invalid.

    The header key takes precedence over the query key. The key is tried as
    an admin key first, then as an invoice key.

    Returns:
        ``WalletTypeInfo`` with type 0 (admin), 1 (invoice) or 2 (invalid,
        with ``wallet`` set to ``None``).

    Raises:
        HTTPException: 400 when neither key is supplied; any error from a
            checker other than 401 is propagated unchanged.
    """
    # 0: admin
    # 1: invoice
    # 2: invalid
    if not api_key_header and not api_key_query:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

    token = api_key_header if api_key_header else api_key_query

    candidates = ((0, WalletAdminKeyChecker), (1, WalletInvoiceKeyChecker))
    for wallet_type, checker_class in candidates:
        checker = checker_class(api_key=token)
        try:
            await checker(r)
        except HTTPException as exc:
            # 401 only means "not a key of this type"; anything else is a
            # real error and must not be swallowed.
            if exc.status_code != HTTPStatus.UNAUTHORIZED:
                raise
            continue
        return WalletTypeInfo(wallet_type, checker.wallet)

    return WalletTypeInfo(2, None)


# Disabled OAuth2-based variant of get_key_type, kept for reference.
#
# api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Admin or Invoice key for wallet API's")
# api_key_query = APIKeyQuery(name="api-key", auto_error=False, description="Admin or Invoice key for wallet API's")
# oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# async def get_key_type(r: Request,
#                         token: str = Security(oauth2_scheme),
#                         api_key_header: str = Security(api_key_header),
#                         api_key_query: str = Security(api_key_query)) -> WalletTypeInfo:
#     # 0: admin
#     # 1: invoice
#     # 2: invalid
#     # print("TOKEN", b64decode(token).decode("utf-8").split(":"))
#
#     key_type, key = b64decode(token).decode("utf-8").split(":")
#     try:
#         checker = WalletAdminKeyChecker(api_key=key if token else api_key_query)
#         await checker.__call__(r)
#         return WalletTypeInfo(0, checker.wallet)
#     except HTTPException as e:
#         if e.status_code == HTTPStatus.BAD_REQUEST:
#             raise
#         if e.status_code == HTTPStatus.UNAUTHORIZED:
#             pass
#     except:
#         raise
#
#     try:
#         checker = WalletInvoiceKeyChecker(api_key=key if token else None)
#         await checker.__call__(r)
#         return WalletTypeInfo(1, checker.wallet)
#     except HTTPException as e:
#         if e.status_code == HTTPStatus.BAD_REQUEST:
#             raise
#         if e.status_code == HTTPStatus.UNAUTHORIZED:
#             return WalletTypeInfo(2, None)
#     except:
#         raise


def api_validate_post_request(*, schema: dict):
    """Build a decorator validating a view's JSON body against ``schema``.

    The decorated view must receive the Starlette ``Request`` as one of its
    keyword arguments. The subset of the body named in ``schema`` is stored in
    ``g().data`` and validated with Cerberus before the view runs.

    Raises (from the wrapped view):
        HTTPException: 400 when the content type is not JSON or when the
            data fails validation.
        RuntimeError: when the view was called without a ``Request`` argument.
    """

    def wrap(view):
        @wraps(view)
        async def wrapped_view(**kwargs):
            # There is no global request object here, so take the request
            # from the arguments injected into the view.
            request = next(
                (arg for arg in kwargs.values() if isinstance(arg, Request)),
                None,
            )
            if request is None:
                raise RuntimeError(
                    "api_validate_post_request requires the view to accept "
                    "a Request argument."
                )

            if "application/json" not in request.headers.get("Content-Type", ""):
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST,
                    detail={"message": "Content-Type must be `application/json`."},
                )

            v = Validator(schema)
            data = await request.json()
            # Keep only the keys the schema knows about.
            payload = {key: data[key] for key in schema if key in data}
            g().data = payload

            if not v.validate(payload):
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST,
                    detail={"message": f"Errors in request data: {v.errors}"},
                )

            return await view(**kwargs)

        return wrapped_view

    return wrap


async def check_user_exists(usr: UUID4) -> User:
    """Load the user ``usr``, store it in ``g().user`` and return it.

    Raises:
        HTTPException: 404 when no such user exists, 401 when
            ``LNBITS_ALLOWED_USERS`` is set and does not contain the user's id.
    """
    user = await get_user(usr.hex)
    # Stored before the checks, so g().user is set even when they fail.
    g().user = user

    if not user:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND, detail="User does not exist."
        )

    if LNBITS_ALLOWED_USERS and user.id not in LNBITS_ALLOWED_USERS:
        raise HTTPException(
            status_code=HTTPStatus.UNAUTHORIZED, detail="User not authorized."
        )

    return user