lnbits-legend/lnbits/middleware.py
2024-09-11 12:41:37 +03:00

163 lines
6 KiB
Python

from http import HTTPStatus
from typing import Any, List, Union
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send
from lnbits.core.db import core_app_extra
from lnbits.helpers import template_renderer
from lnbits.settings import settings
class InstalledExtensionMiddleware:
    """ASGI middleware that intercepts calls made to the extensions API.

    - It blocks the calls if the extension has been disabled or uninstalled.
    - It redirects the calls to the latest version of the extension
      if the extension has been upgraded.
    - Otherwise it has no effect.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Inspect the first path segment and block, re-route or forward the call.

        Scopes without a `path` entry are treated like a request for `/` and
        are forwarded unchanged.
        """
        full_path = scope.get("path", "/")
        segments = [p for p in full_path.split("/") if p]

        # Paths without any non-empty segment ("/", "" or "//") cannot target
        # an extension. Forward them instead of failing on the unpacking below.
        if not segments:
            await self.app(scope, receive, send)
            return

        top_path, *rest = segments
        headers = scope.get("headers", [])

        # block path for all users if the extension is disabled
        if top_path in settings.lnbits_deactivated_extensions:
            response = self._response_by_accepted_type(
                scope, headers, f"Extension '{top_path}' disabled", HTTPStatus.NOT_FOUND
            )
            await response(scope, receive, send)
            return

        # static resources do not require redirect
        if rest[0:1] == ["static"]:
            await self.app(scope, receive, send)
            return

        # re-route all traffic if the extension has been upgraded
        if top_path in settings.lnbits_upgraded_extensions:
            upgrade_path = (
                f"""{settings.lnbits_upgraded_extensions[top_path]}/{top_path}"""
            )
            tail = "/".join(rest)
            scope["path"] = f"/upgrades/{upgrade_path}/{tail}"

        await self.app(scope, receive, send)

    def _response_by_accepted_type(
        self, scope: Scope, headers: List[Any], msg: str, status_code: HTTPStatus
    ) -> Union[HTMLResponse, JSONResponse]:
        """
        Build an HTTP response containing the `msg` as HTTP body and the `status_code`
        as HTTP code. If the `accept` HTTP header is present in the request and
        lists the `text/html` media type then return an `HTMLResponse`,
        otherwise return a `JSONResponse`.

        The header entries are compared after trimming whitespace, dropping
        parameters such as `;q=0.9` and lower-casing, so values like
        `application/json, text/html;q=0.9` are recognised as accepting HTML.
        """
        # latin-1 maps every byte to a character, so decoding cannot raise on
        # unexpected header bytes.
        accept_header: str = next(
            (
                h[1].decode("latin-1")
                for h in headers
                if len(h) >= 2 and h[0].lower() == b"accept"
            ),
            "",
        )
        accepted_types = {
            entry.split(";", 1)[0].strip().lower()
            for entry in accept_header.split(",")
        }

        if "text/html" in accepted_types:
            return HTMLResponse(
                status_code=status_code,
                content=template_renderer()
                .TemplateResponse(Request(scope), "error.html", {"err": msg})
                .body,
            )

        return JSONResponse(
            status_code=status_code,
            content={"detail": msg},
        )
class CustomGZipMiddleware(GZipMiddleware):
    """`GZipMiddleware` variant that can be bypassed for selected paths.

    A request whose scope `path` is contained in `exclude_paths` is handed
    directly to `self.app`; every other call goes through the parent
    `GZipMiddleware.__call__`.
    """

    def __init__(self, *args, exclude_paths=None, **kwargs):
        """Forward all other arguments to `GZipMiddleware`; keep `exclude_paths`."""
        super().__init__(*args, **kwargs)
        # A falsy value (None, empty collection) is replaced by a fresh empty list.
        self.exclude_paths = exclude_paths if exclude_paths else []

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Skip the parent middleware for excluded paths, use it for the rest."""
        is_excluded = "path" in scope and scope["path"] in self.exclude_paths
        if not is_excluded:
            await super().__call__(scope, receive, send)
            return

        await self.app(scope, receive, send)
class ExtensionsRedirectMiddleware:
    """ASGI middleware applying the redirect paths declared by extensions.

    Extensions are allowed to specify redirect paths. A call to a path outside
    the scope of the extension can be redirected to one of the extension's
    endpoints.
    Eg: redirect `GET /.well-known` to `GET /lnurlp/api/v1/well-known`
    """

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Rewrite `scope["path"]` when a redirect matches, then forward the call.

        Scopes without a `path` entry are forwarded unchanged.
        """
        if "path" in scope:
            current_path = scope["path"]
            matched = settings.find_extension_redirect(
                current_path, scope.get("headers", [])
            )
            if matched:
                scope["path"] = matched.new_path_from(current_path)

        await self.app(scope, receive, send)
def add_ratelimit_middleware(app: FastAPI):
    """Wire slowapi rate limiting into `app`.

    Calls `core_app_extra.register_new_ratelimiter()`, installs slowapi's
    handler for `RateLimitExceeded` and finally adds `SlowAPIMiddleware`.
    """
    core_app_extra.register_new_ratelimiter()

    # The current slowapi documentation
    # (https://slowapi.readthedocs.io/en/latest/) shows this as a valid way
    # of registering the handler.
    handler = _rate_limit_exceeded_handler
    app.add_exception_handler(RateLimitExceeded, handler)  # type: ignore

    app.add_middleware(SlowAPIMiddleware)
def add_ip_block_middleware(app: FastAPI):
    """Register an HTTP middleware on `app` that rejects blocked client IPs.

    A request is answered with a 403 JSON response when it has no client
    information, or when the client host is listed in
    `settings.lnbits_blocked_ips` and not in `settings.lnbits_allowed_ips`.
    All other requests are passed on to the next handler.
    """

    def _forbidden(detail: str) -> JSONResponse:
        # 403 == Forbidden
        return JSONResponse(status_code=403, content={"detail": detail})

    @app.middleware("http")
    async def block_allow_ip_middleware(request: Request, call_next):
        client = request.client
        if not client:
            return _forbidden("No request client")

        host = client.host
        # The allow list takes precedence over the block list.
        blocked = host in settings.lnbits_blocked_ips
        if blocked and host not in settings.lnbits_allowed_ips:
            return _forbidden("IP is blocked")

        return await call_next(request)
def add_first_install_middleware(app: FastAPI):
    """Register an HTTP middleware on `app` that enforces the first-install page.

    While `settings.first_install` is truthy, every request is redirected to
    `/first_install`, except requests for `/first_install` itself, for
    `/api/v1/auth/first_install`, and for paths starting with `/static`.
    """

    @app.middleware("http")
    async def first_install_middleware(request: Request, call_next):
        # Nothing to enforce when the first-install flag is not set.
        if not settings.first_install:
            return await call_next(request)

        path = request.url.path
        is_exempt = path in (
            "/api/v1/auth/first_install",
            "/first_install",
        ) or path.startswith("/static")
        if is_exempt:
            return await call_next(request)

        return RedirectResponse("/first_install")