mirror of
https://github.com/lnbits/lnbits-legend.git
synced 2025-02-24 06:48:02 +01:00
[REFACTOR] WalletType into enum (#1888)
* [REFACTOR] WalletType into enum - move wallettype models from decorators.py into models.py - use enum instead of int - use HTTPStatus for consistency * Update lnbits/core/models.py Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com> * use dataclass --------- Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com>
This commit is contained in:
parent
3a653630f1
commit
c4da1dfdce
4 changed files with 44 additions and 36 deletions
|
@ -7,6 +7,7 @@ from uuid import UUID, uuid4
|
|||
import shortuuid
|
||||
|
||||
from lnbits import bolt11
|
||||
from lnbits.core.models import WalletType
|
||||
from lnbits.db import Connection, Database, Filters, Page
|
||||
from lnbits.extension_manager import InstallableExtension
|
||||
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
|
||||
|
@ -294,7 +295,7 @@ async def get_wallet_for_key(
|
|||
if not row:
|
||||
return None
|
||||
|
||||
if key_type == "admin" and row["adminkey"] != key:
|
||||
if key_type == WalletType.admin and row["adminkey"] != key:
|
||||
return None
|
||||
|
||||
return Wallet(**row)
|
||||
|
|
|
@ -3,6 +3,8 @@ import hashlib
|
|||
import hmac
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from sqlite3 import Row
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
|
@ -59,6 +61,22 @@ class Wallet(BaseModel):
|
|||
return await get_standalone_payment(payment_hash)
|
||||
|
||||
|
||||
class WalletType(Enum):
|
||||
admin = 0
|
||||
invoice = 1
|
||||
invalid = 2
|
||||
|
||||
# backwards compatibility
|
||||
def __eq__(self, other):
|
||||
return self.value == other
|
||||
|
||||
|
||||
@dataclass
|
||||
class WalletTypeInfo:
|
||||
wallet_type: WalletType
|
||||
wallet: Wallet
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
email: Optional[str] = None
|
||||
|
|
|
@ -38,6 +38,7 @@ from lnbits.core.models import (
|
|||
PaymentFilters,
|
||||
User,
|
||||
Wallet,
|
||||
WalletType,
|
||||
)
|
||||
from lnbits.db import Filters, Page
|
||||
from lnbits.decorators import (
|
||||
|
@ -102,7 +103,7 @@ async def health():
|
|||
|
||||
@core_app.get("/api/v1/wallet")
|
||||
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
|
||||
if wallet.wallet_type == 0:
|
||||
if wallet.wallet_type == WalletType.admin:
|
||||
return {
|
||||
"id": wallet.wallet.id,
|
||||
"name": wallet.wallet.name,
|
||||
|
@ -318,7 +319,7 @@ async def api_payments_create(
|
|||
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
||||
invoiceData: CreateInvoice = Body(...),
|
||||
):
|
||||
if invoiceData.out is True and wallet.wallet_type == 0:
|
||||
if invoiceData.out is True and wallet.wallet_type == WalletType.admin:
|
||||
if not invoiceData.bolt11:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Literal, Optional, Type
|
||||
|
||||
from fastapi import Query, Request, Security, status
|
||||
from fastapi import Query, Request, Security
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||
|
@ -9,7 +9,7 @@ from fastapi.security.base import SecurityBase
|
|||
from pydantic.types import UUID4
|
||||
|
||||
from lnbits.core.crud import get_user, get_wallet_for_key
|
||||
from lnbits.core.models import User, Wallet
|
||||
from lnbits.core.models import User, WalletType, WalletTypeInfo
|
||||
from lnbits.db import Filter, Filters, TFilterModel
|
||||
from lnbits.requestvars import g
|
||||
from lnbits.settings import settings
|
||||
|
@ -25,7 +25,7 @@ class KeyChecker(SecurityBase):
|
|||
):
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
self._key_type = "invoice"
|
||||
self._key_type = WalletType.invoice
|
||||
self._api_key = api_key
|
||||
if api_key:
|
||||
key = APIKey(
|
||||
|
@ -39,7 +39,7 @@ class KeyChecker(SecurityBase):
|
|||
name="X-API-KEY",
|
||||
description="Wallet API Key - HEADER",
|
||||
)
|
||||
self.wallet = None # type: ignore
|
||||
self.wallet = None
|
||||
self.model: APIKey = key
|
||||
|
||||
async def __call__(self, request: Request):
|
||||
|
@ -81,7 +81,7 @@ class WalletInvoiceKeyChecker(KeyChecker):
|
|||
api_key: Optional[str] = None,
|
||||
):
|
||||
super().__init__(scheme_name, auto_error, api_key)
|
||||
self._key_type = "invoice"
|
||||
self._key_type = WalletType.invoice
|
||||
|
||||
|
||||
class WalletAdminKeyChecker(KeyChecker):
|
||||
|
@ -100,16 +100,7 @@ class WalletAdminKeyChecker(KeyChecker):
|
|||
api_key: Optional[str] = None,
|
||||
):
|
||||
super().__init__(scheme_name, auto_error, api_key)
|
||||
self._key_type = "admin"
|
||||
|
||||
|
||||
class WalletTypeInfo:
|
||||
wallet_type: int
|
||||
wallet: Wallet
|
||||
|
||||
def __init__(self, wallet_type: int, wallet: Wallet) -> None:
|
||||
self.wallet_type = wallet_type
|
||||
self.wallet = wallet
|
||||
self._key_type = WalletType.admin
|
||||
|
||||
|
||||
api_key_header = APIKeyHeader(
|
||||
|
@ -129,11 +120,6 @@ async def get_key_type(
|
|||
api_key_header: str = Security(api_key_header),
|
||||
api_key_query: str = Security(api_key_query),
|
||||
) -> WalletTypeInfo:
|
||||
# 0: admin
|
||||
# 1: invoice
|
||||
# 2: invalid
|
||||
pathname = r["path"].split("/")[1]
|
||||
|
||||
token = api_key_header or api_key_query
|
||||
|
||||
if not token:
|
||||
|
@ -142,33 +128,34 @@ async def get_key_type(
|
|||
detail="Invoice (or Admin) key required.",
|
||||
)
|
||||
|
||||
for typenr, WalletChecker in zip(
|
||||
[0, 1], [WalletAdminKeyChecker, WalletInvoiceKeyChecker]
|
||||
for wallet_type, WalletChecker in zip(
|
||||
[WalletType.admin, WalletType.invoice],
|
||||
[WalletAdminKeyChecker, WalletInvoiceKeyChecker],
|
||||
):
|
||||
try:
|
||||
checker = WalletChecker(api_key=token)
|
||||
await checker.__call__(r)
|
||||
wallet = WalletTypeInfo(typenr, checker.wallet) # type: ignore
|
||||
if wallet is None or wallet.wallet is None:
|
||||
if checker.wallet is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist."
|
||||
)
|
||||
wallet = WalletTypeInfo(wallet_type, checker.wallet)
|
||||
if (
|
||||
wallet.wallet.user != settings.super_user
|
||||
and wallet.wallet.user not in settings.lnbits_admin_users
|
||||
) and (
|
||||
settings.lnbits_admin_extensions
|
||||
and pathname in settings.lnbits_admin_extensions
|
||||
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="User not authorized for this extension.",
|
||||
)
|
||||
return wallet
|
||||
except HTTPException as e:
|
||||
if e.status_code == HTTPStatus.BAD_REQUEST:
|
||||
except HTTPException as exc:
|
||||
if exc.status_code == HTTPStatus.BAD_REQUEST:
|
||||
raise
|
||||
elif e.status_code == HTTPStatus.UNAUTHORIZED:
|
||||
elif exc.status_code == HTTPStatus.UNAUTHORIZED:
|
||||
# we pass this in case it is not an invoice key, nor an admin key, and then return NOT_FOUND at the end of this block
|
||||
pass
|
||||
else:
|
||||
|
@ -199,7 +186,7 @@ async def require_admin_key(
|
|||
# If wallet type is not admin then return the unauthorized status
|
||||
# This also covers when the user passes an invalid key type
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin key required."
|
||||
status_code=HTTPStatus.UNAUTHORIZED, detail="Admin key required."
|
||||
)
|
||||
else:
|
||||
return wallet
|
||||
|
@ -220,11 +207,12 @@ async def require_invoice_key(
|
|||
|
||||
wallet = await get_key_type(r, token)
|
||||
|
||||
if wallet.wallet_type > 1:
|
||||
# If wallet type is not invoice then return the unauthorized status
|
||||
# This also covers when the user passes an invalid key type
|
||||
if (
|
||||
wallet.wallet_type != WalletType.admin
|
||||
and wallet.wallet_type != WalletType.invoice
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
status_code=HTTPStatus.UNAUTHORIZED,
|
||||
detail="Invoice (or Admin) key required.",
|
||||
)
|
||||
else:
|
||||
|
|
Loading…
Add table
Reference in a new issue