mirror of
https://github.com/lnbits/lnbits-legend.git
synced 2025-02-24 06:48:02 +01:00
[REFACTOR] WalletType into enum (#1888)
* [REFACTOR] WalletType into enum - move wallettype models from decorators.py into models.py - use enum instead of int - use HTTPStatus for consistency * Update lnbits/core/models.py Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com> * use dataclass --------- Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com>
This commit is contained in:
parent
3a653630f1
commit
c4da1dfdce
4 changed files with 44 additions and 36 deletions
|
@ -7,6 +7,7 @@ from uuid import UUID, uuid4
|
||||||
import shortuuid
|
import shortuuid
|
||||||
|
|
||||||
from lnbits import bolt11
|
from lnbits import bolt11
|
||||||
|
from lnbits.core.models import WalletType
|
||||||
from lnbits.db import Connection, Database, Filters, Page
|
from lnbits.db import Connection, Database, Filters, Page
|
||||||
from lnbits.extension_manager import InstallableExtension
|
from lnbits.extension_manager import InstallableExtension
|
||||||
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
|
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
|
||||||
|
@ -294,7 +295,7 @@ async def get_wallet_for_key(
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if key_type == "admin" and row["adminkey"] != key:
|
if key_type == WalletType.admin and row["adminkey"] != key:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return Wallet(**row)
|
return Wallet(**row)
|
||||||
|
|
|
@ -3,6 +3,8 @@ import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
from sqlite3 import Row
|
from sqlite3 import Row
|
||||||
from typing import Callable, Dict, List, Optional
|
from typing import Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
@ -59,6 +61,22 @@ class Wallet(BaseModel):
|
||||||
return await get_standalone_payment(payment_hash)
|
return await get_standalone_payment(payment_hash)
|
||||||
|
|
||||||
|
|
||||||
|
class WalletType(Enum):
|
||||||
|
admin = 0
|
||||||
|
invoice = 1
|
||||||
|
invalid = 2
|
||||||
|
|
||||||
|
# backwards compatibility
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.value == other
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WalletTypeInfo:
|
||||||
|
wallet_type: WalletType
|
||||||
|
wallet: Wallet
|
||||||
|
|
||||||
|
|
||||||
class User(BaseModel):
|
class User(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
email: Optional[str] = None
|
email: Optional[str] = None
|
||||||
|
|
|
@ -38,6 +38,7 @@ from lnbits.core.models import (
|
||||||
PaymentFilters,
|
PaymentFilters,
|
||||||
User,
|
User,
|
||||||
Wallet,
|
Wallet,
|
||||||
|
WalletType,
|
||||||
)
|
)
|
||||||
from lnbits.db import Filters, Page
|
from lnbits.db import Filters, Page
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
|
@ -102,7 +103,7 @@ async def health():
|
||||||
|
|
||||||
@core_app.get("/api/v1/wallet")
|
@core_app.get("/api/v1/wallet")
|
||||||
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
|
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
|
||||||
if wallet.wallet_type == 0:
|
if wallet.wallet_type == WalletType.admin:
|
||||||
return {
|
return {
|
||||||
"id": wallet.wallet.id,
|
"id": wallet.wallet.id,
|
||||||
"name": wallet.wallet.name,
|
"name": wallet.wallet.name,
|
||||||
|
@ -318,7 +319,7 @@ async def api_payments_create(
|
||||||
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
||||||
invoiceData: CreateInvoice = Body(...),
|
invoiceData: CreateInvoice = Body(...),
|
||||||
):
|
):
|
||||||
if invoiceData.out is True and wallet.wallet_type == 0:
|
if invoiceData.out is True and wallet.wallet_type == WalletType.admin:
|
||||||
if not invoiceData.bolt11:
|
if not invoiceData.bolt11:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Literal, Optional, Type
|
from typing import Literal, Optional, Type
|
||||||
|
|
||||||
from fastapi import Query, Request, Security, status
|
from fastapi import Query, Request, Security
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||||
|
@ -9,7 +9,7 @@ from fastapi.security.base import SecurityBase
|
||||||
from pydantic.types import UUID4
|
from pydantic.types import UUID4
|
||||||
|
|
||||||
from lnbits.core.crud import get_user, get_wallet_for_key
|
from lnbits.core.crud import get_user, get_wallet_for_key
|
||||||
from lnbits.core.models import User, Wallet
|
from lnbits.core.models import User, WalletType, WalletTypeInfo
|
||||||
from lnbits.db import Filter, Filters, TFilterModel
|
from lnbits.db import Filter, Filters, TFilterModel
|
||||||
from lnbits.requestvars import g
|
from lnbits.requestvars import g
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
@ -25,7 +25,7 @@ class KeyChecker(SecurityBase):
|
||||||
):
|
):
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
self.auto_error = auto_error
|
self.auto_error = auto_error
|
||||||
self._key_type = "invoice"
|
self._key_type = WalletType.invoice
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
if api_key:
|
if api_key:
|
||||||
key = APIKey(
|
key = APIKey(
|
||||||
|
@ -39,7 +39,7 @@ class KeyChecker(SecurityBase):
|
||||||
name="X-API-KEY",
|
name="X-API-KEY",
|
||||||
description="Wallet API Key - HEADER",
|
description="Wallet API Key - HEADER",
|
||||||
)
|
)
|
||||||
self.wallet = None # type: ignore
|
self.wallet = None
|
||||||
self.model: APIKey = key
|
self.model: APIKey = key
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request):
|
||||||
|
@ -81,7 +81,7 @@ class WalletInvoiceKeyChecker(KeyChecker):
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(scheme_name, auto_error, api_key)
|
super().__init__(scheme_name, auto_error, api_key)
|
||||||
self._key_type = "invoice"
|
self._key_type = WalletType.invoice
|
||||||
|
|
||||||
|
|
||||||
class WalletAdminKeyChecker(KeyChecker):
|
class WalletAdminKeyChecker(KeyChecker):
|
||||||
|
@ -100,16 +100,7 @@ class WalletAdminKeyChecker(KeyChecker):
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(scheme_name, auto_error, api_key)
|
super().__init__(scheme_name, auto_error, api_key)
|
||||||
self._key_type = "admin"
|
self._key_type = WalletType.admin
|
||||||
|
|
||||||
|
|
||||||
class WalletTypeInfo:
|
|
||||||
wallet_type: int
|
|
||||||
wallet: Wallet
|
|
||||||
|
|
||||||
def __init__(self, wallet_type: int, wallet: Wallet) -> None:
|
|
||||||
self.wallet_type = wallet_type
|
|
||||||
self.wallet = wallet
|
|
||||||
|
|
||||||
|
|
||||||
api_key_header = APIKeyHeader(
|
api_key_header = APIKeyHeader(
|
||||||
|
@ -129,11 +120,6 @@ async def get_key_type(
|
||||||
api_key_header: str = Security(api_key_header),
|
api_key_header: str = Security(api_key_header),
|
||||||
api_key_query: str = Security(api_key_query),
|
api_key_query: str = Security(api_key_query),
|
||||||
) -> WalletTypeInfo:
|
) -> WalletTypeInfo:
|
||||||
# 0: admin
|
|
||||||
# 1: invoice
|
|
||||||
# 2: invalid
|
|
||||||
pathname = r["path"].split("/")[1]
|
|
||||||
|
|
||||||
token = api_key_header or api_key_query
|
token = api_key_header or api_key_query
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
|
@ -142,33 +128,34 @@ async def get_key_type(
|
||||||
detail="Invoice (or Admin) key required.",
|
detail="Invoice (or Admin) key required.",
|
||||||
)
|
)
|
||||||
|
|
||||||
for typenr, WalletChecker in zip(
|
for wallet_type, WalletChecker in zip(
|
||||||
[0, 1], [WalletAdminKeyChecker, WalletInvoiceKeyChecker]
|
[WalletType.admin, WalletType.invoice],
|
||||||
|
[WalletAdminKeyChecker, WalletInvoiceKeyChecker],
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
checker = WalletChecker(api_key=token)
|
checker = WalletChecker(api_key=token)
|
||||||
await checker.__call__(r)
|
await checker.__call__(r)
|
||||||
wallet = WalletTypeInfo(typenr, checker.wallet) # type: ignore
|
if checker.wallet is None:
|
||||||
if wallet is None or wallet.wallet is None:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist."
|
status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist."
|
||||||
)
|
)
|
||||||
|
wallet = WalletTypeInfo(wallet_type, checker.wallet)
|
||||||
if (
|
if (
|
||||||
wallet.wallet.user != settings.super_user
|
wallet.wallet.user != settings.super_user
|
||||||
and wallet.wallet.user not in settings.lnbits_admin_users
|
and wallet.wallet.user not in settings.lnbits_admin_users
|
||||||
) and (
|
) and (
|
||||||
settings.lnbits_admin_extensions
|
settings.lnbits_admin_extensions
|
||||||
and pathname in settings.lnbits_admin_extensions
|
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.FORBIDDEN,
|
status_code=HTTPStatus.FORBIDDEN,
|
||||||
detail="User not authorized for this extension.",
|
detail="User not authorized for this extension.",
|
||||||
)
|
)
|
||||||
return wallet
|
return wallet
|
||||||
except HTTPException as e:
|
except HTTPException as exc:
|
||||||
if e.status_code == HTTPStatus.BAD_REQUEST:
|
if exc.status_code == HTTPStatus.BAD_REQUEST:
|
||||||
raise
|
raise
|
||||||
elif e.status_code == HTTPStatus.UNAUTHORIZED:
|
elif exc.status_code == HTTPStatus.UNAUTHORIZED:
|
||||||
# we pass this in case it is not an invoice key, nor an admin key, and then return NOT_FOUND at the end of this block
|
# we pass this in case it is not an invoice key, nor an admin key, and then return NOT_FOUND at the end of this block
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
@ -199,7 +186,7 @@ async def require_admin_key(
|
||||||
# If wallet type is not admin then return the unauthorized status
|
# If wallet type is not admin then return the unauthorized status
|
||||||
# This also covers when the user passes an invalid key type
|
# This also covers when the user passes an invalid key type
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin key required."
|
status_code=HTTPStatus.UNAUTHORIZED, detail="Admin key required."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return wallet
|
return wallet
|
||||||
|
@ -220,11 +207,12 @@ async def require_invoice_key(
|
||||||
|
|
||||||
wallet = await get_key_type(r, token)
|
wallet = await get_key_type(r, token)
|
||||||
|
|
||||||
if wallet.wallet_type > 1:
|
if (
|
||||||
# If wallet type is not invoice then return the unauthorized status
|
wallet.wallet_type != WalletType.admin
|
||||||
# This also covers when the user passes an invalid key type
|
and wallet.wallet_type != WalletType.invoice
|
||||||
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=HTTPStatus.UNAUTHORIZED,
|
||||||
detail="Invoice (or Admin) key required.",
|
detail="Invoice (or Admin) key required.",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Add table
Reference in a new issue