From 8ce84ce592349ddc80a093928052974b7dc03c66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dni=20=E2=9A=A1?= Date: Mon, 3 Apr 2023 14:55:49 +0200 Subject: [PATCH] FEAT: Filters for GET requests, add it to GET /payments (#1557) * feat filters, add them to GET payments * add limit and offset to filters (#1563) * add limit and offset to filters * move filters example to parse_filters doc string * black * add openapi docs * remove example commentC * improve typing and make nested filter possible in openapi * typo in fn name * readd Type --------- Co-authored-by: jackstar12 <62219658+jackstar12@users.noreply.github.com> Co-authored-by: calle <93376500+callebtc@users.noreply.github.com> --- lnbits/core/crud.py | 26 +++----- lnbits/core/views/api.py | 22 ++++--- lnbits/db.py | 124 ++++++++++++++++++++++++++++++++++++++- lnbits/decorators.py | 29 +++++++++ lnbits/helpers.py | 46 ++++++++++++++- 5 files changed, 218 insertions(+), 29 deletions(-) diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index db6f503a2..f92249115 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -7,7 +7,7 @@ from uuid import uuid4 import shortuuid from lnbits import bolt11 -from lnbits.db import COCKROACH, POSTGRES, Connection +from lnbits.db import COCKROACH, POSTGRES, Connection, Filters from lnbits.extension_manager import InstallableExtension from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings @@ -347,8 +347,7 @@ async def get_payments( incoming: bool = False, since: Optional[int] = None, exclude_uncheckable: bool = False, - limit: Optional[int] = None, - offset: Optional[int] = None, + filters: Optional[Filters[Payment]] = None, conn: Optional[Connection] = None, ) -> List[Payment]: """ @@ -393,29 +392,20 @@ async def get_payments( clause.append("checking_id NOT LIKE 'temp_%'") clause.append("checking_id NOT LIKE 'internal_%'") - limit_clause = f"LIMIT {limit}" if type(limit) == int and limit > 0 else "" - offset_clause = f"OFFSET {offset}" if type(offset) == int and offset > 0 else "" - # combine limit and offset clauses - limit_offset_clause = ( - f"{limit_clause} {offset_clause}" - if limit_clause and offset_clause - else limit_clause or offset_clause - ) - - where = "" - if clause: - where = f"WHERE {' AND '.join(clause)}" + if not filters: + filters = Filters() rows = await (conn or db).fetchall( f""" SELECT * FROM apipayments - {where} + {filters.where(clause)} ORDER BY time DESC - {limit_offset_clause} + {filters.pagination()} """, - tuple(args), + filters.values(args), ) + return [Payment.from_row(row) for row in rows] diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 43ffffd98..8d4685a34 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -34,10 +34,12 @@ from lnbits.core.helpers import ( stop_extension_background_work, ) from lnbits.core.models import Payment, User, Wallet +from lnbits.db import Filters from lnbits.decorators import ( WalletTypeInfo, check_admin, get_key_type, + parse_filters, require_admin_key, require_invoice_key, ) @@ -48,7 +50,7 @@ from lnbits.extension_manager import ( InstallableExtension, get_valid_extensions, ) -from lnbits.helpers import url_for +from lnbits.helpers import generate_filter_params_openapi, url_for from lnbits.settings import get_wallet_class, settings from lnbits.utils.exchange_rates import ( currencies, @@ -114,18 +116,23 @@ async def api_update_wallet( } -@core_app.get("/api/v1/payments") +@core_app.get( + "/api/v1/payments", + name="Payment List", + summary="get list of payments", + response_description="list of payments", + response_model=List[Payment], + openapi_extra=generate_filter_params_openapi(Payment), +) async def api_payments( - limit: Optional[int] = None, - offset: Optional[int] = None, wallet: WalletTypeInfo = Depends(get_key_type), + filters: Filters = Depends(parse_filters(Payment)), ): pendingPayments = await get_payments( wallet_id=wallet.wallet.id, pending=True, exclude_uncheckable=True, - limit=limit, - offset=offset, + filters=filters, ) for payment in pendingPayments: await check_transaction_status( @@ -135,8 +142,7 @@ async def api_payments( wallet_id=wallet.wallet.id, pending=True, complete=True, - limit=limit, - offset=offset, + filters=filters, ) diff --git a/lnbits/db.py b/lnbits/db.py index 985a658bf..3af11e36c 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -4,9 +4,11 @@ import os import re import time from contextlib import asynccontextmanager -from typing import Optional +from enum import Enum +from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar from loguru import logger +from pydantic import BaseModel, ValidationError from sqlalchemy import create_engine from sqlalchemy_aio.base import AsyncConnection from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY @@ -224,3 +226,123 @@ class Database(Compat): @asynccontextmanager async def reuse_conn(self, conn: Connection): yield conn + + +class Operator(Enum): + GT = "gt" + LT = "lt" + EQ = "eq" + NE = "ne" + INCLUDE = "in" + EXCLUDE = "ex" + + @property + def as_sql(self): + if self == Operator.EQ: + return "=" + elif self == Operator.NE: + return "!=" + elif self == Operator.INCLUDE: + return "IN" + elif self == Operator.EXCLUDE: + return "NOT IN" + elif self == Operator.GT: + return ">" + elif self == Operator.LT: + return "<" + else: + raise ValueError("Unknown SQL Operator") + + +TModel = TypeVar("TModel", bound=BaseModel) + + +class Filter(BaseModel, Generic[TModel]): + field: str + nested: Optional[list[str]] + op: Operator = Operator.EQ + values: list[Any] + + @classmethod + def parse_query(cls, key: str, raw_values: list[Any], model: Type[TModel]): + # Key format: + # key[operator] + # e.g. name[eq] + if key.endswith("]"): + split = key[:-1].split("[") + if len(split) != 2: + raise ValueError("Invalid key") + field_names = split[0].split(".") + op = Operator(split[1]) + else: + field_names = key.split(".") + op = Operator("eq") + + field = field_names[0] + nested = field_names[1:] + + if field in model.__fields__: + compare_field = model.__fields__[field] + values = [] + for raw_value in raw_values: + # If there is a nested field, pydantic expects a dict, so the raw value is turned into a dict before + # and the converted value is extracted afterwards + for name in reversed(nested): + raw_value = {name: raw_value} + + validated, errors = compare_field.validate(raw_value, {}, loc="none") + if errors: + raise ValidationError(errors=[errors], model=model) + + for name in nested: + if isinstance(validated, dict): + validated = validated[name] + else: + validated = getattr(validated, name) + + values.append(validated) + else: + raise ValueError("Unknown filter field") + + return cls(field=field, op=op, nested=nested, values=values) + + @property + def statement(self): + accessor = self.field + if self.nested: + for name in self.nested: + accessor = f"({accessor} ->> '{name}')" + if self.op in (Operator.INCLUDE, Operator.EXCLUDE): + placeholders = ", ".join(["?"] * len(self.values)) + stmt = [f"{accessor} {self.op.as_sql} ({placeholders})"] + else: + stmt = [f"{accessor} {self.op.as_sql} ?"] * len(self.values) + return " OR ".join(stmt) + + +class Filters(BaseModel, Generic[TModel]): + filters: List[Filter[TModel]] = [] + limit: Optional[int] + offset: Optional[int] + + def pagination(self) -> str: + stmt = "" + if self.limit: + stmt += f"LIMIT {self.limit} " + if self.offset: + stmt += f"OFFSET {self.offset}" + return stmt + + def where(self, where_stmts: List[str]) -> str: + if self.filters: + for filter in self.filters: + where_stmts.append(filter.statement) + if where_stmts: + return "WHERE " + " AND ".join(where_stmts) + return "" + + def values(self, values: List[str]) -> Tuple: + if self.filters: + for filter in self.filters: + values.extend(filter.values) + return tuple(values) diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 17134f863..bd1c05207 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -1,15 +1,18 @@ from http import HTTPStatus +from typing import Optional, Type from fastapi import Security, status from fastapi.exceptions import HTTPException from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security.api_key import APIKeyHeader, APIKeyQuery from fastapi.security.base import SecurityBase +from pydantic import BaseModel from pydantic.types import UUID4 from starlette.requests import Request from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.core.models import User, Wallet +from lnbits.db import Filter, Filters from lnbits.requestvars import g from lnbits.settings import settings @@ -266,3 +269,29 @@ async def check_super_user(usr: UUID4) -> User: detail="User not authorized. No super user privileges.", ) return user + + +def parse_filters(model: Type[BaseModel]): + """ + Parses the query params as filters. + :param model: model used for validation of filter values + """ + + def dependency( + request: Request, limit: Optional[int] = None, offset: Optional[int] = None + ): + params = request.query_params + filters = [] + for key in params.keys(): + try: + filters.append(Filter.parse_query(key, params.getlist(key), model)) + except ValueError: + continue + + return Filters( + filters=filters, + limit=limit, + offset=offset, + ) + + return dependency diff --git a/lnbits/helpers.py b/lnbits/helpers.py index d1a9cba81..184a8f857 100644 --- a/lnbits/helpers.py +++ b/lnbits/helpers.py @@ -1,7 +1,13 @@ -from typing import Any, List, Optional +from typing import Any, List, Optional, Type import jinja2 -import shortuuid # type: ignore +import shortuuid +from pydantic import BaseModel +from pydantic.schema import ( + field_schema, + get_flat_models_from_fields, + get_model_name_map, +) from lnbits.jinja2_templating import Jinja2Templates from lnbits.requestvars import g @@ -102,3 +108,39 @@ def get_current_extension_name() -> str: except: ext_name = extension_director_name return ext_name + + +def generate_filter_params_openapi(model: Type[BaseModel], keep_optional=False): + """ + Generate openapi documentation for Filters. This is intended to be used along parse_filters (see example) + :param model: Filter model + :param keep_optional: If false, all parameters will be optional, otherwise inferred from model + """ + fields = list(model.__fields__.values()) + models = get_flat_models_from_fields(fields, set()) + namemap = get_model_name_map(models) + params = [] + for field in fields: + schema, definitions, _ = field_schema(field, model_name_map=namemap) + + # Support nested definition + if "$ref" in schema: + name = schema["$ref"].split("/")[-1] + schema = definitions[name] + + description = "Supports Filtering" + if schema["type"] == "object": + description += f". Nested attributes can be filtered too, e.g. `{field.alias}.[additional].[attributes]`" + + parameter = { + "name": field.alias, + "in": "query", + "required": field.required if keep_optional else False, + "schema": schema, + "description": description, + } + params.append(parameter) + + return { + "parameters": params, + }