mega chore: update sqlalchemy (#2611)

* update sqlalchemy to 1.4
* async postgres

---------

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
This commit is contained in:
dni ⚡ 2024-09-24 10:56:03 +02:00 committed by GitHub
parent c637e8d31e
commit 21d87adc52
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1020 additions and 951 deletions

View file

@ -46,7 +46,10 @@ runs:
- name: Install the project dependencies
shell: bash
run: poetry install
run: |
poetry install
# needed for conv tests
poetry add psycopg2-binary
- name: Use Node.js ${{ inputs.node-version }}
if: ${{ (inputs.npm == 'true') }}

View file

@ -30,6 +30,7 @@
meta.rev = self.dirtyRev or self.rev;
meta.mainProgram = projectName;
overrides = pkgs.poetry2nix.overrides.withDefaults (final: prev: {
coincurve = prev.coincurve.override { preferWheel = true; };
protobuf = prev.protobuf.override { preferWheel = true; };
ruff = prev.ruff.override { preferWheel = true; };
wallycore = prev.wallycore.override { preferWheel = true; };

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,3 @@
import datetime
from time import time
from loguru import logger
@ -102,7 +101,7 @@ async def m002_add_fields_to_apipayments(db):
import json
rows = await (await db.execute("SELECT * FROM apipayments")).fetchall()
rows = await db.fetchall("SELECT * FROM apipayments")
for row in rows:
if not row["memo"] or not row["memo"].startswith("#"):
continue
@ -113,15 +112,15 @@ async def m002_add_fields_to_apipayments(db):
new = row["memo"][len(prefix) :]
await db.execute(
"""
UPDATE apipayments SET extra = ?, memo = ?
WHERE checking_id = ? AND memo = ?
UPDATE apipayments SET extra = :extra, memo = :memo1
WHERE checking_id = :checking_id AND memo = :memo2
""",
(
json.dumps({"tag": ext}),
new,
row["checking_id"],
row["memo"],
),
{
"extra": json.dumps({"tag": ext}),
"memo1": new,
"checking_id": row["checking_id"],
"memo2": row["memo"],
},
)
break
except OperationalError:
@ -212,19 +211,17 @@ async def m007_set_invoice_expiries(db):
Precomputes invoice expiry for existing pending incoming payments.
"""
try:
rows = await (
await db.execute(
f"""
SELECT bolt11, checking_id
FROM apipayments
WHERE pending = true
AND amount > 0
AND bolt11 IS NOT NULL
AND expiry IS NULL
AND time < {db.timestamp_now}
"""
)
).fetchall()
rows = await db.fetchall(
f"""
SELECT bolt11, checking_id
FROM apipayments
WHERE pending = true
AND amount > 0
AND bolt11 IS NOT NULL
AND expiry IS NULL
AND time < {db.timestamp_now}
"""
)
if len(rows):
logger.info(f"Migration: Checking expiry of {len(rows)} invoices")
for i, (
@ -236,22 +233,17 @@ async def m007_set_invoice_expiries(db):
if invoice.expiry is None:
continue
expiration_date = datetime.datetime.fromtimestamp(
invoice.date + invoice.expiry
)
expiration_date = invoice.date + invoice.expiry
logger.info(
f"Migration: {i+1}/{len(rows)} setting expiry of invoice"
f" {invoice.payment_hash} to {expiration_date}"
)
await db.execute(
"""
UPDATE apipayments SET expiry = ?
WHERE checking_id = ? AND amount > 0
f"""
UPDATE apipayments SET expiry = {db.timestamp_placeholder('expiry')}
WHERE checking_id = :checking_id AND amount > 0
""",
(
db.datetime_to_timestamp(expiration_date),
checking_id,
),
{"expiry": expiration_date, "checking_id": checking_id},
)
except Exception:
continue
@ -347,17 +339,15 @@ async def m014_set_deleted_wallets(db):
Sets deleted column to wallets.
"""
try:
rows = await (
await db.execute(
"""
SELECT *
FROM wallets
WHERE user LIKE 'del:%'
AND adminkey LIKE 'del:%'
AND inkey LIKE 'del:%'
"""
)
).fetchall()
rows = await db.fetchall(
"""
SELECT *
FROM wallets
WHERE user LIKE 'del:%'
AND adminkey LIKE 'del:%'
AND inkey LIKE 'del:%'
"""
)
for row in rows:
try:
@ -367,10 +357,15 @@ async def m014_set_deleted_wallets(db):
await db.execute(
"""
UPDATE wallets SET
"user" = ?, adminkey = ?, inkey = ?, deleted = true
WHERE id = ?
"user" = :user, adminkey = :adminkey, inkey = :inkey, deleted = true
WHERE id = :wallet
""",
(user, adminkey, inkey, row[0]),
{
"user": user,
"adminkey": adminkey,
"inkey": inkey,
"wallet": row.get("id"),
},
)
except Exception:
continue
@ -456,17 +451,17 @@ async def m017_add_timestamp_columns_to_accounts_and_wallets(db):
now = int(time())
await db.execute(
f"""
UPDATE wallets SET created_at = {db.timestamp_placeholder}
UPDATE wallets SET created_at = {db.timestamp_placeholder('now')}
WHERE created_at IS NULL
""",
(now,),
{"now": now},
)
await db.execute(
f"""
UPDATE accounts SET created_at = {db.timestamp_placeholder}
UPDATE accounts SET created_at = {db.timestamp_placeholder('now')}
WHERE created_at IS NULL
""",
(now,),
{"now": now},
)
except OperationalError as exc:

View file

@ -7,7 +7,6 @@ import json
import time
from dataclasses import dataclass
from enum import Enum
from sqlite3 import Row
from typing import Callable, Optional
from ecdsa import SECP256k1, SigningKey
@ -240,7 +239,7 @@ class Payment(FromRowModel):
return self.status == PaymentState.FAILED.value
@classmethod
def from_row(cls, row: Row):
def from_row(cls, row: dict):
return cls(
checking_id=row["checking_id"],
payment_hash=row["hash"] or "0" * 64,
@ -347,7 +346,7 @@ class TinyURL(BaseModel):
time: float
@classmethod
def from_row(cls, row: Row):
def from_row(cls, row: dict):
return cls(**dict(row))

View file

@ -7,14 +7,13 @@ import re
import time
from contextlib import asynccontextmanager
from enum import Enum
from sqlite3 import Row
from typing import Any, Generic, Literal, Optional, TypeVar
from loguru import logger
from pydantic import BaseModel, ValidationError, root_validator
from sqlalchemy import create_engine
from sqlalchemy_aio.base import AsyncConnection
from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
from sqlalchemy.sql import text
from lnbits.settings import settings
@ -24,31 +23,15 @@ SQLITE = "SQLITE"
if settings.lnbits_database_url:
database_uri = settings.lnbits_database_url
if database_uri.startswith("cockroachdb://"):
DB_TYPE = COCKROACH
else:
if not database_uri.startswith("postgres://"):
raise ValueError(
"Please use the 'postgres://...' " "format for the database URL."
)
DB_TYPE = POSTGRES
from psycopg2.extensions import DECIMAL, new_type, register_type
def _parse_timestamp(value, _):
if value is None:
return None
f = "%Y-%m-%d %H:%M:%S.%f"
if "." not in value:
f = "%Y-%m-%d %H:%M:%S"
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
register_type(
new_type(
DECIMAL.values,
"DEC2FLOAT",
lambda value, curs: float(value) if value is not None else None,
)
)
register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
else:
if not os.path.isdir(settings.lnbits_data_folder):
os.mkdir(settings.lnbits_data_folder)
@ -56,21 +39,21 @@ else:
DB_TYPE = SQLITE
def compat_timestamp_placeholder():
def compat_timestamp_placeholder(key: str):
if DB_TYPE == POSTGRES:
return "to_timestamp(?)"
return f"to_timestamp(:{key})"
elif DB_TYPE == COCKROACH:
return "cast(? AS timestamp)"
return f"cast(:{key} AS timestamp)"
else:
return "?"
return f":{key}"
def get_placeholder(model: Any, field: str) -> str:
type_ = model.__fields__[field].type_
if type_ == datetime.datetime:
return compat_timestamp_placeholder()
return compat_timestamp_placeholder(field)
else:
return "?"
return f":{field}"
class Compat:
@ -127,15 +110,13 @@ class Compat:
return "BIGINT"
return "INT"
@property
def timestamp_placeholder(self) -> str:
return compat_timestamp_placeholder()
def timestamp_placeholder(self, key: str) -> str:
return compat_timestamp_placeholder(key)
class Connection(Compat):
def __init__(self, conn: AsyncConnection, txn, typ, name, schema):
def __init__(self, conn: AsyncConnection, typ, name, schema):
self.conn = conn
self.txn = txn
self.type = typ
self.name = name
self.schema = schema
@ -146,45 +127,42 @@ class Connection(Compat):
query = query.replace("?", "%s")
return query
def rewrite_values(self, values):
def rewrite_values(self, values: dict) -> dict:
# strip html
clean_regex = re.compile("<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
# tuple to list and back to tuple
raw_values = [values] if isinstance(values, str) else list(values)
values = []
for raw_value in raw_values:
clean_values: dict = {}
for key, raw_value in values.items():
if isinstance(raw_value, str):
values.append(re.sub(clean_regex, "", raw_value))
clean_values[key] = re.sub(clean_regex, "", raw_value)
elif isinstance(raw_value, datetime.datetime):
ts = raw_value.timestamp()
if self.type == SQLITE:
values.append(int(ts))
clean_values[key] = int(ts)
else:
values.append(ts)
clean_values[key] = ts
else:
values.append(raw_value)
return tuple(values)
clean_values[key] = raw_value
return clean_values
async def fetchall(self, query: str, values: tuple = ()) -> list:
result = await self.conn.execute(
self.rewrite_query(query), self.rewrite_values(values)
)
return await result.fetchall()
async def fetchall(self, query: str, values: Optional[dict] = None) -> list[dict]:
params = self.rewrite_values(values) if values else {}
result = await self.conn.execute(text(self.rewrite_query(query)), params)
row = result.mappings().all()
result.close()
return row
async def fetchone(self, query: str, values: tuple = ()):
result = await self.conn.execute(
self.rewrite_query(query), self.rewrite_values(values)
)
row = await result.fetchone()
await result.close()
async def fetchone(self, query: str, values: Optional[dict] = None) -> dict:
params = self.rewrite_values(values) if values else {}
result = await self.conn.execute(text(self.rewrite_query(query)), params)
row = result.mappings().first()
result.close()
return row
async def fetch_page(
self,
query: str,
where: Optional[list[str]] = None,
values: Optional[list[str]] = None,
values: Optional[dict] = None,
filters: Optional[Filters] = None,
model: Optional[type[TRowModel]] = None,
group_by: Optional[list[str]] = None,
@ -211,14 +189,14 @@ class Connection(Compat):
{filters.order_by()}
{filters.pagination()}
""",
parsed_values,
self.rewrite_values(parsed_values),
)
if rows:
# no need for extra query if no pagination is specified
if filters.offset or filters.limit:
count = await self.fetchone(
result = await self.fetchone(
f"""
SELECT COUNT(*) FROM (
SELECT COUNT(*) as count FROM (
{query}
{clause}
{group_by_string}
@ -226,21 +204,22 @@ class Connection(Compat):
""",
parsed_values,
)
count = int(count[0])
count = int(result.get("count", 0))
else:
count = len(rows)
else:
count = 0
return Page(
data=[model.from_row(row) for row in rows] if model else rows,
data=[model.from_row(row) for row in rows] if model else [],
total=count,
)
async def execute(self, query: str, values: tuple = ()):
return await self.conn.execute(
self.rewrite_query(query), self.rewrite_values(values)
)
async def execute(self, query: str, values: Optional[dict] = None):
params = self.rewrite_values(values) if values else {}
result = await self.conn.execute(text(self.rewrite_query(query)), params)
await self.conn.commit()
return result
class Database(Compat):
@ -253,18 +232,44 @@ class Database(Compat):
self.path = os.path.join(
settings.lnbits_data_folder, f"{self.name}.sqlite3"
)
database_uri = f"sqlite:///{self.path}"
database_uri = f"sqlite+aiosqlite:///{self.path}"
else:
database_uri = settings.lnbits_database_url
database_uri = settings.lnbits_database_url.replace(
"postgres://", "postgresql+asyncpg://"
)
if self.name.startswith("ext_"):
self.schema = self.name[4:]
else:
self.schema = None
self.engine = create_engine(
database_uri, strategy=ASYNCIO_STRATEGY, echo=settings.debug_database
self.engine: AsyncEngine = create_async_engine(
database_uri, echo=settings.debug_database
)
if self.type in {POSTGRES, COCKROACH}:
@event.listens_for(self.engine.sync_engine, "connect")
def register_custom_types(dbapi_connection, *_):
def _parse_timestamp(value):
if value is None:
return None
f = "%Y-%m-%d %H:%M:%S.%f"
if "." not in value:
f = "%Y-%m-%d %H:%M:%S"
return int(
time.mktime(datetime.datetime.strptime(value, f).timetuple())
)
dbapi_connection.run_async(
lambda connection: connection.set_type_codec(
"TIMESTAMP",
encoder=datetime.datetime,
decoder=_parse_timestamp,
schema="pg_catalog",
)
)
self.lock = asyncio.Lock()
logger.trace(f"database {self.type} added for {self.name}")
@ -273,41 +278,37 @@ class Database(Compat):
async def connect(self):
await self.lock.acquire()
try:
async with self.engine.connect() as conn: # type: ignore
async with conn.begin() as txn:
wconn = Connection(conn, txn, self.type, self.name, self.schema)
async with self.engine.connect() as conn:
if not conn:
raise Exception("Could not connect to the database")
if self.schema:
if self.type in {POSTGRES, COCKROACH}:
await wconn.execute(
f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
)
elif self.type == SQLITE:
await wconn.execute(
f"ATTACH '{self.path}' AS {self.schema}"
)
wconn = Connection(conn, self.type, self.name, self.schema)
yield wconn
if self.schema:
if self.type in {POSTGRES, COCKROACH}:
await wconn.execute(
f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
)
elif self.type == SQLITE:
await wconn.execute(f"ATTACH '{self.path}' AS {self.schema}")
yield wconn
finally:
self.lock.release()
async def fetchall(self, query: str, values: tuple = ()) -> list:
async def fetchall(self, query: str, values: Optional[dict] = None) -> list[dict]:
async with self.connect() as conn:
result = await conn.execute(query, values)
return await result.fetchall()
return await conn.fetchall(query, values)
async def fetchone(self, query: str, values: tuple = ()):
async def fetchone(self, query: str, values: Optional[dict] = None) -> dict:
async with self.connect() as conn:
result = await conn.execute(query, values)
row = await result.fetchone()
await result.close()
return row
return await conn.fetchone(query, values)
async def fetch_page(
self,
query: str,
where: Optional[list[str]] = None,
values: Optional[list[str]] = None,
values: Optional[dict] = None,
filters: Optional[Filters] = None,
model: Optional[type[TRowModel]] = None,
group_by: Optional[list[str]] = None,
@ -315,7 +316,7 @@ class Database(Compat):
async with self.connect() as conn:
return await conn.fetch_page(query, where, values, filters, model, group_by)
async def execute(self, query: str, values: tuple = ()):
async def execute(self, query: str, values: Optional[dict] = None):
async with self.connect() as conn:
return await conn.execute(query, values)
@ -373,8 +374,8 @@ class Operator(Enum):
class FromRowModel(BaseModel):
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
def from_row(cls, row: dict):
return cls(**row)
class FilterModel(BaseModel):
@ -396,12 +397,13 @@ class Page(BaseModel, Generic[T]):
class Filter(BaseModel, Generic[TFilterModel]):
field: str
op: Operator = Operator.EQ
values: list[Any]
model: Optional[type[TFilterModel]]
values: Optional[dict] = None
@classmethod
def parse_query(cls, key: str, raw_values: list[Any], model: type[TFilterModel]):
def parse_query(
cls, key: str, raw_values: list[Any], model: type[TFilterModel], i: int = 0
):
# Key format:
# key[operator]
# e.g. name[eq]
@ -417,12 +419,12 @@ class Filter(BaseModel, Generic[TFilterModel]):
if field in model.__fields__:
compare_field = model.__fields__[field]
values = []
values: dict = {}
for raw_value in raw_values:
validated, errors = compare_field.validate(raw_value, {}, loc="none")
if errors:
raise ValidationError(errors=[errors], model=model)
values.append(validated)
values[f"{field}__{i}"] = validated
else:
raise ValueError("Unknown filter field")
@ -430,13 +432,17 @@ class Filter(BaseModel, Generic[TFilterModel]):
@property
def statement(self):
assert self.model, "Model is required for statement generation"
placeholder = get_placeholder(self.model, self.field)
if self.op in (Operator.INCLUDE, Operator.EXCLUDE):
placeholders = ", ".join([placeholder] * len(self.values))
stmt = [f"{self.field} {self.op.as_sql} ({placeholders})"]
else:
stmt = [f"{self.field} {self.op.as_sql} {placeholder}"] * len(self.values)
stmt = []
for key in self.values.keys() if self.values else []:
clean_key = key.split("__")[0]
if (
self.model
and self.model.__fields__[clean_key].type_ == datetime.datetime
):
placeholder = compat_timestamp_placeholder(key)
else:
placeholder = f":{key}"
stmt.append(f"{clean_key} {self.op.as_sql} {placeholder}")
return " OR ".join(stmt)
@ -487,14 +493,11 @@ class Filters(BaseModel, Generic[TFilterModel]):
for page_filter in self.filters:
where_stmts.append(page_filter.statement)
if self.search and self.model:
fields = self.model.__search_fields__
if DB_TYPE == POSTGRES:
where_stmts.append(
f"lower(concat({', '.join(self.model.__search_fields__)})) LIKE ?"
)
where_stmts.append(f"lower(concat({', '.join(fields)})) LIKE :search")
elif DB_TYPE == SQLITE:
where_stmts.append(
f"lower({'||'.join(self.model.__search_fields__)}) LIKE ?"
)
where_stmts.append(f"lower({'||'.join(fields)}) LIKE :search")
if where_stmts:
return "WHERE " + " AND ".join(where_stmts)
return ""
@ -504,12 +507,14 @@ class Filters(BaseModel, Generic[TFilterModel]):
return f"ORDER BY {self.sortby} {self.direction or 'asc'}"
return ""
def values(self, values: Optional[list[str]] = None) -> tuple:
def values(self, values: Optional[dict] = None) -> dict:
if not values:
values = []
values = {}
if self.filters:
for page_filter in self.filters:
values.extend(page_filter.values)
if page_filter.values:
for key, value in page_filter.values.items():
values[key] = value
if self.search and self.model:
values.append(f"%{self.search}%")
return tuple(values)
values["search"] = f"%{self.search}%"
return values

View file

@ -204,9 +204,9 @@ def parse_filters(model: Type[TFilterModel]):
):
params = request.query_params
filters = []
for key in params.keys():
for i, key in enumerate(params.keys()):
try:
filters.append(Filter.parse_query(key, params.getlist(key), model))
filters.append(Filter.parse_query(key, params.getlist(key), model, i))
except ValueError:
continue

View file

@ -187,12 +187,14 @@ def insert_query(table_name: str, model: BaseModel) -> str:
return f"INSERT INTO {table_name} ({fields}) VALUES ({values})"
def update_query(table_name: str, model: BaseModel, where: str = "WHERE id = ?") -> str:
def update_query(
table_name: str, model: BaseModel, where: str = "WHERE id = :id"
) -> str:
"""
Generate an update query with placeholders for a given table and model
:param table_name: Name of the table
:param model: Pydantic model
:param where: Where string, default to `WHERE id = ?`
:param where: Where string, default to `WHERE id = :id`
"""
fields = []
for field in model.dict().keys():

View file

@ -74,10 +74,10 @@ def configure_logger() -> None:
logging.getLogger("uvicorn.error").propagate = False
logging.getLogger("sqlalchemy").handlers = [InterceptHandler()]
logging.getLogger("sqlalchemy.engine.base").handlers = [InterceptHandler()]
logging.getLogger("sqlalchemy.engine.base").propagate = False
logging.getLogger("sqlalchemy.engine.base.Engine").handlers = [InterceptHandler()]
logging.getLogger("sqlalchemy.engine.base.Engine").propagate = False
logging.getLogger("sqlalchemy.engine").handlers = [InterceptHandler()]
logging.getLogger("sqlalchemy.engine").propagate = False
logging.getLogger("sqlalchemy.engine.Engine").handlers = [InterceptHandler()]
logging.getLogger("sqlalchemy.engine.Engine").propagate = False
class Formatter:

997
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -16,25 +16,25 @@ python = "^3.12 | ^3.11 | ^3.10 | ^3.9"
bech32 = "1.2.0"
click = "8.1.7"
ecdsa = "0.19.0"
fastapi = "0.112.0"
fastapi = "0.113.0"
httpx = "0.27.0"
jinja2 = "3.1.4"
lnurl = "0.5.3"
psycopg2-binary = "2.9.9"
pydantic = "1.10.17"
pydantic = "1.10.18"
pyqrcode = "1.2.1"
shortuuid = "1.0.13"
sqlalchemy = "1.3.24"
sqlalchemy-aio = "0.17.0"
sse-starlette = "1.8.2"
typing-extensions = "4.12.2"
uvicorn = "0.30.5"
uvicorn = "0.30.6"
sqlalchemy = "1.4.54"
aiosqlite = "0.20.0"
asyncpg = "0.29.0"
uvloop = "0.19.0"
websockets = "11.0.3"
loguru = "0.7.2"
grpcio = "1.65.5"
protobuf = "5.27.3"
pyln-client = "24.5"
grpcio = "1.66.1"
protobuf = "5.28.0"
pyln-client = "24.8.1"
pywebpush = "1.14.1"
slowapi = "0.1.9"
websocket-client = "1.8.0"
@ -70,11 +70,11 @@ black = "^24.8.0"
pytest-asyncio = "^0.21.2"
pytest = "^8.3.2"
pytest-cov = "^4.1.0"
mypy = "^1.11.1"
mypy = "^1.11.2"
types-protobuf = "^5.27.0.20240626"
pre-commit = "^3.8.0"
openapi-spec-validator = "^0.7.1"
ruff = "^0.5.7"
ruff = "^0.6.4"
types-passlib = "^1.7.7.20240327"
openai = "^1.39.0"
json5 = "^0.9.25"
@ -84,7 +84,7 @@ pytest-httpserver = "^1.1.0"
pytest-mock = "^3.14.0"
types-mock = "^5.1.0.20240425"
mock = "^5.1.0"
grpcio-tools = "^1.65.5"
grpcio-tools = "^1.66.1"
[build-system]
requires = ["poetry-core>=1.0.0"]
@ -126,7 +126,6 @@ module = [
"secp256k1.*",
"uvicorn.*",
"sqlalchemy.*",
"sqlalchemy_aio.*",
"websocket.*",
"websockets.*",
"pyqrcode.*",
@ -136,7 +135,6 @@ module = [
"bolt11.*",
"bitstring.*",
"ecdsa.*",
"psycopg2.*",
"pyngrok.*",
"pyln.client.*",
"py_vapid.*",

View file

@ -367,11 +367,11 @@ async def test_get_payments_history(client, adminkey_headers_from, fake_payments
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["spending"] == sum(
payment.amount * 1000 for payment in fake_data if payment.out
)
assert data[0]["income"] == sum(
payment.amount * 1000 for payment in fake_data if not payment.out
[int(payment.amount * 1000) for payment in fake_data if not payment.out]
)
assert data[0]["spending"] == sum(
[int(payment.amount * 1000) for payment in fake_data if payment.out]
)
response = await client.get(

View file

@ -25,7 +25,6 @@ from lnbits.core.views.payment_api import api_payments_create_invoice
from lnbits.db import DB_TYPE, SQLITE, Database
from lnbits.settings import settings
from tests.helpers import (
clean_database,
get_random_invoice_data,
)
@ -47,7 +46,6 @@ def event_loop():
# use session scope to run once before and once after all tests
@pytest_asyncio.fixture(scope="session")
async def app():
clean_database(settings)
app = create_app()
async with LifespanManager(app) as manager:
settings.first_install = False
@ -199,9 +197,9 @@ async def fake_payments(client, adminkey_headers_from):
"/api/v1/payments", headers=adminkey_headers_from, json=invoice.dict()
)
assert response.is_success
await update_payment_status(
response.json()["checking_id"], status=PaymentState.SUCCESS
)
data = response.json()
assert data["checking_id"]
await update_payment_status(data["checking_id"], status=PaymentState.SUCCESS)
params = {"time[ge]": ts, "time[le]": time()}
return fake_data, params

View file

@ -2,11 +2,7 @@ import random
import string
from typing import Optional
from psycopg2 import connect
from psycopg2.errors import InvalidCatalogName
from lnbits import core
from lnbits.db import DB_TYPE, POSTGRES, FromRowModel
from lnbits.db import FromRowModel
from lnbits.wallets import get_funding_source, set_funding_source
@ -35,21 +31,3 @@ set_funding_source()
funding_source = get_funding_source()
is_fake: bool = funding_source.__class__.__name__ == "FakeWallet"
is_regtest: bool = not is_fake
def clean_database(settings):
if DB_TYPE == POSTGRES:
conn = connect(settings.lnbits_database_url)
conn.autocommit = True
with conn.cursor() as cur:
try:
cur.execute("DROP DATABASE lnbits_test")
except InvalidCatalogName:
pass
cur.execute("CREATE DATABASE lnbits_test")
core.db.__init__("database")
conn.close()
else:
# TODO: do this once mock data is removed from test data folder
# os.remove(settings.lnbits_data_folder + "/database.sqlite3")
pass

View file

@ -14,8 +14,8 @@ from lnbits.db import POSTGRES
@pytest.mark.asyncio
async def test_date_conversion(db):
if db.type == POSTGRES:
row = await db.fetchone("SELECT now()::date")
assert row and isinstance(row[0], date)
row = await db.fetchone("SELECT now()::date as now")
assert row and isinstance(row.get("now"), date)
# make test to create wallet and delete wallet

View file

@ -12,10 +12,17 @@ test = DbTestModel(id=1, name="test", value="yes")
@pytest.mark.asyncio
async def test_helpers_insert_query():
q = insert_query("test_helpers_query", test)
assert q == "INSERT INTO test_helpers_query (id, name, value) VALUES (?, ?, ?)"
assert (
q == "INSERT INTO test_helpers_query (id, name, value) "
"VALUES (:id, :name, :value)"
)
@pytest.mark.asyncio
async def test_helpers_update_query():
q = update_query("test_helpers_query", test)
assert q == "UPDATE test_helpers_query SET id = ?, name = ?, value = ? WHERE id = ?"
assert (
q == "UPDATE test_helpers_query "
"SET id = :id, name = :name, value = :value "
"WHERE id = :id"
)

View file

@ -1,5 +1,5 @@
# Python script to migrate an LNbits SQLite DB to Postgres
# All credits to @Fritz446 for the awesome work
# credits to @Fritz446 for the awesome work
# pip install psycopg2 OR psycopg2-binary
@ -9,10 +9,14 @@ import sqlite3
import sys
from typing import List, Optional
import psycopg2
from lnbits.settings import settings
try:
import psycopg2 # type: ignore
except ImportError:
print("Please install psycopg2")
sys.exit(1)
sqfolder = settings.lnbits_data_folder
db_url = settings.lnbits_database_url
@ -55,8 +59,8 @@ def check_db_versions(sqdb):
version = dbpost[key]
if value != version:
raise Exception(
f"sqlite database version ({value}) of {key} doesn't match postgres"
f" database version {version}"
f"sqlite database version ({value}) of {key} doesn't match "
f"postgres database version {version}"
)
connection = postgres.connection