Mirror of https://github.com/lnbits/lnbits-legend.git (synced 2025-02-22 06:21:53 +01:00)
mega chore: update sqlalchemy (#2611)
* update sqlalchemy to 1.4
* async postgres

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
This commit is contained in: parent c637e8d31e, commit 21d87adc52
17 changed files with 1020 additions and 951 deletions
5 .github/actions/prepare/action.yml (vendored)

@@ -46,7 +46,10 @@ runs:
     - name: Install the project dependencies
       shell: bash
-      run: poetry install
+      run: |
+        poetry install
+        # needed for conv tests
+        poetry add psycopg2-binary

     - name: Use Node.js ${{ inputs.node-version }}
       if: ${{ (inputs.npm == 'true') }}
@@ -30,6 +30,7 @@
       meta.rev = self.dirtyRev or self.rev;
       meta.mainProgram = projectName;
       overrides = pkgs.poetry2nix.overrides.withDefaults (final: prev: {
         coincurve = prev.coincurve.override { preferWheel = true; };
+        protobuf = prev.protobuf.override { preferWheel = true; };
         ruff = prev.ruff.override { preferWheel = true; };
         wallycore = prev.wallycore.override { preferWheel = true; };
File diff suppressed because it is too large
@@ -1,4 +1,3 @@
-import datetime
 from time import time

 from loguru import logger
@@ -102,7 +101,7 @@ async def m002_add_fields_to_apipayments(db):

     import json

-    rows = await (await db.execute("SELECT * FROM apipayments")).fetchall()
+    rows = await db.fetchall("SELECT * FROM apipayments")
     for row in rows:
         if not row["memo"] or not row["memo"].startswith("#"):
             continue
@@ -113,15 +112,15 @@ async def m002_add_fields_to_apipayments(db):
                 new = row["memo"][len(prefix) :]
                 await db.execute(
                     """
-                    UPDATE apipayments SET extra = ?, memo = ?
-                    WHERE checking_id = ? AND memo = ?
+                    UPDATE apipayments SET extra = :extra, memo = :memo1
+                    WHERE checking_id = :checking_id AND memo = :memo2
                     """,
-                    (
-                        json.dumps({"tag": ext}),
-                        new,
-                        row["checking_id"],
-                        row["memo"],
-                    ),
+                    {
+                        "extra": json.dumps({"tag": ext}),
+                        "memo1": new,
+                        "checking_id": row["checking_id"],
+                        "memo2": row["memo"],
+                    },
                 )
                 break
     except OperationalError:
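Note: from here on the migrations swap positional `?` placeholders and value tuples for named `:key` parameters bound from a dict, which is the form SQLAlchemy's text() construct expects. A minimal standalone sketch of the binding style (illustrative table and values, not the LNbits schema):

    import asyncio
    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        async with engine.connect() as conn:
            await conn.execute(text("CREATE TABLE notes (id TEXT, memo TEXT)"))
            await conn.execute(
                text("INSERT INTO notes (id, memo) VALUES (:id, :memo)"),
                {"id": "abc", "memo": "#tip hello"},  # dict keys match the :names
            )
            memo = (
                await conn.execute(
                    text("SELECT memo FROM notes WHERE id = :id"), {"id": "abc"}
                )
            ).scalar_one()
            print(memo)  # #tip hello

    asyncio.run(main())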
@@ -212,19 +211,17 @@ async def m007_set_invoice_expiries(db):
     Precomputes invoice expiry for existing pending incoming payments.
     """
     try:
-        rows = await (
-            await db.execute(
-                f"""
-                SELECT bolt11, checking_id
-                FROM apipayments
-                WHERE pending = true
-                AND amount > 0
-                AND bolt11 IS NOT NULL
-                AND expiry IS NULL
-                AND time < {db.timestamp_now}
-                """
-            )
-        ).fetchall()
+        rows = await db.fetchall(
+            f"""
+            SELECT bolt11, checking_id
+            FROM apipayments
+            WHERE pending = true
+            AND amount > 0
+            AND bolt11 IS NOT NULL
+            AND expiry IS NULL
+            AND time < {db.timestamp_now}
+            """
+        )
         if len(rows):
             logger.info(f"Migration: Checking expiry of {len(rows)} invoices")
             for i, (
@@ -236,22 +233,17 @@ async def m007_set_invoice_expiries(db):
                 if invoice.expiry is None:
                     continue

-                expiration_date = datetime.datetime.fromtimestamp(
-                    invoice.date + invoice.expiry
-                )
+                expiration_date = invoice.date + invoice.expiry
                 logger.info(
                     f"Migration: {i+1}/{len(rows)} setting expiry of invoice"
                     f" {invoice.payment_hash} to {expiration_date}"
                 )
                 await db.execute(
-                    """
-                    UPDATE apipayments SET expiry = ?
-                    WHERE checking_id = ? AND amount > 0
+                    f"""
+                    UPDATE apipayments SET expiry = {db.timestamp_placeholder('expiry')}
+                    WHERE checking_id = :checking_id AND amount > 0
                     """,
-                    (
-                        db.datetime_to_timestamp(expiration_date),
-                        checking_id,
-                    ),
+                    {"expiry": expiration_date, "checking_id": checking_id},
                 )
             except Exception:
                 continue
@@ -347,17 +339,15 @@ async def m014_set_deleted_wallets(db):
     Sets deleted column to wallets.
     """
     try:
-        rows = await (
-            await db.execute(
-                """
-                SELECT *
-                FROM wallets
-                WHERE user LIKE 'del:%'
-                AND adminkey LIKE 'del:%'
-                AND inkey LIKE 'del:%'
-                """
-            )
-        ).fetchall()
+        rows = await db.fetchall(
+            """
+            SELECT *
+            FROM wallets
+            WHERE user LIKE 'del:%'
+            AND adminkey LIKE 'del:%'
+            AND inkey LIKE 'del:%'
+            """
+        )

         for row in rows:
             try:
@@ -367,10 +357,15 @@ async def m014_set_deleted_wallets(db):
                 await db.execute(
                     """
                     UPDATE wallets SET
-                    "user" = ?, adminkey = ?, inkey = ?, deleted = true
-                    WHERE id = ?
+                    "user" = :user, adminkey = :adminkey, inkey = :inkey, deleted = true
+                    WHERE id = :wallet
                     """,
-                    (user, adminkey, inkey, row[0]),
+                    {
+                        "user": user,
+                        "adminkey": adminkey,
+                        "inkey": inkey,
+                        "wallet": row.get("id"),
+                    },
                 )
             except Exception:
                 continue
@@ -456,17 +451,17 @@ async def m017_add_timestamp_columns_to_accounts_and_wallets(db):
         now = int(time())
         await db.execute(
             f"""
-            UPDATE wallets SET created_at = {db.timestamp_placeholder}
+            UPDATE wallets SET created_at = {db.timestamp_placeholder('now')}
             WHERE created_at IS NULL
             """,
-            (now,),
+            {"now": now},
         )
         await db.execute(
             f"""
-            UPDATE accounts SET created_at = {db.timestamp_placeholder}
+            UPDATE accounts SET created_at = {db.timestamp_placeholder('now')}
             WHERE created_at IS NULL
             """,
-            (now,),
+            {"now": now},
        )

     except OperationalError as exc:
@@ -7,7 +7,6 @@ import json
 import time
 from dataclasses import dataclass
 from enum import Enum
-from sqlite3 import Row
 from typing import Callable, Optional

 from ecdsa import SECP256k1, SigningKey
@@ -240,7 +239,7 @@ class Payment(FromRowModel):
         return self.status == PaymentState.FAILED.value

     @classmethod
-    def from_row(cls, row: Row):
+    def from_row(cls, row: dict):
         return cls(
             checking_id=row["checking_id"],
             payment_hash=row["hash"] or "0" * 64,
@@ -347,7 +346,7 @@ class TinyURL(BaseModel):
     time: float

     @classmethod
-    def from_row(cls, row: Row):
+    def from_row(cls, row: dict):
         return cls(**dict(row))

243 lnbits/db.py
@@ -7,14 +7,13 @@ import re
 import time
 from contextlib import asynccontextmanager
 from enum import Enum
-from sqlite3 import Row
 from typing import Any, Generic, Literal, Optional, TypeVar

 from loguru import logger
 from pydantic import BaseModel, ValidationError, root_validator
-from sqlalchemy import create_engine
-from sqlalchemy_aio.base import AsyncConnection
-from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY
+from sqlalchemy import event
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
+from sqlalchemy.sql import text

 from lnbits.settings import settings
@@ -24,31 +23,15 @@ SQLITE = "SQLITE"

 if settings.lnbits_database_url:
     database_uri = settings.lnbits_database_url

     if database_uri.startswith("cockroachdb://"):
         DB_TYPE = COCKROACH
     else:
         if not database_uri.startswith("postgres://"):
             raise ValueError(
                 "Please use the 'postgres://...' " "format for the database URL."
             )
         DB_TYPE = POSTGRES
-
-        from psycopg2.extensions import DECIMAL, new_type, register_type
-
-        def _parse_timestamp(value, _):
-            if value is None:
-                return None
-            f = "%Y-%m-%d %H:%M:%S.%f"
-            if "." not in value:
-                f = "%Y-%m-%d %H:%M:%S"
-            return time.mktime(datetime.datetime.strptime(value, f).timetuple())
-
-        register_type(
-            new_type(
-                DECIMAL.values,
-                "DEC2FLOAT",
-                lambda value, curs: float(value) if value is not None else None,
-            )
-        )
-
-        register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
 else:
     if not os.path.isdir(settings.lnbits_data_folder):
         os.mkdir(settings.lnbits_data_folder)
@@ -56,21 +39,21 @@ else:
     DB_TYPE = SQLITE


-def compat_timestamp_placeholder():
+def compat_timestamp_placeholder(key: str):
     if DB_TYPE == POSTGRES:
-        return "to_timestamp(?)"
+        return f"to_timestamp(:{key})"
     elif DB_TYPE == COCKROACH:
-        return "cast(? AS timestamp)"
+        return f"cast(:{key} AS timestamp)"
     else:
-        return "?"
+        return f":{key}"


 def get_placeholder(model: Any, field: str) -> str:
     type_ = model.__fields__[field].type_
     if type_ == datetime.datetime:
-        return compat_timestamp_placeholder()
+        return compat_timestamp_placeholder(field)
     else:
-        return "?"
+        return f":{field}"


 class Compat:
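Note: the placeholder helpers now take the bind-parameter name, so the generated SQL fragment already names the key it expects in the values dict. A standalone sketch of the idea (only the Postgres branch of the string formatting is shown, not the LNbits module itself):

    # standalone sketch of the keyed timestamp placeholder (Postgres branch)
    def timestamp_placeholder(key: str) -> str:
        return f"to_timestamp(:{key})"

    query = (
        f"UPDATE wallets SET created_at = {timestamp_placeholder('now')} "
        "WHERE created_at IS NULL"
    )
    print(query)  # UPDATE wallets SET created_at = to_timestamp(:now) WHERE created_at IS NULL
    # the statement is then executed with {"now": 1725900000}, matching the ":now" name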
@@ -127,15 +110,13 @@ class Compat:
             return "BIGINT"
         return "INT"

-    @property
-    def timestamp_placeholder(self) -> str:
-        return compat_timestamp_placeholder()
+    def timestamp_placeholder(self, key: str) -> str:
+        return compat_timestamp_placeholder(key)


 class Connection(Compat):
-    def __init__(self, conn: AsyncConnection, txn, typ, name, schema):
+    def __init__(self, conn: AsyncConnection, typ, name, schema):
         self.conn = conn
-        self.txn = txn
         self.type = typ
         self.name = name
         self.schema = schema
@@ -146,45 +127,42 @@ class Connection(Compat):
             query = query.replace("?", "%s")
         return query

-    def rewrite_values(self, values):
+    def rewrite_values(self, values: dict) -> dict:
         # strip html
         clean_regex = re.compile("<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")

-        # tuple to list and back to tuple
-        raw_values = [values] if isinstance(values, str) else list(values)
-        values = []
-        for raw_value in raw_values:
+        clean_values: dict = {}
+        for key, raw_value in values.items():
             if isinstance(raw_value, str):
-                values.append(re.sub(clean_regex, "", raw_value))
+                clean_values[key] = re.sub(clean_regex, "", raw_value)
             elif isinstance(raw_value, datetime.datetime):
                 ts = raw_value.timestamp()
                 if self.type == SQLITE:
-                    values.append(int(ts))
+                    clean_values[key] = int(ts)
                 else:
-                    values.append(ts)
+                    clean_values[key] = ts
             else:
-                values.append(raw_value)
-        return tuple(values)
+                clean_values[key] = raw_value
+        return clean_values

-    async def fetchall(self, query: str, values: tuple = ()) -> list:
-        result = await self.conn.execute(
-            self.rewrite_query(query), self.rewrite_values(values)
-        )
-        return await result.fetchall()
+    async def fetchall(self, query: str, values: Optional[dict] = None) -> list[dict]:
+        params = self.rewrite_values(values) if values else {}
+        result = await self.conn.execute(text(self.rewrite_query(query)), params)
+        row = result.mappings().all()
+        result.close()
+        return row

-    async def fetchone(self, query: str, values: tuple = ()):
-        result = await self.conn.execute(
-            self.rewrite_query(query), self.rewrite_values(values)
-        )
-        row = await result.fetchone()
-        await result.close()
+    async def fetchone(self, query: str, values: Optional[dict] = None) -> dict:
+        params = self.rewrite_values(values) if values else {}
+        result = await self.conn.execute(text(self.rewrite_query(query)), params)
+        row = result.mappings().first()
+        result.close()
         return row

     async def fetch_page(
         self,
         query: str,
         where: Optional[list[str]] = None,
-        values: Optional[list[str]] = None,
+        values: Optional[dict] = None,
         filters: Optional[Filters] = None,
         model: Optional[type[TRowModel]] = None,
         group_by: Optional[list[str]] = None,
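Note: with SQLAlchemy 1.4's Result API, rows come back through result.mappings() as dict-like RowMapping objects, which is why the from_row methods and migrations above now index rows by column name instead of relying on sqlite3.Row. A minimal sketch, assuming an in-memory SQLite database rather than the LNbits one:

    import asyncio
    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        async with engine.connect() as conn:
            result = await conn.execute(text("SELECT 1 AS id, 'alice' AS name"))
            row = result.mappings().first()    # dict-like access by column name
            print(row["id"], row.get("name"))  # 1 alice

    asyncio.run(main())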
@@ -211,14 +189,14 @@ class Connection(Compat):
                 {filters.order_by()}
                 {filters.pagination()}
             """,
-            parsed_values,
+            self.rewrite_values(parsed_values),
         )
         if rows:
             # no need for extra query if no pagination is specified
             if filters.offset or filters.limit:
-                count = await self.fetchone(
+                result = await self.fetchone(
                     f"""
-                    SELECT COUNT(*) FROM (
+                    SELECT COUNT(*) as count FROM (
                         {query}
                         {clause}
                         {group_by_string}
@@ -226,21 +204,22 @@ class Connection(Compat):
                     """,
                     parsed_values,
                 )
-                count = int(count[0])
+                count = int(result.get("count", 0))
             else:
                 count = len(rows)
         else:
             count = 0

         return Page(
-            data=[model.from_row(row) for row in rows] if model else rows,
+            data=[model.from_row(row) for row in rows] if model else [],
             total=count,
         )

-    async def execute(self, query: str, values: tuple = ()):
-        return await self.conn.execute(
-            self.rewrite_query(query), self.rewrite_values(values)
-        )
+    async def execute(self, query: str, values: Optional[dict] = None):
+        params = self.rewrite_values(values) if values else {}
+        result = await self.conn.execute(text(self.rewrite_query(query)), params)
+        await self.conn.commit()
+        return result


 class Database(Compat):
@@ -253,18 +232,44 @@ class Database(Compat):
             self.path = os.path.join(
                 settings.lnbits_data_folder, f"{self.name}.sqlite3"
             )
-            database_uri = f"sqlite:///{self.path}"
+            database_uri = f"sqlite+aiosqlite:///{self.path}"
         else:
-            database_uri = settings.lnbits_database_url
+            database_uri = settings.lnbits_database_url.replace(
+                "postgres://", "postgresql+asyncpg://"
+            )

         if self.name.startswith("ext_"):
             self.schema = self.name[4:]
         else:
             self.schema = None

-        self.engine = create_engine(
-            database_uri, strategy=ASYNCIO_STRATEGY, echo=settings.debug_database
+        self.engine: AsyncEngine = create_async_engine(
+            database_uri, echo=settings.debug_database
         )

+        if self.type in {POSTGRES, COCKROACH}:
+
+            @event.listens_for(self.engine.sync_engine, "connect")
+            def register_custom_types(dbapi_connection, *_):
+                def _parse_timestamp(value):
+                    if value is None:
+                        return None
+                    f = "%Y-%m-%d %H:%M:%S.%f"
+                    if "." not in value:
+                        f = "%Y-%m-%d %H:%M:%S"
+                    return int(
+                        time.mktime(datetime.datetime.strptime(value, f).timetuple())
+                    )
+
+                dbapi_connection.run_async(
+                    lambda connection: connection.set_type_codec(
+                        "TIMESTAMP",
+                        encoder=datetime.datetime,
+                        decoder=_parse_timestamp,
+                        schema="pg_catalog",
+                    )
+                )
+
         self.lock = asyncio.Lock()

         logger.trace(f"database {self.type} added for {self.name}")
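Note: two details of the new engine setup are easy to miss. The connection URL must name an async driver (aiosqlite for SQLite, asyncpg for Postgres), and because asyncpg decodes types itself, the custom TIMESTAMP codec is registered on the raw driver connection from a synchronous "connect" event via run_async. A minimal sketch of the URL handling only, with made-up paths and credentials (assumes aiosqlite and asyncpg are installed, as in the updated pyproject):

    from sqlalchemy.ext.asyncio import create_async_engine

    # illustrative URLs, not LNbits settings
    sqlite_engine = create_async_engine("sqlite+aiosqlite:////tmp/example.sqlite3")
    pg_url = "postgres://user:pass@localhost:5432/lnbits"
    pg_engine = create_async_engine(pg_url.replace("postgres://", "postgresql+asyncpg://"))
    # engine creation does not open a connection, so this runs without a live server
    print(sqlite_engine.url.drivername, pg_engine.url.drivername)  # sqlite+aiosqlite postgresql+asyncpg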
@@ -273,41 +278,37 @@ class Database(Compat):
     async def connect(self):
         await self.lock.acquire()
         try:
-            async with self.engine.connect() as conn:  # type: ignore
-                async with conn.begin() as txn:
-                    wconn = Connection(conn, txn, self.type, self.name, self.schema)
-
-                    if self.schema:
-                        if self.type in {POSTGRES, COCKROACH}:
-                            await wconn.execute(
-                                f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
-                            )
-                        elif self.type == SQLITE:
-                            await wconn.execute(
-                                f"ATTACH '{self.path}' AS {self.schema}"
-                            )
-
-                    yield wconn
+            async with self.engine.connect() as conn:
+                if not conn:
+                    raise Exception("Could not connect to the database")
+
+                wconn = Connection(conn, self.type, self.name, self.schema)
+
+                if self.schema:
+                    if self.type in {POSTGRES, COCKROACH}:
+                        await wconn.execute(
+                            f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
+                        )
+                    elif self.type == SQLITE:
+                        await wconn.execute(f"ATTACH '{self.path}' AS {self.schema}")
+
+                yield wconn
         finally:
             self.lock.release()

-    async def fetchall(self, query: str, values: tuple = ()) -> list:
+    async def fetchall(self, query: str, values: Optional[dict] = None) -> list[dict]:
         async with self.connect() as conn:
-            result = await conn.execute(query, values)
-            return await result.fetchall()
+            return await conn.fetchall(query, values)

-    async def fetchone(self, query: str, values: tuple = ()):
+    async def fetchone(self, query: str, values: Optional[dict] = None) -> dict:
         async with self.connect() as conn:
-            result = await conn.execute(query, values)
-            row = await result.fetchone()
-            await result.close()
-            return row
+            return await conn.fetchone(query, values)

     async def fetch_page(
         self,
         query: str,
         where: Optional[list[str]] = None,
-        values: Optional[list[str]] = None,
+        values: Optional[dict] = None,
         filters: Optional[Filters] = None,
         model: Optional[type[TRowModel]] = None,
         group_by: Optional[list[str]] = None,
@@ -315,7 +316,7 @@ class Database(Compat):
         async with self.connect() as conn:
             return await conn.fetch_page(query, where, values, filters, model, group_by)

-    async def execute(self, query: str, values: tuple = ()):
+    async def execute(self, query: str, values: Optional[dict] = None):
         async with self.connect() as conn:
             return await conn.execute(query, values)
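Note: seen from a caller, the public Database API keeps its shape; only the values argument changes from a positional tuple to a dict keyed by the placeholder names. A hedged usage sketch with made-up table and id values, not the LNbits schema:

    # illustrative call sites after the move to named parameters
    # (assumes `db` is a Database instance as defined above)
    async def soft_delete_wallet(db, wallet_id: str) -> dict:
        await db.execute(
            "UPDATE wallets SET deleted = true WHERE id = :wallet",
            {"wallet": wallet_id},
        )
        return await db.fetchone(
            "SELECT * FROM wallets WHERE id = :wallet",
            {"wallet": wallet_id},
        )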
@@ -373,8 +374,8 @@ class Operator(Enum):

 class FromRowModel(BaseModel):
     @classmethod
-    def from_row(cls, row: Row):
-        return cls(**dict(row))
+    def from_row(cls, row: dict):
+        return cls(**row)


 class FilterModel(BaseModel):
@@ -396,12 +397,13 @@ class Page(BaseModel, Generic[T]):
 class Filter(BaseModel, Generic[TFilterModel]):
     field: str
     op: Operator = Operator.EQ
-    values: list[Any]

     model: Optional[type[TFilterModel]]
+    values: Optional[dict] = None

     @classmethod
-    def parse_query(cls, key: str, raw_values: list[Any], model: type[TFilterModel]):
+    def parse_query(
+        cls, key: str, raw_values: list[Any], model: type[TFilterModel], i: int = 0
+    ):
         # Key format:
         # key[operator]
         # e.g. name[eq]
@@ -417,12 +419,12 @@ class Filter(BaseModel, Generic[TFilterModel]):

         if field in model.__fields__:
             compare_field = model.__fields__[field]
-            values = []
+            values: dict = {}
             for raw_value in raw_values:
                 validated, errors = compare_field.validate(raw_value, {}, loc="none")
                 if errors:
                     raise ValidationError(errors=[errors], model=model)
-                values.append(validated)
+                values[f"{field}__{i}"] = validated
         else:
             raise ValueError("Unknown filter field")
@@ -430,13 +432,17 @@ class Filter(BaseModel, Generic[TFilterModel]):

     @property
     def statement(self):
-        assert self.model, "Model is required for statement generation"
-        placeholder = get_placeholder(self.model, self.field)
-        if self.op in (Operator.INCLUDE, Operator.EXCLUDE):
-            placeholders = ", ".join([placeholder] * len(self.values))
-            stmt = [f"{self.field} {self.op.as_sql} ({placeholders})"]
-        else:
-            stmt = [f"{self.field} {self.op.as_sql} {placeholder}"] * len(self.values)
+        stmt = []
+        for key in self.values.keys() if self.values else []:
+            clean_key = key.split("__")[0]
+            if (
+                self.model
+                and self.model.__fields__[clean_key].type_ == datetime.datetime
+            ):
+                placeholder = compat_timestamp_placeholder(key)
+            else:
+                placeholder = f":{key}"
+            stmt.append(f"{clean_key} {self.op.as_sql} {placeholder}")
         return " OR ".join(stmt)
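Note: together with the parse_filters change further down, each parsed filter now carries its index `i`, so bind-parameter names such as `amount__0` and `amount__1` stay unique even when the same field is filtered twice (e.g. `?amount[ge]=10&amount[le]=100`, as in the test fixture that filters on `time[ge]` and `time[le]`). A standalone sketch of the resulting fragment and bind dict, not the LNbits classes themselves:

    # simplified illustration of the Filter.statement / Filter.values pairing
    def filter_fragment(field: str, op_sql: str, value, i: int) -> tuple:
        key = f"{field}__{i}"                      # e.g. "amount__0"
        return f"{field} {op_sql} :{key}", {key: value}

    c0, p0 = filter_fragment("amount", ">=", 10, 0)
    c1, p1 = filter_fragment("amount", "<=", 100, 1)
    print("WHERE " + " AND ".join([c0, c1]))  # WHERE amount >= :amount__0 AND amount <= :amount__1
    print({**p0, **p1})                        # {'amount__0': 10, 'amount__1': 100}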
@@ -487,14 +493,11 @@ class Filters(BaseModel, Generic[TFilterModel]):
         for page_filter in self.filters:
             where_stmts.append(page_filter.statement)
         if self.search and self.model:
+            fields = self.model.__search_fields__
             if DB_TYPE == POSTGRES:
-                where_stmts.append(
-                    f"lower(concat({', '.join(self.model.__search_fields__)})) LIKE ?"
-                )
+                where_stmts.append(f"lower(concat({', '.join(fields)})) LIKE :search")
             elif DB_TYPE == SQLITE:
-                where_stmts.append(
-                    f"lower({'||'.join(self.model.__search_fields__)}) LIKE ?"
-                )
+                where_stmts.append(f"lower({'||'.join(fields)}) LIKE :search")
         if where_stmts:
             return "WHERE " + " AND ".join(where_stmts)
         return ""
@@ -504,12 +507,14 @@ class Filters(BaseModel, Generic[TFilterModel]):
             return f"ORDER BY {self.sortby} {self.direction or 'asc'}"
         return ""

-    def values(self, values: Optional[list[str]] = None) -> tuple:
+    def values(self, values: Optional[dict] = None) -> dict:
         if not values:
-            values = []
+            values = {}
         if self.filters:
             for page_filter in self.filters:
-                values.extend(page_filter.values)
+                if page_filter.values:
+                    for key, value in page_filter.values.items():
+                        values[key] = value
         if self.search and self.model:
-            values.append(f"%{self.search}%")
-        return tuple(values)
+            values["search"] = f"%{self.search}%"
+        return values
@@ -204,9 +204,9 @@ def parse_filters(model: Type[TFilterModel]):
     ):
         params = request.query_params
         filters = []
-        for key in params.keys():
+        for i, key in enumerate(params.keys()):
             try:
-                filters.append(Filter.parse_query(key, params.getlist(key), model))
+                filters.append(Filter.parse_query(key, params.getlist(key), model, i))
             except ValueError:
                 continue

@@ -187,12 +187,14 @@ def insert_query(table_name: str, model: BaseModel) -> str:
     return f"INSERT INTO {table_name} ({fields}) VALUES ({values})"


-def update_query(table_name: str, model: BaseModel, where: str = "WHERE id = ?") -> str:
+def update_query(
+    table_name: str, model: BaseModel, where: str = "WHERE id = :id"
+) -> str:
     """
     Generate an update query with placeholders for a given table and model
     :param table_name: Name of the table
     :param model: Pydantic model
-    :param where: Where string, default to `WHERE id = ?`
+    :param where: Where string, default to `WHERE id = :id`
     """
     fields = []
     for field in model.dict().keys():
@@ -74,10 +74,10 @@ def configure_logger() -> None:
     logging.getLogger("uvicorn.error").propagate = False

     logging.getLogger("sqlalchemy").handlers = [InterceptHandler()]
-    logging.getLogger("sqlalchemy.engine.base").handlers = [InterceptHandler()]
-    logging.getLogger("sqlalchemy.engine.base").propagate = False
-    logging.getLogger("sqlalchemy.engine.base.Engine").handlers = [InterceptHandler()]
-    logging.getLogger("sqlalchemy.engine.base.Engine").propagate = False
+    logging.getLogger("sqlalchemy.engine").handlers = [InterceptHandler()]
+    logging.getLogger("sqlalchemy.engine").propagate = False
+    logging.getLogger("sqlalchemy.engine.Engine").handlers = [InterceptHandler()]
+    logging.getLogger("sqlalchemy.engine.Engine").propagate = False


 class Formatter:
997 poetry.lock (generated)
File diff suppressed because it is too large
@@ -16,25 +16,25 @@ python = "^3.12 | ^3.11 | ^3.10 | ^3.9"
 bech32 = "1.2.0"
 click = "8.1.7"
 ecdsa = "0.19.0"
-fastapi = "0.112.0"
+fastapi = "0.113.0"
 httpx = "0.27.0"
 jinja2 = "3.1.4"
 lnurl = "0.5.3"
-psycopg2-binary = "2.9.9"
-pydantic = "1.10.17"
+pydantic = "1.10.18"
 pyqrcode = "1.2.1"
 shortuuid = "1.0.13"
-sqlalchemy = "1.3.24"
-sqlalchemy-aio = "0.17.0"
 sse-starlette = "1.8.2"
 typing-extensions = "4.12.2"
-uvicorn = "0.30.5"
+uvicorn = "0.30.6"
+sqlalchemy = "1.4.54"
+aiosqlite = "0.20.0"
+asyncpg = "0.29.0"
 uvloop = "0.19.0"
 websockets = "11.0.3"
 loguru = "0.7.2"
-grpcio = "1.65.5"
-protobuf = "5.27.3"
-pyln-client = "24.5"
+grpcio = "1.66.1"
+protobuf = "5.28.0"
+pyln-client = "24.8.1"
 pywebpush = "1.14.1"
 slowapi = "0.1.9"
 websocket-client = "1.8.0"
@@ -70,11 +70,11 @@ black = "^24.8.0"
 pytest-asyncio = "^0.21.2"
 pytest = "^8.3.2"
 pytest-cov = "^4.1.0"
-mypy = "^1.11.1"
+mypy = "^1.11.2"
 types-protobuf = "^5.27.0.20240626"
 pre-commit = "^3.8.0"
 openapi-spec-validator = "^0.7.1"
-ruff = "^0.5.7"
+ruff = "^0.6.4"
 types-passlib = "^1.7.7.20240327"
 openai = "^1.39.0"
 json5 = "^0.9.25"
@@ -84,7 +84,7 @@ pytest-httpserver = "^1.1.0"
 pytest-mock = "^3.14.0"
 types-mock = "^5.1.0.20240425"
 mock = "^5.1.0"
-grpcio-tools = "^1.65.5"
+grpcio-tools = "^1.66.1"

 [build-system]
 requires = ["poetry-core>=1.0.0"]
@@ -126,7 +126,6 @@ module = [
   "secp256k1.*",
   "uvicorn.*",
   "sqlalchemy.*",
-  "sqlalchemy_aio.*",
   "websocket.*",
   "websockets.*",
   "pyqrcode.*",
@@ -136,7 +135,6 @@ module = [
   "bolt11.*",
   "bitstring.*",
   "ecdsa.*",
-  "psycopg2.*",
   "pyngrok.*",
   "pyln.client.*",
   "py_vapid.*",
@@ -367,11 +367,11 @@ async def test_get_payments_history(client, adminkey_headers_from, fake_payments
     assert response.status_code == 200
     data = response.json()
     assert len(data) == 1
-    assert data[0]["spending"] == sum(
-        payment.amount * 1000 for payment in fake_data if payment.out
-    )
     assert data[0]["income"] == sum(
-        payment.amount * 1000 for payment in fake_data if not payment.out
+        [int(payment.amount * 1000) for payment in fake_data if not payment.out]
+    )
+    assert data[0]["spending"] == sum(
+        [int(payment.amount * 1000) for payment in fake_data if payment.out]
     )

     response = await client.get(
@@ -25,7 +25,6 @@ from lnbits.core.views.payment_api import api_payments_create_invoice
 from lnbits.db import DB_TYPE, SQLITE, Database
 from lnbits.settings import settings
 from tests.helpers import (
-    clean_database,
     get_random_invoice_data,
 )

@@ -47,7 +46,6 @@ def event_loop():
 # use session scope to run once before and once after all tests
 @pytest_asyncio.fixture(scope="session")
 async def app():
-    clean_database(settings)
     app = create_app()
     async with LifespanManager(app) as manager:
         settings.first_install = False
@@ -199,9 +197,9 @@ async def fake_payments(client, adminkey_headers_from):
             "/api/v1/payments", headers=adminkey_headers_from, json=invoice.dict()
         )
         assert response.is_success
-        await update_payment_status(
-            response.json()["checking_id"], status=PaymentState.SUCCESS
-        )
+        data = response.json()
+        assert data["checking_id"]
+        await update_payment_status(data["checking_id"], status=PaymentState.SUCCESS)

     params = {"time[ge]": ts, "time[le]": time()}
     return fake_data, params
@@ -2,11 +2,7 @@ import random
 import string
 from typing import Optional

-from psycopg2 import connect
-from psycopg2.errors import InvalidCatalogName
-
 from lnbits import core
-from lnbits.db import DB_TYPE, POSTGRES, FromRowModel
+from lnbits.db import FromRowModel
 from lnbits.wallets import get_funding_source, set_funding_source

@@ -35,21 +31,3 @@ set_funding_source()
 funding_source = get_funding_source()
 is_fake: bool = funding_source.__class__.__name__ == "FakeWallet"
 is_regtest: bool = not is_fake
-
-
-def clean_database(settings):
-    if DB_TYPE == POSTGRES:
-        conn = connect(settings.lnbits_database_url)
-        conn.autocommit = True
-        with conn.cursor() as cur:
-            try:
-                cur.execute("DROP DATABASE lnbits_test")
-            except InvalidCatalogName:
-                pass
-            cur.execute("CREATE DATABASE lnbits_test")
-        core.db.__init__("database")
-        conn.close()
-    else:
-        # TODO: do this once mock data is removed from test data folder
-        # os.remove(settings.lnbits_data_folder + "/database.sqlite3")
-        pass
@@ -14,8 +14,8 @@ from lnbits.db import POSTGRES
 @pytest.mark.asyncio
 async def test_date_conversion(db):
     if db.type == POSTGRES:
-        row = await db.fetchone("SELECT now()::date")
-        assert row and isinstance(row[0], date)
+        row = await db.fetchone("SELECT now()::date as now")
+        assert row and isinstance(row.get("now"), date)


 # make test to create wallet and delete wallet
@@ -12,10 +12,17 @@ test = DbTestModel(id=1, name="test", value="yes")
 @pytest.mark.asyncio
 async def test_helpers_insert_query():
     q = insert_query("test_helpers_query", test)
-    assert q == "INSERT INTO test_helpers_query (id, name, value) VALUES (?, ?, ?)"
+    assert (
+        q == "INSERT INTO test_helpers_query (id, name, value) "
+        "VALUES (:id, :name, :value)"
+    )


 @pytest.mark.asyncio
 async def test_helpers_update_query():
     q = update_query("test_helpers_query", test)
-    assert q == "UPDATE test_helpers_query SET id = ?, name = ?, value = ? WHERE id = ?"
+    assert (
+        q == "UPDATE test_helpers_query "
+        "SET id = :id, name = :name, value = :value "
+        "WHERE id = :id"
+    )
@@ -1,5 +1,5 @@
 # Python script to migrate an LNbits SQLite DB to Postgres
-# All credits to @Fritz446 for the awesome work
+# credits to @Fritz446 for the awesome work

 # pip install psycopg2 OR psycopg2-binary

@@ -9,10 +9,14 @@ import sqlite3
 import sys
 from typing import List, Optional

-import psycopg2
-
 from lnbits.settings import settings

+try:
+    import psycopg2  # type: ignore
+except ImportError:
+    print("Please install psycopg2")
+    sys.exit(1)
+
 sqfolder = settings.lnbits_data_folder
 db_url = settings.lnbits_database_url

@@ -55,8 +59,8 @@ def check_db_versions(sqdb):
         version = dbpost[key]
         if value != version:
             raise Exception(
-                f"sqlite database version ({value}) of {key} doesn't match postgres"
-                f" database version {version}"
+                f"sqlite database version ({value}) of {key} doesn't match "
+                f"postgres database version {version}"
             )

     connection = postgres.connection