diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index ea9edfe95..3656df9e7 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -7,6 +7,7 @@ from uuid import UUID, uuid4 import shortuuid from lnbits import bolt11 +from lnbits.core.models import WalletType from lnbits.db import Connection, Database, Filters, Page from lnbits.extension_manager import InstallableExtension from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings @@ -294,7 +295,7 @@ async def get_wallet_for_key( if not row: return None - if key_type == "admin" and row["adminkey"] != key: + if key_type == WalletType.admin and row["adminkey"] != key: return None return Wallet(**row) diff --git a/lnbits/core/models.py b/lnbits/core/models.py index c9fb56a36..c1ee7097b 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -3,6 +3,8 @@ import hashlib import hmac import json import time +from dataclasses import dataclass +from enum import Enum from sqlite3 import Row from typing import Callable, Dict, List, Optional @@ -59,6 +61,22 @@ class Wallet(BaseModel): return await get_standalone_payment(payment_hash) +class WalletType(Enum): + admin = 0 + invoice = 1 + invalid = 2 + + # backwards compatibility + def __eq__(self, other): + return self.value == other + + +@dataclass +class WalletTypeInfo: + wallet_type: WalletType + wallet: Wallet + + class User(BaseModel): id: str email: Optional[str] = None diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 08c552f18..801731a62 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -38,6 +38,7 @@ from lnbits.core.models import ( PaymentFilters, User, Wallet, + WalletType, ) from lnbits.db import Filters, Page from lnbits.decorators import ( @@ -102,7 +103,7 @@ async def health(): @core_app.get("/api/v1/wallet") async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)): - if wallet.wallet_type == 0: + if wallet.wallet_type == WalletType.admin: return { "id": wallet.wallet.id, "name": wallet.wallet.name, @@ -318,7 +319,7 @@ async def api_payments_create( wallet: WalletTypeInfo = Depends(require_invoice_key), invoiceData: CreateInvoice = Body(...), ): - if invoiceData.out is True and wallet.wallet_type == 0: + if invoiceData.out is True and wallet.wallet_type == WalletType.admin: if not invoiceData.bolt11: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, diff --git a/lnbits/decorators.py b/lnbits/decorators.py index b3c60570d..70f653d46 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -1,7 +1,7 @@ from http import HTTPStatus from typing import Literal, Optional, Type -from fastapi import Query, Request, Security, status +from fastapi import Query, Request, Security from fastapi.exceptions import HTTPException from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security import APIKeyHeader, APIKeyQuery @@ -9,7 +9,7 @@ from fastapi.security.base import SecurityBase from pydantic.types import UUID4 from lnbits.core.crud import get_user, get_wallet_for_key -from lnbits.core.models import User, Wallet +from lnbits.core.models import User, WalletType, WalletTypeInfo from lnbits.db import Filter, Filters, TFilterModel from lnbits.requestvars import g from lnbits.settings import settings @@ -25,7 +25,7 @@ class KeyChecker(SecurityBase): ): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error - self._key_type = "invoice" + self._key_type = WalletType.invoice self._api_key = api_key if api_key: key = APIKey( @@ -39,7 +39,7 @@ class KeyChecker(SecurityBase): name="X-API-KEY", description="Wallet API Key - HEADER", ) - self.wallet = None # type: ignore + self.wallet = None self.model: APIKey = key async def __call__(self, request: Request): @@ -81,7 +81,7 @@ class WalletInvoiceKeyChecker(KeyChecker): api_key: Optional[str] = None, ): super().__init__(scheme_name, auto_error, api_key) - self._key_type = "invoice" + self._key_type = WalletType.invoice class WalletAdminKeyChecker(KeyChecker): @@ -100,16 +100,7 @@ class WalletAdminKeyChecker(KeyChecker): api_key: Optional[str] = None, ): super().__init__(scheme_name, auto_error, api_key) - self._key_type = "admin" - - -class WalletTypeInfo: - wallet_type: int - wallet: Wallet - - def __init__(self, wallet_type: int, wallet: Wallet) -> None: - self.wallet_type = wallet_type - self.wallet = wallet + self._key_type = WalletType.admin api_key_header = APIKeyHeader( @@ -129,11 +120,6 @@ async def get_key_type( api_key_header: str = Security(api_key_header), api_key_query: str = Security(api_key_query), ) -> WalletTypeInfo: - # 0: admin - # 1: invoice - # 2: invalid - pathname = r["path"].split("/")[1] - token = api_key_header or api_key_query if not token: @@ -142,33 +128,34 @@ async def get_key_type( detail="Invoice (or Admin) key required.", ) - for typenr, WalletChecker in zip( - [0, 1], [WalletAdminKeyChecker, WalletInvoiceKeyChecker] + for wallet_type, WalletChecker in zip( + [WalletType.admin, WalletType.invoice], + [WalletAdminKeyChecker, WalletInvoiceKeyChecker], ): try: checker = WalletChecker(api_key=token) await checker.__call__(r) - wallet = WalletTypeInfo(typenr, checker.wallet) # type: ignore - if wallet is None or wallet.wallet is None: + if checker.wallet is None: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist." ) + wallet = WalletTypeInfo(wallet_type, checker.wallet) if ( wallet.wallet.user != settings.super_user and wallet.wallet.user not in settings.lnbits_admin_users ) and ( settings.lnbits_admin_extensions - and pathname in settings.lnbits_admin_extensions + and r["path"].split("/")[1] in settings.lnbits_admin_extensions ): raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail="User not authorized for this extension.", ) return wallet - except HTTPException as e: - if e.status_code == HTTPStatus.BAD_REQUEST: + except HTTPException as exc: + if exc.status_code == HTTPStatus.BAD_REQUEST: raise - elif e.status_code == HTTPStatus.UNAUTHORIZED: + elif exc.status_code == HTTPStatus.UNAUTHORIZED: # we pass this in case it is not an invoice key, nor an admin key, and then return NOT_FOUND at the end of this block pass else: @@ -199,7 +186,7 @@ async def require_admin_key( # If wallet type is not admin then return the unauthorized status # This also covers when the user passes an invalid key type raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin key required." + status_code=HTTPStatus.UNAUTHORIZED, detail="Admin key required." ) else: return wallet @@ -220,11 +207,12 @@ async def require_invoice_key( wallet = await get_key_type(r, token) - if wallet.wallet_type > 1: - # If wallet type is not invoice then return the unauthorized status - # This also covers when the user passes an invalid key type + if ( + wallet.wallet_type != WalletType.admin + and wallet.wallet_type != WalletType.invoice + ): raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, + status_code=HTTPStatus.UNAUTHORIZED, detail="Invoice (or Admin) key required.", ) else: