[REFACTOR] WalletType into enum (#1888)

* [REFACTOR] WalletType into enum

- move wallettype models from decorators.py into models.py
- use enum instead of int
- use HTTPStatus for consistency

* Update lnbits/core/models.py

Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com>

* use dataclass

---------

Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com>
This commit is contained in:
dni ⚡ 2023-08-23 12:41:22 +02:00 committed by GitHub
parent 3a653630f1
commit c4da1dfdce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 36 deletions

View file

@ -7,6 +7,7 @@ from uuid import UUID, uuid4
import shortuuid import shortuuid
from lnbits import bolt11 from lnbits import bolt11
from lnbits.core.models import WalletType
from lnbits.db import Connection, Database, Filters, Page from lnbits.db import Connection, Database, Filters, Page
from lnbits.extension_manager import InstallableExtension from lnbits.extension_manager import InstallableExtension
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
@ -294,7 +295,7 @@ async def get_wallet_for_key(
if not row: if not row:
return None return None
if key_type == "admin" and row["adminkey"] != key: if key_type == WalletType.admin and row["adminkey"] != key:
return None return None
return Wallet(**row) return Wallet(**row)

View file

@ -3,6 +3,8 @@ import hashlib
import hmac import hmac
import json import json
import time import time
from dataclasses import dataclass
from enum import Enum
from sqlite3 import Row from sqlite3 import Row
from typing import Callable, Dict, List, Optional from typing import Callable, Dict, List, Optional
@ -59,6 +61,22 @@ class Wallet(BaseModel):
return await get_standalone_payment(payment_hash) return await get_standalone_payment(payment_hash)
class WalletType(Enum):
admin = 0
invoice = 1
invalid = 2
# backwards compatibility
def __eq__(self, other):
return self.value == other
@dataclass
class WalletTypeInfo:
wallet_type: WalletType
wallet: Wallet
class User(BaseModel): class User(BaseModel):
id: str id: str
email: Optional[str] = None email: Optional[str] = None

View file

@ -38,6 +38,7 @@ from lnbits.core.models import (
PaymentFilters, PaymentFilters,
User, User,
Wallet, Wallet,
WalletType,
) )
from lnbits.db import Filters, Page from lnbits.db import Filters, Page
from lnbits.decorators import ( from lnbits.decorators import (
@ -102,7 +103,7 @@ async def health():
@core_app.get("/api/v1/wallet") @core_app.get("/api/v1/wallet")
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)): async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
if wallet.wallet_type == 0: if wallet.wallet_type == WalletType.admin:
return { return {
"id": wallet.wallet.id, "id": wallet.wallet.id,
"name": wallet.wallet.name, "name": wallet.wallet.name,
@ -318,7 +319,7 @@ async def api_payments_create(
wallet: WalletTypeInfo = Depends(require_invoice_key), wallet: WalletTypeInfo = Depends(require_invoice_key),
invoiceData: CreateInvoice = Body(...), invoiceData: CreateInvoice = Body(...),
): ):
if invoiceData.out is True and wallet.wallet_type == 0: if invoiceData.out is True and wallet.wallet_type == WalletType.admin:
if not invoiceData.bolt11: if not invoiceData.bolt11:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,

View file

@ -1,7 +1,7 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Literal, Optional, Type from typing import Literal, Optional, Type
from fastapi import Query, Request, Security, status from fastapi import Query, Request, Security
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.security import APIKeyHeader, APIKeyQuery
@ -9,7 +9,7 @@ from fastapi.security.base import SecurityBase
from pydantic.types import UUID4 from pydantic.types import UUID4
from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.core.models import User, Wallet from lnbits.core.models import User, WalletType, WalletTypeInfo
from lnbits.db import Filter, Filters, TFilterModel from lnbits.db import Filter, Filters, TFilterModel
from lnbits.requestvars import g from lnbits.requestvars import g
from lnbits.settings import settings from lnbits.settings import settings
@ -25,7 +25,7 @@ class KeyChecker(SecurityBase):
): ):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
self._key_type = "invoice" self._key_type = WalletType.invoice
self._api_key = api_key self._api_key = api_key
if api_key: if api_key:
key = APIKey( key = APIKey(
@ -39,7 +39,7 @@ class KeyChecker(SecurityBase):
name="X-API-KEY", name="X-API-KEY",
description="Wallet API Key - HEADER", description="Wallet API Key - HEADER",
) )
self.wallet = None # type: ignore self.wallet = None
self.model: APIKey = key self.model: APIKey = key
async def __call__(self, request: Request): async def __call__(self, request: Request):
@ -81,7 +81,7 @@ class WalletInvoiceKeyChecker(KeyChecker):
api_key: Optional[str] = None, api_key: Optional[str] = None,
): ):
super().__init__(scheme_name, auto_error, api_key) super().__init__(scheme_name, auto_error, api_key)
self._key_type = "invoice" self._key_type = WalletType.invoice
class WalletAdminKeyChecker(KeyChecker): class WalletAdminKeyChecker(KeyChecker):
@ -100,16 +100,7 @@ class WalletAdminKeyChecker(KeyChecker):
api_key: Optional[str] = None, api_key: Optional[str] = None,
): ):
super().__init__(scheme_name, auto_error, api_key) super().__init__(scheme_name, auto_error, api_key)
self._key_type = "admin" self._key_type = WalletType.admin
class WalletTypeInfo:
wallet_type: int
wallet: Wallet
def __init__(self, wallet_type: int, wallet: Wallet) -> None:
self.wallet_type = wallet_type
self.wallet = wallet
api_key_header = APIKeyHeader( api_key_header = APIKeyHeader(
@ -129,11 +120,6 @@ async def get_key_type(
api_key_header: str = Security(api_key_header), api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query), api_key_query: str = Security(api_key_query),
) -> WalletTypeInfo: ) -> WalletTypeInfo:
# 0: admin
# 1: invoice
# 2: invalid
pathname = r["path"].split("/")[1]
token = api_key_header or api_key_query token = api_key_header or api_key_query
if not token: if not token:
@ -142,33 +128,34 @@ async def get_key_type(
detail="Invoice (or Admin) key required.", detail="Invoice (or Admin) key required.",
) )
for typenr, WalletChecker in zip( for wallet_type, WalletChecker in zip(
[0, 1], [WalletAdminKeyChecker, WalletInvoiceKeyChecker] [WalletType.admin, WalletType.invoice],
[WalletAdminKeyChecker, WalletInvoiceKeyChecker],
): ):
try: try:
checker = WalletChecker(api_key=token) checker = WalletChecker(api_key=token)
await checker.__call__(r) await checker.__call__(r)
wallet = WalletTypeInfo(typenr, checker.wallet) # type: ignore if checker.wallet is None:
if wallet is None or wallet.wallet is None:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist." status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist."
) )
wallet = WalletTypeInfo(wallet_type, checker.wallet)
if ( if (
wallet.wallet.user != settings.super_user wallet.wallet.user != settings.super_user
and wallet.wallet.user not in settings.lnbits_admin_users and wallet.wallet.user not in settings.lnbits_admin_users
) and ( ) and (
settings.lnbits_admin_extensions settings.lnbits_admin_extensions
and pathname in settings.lnbits_admin_extensions and r["path"].split("/")[1] in settings.lnbits_admin_extensions
): ):
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, status_code=HTTPStatus.FORBIDDEN,
detail="User not authorized for this extension.", detail="User not authorized for this extension.",
) )
return wallet return wallet
except HTTPException as e: except HTTPException as exc:
if e.status_code == HTTPStatus.BAD_REQUEST: if exc.status_code == HTTPStatus.BAD_REQUEST:
raise raise
elif e.status_code == HTTPStatus.UNAUTHORIZED: elif exc.status_code == HTTPStatus.UNAUTHORIZED:
# we pass this in case it is not an invoice key, nor an admin key, and then return NOT_FOUND at the end of this block # we pass this in case it is not an invoice key, nor an admin key, and then return NOT_FOUND at the end of this block
pass pass
else: else:
@ -199,7 +186,7 @@ async def require_admin_key(
# If wallet type is not admin then return the unauthorized status # If wallet type is not admin then return the unauthorized status
# This also covers when the user passes an invalid key type # This also covers when the user passes an invalid key type
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin key required." status_code=HTTPStatus.UNAUTHORIZED, detail="Admin key required."
) )
else: else:
return wallet return wallet
@ -220,11 +207,12 @@ async def require_invoice_key(
wallet = await get_key_type(r, token) wallet = await get_key_type(r, token)
if wallet.wallet_type > 1: if (
# If wallet type is not invoice then return the unauthorized status wallet.wallet_type != WalletType.admin
# This also covers when the user passes an invalid key type and wallet.wallet_type != WalletType.invoice
):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=HTTPStatus.UNAUTHORIZED,
detail="Invoice (or Admin) key required.", detail="Invoice (or Admin) key required.",
) )
else: else: