Mirror of https://github.com/lnbits/lnbits-legend.git, synced 2025-01-18 21:32:38 +01:00
fix pyright lnbits
Co-authored-by: dni ⚡ <office@dnilabs.com>
This commit is contained in:
parent 3855cf47f3
commit 02306148df
@@ -9,6 +9,7 @@ import secp256k1
 from bech32 import CHARSET, bech32_decode, bech32_encode
 from ecdsa import SECP256k1, VerifyingKey
 from ecdsa.util import sigdecode_string
+from loguru import logger


 class Route(NamedTuple):
@@ -30,6 +31,7 @@ class Invoice:
     secret: Optional[str] = None
     route_hints: List[Route] = []
     min_final_cltv_expiry: int = 18
+    checking_id: Optional[str] = None


 def decode(pr: str) -> Invoice:
@@ -66,11 +68,13 @@ def decode(pr: str) -> Invoice:
         invoice.amount_msat = _unshorten_amount(amountstr)

     # pull out date
-    invoice.date = data.read(35).uint
+    date_bin = data.read(35)
+    assert date_bin
+    invoice.date = date_bin.uint

     while data.pos != data.len:
         tag, tagdata, data = _pull_tagged(data)
-        data_length = len(tagdata) / 5
+        data_length = len(tagdata or []) / 5

         if tag == "d":
             invoice.description = _trim_to_bytes(tagdata).decode()
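The `data_length` change is the usual guard for a value the type checker considers possibly None: `tagdata or []` substitutes an empty sequence, so `len()` always receives a sized argument. A minimal sketch of the same pattern, with illustrative names rather than the real bolt11 types:

from typing import Optional

def field_length(tagdata: Optional[bytes]) -> float:
    # len(None) would raise and is flagged by pyright; the "or" fallback
    # keeps both the runtime and the type checker satisfied.
    return len(tagdata or b"") / 5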
@@ -89,12 +93,22 @@ def decode(pr: str) -> Invoice:
         elif tag == "r":
             s = bitstring.ConstBitStream(tagdata)
             while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
+                pubkey = s.read(264)
+                assert pubkey
+                short_channel_id = s.read(64)
+                assert short_channel_id
+                base_fee_msat = s.read(32)
+                assert base_fee_msat
+                ppm_fee = s.read(32)
+                assert ppm_fee
+                cltv = s.read(16)
+                assert cltv
                 route = Route(
-                    pubkey=s.read(264).tobytes().hex(),
-                    short_channel_id=_readable_scid(s.read(64).intbe),
-                    base_fee_msat=s.read(32).intbe,
-                    ppm_fee=s.read(32).intbe,
-                    cltv=s.read(16).intbe,
+                    pubkey=pubkey.tobytes().hex(),
+                    short_channel_id=_readable_scid(short_channel_id.intbe),
+                    base_fee_msat=base_fee_msat.intbe,
+                    ppm_fee=ppm_fee.intbe,
+                    cltv=cltv.intbe,
                 )
                 invoice.route_hints.append(route)

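The route-hint rewrite reads each fixed-width field once, asserts on the result, and only then touches `.tobytes()` or `.intbe`; the asserts are what give pyright a definite, non-empty value to work with. The same narrowing idea reduced to a generic Optional return (hypothetical helper, not from bolt11.py):

from typing import Optional

def read_field(raw: bytes, offset: int) -> Optional[bytes]:
    # stand-in for a stream read that may yield nothing
    return raw[offset:offset + 4] or None

def field_as_int(raw: bytes, offset: int) -> int:
    chunk = read_field(raw, offset)
    assert chunk  # narrows Optional[bytes] to bytes for the type checker
    return int.from_bytes(chunk, "big")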
@@ -160,6 +174,10 @@ def encode(options):
     return lnencode(addr, options["privkey"])


+def encode_fallback(v, currency):
+    logger.error(f"hit bolt11.py encode_fallback with v: {v} and currency: {currency}")
+
+
 def lnencode(addr, privkey):
     if addr.amount:
         amount = Decimal(str(addr.amount))
@@ -244,7 +262,13 @@ def lnencode(addr, privkey):

 class LnAddr:
     def __init__(
-        self, paymenthash=None, amount=None, currency="bc", tags=None, date=None
+        self,
+        paymenthash=None,
+        amount=None,
+        currency="bc",
+        tags=None,
+        date=None,
+        fallback=None,
     ):
         self.date = int(time.time()) if not date else int(date)
         self.tags = [] if not tags else tags
@@ -252,6 +276,7 @@ class LnAddr:
         self.paymenthash = paymenthash
         self.signature = None
         self.pubkey = None
+        self.fallback = fallback
         self.currency = currency
         self.amount = amount

@@ -266,6 +291,7 @@ def shorten_amount(amount):
     # Convert to pico initially
     amount = int(amount * 10**12)
     units = ["p", "n", "u", "m", ""]
+    unit = ""
     for unit in units:
         if amount % 1000 == 0:
             amount //= 1000
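`unit = ""` pre-binds the loop variable so pyright cannot flag it as possibly unbound after the `for` loop; the `exists = False` line added to `migrate_databases` further down follows the same reasoning. A toy version of the pattern, with illustrative logic rather than the real shorten_amount:

def pick_unit(amount: int) -> str:
    units = ["p", "n", "u", "m", ""]
    unit = ""  # pre-bound: the name exists even if the loop body never assigns it
    for unit in units:
        if amount % 1000 != 0:
            break
        amount //= 1000
    return unit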
@@ -304,14 +330,6 @@ def _pull_tagged(stream):
     return (CHARSET[tag], stream.read(length * 5), stream)


-def is_p2pkh(currency, prefix):
-    return prefix == base58_prefix_map[currency][0]
-
-
-def is_p2sh(currency, prefix):
-    return prefix == base58_prefix_map[currency][1]
-
-
 # Tagged field containing BitArray
 def tagged(char, l):
     # Tagged fields need to be zero-padded to 5 bits.
@@ -359,5 +377,5 @@ def bitarray_to_u5(barr):
     ret = []
     s = bitstring.ConstBitStream(barr)
     while s.pos != s.len:
-        ret.append(s.read(5).uint)
+        ret.append(s.read(5).uint)  # type: ignore
     return ret
@@ -41,6 +41,7 @@ async def migrate_databases():
     """Creates the necessary databases if they don't exist already; or migrates them."""

     async with core_db.connect() as conn:
+        exists = False
         if conn.type == SQLITE:
             exists = await conn.fetchone(
                 "SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
lnbits/db.py (18 lines changed)
@@ -131,7 +131,7 @@ class Database(Compat):
         else:
             self.type = POSTGRES

-            import psycopg2
+            from psycopg2.extensions import DECIMAL, new_type, register_type

             def _parse_timestamp(value, _):
                 if value is None:
@@ -141,15 +141,15 @@ class Database(Compat):
                     f = "%Y-%m-%d %H:%M:%S"
                 return time.mktime(datetime.datetime.strptime(value, f).timetuple())

-            psycopg2.extensions.register_type(
-                psycopg2.extensions.new_type(
-                    psycopg2.extensions.DECIMAL.values,
+            register_type(
+                new_type(
+                    DECIMAL.values,
                     "DEC2FLOAT",
                     lambda value, curs: float(value) if value is not None else None,
                 )
             )
-            psycopg2.extensions.register_type(
-                psycopg2.extensions.new_type(
+            register_type(
+                new_type(
                     (1082, 1083, 1266),
                     "DATE2INT",
                     lambda value, curs: time.mktime(value.timetuple())
@@ -158,11 +158,7 @@ class Database(Compat):
                 )
             )

-            psycopg2.extensions.register_type(
-                psycopg2.extensions.new_type(
-                    (1184, 1114), "TIMESTAMP2INT", _parse_timestamp
-                )
-            )
+            register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
         else:
             if os.path.isdir(settings.lnbits_data_folder):
                 self.path = os.path.join(
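The db.py hunks only shorten the psycopg2 calls by importing `register_type`, `new_type`, and `DECIMAL` directly; the behaviour, casting NUMERIC and date/timestamp columns as they leave the driver, is unchanged. A standalone sketch of that adapter registration (global scope is assumed here; a connection object can be passed to limit it):

from psycopg2.extensions import DECIMAL, new_type, register_type

# Return NUMERIC columns as Python floats instead of Decimal objects.
DEC2FLOAT = new_type(
    DECIMAL.values,
    "DEC2FLOAT",
    lambda value, curs: float(value) if value is not None else None,
)
register_type(DEC2FLOAT)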
@@ -1,14 +1,12 @@
 from http import HTTPStatus
 from typing import Optional, Type

-from fastapi import Security, status
-from fastapi.exceptions import HTTPException
+from fastapi import HTTPException, Request, Security, status
 from fastapi.openapi.models import APIKey, APIKeyIn
-from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
+from fastapi.security import APIKeyHeader, APIKeyQuery
 from fastapi.security.base import SecurityBase
 from pydantic import BaseModel
 from pydantic.types import UUID4
-from starlette.requests import Request

 from lnbits.core.crud import get_user, get_wallet_for_key
 from lnbits.core.models import User, Wallet
@@ -17,9 +15,13 @@ from lnbits.requestvars import g
 from lnbits.settings import settings


+# TODO: fix type ignores
 class KeyChecker(SecurityBase):
     def __init__(
-        self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
+        self,
+        scheme_name: Optional[str] = None,
+        auto_error: bool = True,
+        api_key: Optional[str] = None,
     ):
         self.scheme_name = scheme_name or self.__class__.__name__
         self.auto_error = auto_error
@@ -27,13 +29,13 @@ class KeyChecker(SecurityBase):
         self._api_key = api_key
         if api_key:
             key = APIKey(
-                **{"in": APIKeyIn.query},
+                **{"in": APIKeyIn.query},  # type: ignore
                 name="X-API-KEY",
                 description="Wallet API Key - QUERY",
             )
         else:
             key = APIKey(
-                **{"in": APIKeyIn.header},
+                **{"in": APIKeyIn.header},  # type: ignore
                 name="X-API-KEY",
                 description="Wallet API Key - HEADER",
             )
@@ -73,7 +75,10 @@ class WalletInvoiceKeyChecker(KeyChecker):
     """

     def __init__(
-        self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
+        self,
+        scheme_name: Optional[str] = None,
+        auto_error: bool = True,
+        api_key: Optional[str] = None,
     ):
         super().__init__(scheme_name, auto_error, api_key)
         self._key_type = "invoice"
@@ -89,7 +94,10 @@ class WalletAdminKeyChecker(KeyChecker):
     """

     def __init__(
-        self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
+        self,
+        scheme_name: Optional[str] = None,
+        auto_error: bool = True,
+        api_key: Optional[str] = None,
     ):
         super().__init__(scheme_name, auto_error, api_key)
         self._key_type = "admin"
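Most of the decorator changes replace `arg: str = None` with `arg: Optional[str] = None`: implicit Optional defaults are rejected by pyright under its default settings, so the annotation has to name None explicitly. A before/after sketch with a hypothetical function, not one from decorators.py:

from typing import Optional

# flagged by pyright: the declared type str does not include None
# def greet(name: str = None) -> str: ...

def greet(name: Optional[str] = None) -> str:
    return f"hello {name or 'anon'}"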
@@ -3,20 +3,146 @@ import json
 import os
 import shutil
 import sys
 import urllib.request
 import zipfile
 from http import HTTPStatus
 from pathlib import Path
 from typing import Any, List, NamedTuple, Optional, Tuple
+from urllib import request

 import httpx
-from fastapi.exceptions import HTTPException
+from fastapi import HTTPException
 from fastapi.responses import JSONResponse
 from loguru import logger
 from pydantic import BaseModel

 from lnbits.settings import settings


+class ExplicitRelease(BaseModel):
+    id: str
+    name: str
+    version: str
+    archive: str
+    hash: str
+    dependencies: List[str] = []
+    icon: Optional[str]
+    short_description: Optional[str]
+    html_url: Optional[str]
+    details: Optional[str]
+    info_notification: Optional[str]
+    critical_notification: Optional[str]
+
+
+class GitHubRelease(BaseModel):
+    id: str
+    organisation: str
+    repository: str
+
+
+class Manifest(BaseModel):
+    featured: List[str] = []
+    extensions: List["ExplicitRelease"] = []
+    repos: List["GitHubRelease"] = []
+
+
+class GitHubRepoRelease(BaseModel):
+    name: str
+    tag_name: str
+    zipball_url: str
+    html_url: str
+
+
+class GitHubRepo(BaseModel):
+    stargazers_count: str
+    html_url: str
+    default_branch: str
+
+
+class ExtensionConfig(BaseModel):
+    name: str
+    short_description: str
+    tile: str = ""
+
+
+def download_url(url, save_path):
+    with request.urlopen(url) as dl_file:
+        with open(save_path, "wb") as out_file:
+            out_file.write(dl_file.read())
+
+
+def file_hash(filename):
+    h = hashlib.sha256()
+    b = bytearray(128 * 1024)
+    mv = memoryview(b)
+    with open(filename, "rb", buffering=0) as f:
+        while n := f.readinto(mv):
+            h.update(mv[:n])
+    return h.hexdigest()
+
+
+async def fetch_github_repo_info(
+    org: str, repository: str
+) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
+    repo_url = f"https://api.github.com/repos/{org}/{repository}"
+    error_msg = "Cannot fetch extension repo"
+    repo = await gihub_api_get(repo_url, error_msg)
+    github_repo = GitHubRepo.parse_obj(repo)
+
+    lates_release_url = (
+        f"https://api.github.com/repos/{org}/{repository}/releases/latest"
+    )
+    error_msg = "Cannot fetch extension releases"
+    latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
+
+    config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
+    error_msg = "Cannot fetch config for extension"
+    config = await gihub_api_get(config_url, error_msg)
+
+    return (
+        github_repo,
+        GitHubRepoRelease.parse_obj(latest_release),
+        ExtensionConfig.parse_obj(config),
+    )
+
+
+async def fetch_manifest(url) -> Manifest:
+    error_msg = "Cannot fetch extensions manifest"
+    manifest = await gihub_api_get(url, error_msg)
+    return Manifest.parse_obj(manifest)
+
+
+async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
+    releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
+    error_msg = "Cannot fetch extension releases"
+    releases = await gihub_api_get(releases_url, error_msg)
+    return [GitHubRepoRelease.parse_obj(r) for r in releases]
+
+
+async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
+    async with httpx.AsyncClient() as client:
+        headers = (
+            {"Authorization": "Bearer " + settings.lnbits_ext_github_token}
+            if settings.lnbits_ext_github_token
+            else None
+        )
+        resp = await client.get(
+            url,
+            headers=headers,
+        )
+        if resp.status_code != 200:
+            logger.warning(f"{error_msg} ({url}): {resp.text}")
+        resp.raise_for_status()
+        return resp.json()
+
+
+def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
+    if not path:
+        return ""
+    _, _, *rest = path.split("/")
+    tail = "/".join(rest)
+    return f"https://github.com/{source_repo}/raw/main/{tail}"
+
+
 class Extension(NamedTuple):
     code: str
     is_valid: bool
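The pydantic models moved to the top of the file are plain data holders; `parse_obj` is how the new fetch helpers turn the JSON pulled from the GitHub API into typed objects. A small usage sketch with a made-up payload (the dict literal is illustrative only):

from pydantic import BaseModel

class GitHubRepoRelease(BaseModel):
    name: str
    tag_name: str
    zipball_url: str
    html_url: str

# parse_obj validates field names and types, raising ValidationError on mismatch
release = GitHubRepoRelease.parse_obj(
    {
        "name": "v0.1.0",
        "tag_name": "v0.1.0",
        "zipball_url": "https://api.github.com/repos/org/repo/zipball/v0.1.0",
        "html_url": "https://github.com/org/repo/releases/tag/v0.1.0",
    }
)
print(release.tag_name)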
@@ -97,12 +223,12 @@ class ExtensionRelease(BaseModel):
     version: str
     archive: str
     source_repo: str
-    is_github_release = False
-    hash: Optional[str]
-    html_url: Optional[str]
-    description: Optional[str]
+    is_github_release: bool = False
+    hash: Optional[str] = None
+    html_url: Optional[str] = None
+    description: Optional[str] = None
     details_html: Optional[str] = None
-    icon: Optional[str]
+    icon: Optional[str] = None

     @classmethod
     def from_github_release(
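Giving every Optional field an explicit `= None` default keeps the static checker satisfied when an `ExtensionRelease` is built without those arguments, and it matches what pydantic v1 already did for bare `Optional[...]` fields, so runtime behaviour should not change. Reduced to an illustrative model:

from typing import Optional
from pydantic import BaseModel

class Release(BaseModel):
    version: str
    hash: Optional[str] = None      # explicit default: may be omitted at construction
    html_url: Optional[str] = None

release = Release(version="1.0")    # hash and html_url default to None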
@@ -132,52 +258,6 @@ class ExtensionRelease(BaseModel):
         return []


-class ExplicitRelease(BaseModel):
-    id: str
-    name: str
-    version: str
-    archive: str
-    hash: str
-    dependencies: List[str] = []
-    icon: Optional[str]
-    short_description: Optional[str]
-    html_url: Optional[str]
-    details: Optional[str]
-    info_notification: Optional[str]
-    critical_notification: Optional[str]
-
-
-class GitHubRelease(BaseModel):
-    id: str
-    organisation: str
-    repository: str
-
-
-class Manifest(BaseModel):
-    featured: List[str] = []
-    extensions: List["ExplicitRelease"] = []
-    repos: List["GitHubRelease"] = []
-
-
-class GitHubRepoRelease(BaseModel):
-    name: str
-    tag_name: str
-    zipball_url: str
-    html_url: str
-
-
-class GitHubRepo(BaseModel):
-    stargazers_count: str
-    html_url: str
-    default_branch: str
-
-
-class ExtensionConfig(BaseModel):
-    name: str
-    short_description: str
-    tile: str = ""
-
-
 class InstallableExtension(BaseModel):
     id: str
     name: str
@@ -187,8 +267,9 @@ class InstallableExtension(BaseModel):
     is_admin_only: bool = False
     stars: int = 0
     featured = False
-    latest_release: Optional[ExtensionRelease]
-    installed_release: Optional[ExtensionRelease]
+    latest_release: Optional[ExtensionRelease] = None
+    installed_release: Optional[ExtensionRelease] = None
+    archive: Optional[str] = None

     @property
     def hash(self) -> str:
@@ -234,6 +315,7 @@ class InstallableExtension(BaseModel):
         if ext_zip_file.is_file():
            os.remove(ext_zip_file)
         try:
+            assert self.installed_release
             download_url(self.installed_release.archive, ext_zip_file)
         except Exception as ex:
             logger.warning(ex)
@@ -334,8 +416,7 @@ class InstallableExtension(BaseModel):
             id=github_release.id,
             name=config.name,
             short_description=config.short_description,
             version="0",
-            stars=repo.stargazers_count,
+            stars=int(repo.stargazers_count),
             icon=icon_to_github_url(
                 f"{github_release.organisation}/{github_release.repository}",
                 config.tile,
@@ -354,7 +435,6 @@ class InstallableExtension(BaseModel):
             id=e.id,
             name=e.name,
             archive=e.archive,
-            hash=e.hash,
             short_description=e.short_description,
             icon=e.icon,
             dependencies=e.dependencies,
@@ -443,6 +523,52 @@ class InstallableExtension(BaseModel):
         return selected_release[0] if len(selected_release) != 0 else None


+class InstalledExtensionMiddleware:
+    # This middleware class intercepts calls made to the extensions API and:
+    #  - it blocks the calls if the extension has been disabled or uninstalled.
+    #  - it redirects the calls to the latest version of the extension if the extension has been upgraded.
+    #  - otherwise it has no effect
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        if "path" not in scope:
+            await self.app(scope, receive, send)
+            return
+
+        path_elements = scope["path"].split("/")
+        if len(path_elements) > 2:
+            _, path_name, path_type, *rest = path_elements
+            tail = "/".join(rest)
+        else:
+            _, path_name = path_elements
+            path_type = None
+            tail = ""
+
+        # block path for all users if the extension is disabled
+        if path_name in settings.lnbits_deactivated_extensions:
+            response = JSONResponse(
+                status_code=HTTPStatus.NOT_FOUND,
+                content={"detail": f"Extension '{path_name}' disabled"},
+            )
+            await response(scope, receive, send)
+            return
+
+        # re-route API trafic if the extension has been upgraded
+        if path_type == "api":
+            upgraded_extensions = list(
+                filter(
+                    lambda ext: ext.endswith(f"/{path_name}"),
+                    settings.lnbits_upgraded_extensions,
+                )
+            )
+            if len(upgraded_extensions) != 0:
+                upgrade_path = upgraded_extensions[0]
+                scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
+
+        await self.app(scope, receive, send)
+
+
 class CreateExtension(BaseModel):
     ext_id: str
     archive: str
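`InstalledExtensionMiddleware` is written against the raw ASGI interface: it inspects `scope["path"]`, either answers directly with a `JSONResponse` or rewrites the path, and otherwise forwards to the wrapped app. A middleware of that shape plugs into FastAPI via `add_middleware`; the sketch below uses an illustrative stand-in class, not the LNbits one:

from fastapi import FastAPI
from starlette.types import ASGIApp, Receive, Scope, Send

class PathLoggerMiddleware:
    # same shape as InstalledExtensionMiddleware: wrap the app, inspect scope, forward
    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope.get("type") == "http":
            print("request path:", scope.get("path"))
        await self.app(scope, receive, send)

app = FastAPI()
app.add_middleware(PathLoggerMiddleware)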
@@ -1,25 +1,18 @@
 # Borrowed from the excellent accent-starlette
 # https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py

 import typing

-from starlette import templating
+from jinja2 import BaseLoader, Environment, pass_context
 from starlette.datastructures import QueryParams
 from starlette.requests import Request
-
-try:
-    import jinja2
-except ImportError:  # pragma: nocover
-    jinja2 = None  # type: ignore
+from starlette.templating import Jinja2Templates as SuperJinja2Templates


-class Jinja2Templates(templating.Jinja2Templates):
-    def __init__(self, loader: jinja2.BaseLoader) -> None:  # pylint: disable=W0231
-        assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
+class Jinja2Templates(SuperJinja2Templates):
+    def __init__(self, loader: BaseLoader) -> None:
+        super().__init__("")
         self.env = self.get_environment(loader)

-    def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment":
-        @jinja2.pass_context
+    def get_environment(self, loader: BaseLoader) -> Environment:
+        @pass_context
         def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
             request: Request = context["request"]
             return request.app.url_path_for(name, **path_params)
@@ -29,7 +22,7 @@ class Jinja2Templates(templating.Jinja2Templates):
             values.update(new)
             return QueryParams(**values)

-        env = jinja2.Environment(loader=loader, autoescape=True)
+        env = Environment(loader=loader, autoescape=True)
         env.globals["url_for"] = url_for
         env.globals["url_params_update"] = url_params_update
         return env
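The templating rewrite drops the guarded `import jinja2` and subclasses starlette's `Jinja2Templates` directly; `pass_context` is the jinja2 3.x spelling of the old `contextfunction`, letting a registered global see the render context. A self-contained sketch of that mechanism with a throwaway template:

from jinja2 import BaseLoader, Environment, pass_context

env = Environment(loader=BaseLoader(), autoescape=True)

@pass_context
def shout(context, name):
    # the active render context is injected as the first argument
    return f"{context.get('prefix', '')}{name}".upper()

env.globals["shout"] = shout
print(env.from_string("{{ shout('lnbits') }}").render(prefix="hi "))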
@@ -24,6 +24,7 @@ def list_parse_fallback(v):


 class LNbitsSettings(BaseSettings):
+    @classmethod
     def validate(cls, val):
         if type(val) == str:
             val = val.split(",") if val else []
@@ -103,6 +104,8 @@ class FakeWalletFundingSource(LNbitsSettings):
 class LNbitsFundingSource(LNbitsSettings):
     lnbits_endpoint: str = Field(default="https://legend.lnbits.com")
     lnbits_key: Optional[str] = Field(default=None)
+    lnbits_admin_key: Optional[str] = Field(default=None)
+    lnbits_invoice_key: Optional[str] = Field(default=None)


 class ClicheFundingSource(LNbitsSettings):
@@ -145,11 +148,14 @@ class LnPayFundingSource(LNbitsSettings):
     lnpay_api_endpoint: Optional[str] = Field(default=None)
     lnpay_api_key: Optional[str] = Field(default=None)
     lnpay_wallet_key: Optional[str] = Field(default=None)
+    lnpay_admin_key: Optional[str] = Field(default=None)


 class OpenNodeFundingSource(LNbitsSettings):
     opennode_api_endpoint: Optional[str] = Field(default=None)
     opennode_key: Optional[str] = Field(default=None)
+    opennode_admin_key: Optional[str] = Field(default=None)
+    opennode_invoice_key: Optional[str] = Field(default=None)


 class SparkFundingSource(LNbitsSettings):
@@ -208,8 +214,9 @@ class EditableSettings(
         "lnbits_admin_extensions",
         pre=True,
     )
+    @classmethod
     def validate_editable_settings(cls, val):
-        return super().validate(cls, val)
+        return super().validate(val)

     @classmethod
     def from_dict(cls, d: dict):
@@ -281,8 +288,9 @@ class ReadOnlySettings(
         "lnbits_allowed_funding_sources",
         pre=True,
     )
+    @classmethod
     def validate_readonly_settings(cls, val):
-        return super().validate(cls, val)
+        return super().validate(val)

     @classmethod
     def readonly_fields(cls):
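The settings fix has two parts: `validate` on the base class is now an explicit `@classmethod`, and the subclass validators call it as `super().validate(val)`; the previous `super().validate(cls, val)` would hand the class over twice once the base method is a real classmethod, which is exactly the kind of arity error pyright reports. A toy reduction with illustrative class names:

class BaseSettingsLike:
    @classmethod
    def validate(cls, val):
        return [v.strip() for v in val.split(",")] if isinstance(val, str) else val

class Editable(BaseSettingsLike):
    @classmethod
    def validate_editable(cls, val):
        # super() inside a classmethod already binds cls; pass only the value
        return super().validate(val)

print(Editable.validate_editable("usermanager, lnurlp"))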
@@ -3,7 +3,7 @@ import time
 import traceback
 import uuid
 from http import HTTPStatus
-from typing import Dict
+from typing import Dict, Optional

 from fastapi.exceptions import HTTPException
 from loguru import logger
@@ -42,7 +42,7 @@ class SseListenersDict(dict):
     A dict of sse listeners.
     """

-    def __init__(self, name: str = None):
+    def __init__(self, name: Optional[str] = None):
         self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}"

     def __setitem__(self, key, value):
@@ -65,7 +65,7 @@ class SseListenersDict(dict):
 invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners")


-def register_invoice_listener(send_chan: asyncio.Queue, name: str = None):
+def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = None):
     """
     A method intended for extensions (and core/tasks.py) to call when they want to be notified about
     new invoice payments incoming. Will emit all incoming payments.
@@ -164,7 +164,7 @@ async def check_pending_payments():
 async def perform_balance_checks():
     while True:
         for bc in await get_balance_checks():
-            redeem_lnurl_withdraw(bc.wallet, bc.url)
+            await redeem_lnurl_withdraw(bc.wallet, bc.url)

         await asyncio.sleep(60 * 60 * 6)  # every 6 hours

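The final hunk is a genuine bug surfaced by the type checker: calling an async function without `await` only creates a coroutine object, so `redeem_lnurl_withdraw` never actually ran inside the balance-check loop. A toy reproduction with illustrative names:

import asyncio

async def redeem(url: str) -> str:
    return f"redeemed {url}"

async def main():
    redeem("lnurl-placeholder")               # coroutine created but never executed
    print(await redeem("lnurl-placeholder"))  # actually runs

asyncio.run(main())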