mirror of
https://github.com/lnbits/lnbits-legend.git
synced 2024-11-20 10:39:59 +01:00
86 lines
2.7 KiB
Python
86 lines
2.7 KiB
Python
import os
|
|
from typing import Tuple, Optional, Any
|
|
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore
|
|
from sqlalchemy import create_engine # type: ignore
|
|
from quart import g
|
|
|
|
from .settings import LNBITS_DATA_FOLDER
|
|
|
|
|
|
class Database:
    """Async access to one SQLite database file stored under ``LNBITS_DATA_FOLDER``.

    The engine is created with ``sqlalchemy_aio``'s ``TRIO_STRATEGY`` and points
    at ``<LNBITS_DATA_FOLDER>/<db_name>.sqlite3``.

    Queries run in one of two modes:

    * Session mode -- after :meth:`begin`, a connection and an open transaction
      are stored on Quart's ``g`` as ``<db_name>_conn`` / ``<db_name>_txn`` and
      every query reuses that connection until :meth:`commit`,
      :meth:`rollback` or :meth:`close_session` ends the session.
    * Standalone mode -- when no session connection is found on ``g`` (including
      when reading ``g`` fails), each query opens and closes its own connection.
    """

    def __init__(self, db_name: str):
        self.db_name = db_name
        db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3")
        self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY)

    def connect(self):
        """Return ``self.engine.connect()``.

        This class uses the returned object both with ``await`` (in
        :meth:`begin`) and as an ``async with`` context manager (in the query
        helpers).
        """
        return self.engine.connect()

    def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]:
        """Return the ``(connection, transaction)`` stored on ``g`` for this database.

        Each element is ``None`` when it is not set. ``(None, None)`` is also
        returned when reading ``g`` raises ``RuntimeError`` (e.g. when there is
        no active app context).
        """
        try:
            return getattr(g, f"{self.db_name}_conn", None), getattr(g, f"{self.db_name}_txn", None)
        except RuntimeError:
            return None, None

    def _forget_session(self) -> None:
        """Remove this database's session connection/transaction attributes from ``g``."""
        for suffix in ("_conn", "_txn"):
            attr = f"{self.db_name}{suffix}"
            if hasattr(g, attr):
                delattr(g, attr)

    async def begin(self):
        """Start a session: open a connection and a transaction and store them on ``g``.

        Does nothing if a session connection already exists. If starting the
        transaction fails, the connection is detached from ``g`` and closed
        before the exception propagates, so no half-open session is left behind.
        """
        conn, _ = self.session_connection()
        if conn:
            return

        conn = await self.engine.connect()
        # Publish the connection before awaiting ``begin()`` so that another
        # ``begin()`` call on the same ``g`` during that await sees a session.
        setattr(g, f"{self.db_name}_conn", conn)
        try:
            txn = await conn.begin()
        except Exception:
            # Without this, a transaction-less connection would stay on ``g``
            # and never be released by commit/rollback/close_session.
            self._forget_session()
            await conn.close()
            raise
        setattr(g, f"{self.db_name}_txn", txn)

    async def fetchall(self, query: str, values: tuple = ()) -> list:
        """Execute ``query`` with ``values`` and return all result rows.

        Uses the session connection when one exists, otherwise a temporary one.
        """
        conn, _ = self.session_connection()
        if conn:
            result = await conn.execute(query, values)
            return await result.fetchall()

        async with self.connect() as conn:
            result = await conn.execute(query, values)
            return await result.fetchall()

    async def fetchone(self, query: str, values: tuple = ()):
        """Execute ``query`` with ``values`` and return the first row.

        The result is closed after the first row is read. Uses the session
        connection when one exists, otherwise a temporary one.
        """
        conn, _ = self.session_connection()
        if conn:
            result = await conn.execute(query, values)
            row = await result.fetchone()
            await result.close()
            return row

        async with self.connect() as conn:
            result = await conn.execute(query, values)
            row = await result.fetchone()
            await result.close()
            return row

    async def execute(self, query: str, values: tuple = ()):
        """Execute ``query`` with ``values`` and return the raw result object.

        Uses the session connection when one exists, otherwise a temporary one.

        NOTE(review): without a session, the temporary connection is closed as
        this method returns, so callers should presumably only read result
        metadata (e.g. ``lastrowid``) rather than fetch rows -- confirm against
        callers.
        """
        conn, _ = self.session_connection()
        if conn:
            return await conn.execute(query, values)

        async with self.connect() as conn:
            return await conn.execute(query, values)

    async def commit(self):
        """Commit the session transaction, if any, then end the session.

        The session is closed and detached from ``g`` even if the commit
        raises; the exception still propagates.
        """
        conn, txn = self.session_connection()
        try:
            if conn and txn:
                await txn.commit()
        finally:
            await self.close_session()

    async def rollback(self):
        """Roll back the session transaction, if any, then end the session.

        The session is closed and detached from ``g`` even if the rollback
        raises; the exception still propagates.
        """
        conn, txn = self.session_connection()
        try:
            if conn and txn:
                await txn.rollback()
        finally:
            await self.close_session()

    async def close_session(self):
        """Close this database's session transaction and connection and detach them from ``g``.

        A no-op when no session is stored. Also handles a connection stored
        without a transaction, and still closes the connection if closing the
        transaction raises.
        """
        conn, txn = self.session_connection()
        if not conn and not txn:
            return

        # Detach first so a failure while closing can't leave a stale session on g.
        self._forget_session()
        try:
            if txn:
                await txn.close()
        finally:
            if conn:
                await conn.close()
|