From 7ce4eddb0ed8d306a6d07c28c083550d27a816c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dni=20=E2=9A=A1?= Date: Tue, 12 Mar 2024 13:55:38 +0100 Subject: [PATCH] feat: add group_by to fetch_page (#2140) --------- Co-authored-by: Pavol Rusnak Co-authored-by: Vlad Stan --- lnbits/db.py | 15 ++++++- tests/core/test_db_fetch_page.py | 74 ++++++++++++++++++++++++++++++++ tests/core/test_helpers_query.py | 14 ++---- tests/helpers.py | 10 ++++- 4 files changed, 100 insertions(+), 13 deletions(-) create mode 100644 tests/core/test_db_fetch_page.py diff --git a/lnbits/db.py b/lnbits/db.py index be43d849c..2fed0bf62 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -179,16 +179,27 @@ class Connection(Compat): values: Optional[List[str]] = None, filters: Optional[Filters] = None, model: Optional[Type[TRowModel]] = None, + group_by: Optional[List[str]] = None, ) -> Page[TRowModel]: if not filters: filters = Filters() clause = filters.where(where) parsed_values = filters.values(values) + group_by_string = "" + if group_by: + for field in group_by: + if not re.fullmatch( + r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?", field + ): + raise ValueError("Value for GROUP BY is invalid") + group_by_string = f"GROUP BY {', '.join(group_by)}" + rows = await self.fetchall( f""" {query} {clause} + {group_by_string} {filters.order_by()} {filters.pagination()} """, @@ -202,6 +213,7 @@ class Connection(Compat): SELECT COUNT(*) FROM ( {query} {clause} + {group_by_string} ) as count """, parsed_values, @@ -288,9 +300,10 @@ class Database(Compat): values: Optional[List[str]] = None, filters: Optional[Filters] = None, model: Optional[Type[TRowModel]] = None, + group_by: Optional[List[str]] = None, ) -> Page[TRowModel]: async with self.connect() as conn: - return await conn.fetch_page(query, where, values, filters, model) + return await conn.fetch_page(query, where, values, filters, model, group_by) async def execute(self, query: str, values: tuple = ()): async with self.connect() as conn: diff --git a/tests/core/test_db_fetch_page.py b/tests/core/test_db_fetch_page.py new file mode 100644 index 000000000..ba0c8409d --- /dev/null +++ b/tests/core/test_db_fetch_page.py @@ -0,0 +1,74 @@ +import pytest +import pytest_asyncio + +from tests.helpers import DbTestModel + + +@pytest_asyncio.fixture(scope="session") +async def fetch_page(db): + await db.execute("DROP TABLE IF EXISTS test_db_fetch_page") + await db.execute( + """ + CREATE TABLE test_db_fetch_page ( + id TEXT PRIMARY KEY, + value TEXT NOT NULL, + name TEXT NOT NULL + ) + """ + ) + await db.execute( + """ + INSERT INTO test_db_fetch_page (id, name, value) VALUES + ('1', 'Alice', 'foo'), + ('2', 'Bob', 'bar'), + ('3', 'Carol', 'bar'), + ('4', 'Dave', 'bar'), + ('5', 'Dave', 'foo') + """ + ) + yield + await db.execute("DROP TABLE test_db_fetch_page") + + +@pytest.mark.asyncio +async def test_db_fetch_page_simple(fetch_page, db): + row = await db.fetch_page( + query="select * from test_db_fetch_page", + model=DbTestModel, + ) + + assert row + assert row.total == 5 + assert len(row.data) == 5 + + +@pytest.mark.asyncio +async def test_db_fetch_page_group_by(fetch_page, db): + row = await db.fetch_page( + query="select max(id) as id, name from test_db_fetch_page", + model=DbTestModel, + group_by=["name"], + ) + assert row + assert row.total == 4 + + +@pytest.mark.asyncio +async def test_db_fetch_page_group_by_multiple(fetch_page, db): + row = await db.fetch_page( + query="select max(id) as id, name, value from test_db_fetch_page", + model=DbTestModel, + group_by=["value", "name"], + ) + assert row + assert row.total == 5 + + +@pytest.mark.asyncio +async def test_db_fetch_page_group_by_evil(fetch_page, db): + with pytest.raises(ValueError, match="Value for GROUP BY is invalid"): + await db.fetch_page( + query="select * from test_db_fetch_page", + model=DbTestModel, + group_by=["name;"], + ) diff --git a/tests/core/test_helpers_query.py b/tests/core/test_helpers_query.py index f81727687..3988eacf0 100644 --- a/tests/core/test_helpers_query.py +++ b/tests/core/test_helpers_query.py @@ -1,27 +1,21 @@ import pytest -from pydantic import BaseModel from lnbits.helpers import ( insert_query, update_query, ) +from tests.helpers import DbTestModel - -class DbTestModel(BaseModel): - id: int - name: str - - -test = DbTestModel(id=1, name="test") +test = DbTestModel(id=1, name="test", value="yes") @pytest.mark.asyncio async def test_helpers_insert_query(): q = insert_query("test_helpers_query", test) - assert q == "INSERT INTO test_helpers_query (id, name) VALUES (?, ?)" + assert q == "INSERT INTO test_helpers_query (id, name, value) VALUES (?, ?, ?)" @pytest.mark.asyncio async def test_helpers_update_query(): q = update_query("test_helpers_query", test) - assert q == "UPDATE test_helpers_query SET id = ?, name = ? WHERE id = ?" + assert q == "UPDATE test_helpers_query SET id = ?, name = ?, value = ? WHERE id = ?" diff --git a/tests/helpers.py b/tests/helpers.py index dc10d654a..4b649075f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -5,17 +5,23 @@ import random import string import time from subprocess import PIPE, Popen, TimeoutExpired -from typing import Tuple +from typing import Optional, Tuple from loguru import logger from psycopg2 import connect from psycopg2.errors import InvalidCatalogName from lnbits import core -from lnbits.db import DB_TYPE, POSTGRES +from lnbits.db import DB_TYPE, POSTGRES, FromRowModel from lnbits.wallets import get_wallet_class, set_wallet_class +class DbTestModel(FromRowModel): + id: int + name: str + value: Optional[str] = None + + def get_random_string(N: int = 10): return "".join( random.SystemRandom().choice(string.ascii_uppercase + string.digits)