lnbits-legend/lnbits/db.py

86 lines
2.7 KiB
Python
Raw Normal View History

2020-01-28 00:31:40 +01:00
import os
from typing import Tuple, Optional, Any
2020-11-22 03:23:11 +01:00
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore
from sqlalchemy import create_engine # type: ignore
from quart import g
2020-04-16 15:23:38 +02:00
from .settings import LNBITS_DATA_FOLDER
class Database:
    """Async access to one SQLite file, with an optional session kept on ``g``.

    The database lives at ``<LNBITS_DATA_FOLDER>/<db_name>.sqlite3`` and is
    driven through an engine created with ``TRIO_STRATEGY``.

    A "session" is a connection plus an open transaction stored on quart's
    ``g`` under ``<db_name>_conn`` / ``<db_name>_txn``.  While a session
    exists, ``fetchall`` / ``fetchone`` / ``execute`` run on that connection;
    otherwise each call opens a short-lived connection of its own.
    """

    def __init__(self, db_name: str):
        """Create the engine for the SQLite file named after *db_name*."""
        self.db_name = db_name
        db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3")
        self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY)

    @property
    def _conn_attr(self) -> str:
        """Name of the ``g`` attribute holding the session connection."""
        return f"{self.db_name}_conn"

    @property
    def _txn_attr(self) -> str:
        """Name of the ``g`` attribute holding the session transaction."""
        return f"{self.db_name}_txn"

    def connect(self):
        """Return ``self.engine.connect()``.

        The methods below use the returned object as an ``async with``
        context manager for one-off connections.
        """
        return self.engine.connect()

    def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]:
        """Return the ``(connection, transaction)`` pair stored on ``g``.

        Either element is ``None`` when the attribute is not set.  A
        ``RuntimeError`` while reading ``g`` yields ``(None, None)`` --
        presumably this is what quart raises when no application context is
        active; confirm against the quart documentation.
        """
        try:
            return (
                getattr(g, self._conn_attr, None),
                getattr(g, self._txn_attr, None),
            )
        except RuntimeError:
            return None, None

    async def begin(self) -> None:
        """Open a session (connection + transaction) unless one already exists.

        The pair is registered on ``g`` only after the transaction has been
        started.  If anything fails before both attributes are set, the fresh
        connection is closed and the error is re-raised, so no
        connection-without-transaction is ever left behind on ``g``.
        """
        conn, _ = self.session_connection()
        if conn:
            return

        conn = await self.engine.connect()
        try:
            txn = await conn.begin()
            setattr(g, self._conn_attr, conn)
            setattr(g, self._txn_attr, txn)
        except Exception:
            # Do not leak the connection when the transaction could not be
            # started or ``g`` could not be written to.
            await conn.close()
            raise

    async def fetchall(self, query: str, values: tuple = ()) -> list:
        """Execute *query* with *values* and return ``result.fetchall()``.

        Runs on the session connection when there is one, otherwise on a
        temporary connection that is released when the ``async with`` exits.
        """
        conn, _ = self.session_connection()
        if conn:
            result = await conn.execute(query, values)
            return await result.fetchall()

        async with self.connect() as conn:
            result = await conn.execute(query, values)
            return await result.fetchall()

    async def fetchone(self, query: str, values: tuple = ()):
        """Execute *query* with *values* and return ``result.fetchone()``.

        The result is closed before returning, including when fetching the
        row raises.  Runs on the session connection when there is one,
        otherwise on a temporary connection.
        """
        conn, _ = self.session_connection()
        if conn:
            return await self._first_row(conn, query, values)

        async with self.connect() as conn:
            return await self._first_row(conn, query, values)

    @staticmethod
    async def _first_row(conn, query: str, values: tuple):
        """Run *query* on *conn*, fetch one row and always close the result."""
        result = await conn.execute(query, values)
        try:
            return await result.fetchone()
        finally:
            await result.close()

    async def execute(self, query: str, values: tuple = ()):
        """Execute *query* with *values* and return the result object.

        Runs on the session connection when there is one.  Otherwise a
        temporary connection is used and its ``async with`` block has already
        exited by the time the caller receives the result.
        NOTE(review): the sessionless path therefore looks intended for
        statements whose rows are not read afterwards -- confirm with callers.
        """
        conn, _ = self.session_connection()
        if conn:
            return await conn.execute(query, values)

        async with self.connect() as conn:
            return await conn.execute(query, values)

    async def commit(self) -> None:
        """Commit the session transaction, then close the session.

        Does nothing without a complete session.  The session is closed even
        when the commit raises; the commit error still propagates.
        """
        conn, txn = self.session_connection()
        if conn and txn:
            try:
                await txn.commit()
            finally:
                await self.close_session()

    async def rollback(self) -> None:
        """Roll back the session transaction, then close the session.

        Does nothing without a complete session.  The session is closed even
        when the rollback raises; the rollback error still propagates.
        """
        conn, txn = self.session_connection()
        if conn and txn:
            try:
                await txn.rollback()
            finally:
                await self.close_session()

    async def close_session(self) -> None:
        """Close the session's transaction and connection and clear ``g``.

        Does nothing when no connection is registered.  A connection without
        a transaction is closed as well.  The connection is closed and the
        ``g`` attributes are removed even if an earlier close step raises.
        """
        conn, txn = self.session_connection()
        if not conn:
            return

        try:
            if txn:
                await txn.close()
        finally:
            try:
                await conn.close()
            finally:
                # Only delete what is really there: the transaction
                # attribute may be missing for a half-open session.
                for attr in (self._conn_attr, self._txn_attr):
                    if hasattr(g, attr):
                        delattr(g, attr)