mirror of
https://github.com/lnbits/lnbits-legend.git
synced 2025-02-21 14:04:25 +01:00
fix proxyfix.
This commit is contained in:
parent
098089af75
commit
49baa07141
2 changed files with 37 additions and 42 deletions
|
@ -9,7 +9,7 @@ from .commands import db_migrate
|
||||||
from .core import core_app
|
from .core import core_app
|
||||||
from .db import open_db
|
from .db import open_db
|
||||||
from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored
|
from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored
|
||||||
from .proxy_fix import ProxyFix
|
from .proxy_fix import ASGIProxyFix
|
||||||
|
|
||||||
secure_headers = SecureHeaders(hsts=False)
|
secure_headers = SecureHeaders(hsts=False)
|
||||||
|
|
||||||
|
@ -20,10 +20,10 @@ def create_app(config_object="lnbits.settings") -> Quart:
|
||||||
"""
|
"""
|
||||||
app = Quart(__name__, static_folder="static")
|
app = Quart(__name__, static_folder="static")
|
||||||
app.config.from_object(config_object)
|
app.config.from_object(config_object)
|
||||||
|
app.asgi_http_class = ASGIProxyFix
|
||||||
|
|
||||||
cors(app)
|
cors(app)
|
||||||
Compress(app)
|
Compress(app)
|
||||||
ProxyFix(app, x_proto=1, x_host=1)
|
|
||||||
|
|
||||||
register_assets(app)
|
register_assets(app)
|
||||||
register_blueprints(app)
|
register_blueprints(app)
|
||||||
|
|
|
@ -1,48 +1,46 @@
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Callable
|
||||||
|
from functools import partial
|
||||||
from urllib.request import parse_http_list as _parse_list_header
|
from urllib.request import parse_http_list as _parse_list_header
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from werkzeug.datastructures import Headers
|
||||||
|
|
||||||
from quart import request
|
from quart import Request
|
||||||
|
from quart.asgi import ASGIHTTPConnection
|
||||||
|
|
||||||
|
|
||||||
class ProxyFix:
|
class ASGIProxyFix(ASGIHTTPConnection):
|
||||||
def __init__(self, app=None, x_for: int = 1, x_proto: int = 1, x_host: int = 0, x_port: int = 0, x_prefix: int = 0):
|
def _create_request_from_scope(self, send: Callable) -> Request:
|
||||||
self.app = app
|
headers = Headers()
|
||||||
self.x_for = x_for
|
headers["Remote-Addr"] = (self.scope.get("client") or ["<local>"])[0]
|
||||||
self.x_proto = x_proto
|
for name, value in self.scope["headers"]:
|
||||||
self.x_host = x_host
|
headers.add(name.decode("latin1").title(), value.decode("latin1"))
|
||||||
self.x_port = x_port
|
if self.scope["http_version"] < "1.1":
|
||||||
self.x_prefix = x_prefix
|
headers.setdefault("Host", self.app.config["SERVER_NAME"] or "")
|
||||||
|
|
||||||
if app:
|
path = self.scope["path"]
|
||||||
self.init_app(app)
|
path = path if path[0] == "/" else urlparse(path).path
|
||||||
|
|
||||||
def init_app(self, app):
|
x_proto = self._get_real_value(1, headers.get("X-Forwarded-Proto"))
|
||||||
@app.before_request
|
if x_proto:
|
||||||
async def before_request():
|
self.scope["scheme"] = x_proto
|
||||||
x_for = self._get_real_value(self.x_for, request.headers.get("X-Forwarded-For"))
|
|
||||||
if x_for:
|
|
||||||
request.headers["Remote-Addr"] = x_for
|
|
||||||
|
|
||||||
x_proto = self._get_real_value(self.x_proto, request.headers.get("X-Forwarded-Proto"))
|
x_host = self._get_real_value(1, headers.get("X-Forwarded-Host"))
|
||||||
if x_proto:
|
if x_host:
|
||||||
request.scheme = x_proto
|
headers["host"] = x_host.lower()
|
||||||
|
|
||||||
x_host = self._get_real_value(self.x_host, request.headers.get("X-Forwarded-Host"))
|
return self.app.request_class(
|
||||||
if x_host:
|
self.scope["method"],
|
||||||
request.headers["host"] = x_host.lower()
|
self.scope["scheme"],
|
||||||
parts = x_host.split(":", 1)
|
path,
|
||||||
# environ["SERVER_NAME"] = parts[0]
|
self.scope["query_string"],
|
||||||
# if len(parts) == 2:
|
headers,
|
||||||
# environ["SERVER_PORT"] = parts[1]
|
self.scope.get("root_path", ""),
|
||||||
|
self.scope["http_version"],
|
||||||
x_port = self._get_real_value(self.x_port, request.headers.get("X-Forwarded-Port"))
|
max_content_length=self.app.config["MAX_CONTENT_LENGTH"],
|
||||||
if x_port:
|
body_timeout=self.app.config["BODY_TIMEOUT"],
|
||||||
host = request.host
|
send_push_promise=partial(self._send_push_promise, send),
|
||||||
if host:
|
scope=self.scope,
|
||||||
parts = host.split(":", 1)
|
)
|
||||||
host = parts[0] if len(parts) == 2 else host
|
|
||||||
request.headers["host"] = f"{host}:{x_port}"
|
|
||||||
# environ["SERVER_PORT"] = x_port
|
|
||||||
|
|
||||||
def _get_real_value(self, trusted: int, value: Optional[str]) -> Optional[str]:
|
def _get_real_value(self, trusted: int, value: Optional[str]) -> Optional[str]:
|
||||||
"""Get the real value from a list header based on the configured
|
"""Get the real value from a list header based on the configured
|
||||||
|
@ -95,6 +93,3 @@ class ProxyFix:
|
||||||
if not is_filename or value[:2] != "\\\\":
|
if not is_filename or value[:2] != "\\\\":
|
||||||
return value.replace("\\\\", "\\").replace('\\"', '"')
|
return value.replace("\\\\", "\\").replace('\\"', '"')
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
# host, request.root_path, subdomain, request.scheme, request.method, request.path, request.query_string.decode(),
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue