fix pyright lnbits

Co-authored-by: dni  <office@dnilabs.com>
This commit is contained in:
Pavol Rusnak 2023-02-02 12:58:23 +00:00 committed by dni ⚡
parent 3855cf47f3
commit 02306148df
No known key found for this signature in database
GPG Key ID: 886317704CC4E618
8 changed files with 266 additions and 116 deletions

View File

@ -9,6 +9,7 @@ import secp256k1
from bech32 import CHARSET, bech32_decode, bech32_encode
from ecdsa import SECP256k1, VerifyingKey
from ecdsa.util import sigdecode_string
from loguru import logger
class Route(NamedTuple):
@ -30,6 +31,7 @@ class Invoice:
secret: Optional[str] = None
route_hints: List[Route] = []
min_final_cltv_expiry: int = 18
checking_id: Optional[str] = None
def decode(pr: str) -> Invoice:
@ -66,11 +68,13 @@ def decode(pr: str) -> Invoice:
invoice.amount_msat = _unshorten_amount(amountstr)
# pull out date
invoice.date = data.read(35).uint
date_bin = data.read(35)
assert date_bin
invoice.date = date_bin.uint
while data.pos != data.len:
tag, tagdata, data = _pull_tagged(data)
data_length = len(tagdata) / 5
data_length = len(tagdata or []) / 5
if tag == "d":
invoice.description = _trim_to_bytes(tagdata).decode()
@ -89,12 +93,22 @@ def decode(pr: str) -> Invoice:
elif tag == "r":
s = bitstring.ConstBitStream(tagdata)
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
pubkey = s.read(264)
assert pubkey
short_channel_id = s.read(64)
assert short_channel_id
base_fee_msat = s.read(32)
assert base_fee_msat
ppm_fee = s.read(32)
assert ppm_fee
cltv = s.read(16)
assert cltv
route = Route(
pubkey=s.read(264).tobytes().hex(),
short_channel_id=_readable_scid(s.read(64).intbe),
base_fee_msat=s.read(32).intbe,
ppm_fee=s.read(32).intbe,
cltv=s.read(16).intbe,
pubkey=pubkey.tobytes().hex(),
short_channel_id=_readable_scid(short_channel_id.intbe),
base_fee_msat=base_fee_msat.intbe,
ppm_fee=ppm_fee.intbe,
cltv=cltv.intbe,
)
invoice.route_hints.append(route)
@ -160,6 +174,10 @@ def encode(options):
return lnencode(addr, options["privkey"])
def encode_fallback(v, currency):
logger.error(f"hit bolt11.py encode_fallback with v: {v} and currency: {currency}")
def lnencode(addr, privkey):
if addr.amount:
amount = Decimal(str(addr.amount))
@ -244,7 +262,13 @@ def lnencode(addr, privkey):
class LnAddr:
def __init__(
self, paymenthash=None, amount=None, currency="bc", tags=None, date=None
self,
paymenthash=None,
amount=None,
currency="bc",
tags=None,
date=None,
fallback=None,
):
self.date = int(time.time()) if not date else int(date)
self.tags = [] if not tags else tags
@ -252,6 +276,7 @@ class LnAddr:
self.paymenthash = paymenthash
self.signature = None
self.pubkey = None
self.fallback = fallback
self.currency = currency
self.amount = amount
@ -266,6 +291,7 @@ def shorten_amount(amount):
# Convert to pico initially
amount = int(amount * 10**12)
units = ["p", "n", "u", "m", ""]
unit = ""
for unit in units:
if amount % 1000 == 0:
amount //= 1000
@ -304,14 +330,6 @@ def _pull_tagged(stream):
return (CHARSET[tag], stream.read(length * 5), stream)
def is_p2pkh(currency, prefix):
return prefix == base58_prefix_map[currency][0]
def is_p2sh(currency, prefix):
return prefix == base58_prefix_map[currency][1]
# Tagged field containing BitArray
def tagged(char, l):
# Tagged fields need to be zero-padded to 5 bits.
@ -359,5 +377,5 @@ def bitarray_to_u5(barr):
ret = []
s = bitstring.ConstBitStream(barr)
while s.pos != s.len:
ret.append(s.read(5).uint)
ret.append(s.read(5).uint) # type: ignore
return ret

View File

@ -41,6 +41,7 @@ async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them."""
async with core_db.connect() as conn:
exists = False
if conn.type == SQLITE:
exists = await conn.fetchone(
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"

View File

@ -131,7 +131,7 @@ class Database(Compat):
else:
self.type = POSTGRES
import psycopg2
from psycopg2.extensions import DECIMAL, new_type, register_type
def _parse_timestamp(value, _):
if value is None:
@ -141,15 +141,15 @@ class Database(Compat):
f = "%Y-%m-%d %H:%M:%S"
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
psycopg2.extensions.DECIMAL.values,
register_type(
new_type(
DECIMAL.values,
"DEC2FLOAT",
lambda value, curs: float(value) if value is not None else None,
)
)
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
register_type(
new_type(
(1082, 1083, 1266),
"DATE2INT",
lambda value, curs: time.mktime(value.timetuple())
@ -158,11 +158,7 @@ class Database(Compat):
)
)
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
(1184, 1114), "TIMESTAMP2INT", _parse_timestamp
)
)
register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
else:
if os.path.isdir(settings.lnbits_data_folder):
self.path = os.path.join(

View File

@ -1,14 +1,12 @@
from http import HTTPStatus
from typing import Optional, Type
from fastapi import Security, status
from fastapi.exceptions import HTTPException
from fastapi import HTTPException, Request, Security, status
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi.security.base import SecurityBase
from pydantic import BaseModel
from pydantic.types import UUID4
from starlette.requests import Request
from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.core.models import User, Wallet
@ -17,9 +15,13 @@ from lnbits.requestvars import g
from lnbits.settings import settings
# TODO: fix type ignores
class KeyChecker(SecurityBase):
def __init__(
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
@ -27,13 +29,13 @@ class KeyChecker(SecurityBase):
self._api_key = api_key
if api_key:
key = APIKey(
**{"in": APIKeyIn.query},
**{"in": APIKeyIn.query}, # type: ignore
name="X-API-KEY",
description="Wallet API Key - QUERY",
)
else:
key = APIKey(
**{"in": APIKeyIn.header},
**{"in": APIKeyIn.header}, # type: ignore
name="X-API-KEY",
description="Wallet API Key - HEADER",
)
@ -73,7 +75,10 @@ class WalletInvoiceKeyChecker(KeyChecker):
"""
def __init__(
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
):
super().__init__(scheme_name, auto_error, api_key)
self._key_type = "invoice"
@ -89,7 +94,10 @@ class WalletAdminKeyChecker(KeyChecker):
"""
def __init__(
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
):
super().__init__(scheme_name, auto_error, api_key)
self._key_type = "admin"

View File

@ -3,20 +3,146 @@ import json
import os
import shutil
import sys
import urllib.request
import zipfile
from http import HTTPStatus
from pathlib import Path
from typing import Any, List, NamedTuple, Optional, Tuple
from urllib import request
import httpx
from fastapi.exceptions import HTTPException
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from loguru import logger
from pydantic import BaseModel
from lnbits.settings import settings
class ExplicitRelease(BaseModel):
id: str
name: str
version: str
archive: str
hash: str
dependencies: List[str] = []
icon: Optional[str]
short_description: Optional[str]
html_url: Optional[str]
details: Optional[str]
info_notification: Optional[str]
critical_notification: Optional[str]
class GitHubRelease(BaseModel):
id: str
organisation: str
repository: str
class Manifest(BaseModel):
featured: List[str] = []
extensions: List["ExplicitRelease"] = []
repos: List["GitHubRelease"] = []
class GitHubRepoRelease(BaseModel):
name: str
tag_name: str
zipball_url: str
html_url: str
class GitHubRepo(BaseModel):
stargazers_count: str
html_url: str
default_branch: str
class ExtensionConfig(BaseModel):
name: str
short_description: str
tile: str = ""
def download_url(url, save_path):
with request.urlopen(url) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
async def fetch_github_repo_info(
org: str, repository: str
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
repo_url = f"https://api.github.com/repos/{org}/{repository}"
error_msg = "Cannot fetch extension repo"
repo = await gihub_api_get(repo_url, error_msg)
github_repo = GitHubRepo.parse_obj(repo)
lates_release_url = (
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
)
error_msg = "Cannot fetch extension releases"
latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
error_msg = "Cannot fetch config for extension"
config = await gihub_api_get(config_url, error_msg)
return (
github_repo,
GitHubRepoRelease.parse_obj(latest_release),
ExtensionConfig.parse_obj(config),
)
async def fetch_manifest(url) -> Manifest:
error_msg = "Cannot fetch extensions manifest"
manifest = await gihub_api_get(url, error_msg)
return Manifest.parse_obj(manifest)
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
error_msg = "Cannot fetch extension releases"
releases = await gihub_api_get(releases_url, error_msg)
return [GitHubRepoRelease.parse_obj(r) for r in releases]
async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
async with httpx.AsyncClient() as client:
headers = (
{"Authorization": "Bearer " + settings.lnbits_ext_github_token}
if settings.lnbits_ext_github_token
else None
)
resp = await client.get(
url,
headers=headers,
)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"
class Extension(NamedTuple):
code: str
is_valid: bool
@ -97,12 +223,12 @@ class ExtensionRelease(BaseModel):
version: str
archive: str
source_repo: str
is_github_release = False
hash: Optional[str]
html_url: Optional[str]
description: Optional[str]
is_github_release: bool = False
hash: Optional[str] = None
html_url: Optional[str] = None
description: Optional[str] = None
details_html: Optional[str] = None
icon: Optional[str]
icon: Optional[str] = None
@classmethod
def from_github_release(
@ -132,52 +258,6 @@ class ExtensionRelease(BaseModel):
return []
class ExplicitRelease(BaseModel):
id: str
name: str
version: str
archive: str
hash: str
dependencies: List[str] = []
icon: Optional[str]
short_description: Optional[str]
html_url: Optional[str]
details: Optional[str]
info_notification: Optional[str]
critical_notification: Optional[str]
class GitHubRelease(BaseModel):
id: str
organisation: str
repository: str
class Manifest(BaseModel):
featured: List[str] = []
extensions: List["ExplicitRelease"] = []
repos: List["GitHubRelease"] = []
class GitHubRepoRelease(BaseModel):
name: str
tag_name: str
zipball_url: str
html_url: str
class GitHubRepo(BaseModel):
stargazers_count: str
html_url: str
default_branch: str
class ExtensionConfig(BaseModel):
name: str
short_description: str
tile: str = ""
class InstallableExtension(BaseModel):
id: str
name: str
@ -187,8 +267,9 @@ class InstallableExtension(BaseModel):
is_admin_only: bool = False
stars: int = 0
featured = False
latest_release: Optional[ExtensionRelease]
installed_release: Optional[ExtensionRelease]
latest_release: Optional[ExtensionRelease] = None
installed_release: Optional[ExtensionRelease] = None
archive: Optional[str] = None
@property
def hash(self) -> str:
@ -234,6 +315,7 @@ class InstallableExtension(BaseModel):
if ext_zip_file.is_file():
os.remove(ext_zip_file)
try:
assert self.installed_release
download_url(self.installed_release.archive, ext_zip_file)
except Exception as ex:
logger.warning(ex)
@ -334,8 +416,7 @@ class InstallableExtension(BaseModel):
id=github_release.id,
name=config.name,
short_description=config.short_description,
version="0",
stars=repo.stargazers_count,
stars=int(repo.stargazers_count),
icon=icon_to_github_url(
f"{github_release.organisation}/{github_release.repository}",
config.tile,
@ -354,7 +435,6 @@ class InstallableExtension(BaseModel):
id=e.id,
name=e.name,
archive=e.archive,
hash=e.hash,
short_description=e.short_description,
icon=e.icon,
dependencies=e.dependencies,
@ -443,6 +523,52 @@ class InstallableExtension(BaseModel):
return selected_release[0] if len(selected_release) != 0 else None
class InstalledExtensionMiddleware:
# This middleware class intercepts calls made to the extensions API and:
# - it blocks the calls if the extension has been disabled or uninstalled.
# - it redirects the calls to the latest version of the extension if the extension has been upgraded.
# - otherwise it has no effect
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if "path" not in scope:
await self.app(scope, receive, send)
return
path_elements = scope["path"].split("/")
if len(path_elements) > 2:
_, path_name, path_type, *rest = path_elements
tail = "/".join(rest)
else:
_, path_name = path_elements
path_type = None
tail = ""
# block path for all users if the extension is disabled
if path_name in settings.lnbits_deactivated_extensions:
response = JSONResponse(
status_code=HTTPStatus.NOT_FOUND,
content={"detail": f"Extension '{path_name}' disabled"},
)
await response(scope, receive, send)
return
# re-route API trafic if the extension has been upgraded
if path_type == "api":
upgraded_extensions = list(
filter(
lambda ext: ext.endswith(f"/{path_name}"),
settings.lnbits_upgraded_extensions,
)
)
if len(upgraded_extensions) != 0:
upgrade_path = upgraded_extensions[0]
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
await self.app(scope, receive, send)
class CreateExtension(BaseModel):
ext_id: str
archive: str

View File

@ -1,25 +1,18 @@
# Borrowed from the excellent accent-starlette
# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py
import typing
from starlette import templating
from jinja2 import BaseLoader, Environment, pass_context
from starlette.datastructures import QueryParams
from starlette.requests import Request
try:
import jinja2
except ImportError: # pragma: nocover
jinja2 = None # type: ignore
from starlette.templating import Jinja2Templates as SuperJinja2Templates
class Jinja2Templates(templating.Jinja2Templates):
def __init__(self, loader: jinja2.BaseLoader) -> None: # pylint: disable=W0231
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
class Jinja2Templates(SuperJinja2Templates):
def __init__(self, loader: BaseLoader) -> None:
super().__init__("")
self.env = self.get_environment(loader)
def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment":
@jinja2.pass_context
def get_environment(self, loader: BaseLoader) -> Environment:
@pass_context
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
request: Request = context["request"]
return request.app.url_path_for(name, **path_params)
@ -29,7 +22,7 @@ class Jinja2Templates(templating.Jinja2Templates):
values.update(new)
return QueryParams(**values)
env = jinja2.Environment(loader=loader, autoescape=True)
env = Environment(loader=loader, autoescape=True)
env.globals["url_for"] = url_for
env.globals["url_params_update"] = url_params_update
return env

View File

@ -24,6 +24,7 @@ def list_parse_fallback(v):
class LNbitsSettings(BaseSettings):
@classmethod
def validate(cls, val):
if type(val) == str:
val = val.split(",") if val else []
@ -103,6 +104,8 @@ class FakeWalletFundingSource(LNbitsSettings):
class LNbitsFundingSource(LNbitsSettings):
lnbits_endpoint: str = Field(default="https://legend.lnbits.com")
lnbits_key: Optional[str] = Field(default=None)
lnbits_admin_key: Optional[str] = Field(default=None)
lnbits_invoice_key: Optional[str] = Field(default=None)
class ClicheFundingSource(LNbitsSettings):
@ -145,11 +148,14 @@ class LnPayFundingSource(LNbitsSettings):
lnpay_api_endpoint: Optional[str] = Field(default=None)
lnpay_api_key: Optional[str] = Field(default=None)
lnpay_wallet_key: Optional[str] = Field(default=None)
lnpay_admin_key: Optional[str] = Field(default=None)
class OpenNodeFundingSource(LNbitsSettings):
opennode_api_endpoint: Optional[str] = Field(default=None)
opennode_key: Optional[str] = Field(default=None)
opennode_admin_key: Optional[str] = Field(default=None)
opennode_invoice_key: Optional[str] = Field(default=None)
class SparkFundingSource(LNbitsSettings):
@ -208,8 +214,9 @@ class EditableSettings(
"lnbits_admin_extensions",
pre=True,
)
@classmethod
def validate_editable_settings(cls, val):
return super().validate(cls, val)
return super().validate(val)
@classmethod
def from_dict(cls, d: dict):
@ -281,8 +288,9 @@ class ReadOnlySettings(
"lnbits_allowed_funding_sources",
pre=True,
)
@classmethod
def validate_readonly_settings(cls, val):
return super().validate(cls, val)
return super().validate(val)
@classmethod
def readonly_fields(cls):

View File

@ -3,7 +3,7 @@ import time
import traceback
import uuid
from http import HTTPStatus
from typing import Dict
from typing import Dict, Optional
from fastapi.exceptions import HTTPException
from loguru import logger
@ -42,7 +42,7 @@ class SseListenersDict(dict):
A dict of sse listeners.
"""
def __init__(self, name: str = None):
def __init__(self, name: Optional[str] = None):
self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}"
def __setitem__(self, key, value):
@ -65,7 +65,7 @@ class SseListenersDict(dict):
invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners")
def register_invoice_listener(send_chan: asyncio.Queue, name: str = None):
def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = None):
"""
A method intended for extensions (and core/tasks.py) to call when they want to be notified about
new invoice payments incoming. Will emit all incoming payments.
@ -164,7 +164,7 @@ async def check_pending_payments():
async def perform_balance_checks():
while True:
for bc in await get_balance_checks():
redeem_lnurl_withdraw(bc.wallet, bc.url)
await redeem_lnurl_withdraw(bc.wallet, bc.url)
await asyncio.sleep(60 * 60 * 6) # every 6 hours