[REFACTOR] WalletType into enum (#1888)

* [REFACTOR] WalletType into enum

- move wallettype models from decorators.py into models.py
- use enum instead of int
- use HTTPStatus for consistency

* Update lnbits/core/models.py

Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com>

* use dataclass

---------

Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com>
This commit is contained in:
dni ⚡ 2023-08-23 12:41:22 +02:00 committed by GitHub
parent 3a653630f1
commit c4da1dfdce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 36 deletions

View file

@ -7,6 +7,7 @@ from uuid import UUID, uuid4
import shortuuid
from lnbits import bolt11
from lnbits.core.models import WalletType
from lnbits.db import Connection, Database, Filters, Page
from lnbits.extension_manager import InstallableExtension
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
@ -294,7 +295,7 @@ async def get_wallet_for_key(
if not row:
return None
if key_type == "admin" and row["adminkey"] != key:
if key_type == WalletType.admin and row["adminkey"] != key:
return None
return Wallet(**row)

View file

@ -3,6 +3,8 @@ import hashlib
import hmac
import json
import time
from dataclasses import dataclass
from enum import Enum
from sqlite3 import Row
from typing import Callable, Dict, List, Optional
@ -59,6 +61,22 @@ class Wallet(BaseModel):
return await get_standalone_payment(payment_hash)
class WalletType(Enum):
admin = 0
invoice = 1
invalid = 2
# backwards compatibility
def __eq__(self, other):
return self.value == other
@dataclass
class WalletTypeInfo:
wallet_type: WalletType
wallet: Wallet
class User(BaseModel):
id: str
email: Optional[str] = None

View file

@ -38,6 +38,7 @@ from lnbits.core.models import (
PaymentFilters,
User,
Wallet,
WalletType,
)
from lnbits.db import Filters, Page
from lnbits.decorators import (
@ -102,7 +103,7 @@ async def health():
@core_app.get("/api/v1/wallet")
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
if wallet.wallet_type == 0:
if wallet.wallet_type == WalletType.admin:
return {
"id": wallet.wallet.id,
"name": wallet.wallet.name,
@ -318,7 +319,7 @@ async def api_payments_create(
wallet: WalletTypeInfo = Depends(require_invoice_key),
invoiceData: CreateInvoice = Body(...),
):
if invoiceData.out is True and wallet.wallet_type == 0:
if invoiceData.out is True and wallet.wallet_type == WalletType.admin:
if not invoiceData.bolt11:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,

View file

@ -1,7 +1,7 @@
from http import HTTPStatus
from typing import Literal, Optional, Type
from fastapi import Query, Request, Security, status
from fastapi import Query, Request, Security
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security import APIKeyHeader, APIKeyQuery
@ -9,7 +9,7 @@ from fastapi.security.base import SecurityBase
from pydantic.types import UUID4
from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.core.models import User, Wallet
from lnbits.core.models import User, WalletType, WalletTypeInfo
from lnbits.db import Filter, Filters, TFilterModel
from lnbits.requestvars import g
from lnbits.settings import settings
@ -25,7 +25,7 @@ class KeyChecker(SecurityBase):
):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
self._key_type = "invoice"
self._key_type = WalletType.invoice
self._api_key = api_key
if api_key:
key = APIKey(
@ -39,7 +39,7 @@ class KeyChecker(SecurityBase):
name="X-API-KEY",
description="Wallet API Key - HEADER",
)
self.wallet = None # type: ignore
self.wallet = None
self.model: APIKey = key
async def __call__(self, request: Request):
@ -81,7 +81,7 @@ class WalletInvoiceKeyChecker(KeyChecker):
api_key: Optional[str] = None,
):
super().__init__(scheme_name, auto_error, api_key)
self._key_type = "invoice"
self._key_type = WalletType.invoice
class WalletAdminKeyChecker(KeyChecker):
@ -100,16 +100,7 @@ class WalletAdminKeyChecker(KeyChecker):
api_key: Optional[str] = None,
):
super().__init__(scheme_name, auto_error, api_key)
self._key_type = "admin"
class WalletTypeInfo:
wallet_type: int
wallet: Wallet
def __init__(self, wallet_type: int, wallet: Wallet) -> None:
self.wallet_type = wallet_type
self.wallet = wallet
self._key_type = WalletType.admin
api_key_header = APIKeyHeader(
@ -129,11 +120,6 @@ async def get_key_type(
api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query),
) -> WalletTypeInfo:
# 0: admin
# 1: invoice
# 2: invalid
pathname = r["path"].split("/")[1]
token = api_key_header or api_key_query
if not token:
@ -142,33 +128,34 @@ async def get_key_type(
detail="Invoice (or Admin) key required.",
)
for typenr, WalletChecker in zip(
[0, 1], [WalletAdminKeyChecker, WalletInvoiceKeyChecker]
for wallet_type, WalletChecker in zip(
[WalletType.admin, WalletType.invoice],
[WalletAdminKeyChecker, WalletInvoiceKeyChecker],
):
try:
checker = WalletChecker(api_key=token)
await checker.__call__(r)
wallet = WalletTypeInfo(typenr, checker.wallet) # type: ignore
if wallet is None or wallet.wallet is None:
if checker.wallet is None:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist."
)
wallet = WalletTypeInfo(wallet_type, checker.wallet)
if (
wallet.wallet.user != settings.super_user
and wallet.wallet.user not in settings.lnbits_admin_users
) and (
settings.lnbits_admin_extensions
and pathname in settings.lnbits_admin_extensions
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User not authorized for this extension.",
)
return wallet
except HTTPException as e:
if e.status_code == HTTPStatus.BAD_REQUEST:
except HTTPException as exc:
if exc.status_code == HTTPStatus.BAD_REQUEST:
raise
elif e.status_code == HTTPStatus.UNAUTHORIZED:
elif exc.status_code == HTTPStatus.UNAUTHORIZED:
# we pass this in case it is not an invoice key, nor an admin key, and then return NOT_FOUND at the end of this block
pass
else:
@ -199,7 +186,7 @@ async def require_admin_key(
# If wallet type is not admin then return the unauthorized status
# This also covers when the user passes an invalid key type
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin key required."
status_code=HTTPStatus.UNAUTHORIZED, detail="Admin key required."
)
else:
return wallet
@ -220,11 +207,12 @@ async def require_invoice_key(
wallet = await get_key_type(r, token)
if wallet.wallet_type > 1:
# If wallet type is not invoice then return the unauthorized status
# This also covers when the user passes an invalid key type
if (
wallet.wallet_type != WalletType.admin
and wallet.wallet_type != WalletType.invoice
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
status_code=HTTPStatus.UNAUTHORIZED,
detail="Invoice (or Admin) key required.",
)
else: