2020-01-28 00:31:40 +01:00
|
|
|
import os
|
2020-11-21 22:04:39 +01:00
|
|
|
from typing import Tuple, Optional, Any
|
2020-11-22 03:23:11 +01:00
|
|
|
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore
|
|
|
|
from sqlalchemy import create_engine # type: ignore
|
2020-11-21 22:04:39 +01:00
|
|
|
from quart import g
|
2019-12-13 18:14:25 +01:00
|
|
|
|
2020-04-16 15:23:38 +02:00
|
|
|
from .settings import LNBITS_DATA_FOLDER
|
2019-12-13 18:14:25 +01:00
|
|
|
|
2019-12-14 02:59:35 +01:00
|
|
|
|
2020-01-31 21:07:05 +01:00
|
|
|
class Database:
    """SQLite database accessed through an async (trio strategy) SQLAlchemy engine.

    Queries run either on a request-scoped "session" connection/transaction
    stored on Quart's ``g`` under ``<db_name>_conn`` / ``<db_name>_txn``
    (opened by ``begin`` and ended by ``commit``, ``rollback`` or
    ``close_session``), or, when no session is active, on a short-lived
    connection opened just for that one query.
    """

    def __init__(self, db_name: str):
        """Create an engine for ``<LNBITS_DATA_FOLDER>/<db_name>.sqlite3``.

        :param db_name: database file stem; also namespaces the session
            attributes stored on ``g``.
        """
        self.db_name = db_name
        db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3")
        self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY)

    def connect(self):
        """Return a new engine connection (used here with ``await`` and ``async with``)."""
        return self.engine.connect()

    def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]:
        """Return the ``(connection, transaction)`` pair of the active session.

        Each element is ``None`` when it has not been set. Both are ``None``
        when accessing ``g`` raises ``RuntimeError`` (presumably because the
        call happens outside a Quart app context — verify against callers).
        """
        try:
            return getattr(g, f"{self.db_name}_conn", None), getattr(g, f"{self.db_name}_txn", None)
        except RuntimeError:
            return None, None

    def _clear_session(self) -> None:
        """Remove whichever session attributes for this database exist on ``g``."""
        for key in (f"{self.db_name}_conn", f"{self.db_name}_txn"):
            if hasattr(g, key):
                delattr(g, key)

    async def begin(self):
        """Open a session connection and transaction unless a session is active.

        Both objects are stored on ``g`` only after the transaction has been
        started. A failure therefore never leaves a half-initialised session
        behind, which later ``begin`` calls would mistake for a live one.
        """
        conn, _ = self.session_connection()
        if conn:
            return

        conn = await self.engine.connect()
        try:
            txn = await conn.begin()
            setattr(g, f"{self.db_name}_conn", conn)
            setattr(g, f"{self.db_name}_txn", txn)
        except Exception:
            # Don't leak the freshly opened connection if starting the
            # transaction (or storing the session) fails.
            await conn.close()
            raise

    async def fetchall(self, query: str, values: tuple = ()) -> list:
        """Execute ``query`` with ``values`` and return all result rows.

        Uses the active session connection when there is one, otherwise a
        temporary connection that is closed before returning.
        """
        conn, _ = self.session_connection()
        if conn:
            result = await conn.execute(query, values)
            return await result.fetchall()

        async with self.connect() as conn:
            result = await conn.execute(query, values)
            return await result.fetchall()

    @staticmethod
    async def _fetch_first(conn, query: str, values: tuple):
        """Execute on ``conn`` and return the first row, always closing the result."""
        result = await conn.execute(query, values)
        try:
            return await result.fetchone()
        finally:
            # Close even if fetching fails so the cursor is not left open.
            await result.close()

    async def fetchone(self, query: str, values: tuple = ()):
        """Execute ``query`` with ``values`` and return the first row.

        Returns whatever the result's ``fetchone()`` yields (presumably
        ``None`` when there are no rows). Uses the session connection when
        one is active, otherwise a temporary connection.
        """
        conn, _ = self.session_connection()
        if conn:
            return await self._fetch_first(conn, query, values)

        async with self.connect() as conn:
            return await self._fetch_first(conn, query, values)

    async def execute(self, query: str, values: tuple = ()):
        """Execute ``query`` with ``values`` and return the raw result object.

        Uses the session connection when one is active, otherwise a temporary
        connection. NOTE(review): in the temporary case the connection is
        closed before the result is returned, so callers should only rely on
        attributes that survive that (e.g. row ids), not on fetching rows.
        """
        conn, _ = self.session_connection()
        if conn:
            return await conn.execute(query, values)

        async with self.connect() as conn:
            return await conn.execute(query, values)

    async def commit(self):
        """Commit the active session's transaction and end the session.

        Does nothing when no session is active. The session is closed even if
        the commit raises, so the connection is never left behind on ``g``.
        """
        conn, txn = self.session_connection()
        if conn and txn:
            try:
                await txn.commit()
            finally:
                await self.close_session()

    async def rollback(self):
        """Roll back the active session's transaction and end the session.

        Does nothing when no session is active. The session is closed even if
        the rollback raises.
        """
        conn, txn = self.session_connection()
        if conn and txn:
            try:
                await txn.rollback()
            finally:
                await self.close_session()

    async def close_session(self):
        """Close the session transaction and connection and clear them from ``g``.

        Tolerates a partially set-up session (only one of the two objects
        present). Cleanup continues past failures:
        - the connection is closed even if closing the transaction raises;
        - the ``g`` attributes are always removed.
        """
        conn, txn = self.session_connection()
        if not (conn or txn):
            return

        try:
            if txn:
                await txn.close()
        finally:
            try:
                if conn:
                    await conn.close()
            finally:
                self._clear_session()
|