lnbits-legend/lnbits/extensions/lnaddress/crud.py
2023-01-05 12:02:23 +01:00

199 lines
5.7 KiB
Python

from datetime import datetime, timedelta
from typing import List, Optional, Union
from loguru import logger
from lnbits.helpers import urlsafe_short_hash
from . import db
from .models import Addresses, CreateAddress, CreateDomain, Domains
async def create_domain(data: CreateDomain) -> Domains:
    """Insert a new domain row and return it as read back from the database.

    A fresh id is generated with ``urlsafe_short_hash``; every other column
    (wallet, domain, webhook, cf_token, cf_zone_id, cost) is copied from
    ``data``.

    Raises:
        AssertionError: if the freshly inserted row cannot be fetched back.
    """
    new_id = urlsafe_short_hash()
    values = (
        new_id,
        data.wallet,
        data.domain,
        data.webhook,
        data.cf_token,
        data.cf_zone_id,
        data.cost,
    )
    await db.execute(
        """
INSERT INTO lnaddress.domain (id, wallet, domain, webhook, cf_token, cf_zone_id, cost)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
        values,
    )
    created = await get_domain(new_id)
    assert created, "Newly created domain couldn't be retrieved"
    return created
async def update_domain(domain_id: str, **kwargs) -> Domains:
    """Update columns of an existing domain and return the refreshed row.

    Each keyword argument becomes a ``<column> = ?`` assignment. Only the
    values are bound as SQL parameters; the keyword *names* are interpolated
    into the statement text, so callers must pass trusted column names only.

    If no keyword arguments are given, no UPDATE is issued (an empty ``SET``
    clause is invalid SQL) and the current row is simply returned.

    Raises:
        AssertionError: if the domain cannot be retrieved afterwards.
    """
    if kwargs:
        # dict order is stable, so the column list and the values tuple line up.
        assignments = ", ".join(f"{column} = ?" for column in kwargs)
        await db.execute(
            f"UPDATE lnaddress.domain SET {assignments} WHERE id = ?",
            (*kwargs.values(), domain_id),
        )
    domain = await get_domain(domain_id)
    assert domain, "Newly updated domain couldn't be retrieved"
    return domain
async def delete_domain(domain_id: str) -> None:
    """Delete the domain row whose id is ``domain_id``.

    Only the ``lnaddress.domain`` row is removed here; rows in
    ``lnaddress.address`` are not touched by this function.
    """
    statement = "DELETE FROM lnaddress.domain WHERE id = ?"
    await db.execute(statement, (domain_id,))
async def get_domain(domain_id: str) -> Optional[Domains]:
    """Return the domain whose id is ``domain_id``, or ``None`` if absent."""
    found = await db.fetchone(
        "SELECT * FROM lnaddress.domain WHERE id = ?", (domain_id,)
    )
    if not found:
        return None
    return Domains(**found)
async def get_domains(wallet_ids: Union[str, List[str]]) -> List[Domains]:
    """Return every domain owned by any of the given wallets.

    Args:
        wallet_ids: a single wallet id, or a list of wallet ids.

    Returns:
        The matching domains; an empty list when ``wallet_ids`` is an empty
        list.
    """
    if isinstance(wallet_ids, str):
        wallet_ids = [wallet_ids]
    # An empty ``IN ()`` list is a syntax error on some SQL backends
    # (e.g. PostgreSQL), and it could never match anything anyway.
    if not wallet_ids:
        return []
    placeholders = ",".join(["?"] * len(wallet_ids))
    rows = await db.fetchall(
        f"SELECT * FROM lnaddress.domain WHERE wallet IN ({placeholders})",
        (*wallet_ids,),
    )
    return [Domains(**row) for row in rows]
## ADDRESSES
async def create_address(
    payment_hash: str, wallet: str, data: CreateAddress
) -> Addresses:
    """Insert a new, unpaid address row and return it as read back.

    Args:
        payment_hash: stored as the address ``id``.
        wallet: stored in the ``wallet`` column.
        data: supplies domain, email, username, wallet_key, wallet_endpoint,
            sats and duration.

    Raises:
        AssertionError: if the row cannot be read back via ``get_address``.
            That lookup inner-joins on the domain, so this also fires when
            ``data.domain`` does not match an existing domain id.
    """
    params = (
        payment_hash,
        wallet,
        data.domain,
        data.email,
        data.username,
        data.wallet_key,
        data.wallet_endpoint,
        data.sats,
        data.duration,
        False,  # every address starts out unpaid
    )
    await db.execute(
        """
INSERT INTO lnaddress.address (id, wallet, domain, email, username, wallet_key, wallet_endpoint, sats, duration, paid)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
        params,
    )
    created = await get_address(payment_hash)
    assert created, "Newly created address couldn't be retrieved"
    return created
async def get_address(address_id: str) -> Optional[Addresses]:
    """Return the address whose id is ``address_id``, or ``None``.

    Because of the INNER JOIN on ``a.domain = d.id``, an address is only
    returned while the domain row it references still exists.
    """
    query = "SELECT a.* FROM lnaddress.address AS a INNER JOIN lnaddress.domain AS d ON a.id = ? AND a.domain = d.id"
    record = await db.fetchone(query, (address_id,))
    if not record:
        return None
    return Addresses(**record)
async def get_address_by_username(username: str, domain: str) -> Optional[Addresses]:
    """Return the address registered under ``username`` on the named domain.

    Args:
        username: matched against ``lnaddress.address.username``.
        domain: the domain *name*, matched against ``lnaddress.domain.domain``
            (not the domain id).

    Returns:
        The matching address, or ``None`` if there is none.
    """
    # The join must tie each address to its own domain (a.domain = d.id).
    # Without that condition, an address with this username on *any* domain
    # would match as long as some domain row carried the requested name.
    row = await db.fetchone(
        "SELECT a.* FROM lnaddress.address AS a "
        "INNER JOIN lnaddress.domain AS d ON a.domain = d.id "
        "WHERE a.username = ? AND d.domain = ?",
        (username, domain),
    )
    return Addresses(**row) if row else None
async def delete_address(address_id: str) -> None:
    """Delete the address row whose id is ``address_id`` (no-op if absent)."""
    statement = "DELETE FROM lnaddress.address WHERE id = ?"
    await db.execute(statement, (address_id,))
async def get_addresses(wallet_ids: Union[str, List[str]]) -> List[Addresses]:
    """Return every address belonging to any of the given wallets.

    Args:
        wallet_ids: a single wallet id, or a list of wallet ids.

    Returns:
        The matching addresses; an empty list when ``wallet_ids`` is an empty
        list.
    """
    if isinstance(wallet_ids, str):
        wallet_ids = [wallet_ids]
    # An empty ``IN ()`` list is a syntax error on some SQL backends
    # (e.g. PostgreSQL), and it could never match anything anyway.
    if not wallet_ids:
        return []
    placeholders = ",".join(["?"] * len(wallet_ids))
    rows = await db.fetchall(
        f"SELECT * FROM lnaddress.address WHERE wallet IN ({placeholders})",
        (*wallet_ids,),
    )
    return [Addresses(**row) for row in rows]
async def set_address_paid(payment_hash: str) -> Addresses:
    """Mark the address identified by ``payment_hash`` as paid and return it.

    Addresses are keyed by the payment hash they were created with. If the
    address is already marked paid, no UPDATE is issued.

    Raises:
        AssertionError: if the address cannot be retrieved before or after
            the update.
    """
    address = await get_address(payment_hash)
    assert address
    # Skip the write when nothing would change.
    if not address.paid:
        await db.execute(
            """
UPDATE lnaddress.address
SET paid = true
WHERE id = ?
""",
            (payment_hash,),
        )
    new_address = await get_address(payment_hash)
    assert new_address, "Newly paid address couldn't be retrieved"
    return new_address
async def set_address_renewed(address_id: str, duration: int) -> Addresses:
    """Extend an address's stored duration by ``duration`` and return it.

    The current stored duration is coerced with ``int()`` before the extra
    ``duration`` is added; the sum replaces the ``duration`` column.

    Raises:
        AssertionError: if the address cannot be retrieved before or after
            the update.
    """
    current = await get_address(address_id)
    assert current
    new_duration = int(current.duration) + duration
    await db.execute(
        """
UPDATE lnaddress.address
SET duration = ?
WHERE id = ?
""",
        (new_duration, address_id),
    )
    refreshed = await get_address(address_id)
    assert refreshed, "Renewed address couldn't be retrieved"
    return refreshed
async def check_address_available(username: str, domain: str):
    """Count the existing addresses with ``username`` on ``domain``.

    Despite the name, this returns a count, not a boolean: ``0`` means no
    address with that username exists on that domain yet.

    ``domain`` is compared with ``lnaddress.address.domain``, which holds the
    domain *id*, since addresses join domains on ``a.domain = d.id``.
    """
    result = await db.fetchone(
        "SELECT COUNT(username) FROM lnaddress.address WHERE username = ? AND domain = ?",
        (username, domain),
    )
    (count,) = result
    return count
async def purge_addresses(domain_id: str):
    """Delete stale addresses belonging to the domain ``domain_id``.

    Two kinds of rows are removed:

    * unpaid addresses created more than one day ago (payment never arrived);
    * paid addresses for which ``duration`` days plus a one-day grace period
      (time to top up) have elapsed since creation.

    ``time`` is treated as a POSIX timestamp in seconds. All comparisons use
    plain epoch-second arithmetic, so DST changes do not shift the deadlines.
    """
    rows = await db.fetchall(
        "SELECT * FROM lnaddress.address WHERE domain = ?", (domain_id,)
    )
    one_day = timedelta(days=1).total_seconds()
    now = datetime.now().timestamp()
    for row in rows:
        address = Addresses(**row)
        created = address.time
        if not address.paid:
            if now > created + one_day:  # payment wasn't made within 1 day
                # loguru formats extra args into "{}" placeholders; without a
                # placeholder the username would be silently dropped.
                logger.debug("DELETE UNP_PAY_EXP {}", address.username)
                await delete_address(address.id)
        else:
            expires_at = created + timedelta(days=address.duration + 1).total_seconds()
            if now > expires_at:  # one extra day of grace to top up
                logger.debug("DELETE PAID_EXP {}", address.username)
                await delete_address(address.id)