lnbits-legend/lnbits/db.py

345 lines
11 KiB
Python
Raw Normal View History

import asyncio
import datetime
2021-11-15 12:11:42 +00:00
import os
2022-11-25 14:53:03 +01:00
import re
2021-11-15 12:11:42 +00:00
import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar
2021-11-15 12:11:42 +00:00
from loguru import logger
from pydantic import BaseModel, ValidationError
from sqlalchemy import create_engine
from sqlalchemy_aio.base import AsyncConnection
from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY
from lnbits.settings import settings
2021-06-21 23:22:52 -03:00
# Backend identifiers. `Compat.type` (and thus Connection/Database) is compared
# against these strings to pick dialect-specific SQL.
POSTGRES, COCKROACH, SQLITE = "POSTGRES", "COCKROACH", "SQLITE"
2021-06-21 23:22:52 -03:00
class Compat:
    """Mixin that renders SQL fragments for the active database backend.

    ``type`` is compared against POSTGRES, COCKROACH and SQLITE, and ``schema``
    is used by ``references_schema``. Both hold the placeholder
    ``"<inherited>"`` here and are assigned by subclasses. For an unrecognised
    ``type`` the helpers return ``"<nothing>"``, except ``big_int``, which
    falls back to ``"INT"``.
    """

    type: Optional[str] = "<inherited>"
    schema: Optional[str] = "<inherited>"

    def interval_seconds(self, seconds: int) -> str:
        """SQL for an interval of ``seconds`` seconds (a bare number on SQLite)."""
        if self.type in {POSTGRES, COCKROACH}:
            return f"interval '{seconds} seconds'"
        return f"{seconds}" if self.type == SQLITE else "<nothing>"

    def datetime_to_timestamp(self, date: datetime.datetime):
        """Render ``date`` for comparison against stored timestamps.

        PostgreSQL/CockroachDB get a ``"%Y-%m-%d %H:%M:%S"`` string. SQLite
        gets the float returned by ``time.mktime``, which reads the datetime's
        fields as local time.
        """
        if self.type in {POSTGRES, COCKROACH}:
            return date.strftime("%Y-%m-%d %H:%M:%S")
        return time.mktime(date.timetuple()) if self.type == SQLITE else "<nothing>"

    @property
    def timestamp_now(self) -> str:
        """SQL expression evaluating to the current time."""
        if self.type in {POSTGRES, COCKROACH}:
            return "now()"
        return "(strftime('%s', 'now'))" if self.type == SQLITE else "<nothing>"

    @property
    def serial_primary_key(self) -> str:
        """Column definition for an auto-incrementing integer primary key."""
        if self.type in {POSTGRES, COCKROACH}:
            return "SERIAL PRIMARY KEY"
        return (
            "INTEGER PRIMARY KEY AUTOINCREMENT" if self.type == SQLITE else "<nothing>"
        )

    @property
    def references_schema(self) -> str:
        """Prefix for referencing a table in ``schema`` ("" on SQLite)."""
        if self.type in {POSTGRES, COCKROACH}:
            return f"{self.schema}."
        return "" if self.type == SQLITE else "<nothing>"

    @property
    def big_int(self) -> str:
        """Integer column type for large values: BIGINT on PostgreSQL, else INT."""
        return "BIGINT" if self.type in {POSTGRES} else "INT"
2021-06-21 23:22:52 -03:00
class Connection(Compat):
    """One open connection plus its transaction, with dialect-aware helpers.

    Queries are written with ``?`` placeholders, and ``rewrite_query`` converts
    them for the PostgreSQL/CockroachDB driver. All bound values go through
    ``rewrite_values``, which strips HTML tags and entities from strings.
    """

    # Matches an HTML tag or an HTML entity (named, decimal or hex).
    # re.DOTALL lets ".*?" span newlines, so a tag broken across lines
    # (e.g. "<img\nsrc=x onerror=...>") is removed too instead of passing
    # through untouched. It is compiled once for the class rather than on
    # every rewrite_values() call.
    _HTML_CLEANR = re.compile(
        "<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});", re.DOTALL
    )

    def __init__(self, conn: AsyncConnection, txn, typ, name, schema):
        self.conn = conn
        self.txn = txn
        self.type = typ
        self.name = name
        self.schema = schema

    def rewrite_query(self, query) -> str:
        """Adapt a ``?``-placeholder query to the active backend's driver.

        On PostgreSQL/CockroachDB, literal ``%`` is doubled first so it is not
        taken for a driver placeholder; then ``?`` becomes ``%s``. On other
        backends the query is returned unchanged.
        """
        if self.type in {POSTGRES, COCKROACH}:
            query = query.replace("%", "%%")
            query = query.replace("?", "%s")
        return query

    def rewrite_values(self, values):
        """Return ``values`` as a tuple with HTML stripped from every string.

        A bare string is treated as a single value. Non-string values are
        passed through unchanged.
        """
        value_list = [values] if isinstance(values, str) else list(values)
        return tuple(
            self._HTML_CLEANR.sub("", value) if isinstance(value, str) else value
            for value in value_list
        )

    async def fetchall(self, query: str, values: tuple = ()) -> list:
        """Execute ``query`` with ``values`` and return all result rows."""
        result = await self.conn.execute(
            self.rewrite_query(query), self.rewrite_values(values)
        )
        return await result.fetchall()

    async def fetchone(self, query: str, values: tuple = ()):
        """Execute ``query`` and return its first row (the result is closed)."""
        result = await self.conn.execute(
            self.rewrite_query(query), self.rewrite_values(values)
        )
        row = await result.fetchone()
        await result.close()
        return row

    async def execute(self, query: str, values: tuple = ()):
        """Execute ``query`` with ``values`` and return the driver's result."""
        return await self.conn.execute(
            self.rewrite_query(query), self.rewrite_values(values)
        )
2021-06-21 23:22:52 -03:00
class Database(Compat):
    """Database handle for one LNbits module.

    The backend is picked from ``settings.lnbits_database_url``:
    - a URL starting with ``cockroachdb://`` selects CockroachDB;
    - any other URL selects PostgreSQL;
    - no URL selects a SQLite file ``<lnbits_data_folder>/<db_name>.sqlite3``.

    Names starting with ``ext_`` get a schema named after the rest of the name.
    ``connect()`` holds an ``asyncio.Lock`` while a connection is open, so
    connections from one Database are serialized.
    """

    def __init__(self, db_name: str):
        self.name = db_name

        if settings.lnbits_database_url:
            database_uri = settings.lnbits_database_url
            if database_uri.startswith("cockroachdb://"):
                self.type = COCKROACH
            else:
                self.type = POSTGRES
            self._register_psycopg2_types()
        else:
            if not os.path.isdir(settings.lnbits_data_folder):
                raise NotADirectoryError(
                    f"LNBITS_DATA_FOLDER named {settings.lnbits_data_folder} was not created"
                    f" - please 'mkdir {settings.lnbits_data_folder}' and try again"
                )
            self.path = os.path.join(
                settings.lnbits_data_folder, f"{self.name}.sqlite3"
            )
            database_uri = f"sqlite:///{self.path}"
            self.type = SQLITE

        logger.trace(f"database {self.type} added for {self.name}")

        # Extension databases ("ext_<name>") use their own "<name>" schema;
        # every other database uses none.
        self.schema = self.name[4:] if self.name.startswith("ext_") else None

        self.engine = create_engine(database_uri, strategy=ASYNCIO_STRATEGY)
        self.lock = asyncio.Lock()

    @staticmethod
    def _register_psycopg2_types() -> None:
        """Register global psycopg2 typecasters for numeric and date/time columns."""
        from psycopg2.extensions import DECIMAL, new_type, register_type

        def _parse_timestamp(value, _):
            if value is None:
                return None
            f = "%Y-%m-%d %H:%M:%S.%f"
            if "." not in value:
                f = "%Y-%m-%d %H:%M:%S"
            return time.mktime(datetime.datetime.strptime(value, f).timetuple())

        # DECIMAL/NUMERIC columns -> float.
        register_type(
            new_type(
                DECIMAL.values,
                "DEC2FLOAT",
                lambda value, curs: float(value) if value is not None else None,
            )
        )
        # OIDs 1082/1083/1266 are date, time and timetz.
        # NOTE(review): psycopg2 passes typecasters the raw string sent by the
        # server, so `value.timetuple()` looks like it would raise for any
        # non-NULL value here. Confirm these column types are never read.
        register_type(
            new_type(
                (1082, 1083, 1266),
                "DATE2INT",
                lambda value, curs: time.mktime(value.timetuple())
                if value is not None
                else None,
            )
        )
        # OIDs 1184/1114 are timestamptz and timestamp -> epoch seconds via mktime.
        # NOTE(review): timestamptz text output carries a UTC offset (e.g. "+00")
        # that neither strptime format in _parse_timestamp accepts. Confirm.
        register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))

    @asynccontextmanager
    async def connect(self):
        """Yield a `Connection` wrapping a new connection inside a transaction.

        ``self.lock`` is held for the whole context. When the database has a
        schema, it is first created (PostgreSQL/CockroachDB) or the SQLite file
        is attached under the schema name.
        """
        async with self.lock:
            async with self.engine.connect() as conn:  # type: ignore
                async with conn.begin() as txn:
                    wconn = Connection(conn, txn, self.type, self.name, self.schema)
                    if self.schema:
                        if self.type in {POSTGRES, COCKROACH}:
                            await wconn.execute(
                                f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
                            )
                        elif self.type == SQLITE:
                            # Double embedded single quotes so a data-folder
                            # path containing an apostrophe is still a valid
                            # SQL string literal.
                            escaped_path = self.path.replace("'", "''")
                            await wconn.execute(
                                f"ATTACH '{escaped_path}' AS {self.schema}"
                            )
                    yield wconn

    async def fetchall(self, query: str, values: tuple = ()) -> list:
        """Run ``query`` on a fresh connection and return all rows."""
        async with self.connect() as conn:
            result = await conn.execute(query, values)
            return await result.fetchall()

    async def fetchone(self, query: str, values: tuple = ()):
        """Run ``query`` on a fresh connection and return its first row."""
        async with self.connect() as conn:
            result = await conn.execute(query, values)
            row = await result.fetchone()
            await result.close()
            return row

    async def execute(self, query: str, values: tuple = ()):
        """Run ``query`` on a fresh connection and return the driver result.

        NOTE(review): the connection context is exited on return, so the result
        presumably only supports metadata (e.g. rowcount), not row fetching.
        """
        async with self.connect() as conn:
            return await conn.execute(query, values)

    @asynccontextmanager
    async def reuse_conn(self, conn: Connection):
        """Yield an already-open ``conn`` unchanged, as a context manager."""
        yield conn
class Operator(Enum):
    """Comparison operators usable in a filter key suffix such as ``name[eq]``."""

    GT = "gt"
    LT = "lt"
    EQ = "eq"
    NE = "ne"
    INCLUDE = "in"
    EXCLUDE = "ex"

    @property
    def as_sql(self):
        """SQL spelling of this operator, e.g. ``Operator.NE.as_sql == "!="``."""
        sql = {
            "eq": "=",
            "ne": "!=",
            "in": "IN",
            "ex": "NOT IN",
            "gt": ">",
            "lt": "<",
        }.get(self.value)
        if sql is None:
            raise ValueError("Unknown SQL Operator")
        return sql
TModel = TypeVar("TModel", bound=BaseModel)

# Allowed nested JSON key names in a filter: word characters and "-".
# Nested names come from the request key and are spliced into the SQL text, so
# characters that could close the string literal (') or be rewritten into a
# bind placeholder by Connection.rewrite_query (?) must be rejected.
_NESTED_NAME_RE = re.compile(r"[\w-]+")


class Filter(BaseModel, Generic[TModel]):
    """A single WHERE condition on ``field`` built from query-string input."""

    field: str
    nested: Optional[list[str]]
    op: Operator = Operator.EQ
    values: list[Any]

    @classmethod
    def parse_query(cls, key: str, raw_values: list[Any], model: Type[TModel]):
        """Build a Filter from one query-string key and its raw values.

        ``key`` has the form ``field[.nested...][op]``, e.g. ``name[eq]`` or
        ``extra.tag[in]``. Without the ``[op]`` suffix the operator is ``eq``.
        Each raw value is validated against ``model``'s type for ``field``.

        Raises:
            ValueError: malformed key, unknown operator, unknown field, or a
                nested name containing characters other than letters, digits,
                ``_`` and ``-``.
            ValidationError: a value fails validation for the field.
        """
        # Key format:
        # key[operator]
        # e.g. name[eq]
        if key.endswith("]"):
            split = key[:-1].split("[")
            if len(split) != 2:
                raise ValueError("Invalid key")
            field_names = split[0].split(".")
            op = Operator(split[1])
        else:
            field_names = key.split(".")
            op = Operator("eq")

        field = field_names[0]
        nested = field_names[1:]
        # Nested names are embedded verbatim in the SQL built by `statement`.
        if any(not _NESTED_NAME_RE.fullmatch(name) for name in nested):
            raise ValueError("Invalid key")

        if field in model.__fields__:
            compare_field = model.__fields__[field]
            values = []
            for raw_value in raw_values:
                # If there is a nested field, pydantic expects a dict, so the
                # raw value is turned into a dict before validation and the
                # converted value is extracted afterwards.
                for name in reversed(nested):
                    raw_value = {name: raw_value}

                validated, errors = compare_field.validate(raw_value, {}, loc="none")
                if errors:
                    raise ValidationError(errors=[errors], model=model)

                for name in nested:
                    if isinstance(validated, dict):
                        validated = validated[name]
                    else:
                        validated = getattr(validated, name)

                values.append(validated)
        else:
            raise ValueError("Unknown filter field")

        return cls(field=field, op=op, nested=nested, values=values)

    @property
    def statement(self):
        """SQL condition for this filter, using one ``?`` per value.

        IN / NOT IN produce a single ``field IN (?, ...)`` clause. Other
        operators produce one comparison per value, joined with OR and
        parenthesized, so callers can safely combine the result with AND.

        Raises:
            ValueError: a nested name contains characters other than letters,
                digits, ``_`` and ``-``.
        """
        accessor = self.field
        for name in self.nested or []:
            # Re-checked here because a Filter can be constructed directly,
            # bypassing parse_query.
            if not _NESTED_NAME_RE.fullmatch(name):
                raise ValueError("Invalid nested field name")
            accessor = f"({accessor} ->> '{name}')"

        if self.op in (Operator.INCLUDE, Operator.EXCLUDE):
            placeholders = ", ".join(["?"] * len(self.values))
            stmt = [f"{accessor} {self.op.as_sql} ({placeholders})"]
        else:
            stmt = [f"{accessor} {self.op.as_sql} ?"] * len(self.values)

        joined = " OR ".join(stmt)
        # Without parentheses, "x AND a = ? OR a = ?" would bind as
        # "(x AND a = ?) OR a = ?" and widen the result set.
        return f"({joined})" if len(stmt) > 1 else joined
class Filters(BaseModel, Generic[TModel]):
    """A collection of filters plus optional LIMIT/OFFSET for a list query."""

    filters: List[Filter[TModel]] = []
    limit: Optional[int]
    offset: Optional[int]

    def pagination(self) -> str:
        """Return the ``LIMIT``/``OFFSET`` clause; falsy values are omitted."""
        stmt = ""
        if self.limit:
            stmt += f"LIMIT {self.limit} "
        if self.offset:
            stmt += f"OFFSET {self.offset}"
        return stmt

    def where(self, where_stmts: List[str]) -> str:
        """Return a WHERE clause that ANDs ``where_stmts`` with every filter.

        Returns "" when there is nothing to filter on. ``where_stmts`` is not
        modified, so a caller can reuse it, e.g. for a follow-up count query.
        """
        stmts = [*where_stmts, *(f.statement for f in self.filters or [])]
        if stmts:
            return "WHERE " + " AND ".join(stmts)
        return ""

    def values(self, values: List[Any]) -> Tuple:
        """Return ``values`` followed by every filter's values, as a tuple.

        The order matches the placeholders produced by ``where``. ``values``
        itself is not modified.
        """
        filter_values = (v for f in self.filters or [] for v in f.values)
        return (*values, *filter_values)