feat: add group_by to fetch_page (#2140)

---------

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
Co-authored-by: Vlad Stan <stan.v.vlad@gmail.com>
This commit is contained in:
dni ⚡ 2024-03-12 13:55:38 +01:00 committed by GitHub
parent 14519135d8
commit 7ce4eddb0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 100 additions and 13 deletions

View file

@ -179,16 +179,27 @@ class Connection(Compat):
values: Optional[List[str]] = None,
filters: Optional[Filters] = None,
model: Optional[Type[TRowModel]] = None,
group_by: Optional[List[str]] = None,
) -> Page[TRowModel]:
if not filters:
filters = Filters()
clause = filters.where(where)
parsed_values = filters.values(values)
group_by_string = ""
if group_by:
for field in group_by:
if not re.fullmatch(
r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?", field
):
raise ValueError("Value for GROUP BY is invalid")
group_by_string = f"GROUP BY {', '.join(group_by)}"
rows = await self.fetchall(
f"""
{query}
{clause}
{group_by_string}
{filters.order_by()}
{filters.pagination()}
""",
@ -202,6 +213,7 @@ class Connection(Compat):
SELECT COUNT(*) FROM (
{query}
{clause}
{group_by_string}
) as count
""",
parsed_values,
@ -288,9 +300,10 @@ class Database(Compat):
values: Optional[List[str]] = None,
filters: Optional[Filters] = None,
model: Optional[Type[TRowModel]] = None,
group_by: Optional[List[str]] = None,
) -> Page[TRowModel]:
async with self.connect() as conn:
return await conn.fetch_page(query, where, values, filters, model)
return await conn.fetch_page(query, where, values, filters, model, group_by)
async def execute(self, query: str, values: tuple = ()):
async with self.connect() as conn:

View file

@ -0,0 +1,74 @@
import pytest
import pytest_asyncio
from tests.helpers import DbTestModel
@pytest_asyncio.fixture(scope="session")
async def fetch_page(db):
await db.execute("DROP TABLE IF EXISTS test_db_fetch_page")
await db.execute(
"""
CREATE TABLE test_db_fetch_page (
id TEXT PRIMARY KEY,
value TEXT NOT NULL,
name TEXT NOT NULL
)
"""
)
await db.execute(
"""
INSERT INTO test_db_fetch_page (id, name, value) VALUES
('1', 'Alice', 'foo'),
('2', 'Bob', 'bar'),
('3', 'Carol', 'bar'),
('4', 'Dave', 'bar'),
('5', 'Dave', 'foo')
"""
)
yield
await db.execute("DROP TABLE test_db_fetch_page")
@pytest.mark.asyncio
async def test_db_fetch_page_simple(fetch_page, db):
row = await db.fetch_page(
query="select * from test_db_fetch_page",
model=DbTestModel,
)
assert row
assert row.total == 5
assert len(row.data) == 5
@pytest.mark.asyncio
async def test_db_fetch_page_group_by(fetch_page, db):
row = await db.fetch_page(
query="select max(id) as id, name from test_db_fetch_page",
model=DbTestModel,
group_by=["name"],
)
assert row
assert row.total == 4
@pytest.mark.asyncio
async def test_db_fetch_page_group_by_multiple(fetch_page, db):
row = await db.fetch_page(
query="select max(id) as id, name, value from test_db_fetch_page",
model=DbTestModel,
group_by=["value", "name"],
)
assert row
assert row.total == 5
@pytest.mark.asyncio
async def test_db_fetch_page_group_by_evil(fetch_page, db):
with pytest.raises(ValueError, match="Value for GROUP BY is invalid"):
await db.fetch_page(
query="select * from test_db_fetch_page",
model=DbTestModel,
group_by=["name;"],
)

View file

@ -1,27 +1,21 @@
import pytest
from pydantic import BaseModel
from lnbits.helpers import (
insert_query,
update_query,
)
from tests.helpers import DbTestModel
class DbTestModel(BaseModel):
id: int
name: str
test = DbTestModel(id=1, name="test")
test = DbTestModel(id=1, name="test", value="yes")
@pytest.mark.asyncio
async def test_helpers_insert_query():
q = insert_query("test_helpers_query", test)
assert q == "INSERT INTO test_helpers_query (id, name) VALUES (?, ?)"
assert q == "INSERT INTO test_helpers_query (id, name, value) VALUES (?, ?, ?)"
@pytest.mark.asyncio
async def test_helpers_update_query():
q = update_query("test_helpers_query", test)
assert q == "UPDATE test_helpers_query SET id = ?, name = ? WHERE id = ?"
assert q == "UPDATE test_helpers_query SET id = ?, name = ?, value = ? WHERE id = ?"

View file

@ -5,17 +5,23 @@ import random
import string
import time
from subprocess import PIPE, Popen, TimeoutExpired
from typing import Tuple
from typing import Optional, Tuple
from loguru import logger
from psycopg2 import connect
from psycopg2.errors import InvalidCatalogName
from lnbits import core
from lnbits.db import DB_TYPE, POSTGRES
from lnbits.db import DB_TYPE, POSTGRES, FromRowModel
from lnbits.wallets import get_wallet_class, set_wallet_class
class DbTestModel(FromRowModel):
id: int
name: str
value: Optional[str] = None
def get_random_string(N: int = 10):
return "".join(
random.SystemRandom().choice(string.ascii_uppercase + string.digits)