feat: switch from Quart to FastAPI part I

This commit is contained in:
Stefan Stammberger 2021-08-22 20:07:24 +02:00
parent fc68e0a6da
commit 938fc54af3
No known key found for this signature in database
GPG Key ID: 645FA807E935D9D5
12 changed files with 245 additions and 114 deletions

View File

@ -1,4 +1,7 @@
from hypercorn.trio import serve
import trio
import trio_asyncio
from hypercorn.config import Config
from .commands import migrate_databases, transpile_scss, bundle_vendored
@ -8,7 +11,7 @@ bundle_vendored()
from .app import create_app
app = create_app()
app = trio.run(create_app)
from .settings import (
LNBITS_SITE_TITLE,
@ -17,6 +20,8 @@ from .settings import (
LNBITS_DATA_FOLDER,
WALLET,
LNBITS_COMMIT,
HOST,
PORT
)
print(
@ -30,4 +35,6 @@ print(
"""
)
app.run(host=app.config["HOST"], port=app.config["PORT"])
config = Config()
config.bind = [f"{HOST}:{PORT}"]
trio_asyncio.run(serve, app, config)

View File

@ -1,12 +1,15 @@
import jinja2
from lnbits.jinja2_templating import Jinja2Templates
import sys
import warnings
import importlib
import traceback
import trio
from quart import g
from quart_trio import QuartTrio
from quart_cors import cors # type: ignore
from quart_compress import Compress # type: ignore
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.staticfiles import StaticFiles
from .commands import db_migrate, handle_assets
from .core import core_app
@ -26,32 +29,66 @@ from .tasks import (
catch_everything_and_restart,
)
from .settings import WALLET
from .requestvars import g, request_global
import lnbits.settings
def create_app(config_object="lnbits.settings") -> QuartTrio:
async def create_app(config_object="lnbits.settings") -> FastAPI:
"""Create application factory.
:param config_object: The configuration object to use.
"""
app = QuartTrio(__name__, static_folder="static")
app.config.from_object(config_object)
app.asgi_http_class = ASGIProxyFix
app = FastAPI()
app.mount("/static", StaticFiles(directory="lnbits/static"), name="static")
cors(app)
Compress(app)
origins = [
"http://localhost",
"http://localhost:5000",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
g().config = lnbits.settings
g().templates = build_standard_jinja_templates()
app.add_middleware(GZipMiddleware, minimum_size=1000)
# app.add_middleware(ASGIProxyFix)
check_funding_source(app)
register_assets(app)
register_blueprints(app)
register_filters(app)
register_commands(app)
register_routes(app)
# register_commands(app)
register_async_tasks(app)
register_exception_handlers(app)
# register_exception_handlers(app)
return app
def build_standard_jinja_templates():
t = Jinja2Templates(
loader=jinja2.FileSystemLoader(["lnbits/templates", "lnbits/core/templates"]),
)
t.env.globals["SITE_TITLE"] = lnbits.settings.LNBITS_SITE_TITLE
t.env.globals["SITE_TAGLINE"] = lnbits.settings.LNBITS_SITE_TAGLINE
t.env.globals["SITE_DESCRIPTION"] = lnbits.settings.LNBITS_SITE_DESCRIPTION
t.env.globals["LNBITS_THEME_OPTIONS"] = lnbits.settings.LNBITS_THEME_OPTIONS
t.env.globals["LNBITS_VERSION"] = lnbits.settings.LNBITS_COMMIT
t.env.globals["EXTENSIONS"] = get_valid_extensions()
if g().config.DEBUG:
t.env.globals["VENDORED_JS"] = map(url_for_vendored, get_js_vendored())
t.env.globals["VENDORED_CSS"] = map(url_for_vendored, get_css_vendored())
else:
t.env.globals["VENDORED_JS"] = ["/static/bundle.js"]
t.env.globals["VENDORED_CSS"] = ["/static/bundle.css"]
def check_funding_source(app: QuartTrio) -> None:
@app.before_serving
return t
def check_funding_source(app: FastAPI) -> None:
@app.on_event("startup")
async def check_wallet_status():
error_message, balance = await WALLET.status()
if error_message:
@ -67,64 +104,60 @@ def check_funding_source(app: QuartTrio) -> None:
)
def register_blueprints(app: QuartTrio) -> None:
def register_routes(app: FastAPI) -> None:
"""Register Flask blueprints / LNbits extensions."""
app.register_blueprint(core_app)
app.include_router(core_app)
for ext in get_valid_extensions():
try:
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
bp = getattr(ext_module, f"{ext.code}_ext")
ext_route = getattr(ext_module, f"{ext.code}_ext")
app.register_blueprint(bp, url_prefix=f"/{ext.code}")
app.include_router(ext_route)
except Exception:
raise ImportError(
f"Please make sure that the extension `{ext.code}` follows conventions."
)
def register_commands(app: QuartTrio):
def register_commands(app: FastAPI):
"""Register Click commands."""
app.cli.add_command(db_migrate)
app.cli.add_command(handle_assets)
def register_assets(app: QuartTrio):
def register_assets(app: FastAPI):
"""Serve each vendored asset separately or a bundle."""
@app.before_request
@app.on_event("startup")
async def vendored_assets_variable():
if app.config["DEBUG"]:
g.VENDORED_JS = map(url_for_vendored, get_js_vendored())
g.VENDORED_CSS = map(url_for_vendored, get_css_vendored())
if g().config.DEBUG:
g().VENDORED_JS = map(url_for_vendored, get_js_vendored())
g().VENDORED_CSS = map(url_for_vendored, get_css_vendored())
else:
g.VENDORED_JS = ["/static/bundle.js"]
g.VENDORED_CSS = ["/static/bundle.css"]
def register_filters(app: QuartTrio):
"""Jinja filters."""
app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"]
app.jinja_env.globals["SITE_TAGLINE"] = app.config["LNBITS_SITE_TAGLINE"]
app.jinja_env.globals["SITE_DESCRIPTION"] = app.config["LNBITS_SITE_DESCRIPTION"]
app.jinja_env.globals["LNBITS_THEME_OPTIONS"] = app.config["LNBITS_THEME_OPTIONS"]
app.jinja_env.globals["LNBITS_VERSION"] = app.config["LNBITS_COMMIT"]
app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions()
g().VENDORED_JS = ["/static/bundle.js"]
g().VENDORED_CSS = ["/static/bundle.css"]
def register_async_tasks(app):
@app.route("/wallet/webhook", methods=["GET", "POST", "PUT", "PATCH", "DELETE"])
@app.route("/wallet/webhook")
async def webhook_listener():
return await webhook_handler()
@app.before_serving
@app.on_event("startup")
async def listeners():
run_deferred_async()
app.nursery.start_soon(catch_everything_and_restart, check_pending_payments)
app.nursery.start_soon(catch_everything_and_restart, invoice_listener)
app.nursery.start_soon(catch_everything_and_restart, internal_invoice_listener)
trio.open_process(check_pending_payments)
trio.open_process(invoice_listener)
trio.open_process(internal_invoice_listener)
async with trio.open_nursery() as n:
pass
# n.start_soon(catch_everything_and_restart, check_pending_payments)
# n.start_soon(catch_everything_and_restart, invoice_listener)
# n.start_soon(catch_everything_and_restart, internal_invoice_listener)
@app.after_serving
@app.on_event("shutdown")
async def stop_listeners():
pass

49
lnbits/auth_bearer.py Normal file
View File

@ -0,0 +1,49 @@
from fastapi import Request, HTTPException
from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey
# https://medium.com/data-rebels/fastapi-authentication-revisited-enabling-api-key-authentication-122dc5975680
from fastapi import Security, Depends, FastAPI, HTTPException
from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey
from fastapi.security.base import SecurityBase
API_KEY = "usr"
API_KEY_NAME = "X-API-key"
api_key_query = APIKeyQuery(name=API_KEY_NAME, auto_error=False)
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
class AuthBearer(SecurityBase):
def __init__(self, scheme_name: str = None, auto_error: bool = True):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request):
key = await self.get_api_key()
print(key)
# credentials: HTTPAuthorizationCredentials = await super(AuthBearer, self).__call__(request)
# if credentials:
# if not credentials.scheme == "Bearer":
# raise HTTPException(
# status_code=403, detail="Invalid authentication scheme.")
# if not self.verify_jwt(credentials.credentials):
# raise HTTPException(
# status_code=403, detail="Invalid token or expired token.")
# return credentials.credentials
# else:
# raise HTTPException(
# status_code=403, detail="Invalid authorization code.")
async def get_api_key(self,
api_key_query: str = Security(api_key_query),
api_key_header: str = Security(api_key_header),
):
if api_key_query == API_KEY:
return api_key_query
elif api_key_header == API_KEY:
return api_key_header
else:
raise HTTPException(status_code=403, detail="Could not validate credentials")

View File

@ -1,22 +1,19 @@
from quart import Blueprint
from fastapi.routing import APIRouter
from lnbits.db import Database
db = Database("database")
core_app: Blueprint = Blueprint(
"core",
__name__,
template_folder="templates",
static_folder="static",
static_url_path="/core/static",
)
from .views.api import * # noqa
from .views.generic import * # noqa
from .views.public_api import * # noqa
from .tasks import register_listeners
core_app: APIRouter = APIRouter()
from lnbits.tasks import record_async
core_app.record(record_async(register_listeners))
from .tasks import register_listeners
from .views.api import * # noqa
from .views.generic import * # noqa
from .views.public_api import * # noqa
@core_app.on_event("startup")
def do_startup():
record_async(register_listeners)

View File

@ -1,10 +1,12 @@
from fastapi.param_functions import Depends
from lnbits.auth_bearer import AuthBearer
from pydantic import BaseModel
import trio
import json
import httpx
import hashlib
from urllib.parse import urlparse, urlunparse, urlencode, parse_qs, ParseResult
from quart import g, current_app, make_response, url_for
from quart import current_app, make_response, url_for
from fastapi import Query
@ -15,6 +17,7 @@ from typing import Dict, List, Optional, Union
from lnbits import bolt11, lnurl
from lnbits.decorators import api_check_wallet_key, api_validate_post_request
from lnbits.utils.exchange_rates import currencies, fiat_amount_as_satoshis
from lnbits.requestvars import g
from .. import core_app, db
from ..crud import get_payments, save_balance_check, update_wallet
@ -28,11 +31,14 @@ from ..services import (
from ..tasks import api_invoice_listeners
@core_app.get("/api/v1/wallet")
@api_check_wallet_key("invoice")
@core_app.get(
"/api/v1/wallet",
# dependencies=[Depends(AuthBearer())]
)
# @api_check_wallet_key("invoice")
async def api_wallet():
return (
{"id": g.wallet.id, "name": g.wallet.name, "balance": g.wallet.balance_msat},
{"id": g().wallet.id, "name": g().wallet.name, "balance": g().wallet.balance_msat},
HTTPStatus.OK,
)
@ -40,12 +46,12 @@ async def api_wallet():
@core_app.put("/api/v1/wallet/<new_name>")
@api_check_wallet_key("invoice")
async def api_update_wallet(new_name: str):
await update_wallet(g.wallet.id, new_name)
await update_wallet(g().wallet.id, new_name)
return (
{
"id": g.wallet.id,
"name": g.wallet.name,
"balance": g.wallet.balance_msat,
"id": g().wallet.id,
"name": g().wallet.name,
"balance": g().wallet.balance_msat,
},
HTTPStatus.OK,
)
@ -55,7 +61,7 @@ async def api_update_wallet(new_name: str):
@api_check_wallet_key("invoice")
async def api_payments():
return (
await get_payments(wallet_id=g.wallet.id, pending=True, complete=True),
await get_payments(wallet_id=g().wallet.id, pending=True, complete=True),
HTTPStatus.OK,
)
@ -88,7 +94,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData):
async with db.connect() as conn:
try:
payment_hash, payment_request = await create_invoice(
wallet_id=g.wallet.id,
wallet_id=g().wallet.id,
amount=amount,
memo=memo,
description_hash=description_hash,
@ -105,8 +111,8 @@ async def api_payments_create_invoice(data: CreateInvoiceData):
lnurl_response: Union[None, bool, str] = None
if data.lnurl_callback:
if "lnurl_balance_check" in g.data:
save_balance_check(g.wallet.id, data.lnurl_balance_check)
if "lnurl_balance_check" in g().data:
save_balance_check(g().wallet.id, data.lnurl_balance_check)
async with httpx.AsyncClient() as client:
try:
@ -117,7 +123,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData):
"balanceNotify": url_for(
"core.lnurl_balance_notify",
service=urlparse(data.lnurl_callback).netloc,
wal=g.wallet.id,
wal=g().wallet.id,
_external=True,
),
},
@ -217,14 +223,14 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
if invoice.amount_msat != data.amount:
return (
{
"message": f"{domain} returned an invalid invoice. Expected {g.data['amount']} msat, got {invoice.amount_msat}."
"message": f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}."
},
HTTPStatus.BAD_REQUEST,
)
if invoice.description_hash != g.data["description_hash"]:
if invoice.description_hash != g().data["description_hash"]:
return (
{
"message": f"{domain} returned an invalid invoice. Expected description_hash == {g.data['description_hash']}, got {invoice.description_hash}."
"message": f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}."
},
HTTPStatus.BAD_REQUEST,
)
@ -237,7 +243,7 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
extra["comment"] = data.comment
payment_hash = await pay_invoice(
wallet_id=g.wallet.id,
wallet_id=g().wallet.id,
payment_request=params["pr"],
description=data.description,
extra=extra,
@ -257,7 +263,7 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
@core_app.get("/api/v1/payments/<payment_hash>")
@api_check_wallet_key("invoice")
async def api_payment(payment_hash):
payment = await g.wallet.get_payment(payment_hash)
payment = await g().wallet.get_payment(payment_hash)
if not payment:
return {"message": "Payment does not exist."}, HTTPStatus.NOT_FOUND
@ -278,7 +284,7 @@ async def api_payment(payment_hash):
@core_app.get("/api/v1/payments/sse")
@api_check_wallet_key("invoice", accept_querystring=True)
async def api_payments_sse():
this_wallet_id = g.wallet.id
this_wallet_id = g().wallet.id
send_payment, receive_payment = trio.open_memory_channel(0)
@ -356,7 +362,7 @@ async def api_lnurlscan(code: str):
params.update(kind="auth")
params.update(callback=url) # with k1 already in it
lnurlauth_key = g.wallet.lnurlauth_key(domain)
lnurlauth_key = g().wallet.lnurlauth_key(domain)
params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex())
else:
async with httpx.AsyncClient() as client:

View File

@ -1,15 +1,10 @@
from lnbits.requestvars import g
from os import path
from http import HTTPStatus
from quart import (
g,
current_app,
abort,
request,
redirect,
render_template,
send_from_directory,
url_for,
)
from typing import Optional
import jinja2
from starlette.responses import HTMLResponse
from lnbits.core import core_app, db
from lnbits.decorators import check_user_exists, validate_uuids
@ -26,20 +21,18 @@ from ..crud import (
)
from ..services import redeem_lnurl_withdraw, pay_invoice
from fastapi import FastAPI, Request
from fastapi.templating import Jinja2Templates
from fastapi.responses import FileResponse
from lnbits.jinja2_templating import Jinja2Templates
templates = Jinja2Templates(directory="templates")
@core_app.get("/favicon.ico")
async def favicon():
return await send_from_directory(
path.join(core_app.root_path, "static"), "favicon.ico"
)
return FileResponse("lnbits/core/static/favicon.ico")
@core_app.get("/")
@core_app.get("/", response_class=HTMLResponse)
async def home(request: Request, lightning: str = None):
return templates.TemplateResponse("core/index.html", {"request": request, "lnurl": lightning})
return g().templates.TemplateResponse("core/index.html", {"request": request, "lnurl": lightning})
@core_app.get("/extensions")

View File

@ -7,7 +7,7 @@ from uuid import UUID
from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.settings import LNBITS_ALLOWED_USERS
from lnbits.requestvars import g
def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False):
def wrap(view):
@ -15,14 +15,14 @@ def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False):
async def wrapped_view(**kwargs):
try:
key_value = request.headers.get("X-Api-Key") or request.args["api-key"]
g.wallet = await get_wallet_for_key(key_value, key_type)
g().wallet = await get_wallet_for_key(key_value, key_type)
except KeyError:
return (
jsonify({"message": "`X-Api-Key` header missing."}),
HTTPStatus.BAD_REQUEST,
)
if not g.wallet:
if not g().wallet:
return jsonify({"message": "Wrong keys."}), HTTPStatus.UNAUTHORIZED
return await view(**kwargs)
@ -44,9 +44,9 @@ def api_validate_post_request(*, schema: dict):
v = Validator(schema)
data = await request.get_json()
g.data = {key: data[key] for key in schema.keys() if key in data}
g().data = {key: data[key] for key in schema.keys() if key in data}
if not v.validate(g.data):
if not v.validate(g().data):
return (
jsonify({"message": f"Errors in request data: {v.errors}"}),
HTTPStatus.BAD_REQUEST,
@ -63,11 +63,11 @@ def check_user_exists(param: str = "usr"):
def wrap(view):
@wraps(view)
async def wrapped_view(**kwargs):
g.user = await get_user(request.args.get(param, type=str)) or abort(
g().user = await get_user(request.args.get(param, type=str)) or abort(
HTTPStatus.NOT_FOUND, "User does not exist."
)
if LNBITS_ALLOWED_USERS and g.user.id not in LNBITS_ALLOWED_USERS:
if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS:
abort(HTTPStatus.UNAUTHORIZED, "User not authorized.")
return await view(**kwargs)

View File

@ -1,11 +1,12 @@
from quart import Blueprint
from fastapi import APIRouter
from lnbits.db import Database
db = Database("ext_offlineshop")
offlineshop_ext: Blueprint = Blueprint(
"offlineshop", __name__, static_folder="static", template_folder="templates"
offlineshop_ext: APIRouter = APIRouter(
prefix="/Extension",
tags=["Apps", "Offlineshop"]
)

View File

@ -0,0 +1,36 @@
# Borrowed from the excellent accent-starlette
# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py
import typing
from starlette import templating
from starlette.datastructures import QueryParams
from lnbits.requestvars import g
try:
import jinja2
except ImportError: # pragma: nocover
jinja2 = None # type: ignore
class Jinja2Templates(templating.Jinja2Templates):
def __init__(self, loader: jinja2.BaseLoader) -> None:
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
self.env = self.get_environment(loader)
def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment":
@jinja2.contextfunction
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
request = context["request"]
return request.url_for(name, **path_params)
def url_params_update(init: QueryParams, **new: typing.Any) -> QueryParams:
values = dict(init)
values.update(new)
return QueryParams(**values)
env = jinja2.Environment(loader=loader, autoescape=True)
env.globals["url_for"] = url_for
env.globals["url_params_update"] = url_params_update
return env

9
lnbits/requestvars.py Normal file
View File

@ -0,0 +1,9 @@
import contextvars
import types
request_global = contextvars.ContextVar("request_global",
default=types.SimpleNamespace())
def g() -> types.SimpleNamespace:
return request_global.get()

View File

@ -2,7 +2,7 @@
<html lang="en">
<head>
{% for url in g.VENDORED_CSS %}
{% for url in VENDORED_CSS %}
<link rel="stylesheet" type="text/css" href="{{ url }}" />
{% endfor %}
<!---->
@ -184,7 +184,7 @@
{% block vue_templates %}{% endblock %}
<!---->
{% for url in g.VENDORED_JS %}
{% for url in VENDORED_JS %}
<script src="{{ url }}"></script>
{% endfor %}
<!---->

View File

@ -2,7 +2,7 @@
<html lang="en">
<head>
{% for url in g.VENDORED_CSS %}
{% for url in VENDORED_CSS %}
<link rel="stylesheet" type="text/css" href="{{ url }}" />
{% endfor %}
<style>
@ -33,7 +33,7 @@
</q-page-container>
</q-layout>
{% for url in g.VENDORED_JS %}
{% for url in VENDORED_JS %}
<script src="{{ url }}"></script>
{% endfor %}
<!---->