diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 17eacef1b..72f4b9dff 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -3,7 +3,7 @@ from typing import Annotated, Literal, Optional, Type, Union from fastapi import Cookie, Depends, Query, Request, Security from fastapi.exceptions import HTTPException -from fastapi.openapi.models import APIKey, APIKeyIn +from fastapi.openapi.models import APIKey, APIKeyIn, SecuritySchemeType from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer from fastapi.security.base import SecurityBase from jose import ExpiredSignatureError, JWTError, jwt @@ -17,14 +17,13 @@ from lnbits.core.crud import ( get_user, get_wallet_for_key, ) -from lnbits.core.models import User, WalletType, WalletTypeInfo +from lnbits.core.models import User, Wallet, WalletType, WalletTypeInfo from lnbits.db import Filter, Filters, TFilterModel from lnbits.settings import AuthMethods, settings oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False) -# TODO: fix type ignores class KeyChecker(SecurityBase): def __init__( self, @@ -33,23 +32,25 @@ class KeyChecker(SecurityBase): api_key: Optional[str] = None, ): self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error = auto_error - self._key_type = WalletType.invoice + self.auto_error: bool = auto_error + self._key_type: WalletType = WalletType.invoice self._api_key = api_key if api_key: - key = APIKey( - **{"in": APIKeyIn.query}, # type: ignore + openapi_model = APIKey( + **{"in": APIKeyIn.query}, + type=SecuritySchemeType.apiKey, name="X-API-KEY", description="Wallet API Key - QUERY", ) else: - key = APIKey( - **{"in": APIKeyIn.header}, # type: ignore + openapi_model = APIKey( + **{"in": APIKeyIn.header}, + type=SecuritySchemeType.apiKey, name="X-API-KEY", description="Wallet API Key - HEADER", ) - self.wallet = None - self.model: APIKey = key + self.wallet: Optional[Wallet] = None + self.model: APIKey = openapi_model async def __call__(self, request: Request): try: @@ -67,7 +68,7 @@ class KeyChecker(SecurityBase): status_code=HTTPStatus.UNAUTHORIZED, detail="Invalid key or wallet.", ) - self.wallet = wallet # type: ignore + self.wallet = wallet except KeyError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail="`X-API-KEY` header missing."