lnbits-legend/lnbits/decorators.py

230 lines
7.2 KiB
Python
Raw Normal View History

from http import HTTPStatus
2021-08-29 19:38:42 +02:00
from cerberus import Validator # type: ignore
2021-10-18 17:06:06 +02:00
from fastapi import status
2021-08-29 19:38:42 +02:00
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.params import Security
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
from fastapi.security.base import SecurityBase
2021-10-18 17:06:06 +02:00
from pydantic.types import UUID4
2021-08-29 19:38:42 +02:00
from starlette.requests import Request
from lnbits.core.crud import get_user, get_wallet_for_key
2021-10-18 17:06:06 +02:00
from lnbits.core.models import User, Wallet
from lnbits.requestvars import g
2022-06-01 14:53:05 +02:00
from lnbits.settings import (
LNBITS_ADMIN_EXTENSIONS,
2022-07-16 14:23:03 +02:00
LNBITS_ADMIN_USERS,
LNBITS_ALLOWED_USERS,
2022-06-01 14:53:05 +02:00
)
2021-08-29 19:38:42 +02:00
class KeyChecker(SecurityBase):
    """
    Base FastAPI security scheme that resolves a wallet from an API key.

    The key is taken from the ``api_key`` constructor argument when one was
    given; otherwise from the ``X-API-KEY`` request header, falling back to the
    ``api-key`` query parameter. The wallet is looked up with
    ``get_wallet_for_key`` using ``self._key_type`` (``"invoice"`` here;
    subclasses override it) and stored on ``self.wallet``.
    """

    def __init__(
        self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
    ):
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error
        # Default key type; subclasses overwrite it after calling super().__init__().
        self._key_type = "invoice"
        self._api_key = api_key
        # The OpenAPI model only documents where the key is expected; __call__
        # reads the key independently of it.
        if api_key:
            self.model: APIKey = APIKey(
                **{"in": APIKeyIn.query},
                name="X-API-KEY",
                description="Wallet API Key - QUERY",
            )
        else:
            self.model: APIKey = APIKey(
                **{"in": APIKeyIn.header},
                name="X-API-KEY",
                description="Wallet API Key - HEADER",
            )
        self.wallet = None

    async def __call__(self, request: Request) -> Wallet:
        """
        Resolve the wallet for the request's API key.

        Returns:
            The matching wallet. It is also stored on ``self.wallet``.
            NOTE(review): if one instance is reused across concurrent requests,
            ``self.wallet`` may be overwritten; prefer the return value.

        Raises:
            HTTPException(400): no key in ``api_key``, the header or the query.
            HTTPException(401): no wallet matches the key for ``self._key_type``.
        """
        try:
            key_value = (
                self._api_key
                if self._api_key
                else request.headers.get("X-API-KEY") or request.query_params["api-key"]
            )
        except KeyError:
            # Only the ``api-key`` query lookup can raise here, so a KeyError
            # means neither header nor query parameter carried a key.
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST, detail="`X-API-KEY` header missing."
            )

        # FIXME: Find another way to validate the key. A fetch from DB should be avoided here.
        # Possibly store it in a Redis DB
        self.wallet = await get_wallet_for_key(key_value, self._key_type)
        if not self.wallet:
            raise HTTPException(
                status_code=HTTPStatus.UNAUTHORIZED,
                detail="Invalid key or expired key.",
            )
        # Return the wallet so the declared return type holds and the checker
        # injects the wallet when used directly as a FastAPI dependency.
        return self.wallet
2021-08-29 19:38:42 +02:00
class WalletInvoiceKeyChecker(KeyChecker):
    """
    Key checker that looks the key up as an invoice key.

    The key comes from the ``api_key`` argument, the ``X-API-KEY`` header or
    the ``api-key`` query parameter. On success the wallet is stored on
    ``self.wallet``. An HTTPException is raised when the key is missing (400)
    or does not match a wallet (401).
    """

    def __init__(
        self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
    ):
        super().__init__(scheme_name, auto_error, api_key)
        # Key type passed to get_wallet_for_key by the base class.
        self._key_type = "invoice"
2021-10-17 19:33:29 +02:00
2021-08-29 19:38:42 +02:00
class WalletAdminKeyChecker(KeyChecker):
    """
    Key checker that looks the key up as an admin key.

    The key comes from the ``api_key`` argument, the ``X-API-KEY`` header or
    the ``api-key`` query parameter. On success the wallet is stored on
    ``self.wallet``. An HTTPException is raised when the key is missing (400)
    or does not match a wallet (401).
    """

    def __init__(
        self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
    ):
        super().__init__(scheme_name, auto_error, api_key)
        # Key type passed to get_wallet_for_key by the base class.
        self._key_type = "admin"
2021-10-17 19:33:29 +02:00
class WalletTypeInfo:
    """
    A wallet paired with the kind of key that unlocked it.

    ``wallet_type`` uses the codes produced by ``get_key_type``:
    0 = admin key, 1 = invoice key, 2 = invalid key (``wallet`` is None then).
    """

    wallet_type: int
    wallet: Wallet

    def __init__(self, wallet_type: int, wallet: Wallet) -> None:
        self.wallet_type, self.wallet = wallet_type, wallet
2021-10-17 19:33:29 +02:00
# Wallet keys may arrive either as the ``X-API-KEY`` header or as the
# ``api-key`` query parameter. With auto_error=False a missing key yields None
# instead of an immediate error, so the consuming dependency decides what to do.
api_key_header = APIKeyHeader(
    description="Admin or Invoice key for wallet API's",
    auto_error=False,
    name="X-API-KEY",
)
api_key_query = APIKeyQuery(
    description="Admin or Invoice key for wallet API's",
    auto_error=False,
    name="api-key",
)
async def get_key_type(
    r: Request,
    api_key_header: str = Security(api_key_header),
    api_key_query: str = Security(api_key_query),
) -> WalletTypeInfo:
    """
    Classify the request's API key as an admin key, an invoice key or invalid.

    The header key wins over the query key when both are present. The key is
    tried as an admin key first, then as an invoice key.

    Returns:
        WalletTypeInfo with wallet_type 0 (admin), 1 (invoice) or
        2 (invalid; ``wallet`` is None).

    Raises:
        HTTPException(400): no key was supplied, or a checker reported 400.
        HTTPException: any other non-401 error raised while checking the key.
    """
    # First path segment (e.g. "foo" for "/foo/api/v1/x"); compared against
    # LNBITS_ADMIN_EXTENSIONS below.
    pathname = r["path"].split("/")[1]

    if not api_key_header and not api_key_query:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

    token = api_key_header if api_key_header else api_key_query

    # Ordered: the more privileged key type is tried first.
    key_checkers = (
        (0, WalletAdminKeyChecker),
        (1, WalletInvoiceKeyChecker),
    )
    for wallet_type, checker_cls in key_checkers:
        try:
            checker = checker_cls(api_key=token)
            await checker.__call__(r)
            wallet = WalletTypeInfo(wallet_type, checker.wallet)
            if (LNBITS_ADMIN_USERS and wallet.wallet.user not in LNBITS_ADMIN_USERS) and (
                LNBITS_ADMIN_EXTENSIONS and pathname in LNBITS_ADMIN_EXTENSIONS
            ):
                # NOTE(review): this 401 is caught just below like a wrong-key
                # 401, so the client ends up with wallet_type 2 rather than
                # seeing this detail message.
                raise HTTPException(
                    status_code=HTTPStatus.UNAUTHORIZED, detail="User not authorized."
                )
        except HTTPException as e:
            if e.status_code == HTTPStatus.UNAUTHORIZED:
                # Not valid as this key type: try the next one.
                continue
            # 400 and any unexpected status propagate instead of being
            # swallowed (previously this could end in an implicit None return).
            raise
        else:
            return wallet

    return WalletTypeInfo(2, None)
2021-10-17 19:33:29 +02:00
2021-10-18 17:06:06 +02:00
async def require_admin_key(
    r: Request,
    api_key_header: str = Security(api_key_header),
    api_key_query: str = Security(api_key_query),
):
    """
    FastAPI dependency that only accepts an admin key.

    Returns:
        The WalletTypeInfo from get_key_type (wallet_type 0).

    Raises:
        HTTPException(401): no key was supplied, or the key is not an admin key
            (this also covers invalid keys).
    """
    token = api_key_header if api_key_header else api_key_query

    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin key required."
        )

    # get_key_type is called directly here, not via dependency injection.
    # Omitted parameters would therefore receive their Security(...) default
    # objects (truthy, not key strings), so both are passed explicitly.
    wallet = await get_key_type(r, api_key_header=token, api_key_query=None)

    if wallet.wallet_type != 0:
        # If wallet type is not admin then return the unauthorized status
        # This also covers when the user passes an invalid key type
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin key required."
        )
    return wallet
2022-01-30 20:43:30 +01:00
2021-12-28 16:22:45 +01:00
async def require_invoice_key(
    r: Request,
    api_key_header: str = Security(api_key_header),
    api_key_query: str = Security(api_key_query),
):
    """
    FastAPI dependency that accepts an invoice key or an admin key.

    Returns:
        The WalletTypeInfo from get_key_type (wallet_type 0 or 1).

    Raises:
        HTTPException(401): no key was supplied, or the key is invalid.
    """
    token = api_key_header if api_key_header else api_key_query

    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invoice (or Admin) key required.",
        )

    # get_key_type is called directly here, not via dependency injection.
    # Omitted parameters would therefore receive their Security(...) default
    # objects (truthy, not key strings), so both are passed explicitly.
    wallet = await get_key_type(r, api_key_header=token, api_key_query=None)

    if wallet.wallet_type > 1:
        # If wallet type is not invoice then return the unauthorized status
        # This also covers when the user passes an invalid key type
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invoice (or Admin) key required.",
        )
    return wallet
async def check_user_exists(usr: UUID4) -> User:
    """
    Load the user identified by ``usr``, publish it on ``g().user`` and return it.

    Users listed in LNBITS_ADMIN_USERS get ``admin`` set to True.

    Raises:
        HTTPException(404): no user exists for ``usr``.
        HTTPException(401): LNBITS_ALLOWED_USERS is non-empty and does not
            contain the user's id.
    """
    user = await get_user(usr.hex)
    # Stored before validation, matching the request-global contract.
    g().user = user

    if not user:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND, detail="User does not exist."
        )

    restricted = bool(LNBITS_ALLOWED_USERS)
    if restricted and user.id not in LNBITS_ALLOWED_USERS:
        raise HTTPException(
            status_code=HTTPStatus.UNAUTHORIZED, detail="User not authorized."
        )

    if LNBITS_ADMIN_USERS and user.id in LNBITS_ADMIN_USERS:
        user.admin = True
    return user