fix pyright lnbits/extensions

Co-authored-by: dni  <office@dnilabs.com>
This commit is contained in:
Pavol Rusnak 2023-02-02 12:57:52 +00:00 committed by dni ⚡
parent 09eab69e12
commit 3bd68cb394
No known key found for this signature in database
GPG key ID: 886317704CC4E618
25 changed files with 142 additions and 160 deletions

View file

@ -80,7 +80,5 @@ async def fetch_fiat_exchange_rate(currency: str, provider: str):
else:
data = {}
getter = exchange_rate_providers[provider]["getter"]
print(getter)
if callable(getter):
rate = float(getter(data, replacements))
return rate
assert callable(getter), "cannot call getter function"
return float(getter(data, replacements))

View file

@ -68,10 +68,11 @@ class Proof(BaseModel):
def from_dict(cls, d: dict):
assert "secret" in d, "no secret in proof"
assert "amount" in d, "no amount in proof"
assert "C" in d, "no C in proof"
return cls(
amount=d.get("amount"),
C=d.get("C"),
secret=d.get("secret"),
amount=d["amount"],
C=d["C"],
secret=d["secret"],
reserved=d.get("reserved") or False,
send_id=d.get("send_id") or "",
time_created=d.get("time_created") or "",

View file

@ -1,7 +1,8 @@
import math
from typing import Tuple
def si_classifier(val):
def si_classifier(val) -> dict:
suffixes = {
24: {"long_suffix": "yotta", "short_suffix": "Y", "scalar": 10**24},
21: {"long_suffix": "zetta", "short_suffix": "Z", "scalar": 10**21},
@ -22,24 +23,22 @@ def si_classifier(val):
-24: {"long_suffix": "yocto", "short_suffix": "y", "scalar": 10**-24},
}
exponent = int(math.floor(math.log10(abs(val)) / 3.0) * 3)
return suffixes.get(exponent, None)
suffix = suffixes.get(exponent)
assert suffix, f"could not classify: {val}"
return suffix
def si_formatter(value):
def si_formatter(value) -> Tuple:
"""
Return a triple of scaled value, short suffix, long suffix, or None if
the value cannot be classified.
"""
classifier = si_classifier(value)
if classifier is None:
# Don't know how to classify this value
return None
scaled = value / classifier["scalar"]
return (scaled, classifier["short_suffix"], classifier["long_suffix"])
return scaled, classifier["short_suffix"], classifier["long_suffix"]
def si_format(value, precision=4, long_form=False, separator=""):
def si_format(value: float, precision=4, long_form=False, separator="") -> str:
"""
"SI prefix" formatted string: return a string with the given precision
and an appropriate order-of-3-magnitudes suffix, e.g.:
@ -47,11 +46,6 @@ def si_format(value, precision=4, long_form=False, separator=""):
si_format(0.00000000123, long_form=True, separator=' ') => '1.230 nano'
"""
scaled, short_suffix, long_suffix = si_formatter(value)
if scaled is None:
# Don't know how to format this value
return value
suffix = long_suffix if long_form else short_suffix
if abs(scaled) < 10:

View file

@ -23,14 +23,14 @@ async def create_livestream(*, wallet_id: str) -> int:
if db.type == SQLITE:
return result._result_proxy.lastrowid
else:
return result[0]
return result[0] # type: ignore
async def get_livestream(ls_id: int) -> Optional[Livestream]:
row = await db.fetchone(
"SELECT * FROM livestream.livestreams WHERE id = ?", (ls_id,)
)
return Livestream(**dict(row)) if row else None
return Livestream(**row) if row else None
async def get_livestream_by_track(track_id: int) -> Optional[Livestream]:
@ -42,7 +42,7 @@ async def get_livestream_by_track(track_id: int) -> Optional[Livestream]:
""",
(track_id,),
)
return Livestream(**dict(row)) if row else None
return Livestream(**row) if row else None
async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream]:
@ -55,7 +55,7 @@ async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream
ls_id = await create_livestream(wallet_id=wallet)
return await get_livestream(ls_id)
return Livestream(**dict(row)) if row else None
return Livestream(**row) if row else None
async def update_current_track(ls_id: int, track_id: Optional[int]):
@ -121,7 +121,7 @@ async def get_track(track_id: Optional[int]) -> Optional[Track]:
""",
(track_id,),
)
return Track(**dict(row)) if row else None
return Track(**row) if row else None
async def get_tracks(livestream: int) -> List[Track]:
@ -132,7 +132,7 @@ async def get_tracks(livestream: int) -> List[Track]:
""",
(livestream,),
)
return [Track(**dict(row)) for row in rows]
return [Track(**row) for row in rows]
async def delete_track_from_livestream(livestream: int, track_id: int):
@ -174,7 +174,7 @@ async def add_producer(livestream: int, name: str) -> int:
if db.type == SQLITE:
return result._result_proxy.lastrowid
else:
return result[0]
return result[0] # type: ignore
async def get_producer(producer_id: int) -> Optional[Producer]:
@ -185,7 +185,7 @@ async def get_producer(producer_id: int) -> Optional[Producer]:
""",
(producer_id,),
)
return Producer(**dict(row)) if row else None
return Producer(**row) if row else None
async def get_producers(livestream: int) -> List[Producer]:
@ -196,4 +196,4 @@ async def get_producers(livestream: int) -> List[Producer]:
""",
(livestream,),
)
return [Producer(**dict(row)) for row in rows]
return [Producer(**row) for row in rows]

View file

@ -1,4 +1,5 @@
import json
from sqlite3 import Row
from typing import Optional
from fastapi import Query, Request
@ -27,6 +28,10 @@ class Livestream(BaseModel):
url = request.url_for("livestream.lnurl_livestream", ls_id=self.id)
return lnurl_encode(url)
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
class Track(BaseModel):
id: int
@ -35,6 +40,10 @@ class Track(BaseModel):
name: str
producer: int
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
@property
def min_sendable(self) -> int:
return min(100_000, self.price_msat or 100_000)
@ -88,3 +97,7 @@ class Producer(BaseModel):
user: str
wallet: str
name: str
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))

View file

@ -71,9 +71,7 @@ async def api_domain_create(
if not cf_response or not cf_response["success"]:
await delete_domain(domain.id)
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Problem with cloudflare: "
+ cf_response["errors"][0]["message"],
status_code=HTTPStatus.BAD_REQUEST, detail="Problem with cloudflare."
)
return domain.dict()

View file

@ -61,8 +61,8 @@ class PayLink(BaseModel):
def success_action(self, payment_hash: str) -> Optional[Dict]:
if self.success_url:
url: ParseResult = urlparse(self.success_url)
qs: Dict = parse_qs(url.query)
qs["payment_hash"] = payment_hash
qs = parse_qs(url.query)
setattr(qs, "payment_hash", payment_hash)
url = url._replace(query=urlencode(qs, doseq=True))
return {
"tag": "url",

View file

@ -6,6 +6,7 @@ and delivery to the specific person
import json
from collections import defaultdict
from typing import AsyncGenerator
from fastapi import WebSocket
from loguru import logger
@ -34,7 +35,7 @@ class Notifier:
# Create notification generator:
self.generator = self.get_notification_generator()
async def get_notification_generator(self):
async def get_notification_generator(self) -> AsyncGenerator:
"""Notification Generator"""
while True:
@ -54,9 +55,8 @@ class Notifier:
logger.exception(f"There is no member in room: {room_name}")
return None
async def push(self, message: str, room_name: str = None):
async def push(self, message: str, room_name: str):
"""Push a message"""
message_body = {"message": message, "room_name": room_name}
await self.generator.asend(message_body)

View file

@ -1,15 +1,8 @@
import json
from http import HTTPStatus
from fastapi import (
BackgroundTasks,
Depends,
Query,
Request,
WebSocket,
WebSocketDisconnect,
)
from fastapi import Depends, Query, Request, WebSocket, WebSocketDisconnect
from fastapi.templating import Jinja2Templates
from loguru import logger
from starlette.exceptions import HTTPException
from starlette.responses import HTMLResponse
@ -147,24 +140,14 @@ notifier = Notifier()
@market_ext.websocket("/ws/{room_name}")
async def websocket_endpoint(
websocket: WebSocket, room_name: str, background_tasks: BackgroundTasks
):
async def websocket_endpoint(websocket: WebSocket, room_name: str):
await notifier.connect(websocket, room_name)
try:
while True:
data = await websocket.receive_text()
d = json.loads(data)
d["room_name"] = room_name
room_members = (
notifier.get_members(room_name)
if notifier.get_members(room_name) is not None
else []
)
room_members = notifier.get_members(room_name) or []
if websocket not in room_members:
print("Sender not in room member: Reconnecting...")
logger.warning("Sender not in room member: Reconnecting...")
await notifier.connect(websocket, room_name)
await notifier._notify(data, room_name)

View file

@ -1,4 +1,5 @@
from http import HTTPStatus
from typing import Optional
from fastapi import Depends, Query
from loguru import logger
@ -224,7 +225,7 @@ async def api_market_stalls(
@market_ext.put("/api/v1/stalls/{stall_id}")
async def api_market_stall_create(
data: createStalls,
stall_id: str = None,
stall_id: Optional[str] = None,
wallet: WalletTypeInfo = Depends(require_invoice_key),
):
@ -447,7 +448,7 @@ async def api_market_market_stalls(market_id: str):
@market_ext.put("/api/v1/markets/{market_id}")
async def api_market_market_create(
data: CreateMarket,
market_id: str = None,
market_id: Optional[str] = None,
wallet: WalletTypeInfo = Depends(require_invoice_key),
):
if market_id:
@ -506,7 +507,7 @@ async def api_get_settings(wallet: WalletTypeInfo = Depends(require_admin_key)):
@market_ext.put("/api/v1/settings/{usr}")
async def api_set_settings(
data: SetSettings,
usr: str = None,
usr: Optional[str] = None,
wallet: WalletTypeInfo = Depends(require_admin_key),
):
if usr:

View file

@ -2,7 +2,7 @@ from os import getenv
from fastapi import Depends, Request
from fastapi.templating import Jinja2Templates
from pyngrok import conf, ngrok
from pyngrok import conf, ngrok # type: ignore
from lnbits.core.models import User
from lnbits.decorators import check_user_exists

View file

@ -22,12 +22,12 @@ async def create_shop(*, wallet_id: str) -> int:
if db.type == SQLITE:
return result._result_proxy.lastrowid
else:
return result[0]
return result[0] # type: ignore
async def get_shop(id: int) -> Optional[Shop]:
row = await db.fetchone("SELECT * FROM offlineshop.shops WHERE id = ?", (id,))
return Shop(**dict(row)) if row else None
return Shop(**row) if row else None
async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]:
@ -40,7 +40,7 @@ async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]:
ls_id = await create_shop(wallet_id=wallet)
return await get_shop(ls_id)
return Shop(**dict(row)) if row else None
return Shop(**row) if row else None
async def set_method(shop: int, method: str, wordlist: str = "") -> Optional[Shop]:

View file

@ -52,7 +52,7 @@ class ShopCounter:
# cleanup confirmation words cache
to_remove = len(self.fulfilled_payments) - 23
if to_remove > 0:
for i in range(to_remove):
for _ in range(to_remove):
self.fulfilled_payments.popitem(False)
return word
@ -64,6 +64,10 @@ class Shop(BaseModel):
method: str
wordlist: str
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
@property
def otp_key(self) -> str:
return base64.b32encode(

View file

@ -18,8 +18,6 @@ from .crud import (
)
from .models import CreateSatsDicePayment
##############LNURLP STUFF
@satsdice_ext.get(
"/api/v1/lnurlp/{link_id}",
@ -84,7 +82,7 @@ async def api_lnurlp_callback(
data = CreateSatsDicePayment(
satsdice_pay=link.id,
value=amount_received / 1000,
value=int(amount_received / 1000),
payment_hash=payment_hash,
)

View file

@ -13,7 +13,7 @@ from .helpers import fetch_onchain_balance
from .models import Charges, CreateCharge, SatsPayThemes
async def create_charge(user: str, data: CreateCharge) -> Optional[Charges]:
async def create_charge(user: str, data: CreateCharge) -> Charges:
data = CreateCharge(**data.dict())
charge_id = urlsafe_short_hash()
if data.onchainwallet:
@ -79,7 +79,9 @@ async def create_charge(user: str, data: CreateCharge) -> Optional[Charges]:
data.custom_css,
),
)
return await get_charge(charge_id)
charge = await get_charge(charge_id)
assert charge, "Newly created charge does not exist"
return charge
async def update_charge(charge_id: str, **kwargs) -> Optional[Charges]:

View file

@ -10,7 +10,7 @@ async def get_targets(source_wallet: str) -> List[Target]:
rows = await db.fetchall(
"SELECT * FROM splitpayments.targets WHERE source = ?", (source_wallet,)
)
return [Target(**dict(row)) for row in rows]
return [Target(**row) for row in rows]
async def set_targets(source_wallet: str, targets: List[Target]):

View file

@ -1,6 +1,7 @@
from sqlite3 import Row
from typing import List, Optional
from fastapi.param_functions import Query
from fastapi import Query
from pydantic import BaseModel
@ -11,6 +12,10 @@ class Target(BaseModel):
tag: str
alias: Optional[str]
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
class TargetPutList(BaseModel):
wallet: str = Query(...)

View file

@ -147,14 +147,16 @@ async def create_service(data: CreateService) -> Service:
if db.type == SQLITE:
service_id = result._result_proxy.lastrowid
else:
service_id = result[0]
service_id = result[0] # type: ignore
service = await get_service(service_id)
assert service
return service
async def get_service(service_id: int, by_state: str = None) -> Optional[Service]:
async def get_service(
service_id: int, by_state: Optional[str] = None
) -> Optional[Service]:
"""Return a service either by ID or, available, by state
Each Service's donation page is reached through its "state" hash
@ -184,7 +186,9 @@ async def authenticate_service(service_id, code, redirect_uri):
"""Use authentication code from third party API to retreive access token"""
# The API token is passed in the querystring as 'code'
service = await get_service(service_id)
assert service
wallet = await get_wallet(service.wallet)
assert wallet
user = wallet.user
url = "https://streamlabs.com/api/v1.0/token"
data = {
@ -208,8 +212,11 @@ async def service_add_token(service_id, token):
is not overwritten.
Tokens for Streamlabs never need to be refreshed.
"""
if (await get_service(service_id)).authenticated:
service = await get_service(service_id)
assert service
if service.authenticated:
return False
await db.execute(
"UPDATE streamalerts.Services SET authenticated = 1, token = ? where id = ?",
(token, service_id),

View file

@ -1,5 +1,3 @@
import json
import httpx
from .models import Domains
@ -20,9 +18,7 @@ async def cloudflare_create_subdomain(
"Content-Type": "application/json",
}
aRecord = subdomain + "." + domain.domain
cf_response = ""
async with httpx.AsyncClient() as client:
try:
r = await client.post(
url,
headers=header,
@ -35,10 +31,8 @@ async def cloudflare_create_subdomain(
},
timeout=40,
)
cf_response = json.loads(r.text)
except AssertionError:
cf_response = "Error occured"
return cf_response
r.raise_for_status()
return r.json()
async def cloudflare_deletesubdomain(domain: Domains, domain_id: str):
@ -52,8 +46,4 @@ async def cloudflare_deletesubdomain(domain: Domains, domain_id: str):
"Content-Type": "application/json",
}
async with httpx.AsyncClient() as client:
try:
r = await client.delete(url + "/" + domain_id, headers=header, timeout=40)
r.text
except AssertionError:
pass
await client.delete(url + "/" + domain_id, headers=header, timeout=40)

View file

@ -121,21 +121,21 @@ async def api_subdomain_make_subdomain(domain_id, data: CreateSubdomain):
detail=f"{data.subdomain}.{domain.domain} domain already taken.",
)
## Dry run cloudflare... (create and if create is sucessful delete it)
cf_response = await cloudflare_create_subdomain(
## Dry run cloudflare... (create and if create is successful delete it)
try:
res_json = await cloudflare_create_subdomain(
domain=domain,
subdomain=data.subdomain,
record_type=data.record_type,
ip=data.ip,
)
if cf_response["success"] is True:
await cloudflare_deletesubdomain(
domain=domain, domain_id=cf_response["result"]["id"]
domain=domain, domain_id=res_json["result"]["id"]
)
else:
except:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f'Problem with cloudflare: {cf_response["errors"][0]["message"]}',
detail="Problem with cloudflare.",
)
## ALL OK - create an invoice and return it to the user

View file

@ -1,20 +0,0 @@
from lnbits.core.crud import get_wallet
from .crud import get_tipjar
async def get_charge_details(tipjar_id):
"""Return the default details for a satspay charge"""
tipjar = await get_tipjar(tipjar_id)
wallet_id = tipjar.wallet
wallet = await get_wallet(wallet_id)
user = wallet.user
details = {
"time": 1440,
"user": user,
"lnbitswallet": wallet_id,
"onchainwallet": tipjar.onchain,
"completelink": "/tipjar/" + str(tipjar_id),
"completelinktext": "Thanks for the tip!",
}
return details

View file

@ -1,7 +1,6 @@
from http import HTTPStatus
from fastapi import Depends, Request
from fastapi.param_functions import Query
from fastapi import Depends, Query, Request
from fastapi.templating import Jinja2Templates
from starlette.exceptions import HTTPException

View file

@ -3,7 +3,7 @@ from http import HTTPStatus
from fastapi import Depends, Query
from starlette.exceptions import HTTPException
from lnbits.core.crud import get_user
from lnbits.core.crud import get_user, get_wallet
from lnbits.decorators import WalletTypeInfo, get_key_type
# todo: use the API, not direct import
@ -22,7 +22,6 @@ from .crud import (
update_tip,
update_tipjar,
)
from .helpers import get_charge_details
from .models import createTip, createTipJar, createTips
@ -55,25 +54,32 @@ async def api_create_tip(data: createTips):
status_code=HTTPStatus.NOT_FOUND, detail="Tipjar does not exist."
)
webhook = tipjar.webhook
charge_details = await get_charge_details(tipjar.id)
wallet_id = tipjar.wallet
wallet = await get_wallet(wallet_id)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Tipjar wallet does not exist."
)
name = data.name
# Ensure that description string can be split reliably
name = name.replace('"', "''")
if not name:
name = "Anonymous"
description = f"{name}: {message}"
charge = await create_charge(
user=charge_details["user"],
user=wallet.user,
data=CreateCharge(
amount=sats,
webhook=webhook or "",
webhook=tipjar.webhook or "",
description=description,
onchainwallet=charge_details["onchainwallet"],
lnbitswallet=charge_details["lnbitswallet"],
completelink=charge_details["completelink"],
completelinktext=charge_details["completelinktext"],
time=charge_details["time"],
onchainwallet=tipjar.onchain or "",
lnbitswallet=tipjar.wallet,
completelink="/tipjar/" + str(tipjar_id),
completelinktext="Thanks for the tip!",
time=1440,
custom_css="",
),
)

View file

@ -1,3 +1,5 @@
from typing import Optional, Tuple
from embit.descriptor import Descriptor, Key
from embit.descriptor.arguments import AllowedDerivation
from embit.networks import NETWORKS
@ -12,7 +14,7 @@ def detect_network(k):
return net
def parse_key(masterpub: str) -> Descriptor:
def parse_key(masterpub: str) -> Tuple[Descriptor, Optional[dict]]:
"""Parses masterpub or descriptor and returns a tuple: (Descriptor, network)
To create addresses use descriptor.derive(num).address(network=network)
"""
@ -34,6 +36,7 @@ def parse_key(masterpub: str) -> Descriptor:
k.allowed_derivation = AllowedDerivation.default()
# get version bytes
version = k.key.version
desc = Descriptor()
for network_name in NETWORKS:
net = NETWORKS[network_name]
# not found in this network
@ -47,8 +50,9 @@ def parse_key(masterpub: str) -> Descriptor:
desc = Descriptor.from_string("wpkh(%s)" % str(k))
break
# we didn't find correct version
if network is None:
if not network:
raise ValueError("Unknown master public key version")
else:
desc = Descriptor.from_string(masterpub)
if not desc.is_wildcard:
@ -61,6 +65,7 @@ def parse_key(masterpub: str) -> Descriptor:
if network is not None and network != net:
raise ValueError("Keys from different networks")
network = net
return desc, network

View file

@ -73,7 +73,8 @@ async def api_wallet_create_or_update(
data: CreateWallet, w: WalletTypeInfo = Depends(require_admin_key)
):
try:
(descriptor, network) = parse_key(data.masterpub)
descriptor, network = parse_key(data.masterpub)
assert network
if data.network != network["name"]:
raise ValueError(
"Account network error. This account is for '{}'".format(
@ -308,12 +309,9 @@ async def api_psbt_utxos_tx(
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
@watchonly_ext.put("/api/v1/psbt/extract")
async def api_psbt_extract_tx(
data: ExtractPsbt, w: WalletTypeInfo = Depends(require_admin_key)
):
@watchonly_ext.put("/api/v1/psbt/extract", dependencies=[Depends(require_admin_key)])
async def api_psbt_extract_tx(data: ExtractPsbt):
network = NETWORKS["main"] if data.network == "Mainnet" else NETWORKS["test"]
res = SignedTransaction()
try:
psbt = PSBT.from_base64(data.psbtBase64)
for i, inp in enumerate(data.inputs):
@ -322,9 +320,9 @@ async def api_psbt_extract_tx(
final_psbt = finalizer.finalize_psbt(psbt)
if not final_psbt:
raise ValueError("PSBT cannot be finalized!")
res.tx_hex = final_psbt.to_string()
transaction = Transaction.from_string(res.tx_hex)
tx_hex = final_psbt.to_string()
transaction = Transaction.from_string(tx_hex)
tx = {
"locktime": transaction.locktime,
"version": transaction.version,
@ -336,10 +334,10 @@ async def api_psbt_extract_tx(
tx["outputs"].append(
{"amount": out.value, "address": out.script_pubkey.address(network)}
)
res.tx_json = json.dumps(tx)
signed_tx = SignedTransaction(tx_hex=tx_hex, tx_json=json.dumps(tx))
return signed_tx.dict()
except Exception as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
return res.dict()
@watchonly_ext.post("/api/v1/tx")