lnbits-legend/lnbits/db.py

193 lines
6.3 KiB
Python
Raw Normal View History

import asyncio
import datetime
2021-11-15 13:11:42 +01:00
import os
import time
from contextlib import asynccontextmanager
2021-11-15 13:11:42 +01:00
from typing import Optional
from sqlalchemy import create_engine
from sqlalchemy_aio.base import AsyncConnection
from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY # type: ignore
2021-06-22 04:22:52 +02:00
from .settings import LNBITS_DATA_FOLDER, LNBITS_DATABASE_URL
2021-06-22 04:22:52 +02:00
POSTGRES = "POSTGRES"
2021-07-02 23:32:58 +02:00
COCKROACH = "COCKROACH"
2021-06-22 04:22:52 +02:00
SQLITE = "SQLITE"
2021-06-22 04:22:52 +02:00
class Compat:
2021-07-01 18:09:02 +02:00
type: Optional[str] = "<inherited>"
schema: Optional[str] = "<inherited>"
2021-06-22 04:22:52 +02:00
def interval_seconds(self, seconds: int) -> str:
2021-07-02 23:32:58 +02:00
if self.type in {POSTGRES, COCKROACH}:
2021-06-22 04:22:52 +02:00
return f"interval '{seconds} seconds'"
elif self.type == SQLITE:
return f"{seconds}"
return "<nothing>"
@property
def timestamp_now(self) -> str:
2021-07-02 23:32:58 +02:00
if self.type in {POSTGRES, COCKROACH}:
2021-06-22 04:22:52 +02:00
return "now()"
elif self.type == SQLITE:
return "(strftime('%s', 'now'))"
return "<nothing>"
@property
def serial_primary_key(self) -> str:
2021-07-02 23:32:58 +02:00
if self.type in {POSTGRES, COCKROACH}:
2021-06-22 04:22:52 +02:00
return "SERIAL PRIMARY KEY"
elif self.type == SQLITE:
return "INTEGER PRIMARY KEY AUTOINCREMENT"
return "<nothing>"
@property
def references_schema(self) -> str:
2021-07-02 23:32:58 +02:00
if self.type in {POSTGRES, COCKROACH}:
2021-06-22 04:22:52 +02:00
return f"{self.schema}."
elif self.type == SQLITE:
return ""
return "<nothing>"
class Connection(Compat):
def __init__(self, conn: AsyncConnection, txn, typ, name, schema):
self.conn = conn
2021-06-22 04:22:52 +02:00
self.txn = txn
self.type = typ
self.name = name
self.schema = schema
def rewrite_query(self, query) -> str:
2021-07-02 23:32:58 +02:00
if self.type in {POSTGRES, COCKROACH}:
2021-06-22 04:22:52 +02:00
query = query.replace("%", "%%")
query = query.replace("?", "%s")
return query
async def fetchall(self, query: str, values: tuple = ()) -> list:
2021-06-22 04:22:52 +02:00
result = await self.conn.execute(self.rewrite_query(query), values)
return await result.fetchall()
async def fetchone(self, query: str, values: tuple = ()):
2021-06-22 04:22:52 +02:00
result = await self.conn.execute(self.rewrite_query(query), values)
row = await result.fetchone()
await result.close()
return row
async def execute(self, query: str, values: tuple = ()):
2021-06-22 04:22:52 +02:00
return await self.conn.execute(self.rewrite_query(query), values)
2021-06-22 04:22:52 +02:00
class Database(Compat):
def __init__(self, db_name: str):
2021-06-22 04:22:52 +02:00
self.name = db_name
if LNBITS_DATABASE_URL:
database_uri = LNBITS_DATABASE_URL
2021-07-02 23:32:58 +02:00
if database_uri.startswith("cockroachdb://"):
self.type = COCKROACH
else:
self.type = POSTGRES
2021-06-22 04:22:52 +02:00
2021-07-01 18:09:02 +02:00
import psycopg2 # type: ignore
2021-11-15 13:11:42 +01:00
def _parse_timestamp(value, _):
f = "%Y-%m-%d %H:%M:%S.%f"
if not "." in value:
f = "%Y-%m-%d %H:%M:%S"
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
2021-07-02 23:32:58 +02:00
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
psycopg2.extensions.DECIMAL.values,
"DEC2FLOAT",
lambda value, curs: float(value) if value is not None else None,
)
)
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
(1082, 1083, 1266),
2021-07-02 23:32:58 +02:00
"DATE2INT",
lambda value, curs: time.mktime(value.timetuple())
if value is not None
else None,
)
2021-06-22 04:22:52 +02:00
)
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
(1184, 1114),
"TIMESTAMP2INT",
2021-11-15 13:11:42 +01:00
_parse_timestamp
# lambda value, curs: time.mktime(
# datetime.datetime.strptime(
# value, "%Y-%m-%d %H:%M:%S.%f"
# ).timetuple()
# ),
)
)
2021-06-22 04:22:52 +02:00
else:
if os.path.isdir(LNBITS_DATA_FOLDER):
self.path = os.path.join(LNBITS_DATA_FOLDER, f"{self.name}.sqlite3")
database_uri = f"sqlite:///{self.path}"
self.type = SQLITE
else:
raise NotADirectoryError(
f"LNBITS_DATA_FOLDER named {LNBITS_DATA_FOLDER} was not created"
f" - please 'mkdir {LNBITS_DATA_FOLDER}' and try again"
)
2021-06-22 04:22:52 +02:00
self.schema = self.name
if self.name.startswith("ext_"):
self.schema = self.name[4:]
else:
self.schema = None
self.engine = create_engine(database_uri, strategy=ASYNCIO_STRATEGY)
self.lock = asyncio.Lock()
@asynccontextmanager
async def connect(self):
await self.lock.acquire()
try:
async with self.engine.connect() as conn:
2021-06-22 04:22:52 +02:00
async with conn.begin() as txn:
wconn = Connection(conn, txn, self.type, self.name, self.schema)
if self.schema:
2021-07-02 23:32:58 +02:00
if self.type in {POSTGRES, COCKROACH}:
2021-06-22 04:22:52 +02:00
await wconn.execute(
f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
)
elif self.type == SQLITE:
await wconn.execute(
f"ATTACH '{self.path}' AS {self.schema}"
)
yield wconn
finally:
self.lock.release()
async def fetchall(self, query: str, values: tuple = ()) -> list:
async with self.connect() as conn:
result = await conn.execute(query, values)
return await result.fetchall()
async def fetchone(self, query: str, values: tuple = ()):
async with self.connect() as conn:
result = await conn.execute(query, values)
row = await result.fetchone()
await result.close()
return row
async def execute(self, query: str, values: tuple = ()):
async with self.connect() as conn:
return await conn.execute(query, values)
@asynccontextmanager
async def reuse_conn(self, conn: Connection):
yield conn