2020-03-04 23:11:15 +01:00
|
|
|
from functools import wraps
|
2020-05-03 15:57:05 +02:00
|
|
|
from http import HTTPStatus
|
2021-08-29 19:38:42 +02:00
|
|
|
|
|
|
|
from fastapi.security import api_key
|
2021-09-11 15:18:09 +02:00
|
|
|
from pydantic.types import UUID4
|
|
|
|
from lnbits.core.models import User, Wallet
|
2020-03-04 23:11:15 +01:00
|
|
|
from typing import List, Union
|
|
|
|
from uuid import UUID
|
|
|
|
|
2021-08-29 19:38:42 +02:00
|
|
|
from cerberus import Validator # type: ignore
|
|
|
|
from fastapi.exceptions import HTTPException
|
|
|
|
from fastapi.openapi.models import APIKey, APIKeyIn
|
|
|
|
from fastapi.params import Security
|
|
|
|
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
|
|
|
|
from fastapi.security.base import SecurityBase
|
|
|
|
from starlette.requests import Request
|
|
|
|
|
2020-03-04 23:11:15 +01:00
|
|
|
from lnbits.core.crud import get_user, get_wallet_for_key
|
2021-08-22 20:07:24 +02:00
|
|
|
from lnbits.requestvars import g
|
2021-08-29 19:38:42 +02:00
|
|
|
from lnbits.settings import LNBITS_ALLOWED_USERS
|
2020-03-04 23:11:15 +01:00
|
|
|
|
|
|
|
|
2021-08-29 19:38:42 +02:00
|
|
|
class KeyChecker(SecurityBase):
    """FastAPI security dependency that resolves a wallet from an API key.

    The key is taken from, in order of precedence:

    1. the ``api_key`` passed to the constructor,
    2. the ``X-API-KEY`` request header,
    3. the ``api-key`` query parameter.

    The wallet is looked up with ``get_wallet_for_key`` using
    ``self._key_type``. That is ``"invoice"`` here, and subclasses override it.
    """

    def __init__(self, scheme_name: str = None, auto_error: bool = True, api_key: str = None):
        self.scheme_name = scheme_name or self.__class__.__name__
        # NOTE(review): auto_error is stored but never consulted by __call__.
        self.auto_error = auto_error
        self._key_type = "invoice"
        self._api_key = api_key
        # OpenAPI model: advertise a query key when one was supplied explicitly,
        # otherwise a header key.
        if api_key:
            self.model: APIKey = APIKey(
                **{"in": APIKeyIn.query}, name="X-API-KEY", description="Wallet API Key - QUERY"
            )
        else:
            self.model: APIKey = APIKey(
                **{"in": APIKeyIn.header}, name="X-API-KEY", description="Wallet API Key - HEADER"
            )
        self.wallet = None

    async def __call__(self, request: Request) -> Wallet:
        """Validate the key carried by ``request`` and return its wallet.

        The wallet is also stored on ``self.wallet`` for callers that read it
        from there. If one instance is shared across concurrent requests, that
        attribute can be overwritten, so prefer the return value.

        Raises:
            HTTPException(400): no key was supplied anywhere.
            HTTPException(401): ``get_wallet_for_key`` found no wallet for the
                key and ``self._key_type``.
        """
        # Only key extraction can legitimately raise KeyError (missing query
        # param). Keep the DB lookup outside the try so its errors are not
        # misreported as "header missing".
        try:
            key_value = (
                self._api_key
                if self._api_key
                else request.headers.get("X-API-KEY") or request.query_params["api-key"]
            )
        except KeyError:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST,
                detail="`X-API-KEY` header missing.",
            ) from None

        # FIXME: Find another way to validate the key. A fetch from DB should be avoided here.
        # Possibly store it in a Redis DB
        self.wallet = await get_wallet_for_key(key_value, self._key_type)
        if not self.wallet:
            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Invalid key or expired key.")
        return self.wallet
|
|
|
|
|
|
|
|
class WalletInvoiceKeyChecker(KeyChecker):
    """Key checker that validates a wallet *invoice* key.

    Key resolution is inherited from ``KeyChecker`` (explicit ``api_key``,
    then the ``X-API-KEY`` header, then the ``api-key`` query parameter). The
    key is looked up with key type ``"invoice"``.

    On success the wallet is stored on ``self.wallet``. A missing or unknown
    key raises an HTTPException.
    """

    def __init__(self, scheme_name: str = None, auto_error: bool = True, api_key: str = None):
        super().__init__(scheme_name=scheme_name, auto_error=auto_error, api_key=api_key)
        # Set after the base initialiser so it overrides the base default.
        self._key_type = "invoice"
|
|
|
|
|
|
|
|
class WalletAdminKeyChecker(KeyChecker):
    """Key checker that validates a wallet *admin* key.

    Key resolution is inherited from ``KeyChecker`` (explicit ``api_key``,
    then the ``X-API-KEY`` header, then the ``api-key`` query parameter). The
    key is looked up with key type ``"admin"``.

    On success the wallet is stored on ``self.wallet``. A missing or unknown
    key raises an HTTPException.
    """

    def __init__(self, scheme_name: str = None, auto_error: bool = True, api_key: str = None):
        super().__init__(scheme_name=scheme_name, auto_error=auto_error, api_key=api_key)
        # Set after the base initialiser, replacing its "invoice" default.
        self._key_type = "admin"
|
|
|
|
|
|
|
|
class WalletTypeInfo:
    """A resolved wallet paired with the kind of key that unlocked it.

    ``wallet_type`` codes as used by ``get_key_type``:

    - 0: admin key
    - 1: invoice key
    - 2: invalid key (``wallet`` is then ``None``)
    """

    wallet_type: int
    wallet: Wallet  # may be None when wallet_type is 2

    def __init__(self, wallet_type: int, wallet: Wallet) -> None:
        self.wallet_type, self.wallet = wallet_type, wallet
|
2020-03-04 23:11:15 +01:00
|
|
|
|
|
|
|
|
2021-08-29 19:38:42 +02:00
|
|
|
# Both wallet-key schemes share one OpenAPI description. With auto_error=False,
# a missing key presumably reaches get_key_type as None instead of being
# rejected up front (FastAPI security semantics).
_WALLET_KEY_DESCRIPTION = "Admin or Invoice key for wallet API's"

api_key_header = APIKeyHeader(
    name="X-API-KEY",
    auto_error=False,
    description=_WALLET_KEY_DESCRIPTION,
)
api_key_query = APIKeyQuery(
    name="api-key",
    auto_error=False,
    description=_WALLET_KEY_DESCRIPTION,
)
|
|
|
|
async def get_key_type(
    r: Request,
    api_key_header: str = Security(api_key_header),
    api_key_query: str = Security(api_key_query),
) -> WalletTypeInfo:
    """Classify the request's wallet key as admin, invoice or invalid.

    The admin checker is tried first, then the invoice checker. Both receive
    the same explicit ``api_key_query`` value, so they validate the same key.
    When that value is empty, ``KeyChecker`` falls back to the ``X-API-KEY``
    header, then the ``api-key`` query parameter.

    ``api_key_header`` is not read here. It is presumably declared so that the
    header scheme appears in the OpenAPI docs.

    Returns:
        ``WalletTypeInfo(0, wallet)`` for an admin key,
        ``WalletTypeInfo(1, wallet)`` for an invoice key, or
        ``WalletTypeInfo(2, None)`` when neither checker accepts the key.

    Raises:
        HTTPException: any non-401 error from a checker, e.g. 400 when no
            key was supplied.
    """
    for wallet_type, checker_cls in ((0, WalletAdminKeyChecker), (1, WalletInvoiceKeyChecker)):
        checker = checker_cls(api_key=api_key_query)
        try:
            await checker(r)
        except HTTPException as e:
            # 401 only means "not a key of this type": try the next checker.
            # Everything else (400 missing key, unexpected errors) propagates
            # instead of being swallowed or turned into an implicit None.
            if e.status_code == HTTPStatus.UNAUTHORIZED:
                continue
            raise
        return WalletTypeInfo(wallet_type, checker.wallet)

    return WalletTypeInfo(2, None)
|
2020-03-04 23:11:15 +01:00
|
|
|
|
2020-04-11 19:47:25 +02:00
|
|
|
def _find_request(kwargs: dict) -> Request:
    """Return the first starlette ``Request`` among a view's keyword arguments."""
    for value in kwargs.values():
        if isinstance(value, Request):
            return value
    raise TypeError(
        "api_validate_post_request requires the decorated view to declare a `Request` parameter."
    )


def api_validate_post_request(*, schema: dict):
    """Decorator factory that validates a JSON request body against a Cerberus ``schema``.

    Only the keys named in ``schema`` are copied from the body into
    ``g().data``, and that filtered dict is what gets validated. On success
    the view is awaited with its original keyword arguments.

    The decorated view must accept a starlette ``Request`` parameter; the
    body is read from it. Because ``functools.wraps`` exposes the view's
    signature, the framework injects that request into ``kwargs``.

    The wrapped view raises:
        HTTPException(400): the Content-Type is not JSON, the body is not
            valid JSON, the body is not a JSON object, or the schema
            validation failed.
        TypeError: the view does not take a ``Request`` parameter
            (programming error).
    """

    def wrap(view):
        @wraps(view)
        async def wrapped_view(**kwargs):
            request = _find_request(kwargs)

            # A missing Content-Type header is treated like a non-JSON one
            # instead of raising KeyError.
            if "application/json" not in request.headers.get("Content-Type", ""):
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST,
                    detail={"message": "Content-Type must be `application/json`."},
                )

            try:
                data = await request.json()
            except ValueError:  # covers json.JSONDecodeError and bad encodings
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST,
                    detail={"message": "Request body is not valid JSON."},
                ) from None

            # The key filtering below assumes a mapping. A JSON list or scalar
            # would otherwise crash with a TypeError (500).
            if not isinstance(data, dict):
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST,
                    detail={"message": "Request body must be a JSON object."},
                )

            v = Validator(schema)
            g().data = {key: data[key] for key in schema.keys() if key in data}

            if not v.validate(g().data):
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST,
                    detail={"message": f"Errors in request data: {v.errors}"},
                )

            return await view(**kwargs)

        return wrapped_view

    return wrap
|
|
|
|
|
|
|
|
|
2021-09-11 15:18:09 +02:00
|
|
|
async def check_user_exists(usr: UUID4) -> User:
    """Load the user identified by ``usr`` and enforce the allow-list.

    The loaded user is stored on ``g().user`` and returned.

    Raises:
        HTTPException(404): no user was found for ``usr.hex``.
        HTTPException(401): ``LNBITS_ALLOWED_USERS`` is non-empty and does
            not contain the user's id.
    """
    g().user = await get_user(usr.hex)
    user = g().user

    if not user:
        raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="User does not exist.")

    # An empty allow-list disables the restriction entirely.
    restricted = bool(LNBITS_ALLOWED_USERS)
    if restricted and user.id not in LNBITS_ALLOWED_USERS:
        raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="User not authorized.")

    return user
|