From 3bd68cb394afcccc28df27f4bf47a248ebf7131e Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Thu, 2 Feb 2023 12:57:52 +0000 Subject: [PATCH] fix pyright lnbits/extensions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: dni ⚡ --- lnbits/extensions/bleskomat/exchange_rates.py | 6 +-- lnbits/extensions/cashu/models.py | 7 ++-- lnbits/extensions/gerty/number_prefixer.py | 22 ++++------ lnbits/extensions/livestream/crud.py | 18 ++++----- lnbits/extensions/livestream/models.py | 13 ++++++ lnbits/extensions/lnaddress/views_api.py | 4 +- lnbits/extensions/lnurlp/models.py | 4 +- lnbits/extensions/market/notifier.py | 6 +-- lnbits/extensions/market/views.py | 27 +++---------- lnbits/extensions/market/views_api.py | 7 ++-- lnbits/extensions/ngrok/views.py | 2 +- lnbits/extensions/offlineshop/crud.py | 6 +-- lnbits/extensions/offlineshop/models.py | 6 ++- lnbits/extensions/satsdice/lnurl.py | 4 +- lnbits/extensions/satspay/crud.py | 6 ++- lnbits/extensions/splitpayments/crud.py | 2 +- lnbits/extensions/splitpayments/models.py | 7 +++- lnbits/extensions/streamalerts/crud.py | 13 ++++-- lnbits/extensions/subdomains/cloudflare.py | 40 +++++++------------ lnbits/extensions/subdomains/views_api.py | 24 +++++------ lnbits/extensions/tipjar/helpers.py | 20 ---------- lnbits/extensions/tipjar/views.py | 3 +- lnbits/extensions/tipjar/views_api.py | 28 ++++++++----- lnbits/extensions/watchonly/helpers.py | 9 ++++- lnbits/extensions/watchonly/views_api.py | 18 ++++----- 25 files changed, 142 insertions(+), 160 deletions(-) delete mode 100644 lnbits/extensions/tipjar/helpers.py diff --git a/lnbits/extensions/bleskomat/exchange_rates.py b/lnbits/extensions/bleskomat/exchange_rates.py index aff9ce657..b0a3969c7 100644 --- a/lnbits/extensions/bleskomat/exchange_rates.py +++ b/lnbits/extensions/bleskomat/exchange_rates.py @@ -80,7 +80,5 @@ async def fetch_fiat_exchange_rate(currency: str, provider: str): else: data = {} getter = exchange_rate_providers[provider]["getter"] - print(getter) - if callable(getter): - rate = float(getter(data, replacements)) - return rate + assert callable(getter), "cannot call getter function" + return float(getter(data, replacements)) diff --git a/lnbits/extensions/cashu/models.py b/lnbits/extensions/cashu/models.py index 84f28c2bc..aaff195f8 100644 --- a/lnbits/extensions/cashu/models.py +++ b/lnbits/extensions/cashu/models.py @@ -68,10 +68,11 @@ class Proof(BaseModel): def from_dict(cls, d: dict): assert "secret" in d, "no secret in proof" assert "amount" in d, "no amount in proof" + assert "C" in d, "no C in proof" return cls( - amount=d.get("amount"), - C=d.get("C"), - secret=d.get("secret"), + amount=d["amount"], + C=d["C"], + secret=d["secret"], reserved=d.get("reserved") or False, send_id=d.get("send_id") or "", time_created=d.get("time_created") or "", diff --git a/lnbits/extensions/gerty/number_prefixer.py b/lnbits/extensions/gerty/number_prefixer.py index dca001087..78f583bf4 100644 --- a/lnbits/extensions/gerty/number_prefixer.py +++ b/lnbits/extensions/gerty/number_prefixer.py @@ -1,7 +1,8 @@ import math +from typing import Tuple -def si_classifier(val): +def si_classifier(val) -> dict: suffixes = { 24: {"long_suffix": "yotta", "short_suffix": "Y", "scalar": 10**24}, 21: {"long_suffix": "zetta", "short_suffix": "Z", "scalar": 10**21}, @@ -22,24 +23,22 @@ def si_classifier(val): -24: {"long_suffix": "yocto", "short_suffix": "y", "scalar": 10**-24}, } exponent = int(math.floor(math.log10(abs(val)) / 3.0) * 3) - return suffixes.get(exponent, None) + suffix = suffixes.get(exponent) + assert suffix, f"could not classify: {val}" + return suffix -def si_formatter(value): +def si_formatter(value) -> Tuple: """ Return a triple of scaled value, short suffix, long suffix, or None if the value cannot be classified. """ classifier = si_classifier(value) - if classifier is None: - # Don't know how to classify this value - return None - scaled = value / classifier["scalar"] - return (scaled, classifier["short_suffix"], classifier["long_suffix"]) + return scaled, classifier["short_suffix"], classifier["long_suffix"] -def si_format(value, precision=4, long_form=False, separator=""): +def si_format(value: float, precision=4, long_form=False, separator="") -> str: """ "SI prefix" formatted string: return a string with the given precision and an appropriate order-of-3-magnitudes suffix, e.g.: @@ -47,11 +46,6 @@ def si_format(value, precision=4, long_form=False, separator=""): si_format(0.00000000123, long_form=True, separator=' ') => '1.230 nano' """ scaled, short_suffix, long_suffix = si_formatter(value) - - if scaled is None: - # Don't know how to format this value - return value - suffix = long_suffix if long_form else short_suffix if abs(scaled) < 10: diff --git a/lnbits/extensions/livestream/crud.py b/lnbits/extensions/livestream/crud.py index 4784494c0..fe4194eb0 100644 --- a/lnbits/extensions/livestream/crud.py +++ b/lnbits/extensions/livestream/crud.py @@ -23,14 +23,14 @@ async def create_livestream(*, wallet_id: str) -> int: if db.type == SQLITE: return result._result_proxy.lastrowid else: - return result[0] + return result[0] # type: ignore async def get_livestream(ls_id: int) -> Optional[Livestream]: row = await db.fetchone( "SELECT * FROM livestream.livestreams WHERE id = ?", (ls_id,) ) - return Livestream(**dict(row)) if row else None + return Livestream(**row) if row else None async def get_livestream_by_track(track_id: int) -> Optional[Livestream]: @@ -42,7 +42,7 @@ async def get_livestream_by_track(track_id: int) -> Optional[Livestream]: """, (track_id,), ) - return Livestream(**dict(row)) if row else None + return Livestream(**row) if row else None async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream]: @@ -55,7 +55,7 @@ async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream ls_id = await create_livestream(wallet_id=wallet) return await get_livestream(ls_id) - return Livestream(**dict(row)) if row else None + return Livestream(**row) if row else None async def update_current_track(ls_id: int, track_id: Optional[int]): @@ -121,7 +121,7 @@ async def get_track(track_id: Optional[int]) -> Optional[Track]: """, (track_id,), ) - return Track(**dict(row)) if row else None + return Track(**row) if row else None async def get_tracks(livestream: int) -> List[Track]: @@ -132,7 +132,7 @@ async def get_tracks(livestream: int) -> List[Track]: """, (livestream,), ) - return [Track(**dict(row)) for row in rows] + return [Track(**row) for row in rows] async def delete_track_from_livestream(livestream: int, track_id: int): @@ -174,7 +174,7 @@ async def add_producer(livestream: int, name: str) -> int: if db.type == SQLITE: return result._result_proxy.lastrowid else: - return result[0] + return result[0] # type: ignore async def get_producer(producer_id: int) -> Optional[Producer]: @@ -185,7 +185,7 @@ async def get_producer(producer_id: int) -> Optional[Producer]: """, (producer_id,), ) - return Producer(**dict(row)) if row else None + return Producer(**row) if row else None async def get_producers(livestream: int) -> List[Producer]: @@ -196,4 +196,4 @@ async def get_producers(livestream: int) -> List[Producer]: """, (livestream,), ) - return [Producer(**dict(row)) for row in rows] + return [Producer(**row) for row in rows] diff --git a/lnbits/extensions/livestream/models.py b/lnbits/extensions/livestream/models.py index 5d617da99..31d3f6ebf 100644 --- a/lnbits/extensions/livestream/models.py +++ b/lnbits/extensions/livestream/models.py @@ -1,4 +1,5 @@ import json +from sqlite3 import Row from typing import Optional from fastapi import Query, Request @@ -27,6 +28,10 @@ class Livestream(BaseModel): url = request.url_for("livestream.lnurl_livestream", ls_id=self.id) return lnurl_encode(url) + @classmethod + def from_row(cls, row: Row): + return cls(**dict(row)) + class Track(BaseModel): id: int @@ -35,6 +40,10 @@ class Track(BaseModel): name: str producer: int + @classmethod + def from_row(cls, row: Row): + return cls(**dict(row)) + @property def min_sendable(self) -> int: return min(100_000, self.price_msat or 100_000) @@ -88,3 +97,7 @@ class Producer(BaseModel): user: str wallet: str name: str + + @classmethod + def from_row(cls, row: Row): + return cls(**dict(row)) diff --git a/lnbits/extensions/lnaddress/views_api.py b/lnbits/extensions/lnaddress/views_api.py index 7d15a55f7..3b44fc31b 100644 --- a/lnbits/extensions/lnaddress/views_api.py +++ b/lnbits/extensions/lnaddress/views_api.py @@ -71,9 +71,7 @@ async def api_domain_create( if not cf_response or not cf_response["success"]: await delete_domain(domain.id) raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="Problem with cloudflare: " - + cf_response["errors"][0]["message"], + status_code=HTTPStatus.BAD_REQUEST, detail="Problem with cloudflare." ) return domain.dict() diff --git a/lnbits/extensions/lnurlp/models.py b/lnbits/extensions/lnurlp/models.py index 1c6b6f711..ce1095794 100644 --- a/lnbits/extensions/lnurlp/models.py +++ b/lnbits/extensions/lnurlp/models.py @@ -61,8 +61,8 @@ class PayLink(BaseModel): def success_action(self, payment_hash: str) -> Optional[Dict]: if self.success_url: url: ParseResult = urlparse(self.success_url) - qs: Dict = parse_qs(url.query) - qs["payment_hash"] = payment_hash + qs = parse_qs(url.query) + setattr(qs, "payment_hash", payment_hash) url = url._replace(query=urlencode(qs, doseq=True)) return { "tag": "url", diff --git a/lnbits/extensions/market/notifier.py b/lnbits/extensions/market/notifier.py index 88a1a4a38..4fe366f3e 100644 --- a/lnbits/extensions/market/notifier.py +++ b/lnbits/extensions/market/notifier.py @@ -6,6 +6,7 @@ and delivery to the specific person import json from collections import defaultdict +from typing import AsyncGenerator from fastapi import WebSocket from loguru import logger @@ -34,7 +35,7 @@ class Notifier: # Create notification generator: self.generator = self.get_notification_generator() - async def get_notification_generator(self): + async def get_notification_generator(self) -> AsyncGenerator: """Notification Generator""" while True: @@ -54,9 +55,8 @@ class Notifier: logger.exception(f"There is no member in room: {room_name}") return None - async def push(self, message: str, room_name: str = None): + async def push(self, message: str, room_name: str): """Push a message""" - message_body = {"message": message, "room_name": room_name} await self.generator.asend(message_body) diff --git a/lnbits/extensions/market/views.py b/lnbits/extensions/market/views.py index 0bcfac459..f9e7131bc 100644 --- a/lnbits/extensions/market/views.py +++ b/lnbits/extensions/market/views.py @@ -1,15 +1,8 @@ -import json from http import HTTPStatus -from fastapi import ( - BackgroundTasks, - Depends, - Query, - Request, - WebSocket, - WebSocketDisconnect, -) +from fastapi import Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.templating import Jinja2Templates +from loguru import logger from starlette.exceptions import HTTPException from starlette.responses import HTMLResponse @@ -147,24 +140,14 @@ notifier = Notifier() @market_ext.websocket("/ws/{room_name}") -async def websocket_endpoint( - websocket: WebSocket, room_name: str, background_tasks: BackgroundTasks -): +async def websocket_endpoint(websocket: WebSocket, room_name: str): await notifier.connect(websocket, room_name) try: while True: data = await websocket.receive_text() - d = json.loads(data) - d["room_name"] = room_name - - room_members = ( - notifier.get_members(room_name) - if notifier.get_members(room_name) is not None - else [] - ) - + room_members = notifier.get_members(room_name) or [] if websocket not in room_members: - print("Sender not in room member: Reconnecting...") + logger.warning("Sender not in room member: Reconnecting...") await notifier.connect(websocket, room_name) await notifier._notify(data, room_name) diff --git a/lnbits/extensions/market/views_api.py b/lnbits/extensions/market/views_api.py index ad0cbb463..221d51bbf 100644 --- a/lnbits/extensions/market/views_api.py +++ b/lnbits/extensions/market/views_api.py @@ -1,4 +1,5 @@ from http import HTTPStatus +from typing import Optional from fastapi import Depends, Query from loguru import logger @@ -224,7 +225,7 @@ async def api_market_stalls( @market_ext.put("/api/v1/stalls/{stall_id}") async def api_market_stall_create( data: createStalls, - stall_id: str = None, + stall_id: Optional[str] = None, wallet: WalletTypeInfo = Depends(require_invoice_key), ): @@ -447,7 +448,7 @@ async def api_market_market_stalls(market_id: str): @market_ext.put("/api/v1/markets/{market_id}") async def api_market_market_create( data: CreateMarket, - market_id: str = None, + market_id: Optional[str] = None, wallet: WalletTypeInfo = Depends(require_invoice_key), ): if market_id: @@ -506,7 +507,7 @@ async def api_get_settings(wallet: WalletTypeInfo = Depends(require_admin_key)): @market_ext.put("/api/v1/settings/{usr}") async def api_set_settings( data: SetSettings, - usr: str = None, + usr: Optional[str] = None, wallet: WalletTypeInfo = Depends(require_admin_key), ): if usr: diff --git a/lnbits/extensions/ngrok/views.py b/lnbits/extensions/ngrok/views.py index d84ecd2d9..2fa4df9cf 100644 --- a/lnbits/extensions/ngrok/views.py +++ b/lnbits/extensions/ngrok/views.py @@ -2,7 +2,7 @@ from os import getenv from fastapi import Depends, Request from fastapi.templating import Jinja2Templates -from pyngrok import conf, ngrok +from pyngrok import conf, ngrok # type: ignore from lnbits.core.models import User from lnbits.decorators import check_user_exists diff --git a/lnbits/extensions/offlineshop/crud.py b/lnbits/extensions/offlineshop/crud.py index 896842d80..1fa63f3e0 100644 --- a/lnbits/extensions/offlineshop/crud.py +++ b/lnbits/extensions/offlineshop/crud.py @@ -22,12 +22,12 @@ async def create_shop(*, wallet_id: str) -> int: if db.type == SQLITE: return result._result_proxy.lastrowid else: - return result[0] + return result[0] # type: ignore async def get_shop(id: int) -> Optional[Shop]: row = await db.fetchone("SELECT * FROM offlineshop.shops WHERE id = ?", (id,)) - return Shop(**dict(row)) if row else None + return Shop(**row) if row else None async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]: @@ -40,7 +40,7 @@ async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]: ls_id = await create_shop(wallet_id=wallet) return await get_shop(ls_id) - return Shop(**dict(row)) if row else None + return Shop(**row) if row else None async def set_method(shop: int, method: str, wordlist: str = "") -> Optional[Shop]: diff --git a/lnbits/extensions/offlineshop/models.py b/lnbits/extensions/offlineshop/models.py index d2e3b3d27..01044cb0f 100644 --- a/lnbits/extensions/offlineshop/models.py +++ b/lnbits/extensions/offlineshop/models.py @@ -52,7 +52,7 @@ class ShopCounter: # cleanup confirmation words cache to_remove = len(self.fulfilled_payments) - 23 if to_remove > 0: - for i in range(to_remove): + for _ in range(to_remove): self.fulfilled_payments.popitem(False) return word @@ -64,6 +64,10 @@ class Shop(BaseModel): method: str wordlist: str + @classmethod + def from_row(cls, row: Row): + return cls(**dict(row)) + @property def otp_key(self) -> str: return base64.b32encode( diff --git a/lnbits/extensions/satsdice/lnurl.py b/lnbits/extensions/satsdice/lnurl.py index f766d8cbd..2bb590162 100644 --- a/lnbits/extensions/satsdice/lnurl.py +++ b/lnbits/extensions/satsdice/lnurl.py @@ -18,8 +18,6 @@ from .crud import ( ) from .models import CreateSatsDicePayment -##############LNURLP STUFF - @satsdice_ext.get( "/api/v1/lnurlp/{link_id}", @@ -84,7 +82,7 @@ async def api_lnurlp_callback( data = CreateSatsDicePayment( satsdice_pay=link.id, - value=amount_received / 1000, + value=int(amount_received / 1000), payment_hash=payment_hash, ) diff --git a/lnbits/extensions/satspay/crud.py b/lnbits/extensions/satspay/crud.py index 01abe24e4..c13d0a4b8 100644 --- a/lnbits/extensions/satspay/crud.py +++ b/lnbits/extensions/satspay/crud.py @@ -13,7 +13,7 @@ from .helpers import fetch_onchain_balance from .models import Charges, CreateCharge, SatsPayThemes -async def create_charge(user: str, data: CreateCharge) -> Optional[Charges]: +async def create_charge(user: str, data: CreateCharge) -> Charges: data = CreateCharge(**data.dict()) charge_id = urlsafe_short_hash() if data.onchainwallet: @@ -79,7 +79,9 @@ async def create_charge(user: str, data: CreateCharge) -> Optional[Charges]: data.custom_css, ), ) - return await get_charge(charge_id) + charge = await get_charge(charge_id) + assert charge, "Newly created charge does not exist" + return charge async def update_charge(charge_id: str, **kwargs) -> Optional[Charges]: diff --git a/lnbits/extensions/splitpayments/crud.py b/lnbits/extensions/splitpayments/crud.py index de4e0822a..737e7bb9a 100644 --- a/lnbits/extensions/splitpayments/crud.py +++ b/lnbits/extensions/splitpayments/crud.py @@ -10,7 +10,7 @@ async def get_targets(source_wallet: str) -> List[Target]: rows = await db.fetchall( "SELECT * FROM splitpayments.targets WHERE source = ?", (source_wallet,) ) - return [Target(**dict(row)) for row in rows] + return [Target(**row) for row in rows] async def set_targets(source_wallet: str, targets: List[Target]): diff --git a/lnbits/extensions/splitpayments/models.py b/lnbits/extensions/splitpayments/models.py index fc3db2c6d..4f2bb0106 100644 --- a/lnbits/extensions/splitpayments/models.py +++ b/lnbits/extensions/splitpayments/models.py @@ -1,6 +1,7 @@ +from sqlite3 import Row from typing import List, Optional -from fastapi.param_functions import Query +from fastapi import Query from pydantic import BaseModel @@ -11,6 +12,10 @@ class Target(BaseModel): tag: str alias: Optional[str] + @classmethod + def from_row(cls, row: Row): + return cls(**dict(row)) + class TargetPutList(BaseModel): wallet: str = Query(...) diff --git a/lnbits/extensions/streamalerts/crud.py b/lnbits/extensions/streamalerts/crud.py index f376841a4..1745623d4 100644 --- a/lnbits/extensions/streamalerts/crud.py +++ b/lnbits/extensions/streamalerts/crud.py @@ -147,14 +147,16 @@ async def create_service(data: CreateService) -> Service: if db.type == SQLITE: service_id = result._result_proxy.lastrowid else: - service_id = result[0] + service_id = result[0] # type: ignore service = await get_service(service_id) assert service return service -async def get_service(service_id: int, by_state: str = None) -> Optional[Service]: +async def get_service( + service_id: int, by_state: Optional[str] = None +) -> Optional[Service]: """Return a service either by ID or, available, by state Each Service's donation page is reached through its "state" hash @@ -184,7 +186,9 @@ async def authenticate_service(service_id, code, redirect_uri): """Use authentication code from third party API to retreive access token""" # The API token is passed in the querystring as 'code' service = await get_service(service_id) + assert service wallet = await get_wallet(service.wallet) + assert wallet user = wallet.user url = "https://streamlabs.com/api/v1.0/token" data = { @@ -208,8 +212,11 @@ async def service_add_token(service_id, token): is not overwritten. Tokens for Streamlabs never need to be refreshed. """ - if (await get_service(service_id)).authenticated: + service = await get_service(service_id) + assert service + if service.authenticated: return False + await db.execute( "UPDATE streamalerts.Services SET authenticated = 1, token = ? where id = ?", (token, service_id), diff --git a/lnbits/extensions/subdomains/cloudflare.py b/lnbits/extensions/subdomains/cloudflare.py index d0d8c4f31..3d3b9bdeb 100644 --- a/lnbits/extensions/subdomains/cloudflare.py +++ b/lnbits/extensions/subdomains/cloudflare.py @@ -1,5 +1,3 @@ -import json - import httpx from .models import Domains @@ -20,25 +18,21 @@ async def cloudflare_create_subdomain( "Content-Type": "application/json", } aRecord = subdomain + "." + domain.domain - cf_response = "" async with httpx.AsyncClient() as client: - try: - r = await client.post( - url, - headers=header, - json={ - "type": record_type, - "name": aRecord, - "content": ip, - "ttl": 0, - "proxied": False, - }, - timeout=40, - ) - cf_response = json.loads(r.text) - except AssertionError: - cf_response = "Error occured" - return cf_response + r = await client.post( + url, + headers=header, + json={ + "type": record_type, + "name": aRecord, + "content": ip, + "ttl": 0, + "proxied": False, + }, + timeout=40, + ) + r.raise_for_status() + return r.json() async def cloudflare_deletesubdomain(domain: Domains, domain_id: str): @@ -52,8 +46,4 @@ async def cloudflare_deletesubdomain(domain: Domains, domain_id: str): "Content-Type": "application/json", } async with httpx.AsyncClient() as client: - try: - r = await client.delete(url + "/" + domain_id, headers=header, timeout=40) - r.text - except AssertionError: - pass + await client.delete(url + "/" + domain_id, headers=header, timeout=40) diff --git a/lnbits/extensions/subdomains/views_api.py b/lnbits/extensions/subdomains/views_api.py index 6f85c66e0..3c0330f53 100644 --- a/lnbits/extensions/subdomains/views_api.py +++ b/lnbits/extensions/subdomains/views_api.py @@ -121,21 +121,21 @@ async def api_subdomain_make_subdomain(domain_id, data: CreateSubdomain): detail=f"{data.subdomain}.{domain.domain} domain already taken.", ) - ## Dry run cloudflare... (create and if create is sucessful delete it) - cf_response = await cloudflare_create_subdomain( - domain=domain, - subdomain=data.subdomain, - record_type=data.record_type, - ip=data.ip, - ) - if cf_response["success"] is True: - await cloudflare_deletesubdomain( - domain=domain, domain_id=cf_response["result"]["id"] + ## Dry run cloudflare... (create and if create is successful delete it) + try: + res_json = await cloudflare_create_subdomain( + domain=domain, + subdomain=data.subdomain, + record_type=data.record_type, + ip=data.ip, ) - else: + await cloudflare_deletesubdomain( + domain=domain, domain_id=res_json["result"]["id"] + ) + except: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, - detail=f'Problem with cloudflare: {cf_response["errors"][0]["message"]}', + detail="Problem with cloudflare.", ) ## ALL OK - create an invoice and return it to the user diff --git a/lnbits/extensions/tipjar/helpers.py b/lnbits/extensions/tipjar/helpers.py deleted file mode 100644 index 7214e19c3..000000000 --- a/lnbits/extensions/tipjar/helpers.py +++ /dev/null @@ -1,20 +0,0 @@ -from lnbits.core.crud import get_wallet - -from .crud import get_tipjar - - -async def get_charge_details(tipjar_id): - """Return the default details for a satspay charge""" - tipjar = await get_tipjar(tipjar_id) - wallet_id = tipjar.wallet - wallet = await get_wallet(wallet_id) - user = wallet.user - details = { - "time": 1440, - "user": user, - "lnbitswallet": wallet_id, - "onchainwallet": tipjar.onchain, - "completelink": "/tipjar/" + str(tipjar_id), - "completelinktext": "Thanks for the tip!", - } - return details diff --git a/lnbits/extensions/tipjar/views.py b/lnbits/extensions/tipjar/views.py index 56f718e21..ddb1b63c0 100644 --- a/lnbits/extensions/tipjar/views.py +++ b/lnbits/extensions/tipjar/views.py @@ -1,7 +1,6 @@ from http import HTTPStatus -from fastapi import Depends, Request -from fastapi.param_functions import Query +from fastapi import Depends, Query, Request from fastapi.templating import Jinja2Templates from starlette.exceptions import HTTPException diff --git a/lnbits/extensions/tipjar/views_api.py b/lnbits/extensions/tipjar/views_api.py index 7d3df9205..7d420fae8 100644 --- a/lnbits/extensions/tipjar/views_api.py +++ b/lnbits/extensions/tipjar/views_api.py @@ -3,7 +3,7 @@ from http import HTTPStatus from fastapi import Depends, Query from starlette.exceptions import HTTPException -from lnbits.core.crud import get_user +from lnbits.core.crud import get_user, get_wallet from lnbits.decorators import WalletTypeInfo, get_key_type # todo: use the API, not direct import @@ -22,7 +22,6 @@ from .crud import ( update_tip, update_tipjar, ) -from .helpers import get_charge_details from .models import createTip, createTipJar, createTips @@ -55,25 +54,32 @@ async def api_create_tip(data: createTips): status_code=HTTPStatus.NOT_FOUND, detail="Tipjar does not exist." ) - webhook = tipjar.webhook - charge_details = await get_charge_details(tipjar.id) + wallet_id = tipjar.wallet + wallet = await get_wallet(wallet_id) + if not wallet: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, detail="Tipjar wallet does not exist." + ) + name = data.name + # Ensure that description string can be split reliably name = name.replace('"', "''") if not name: name = "Anonymous" + description = f"{name}: {message}" charge = await create_charge( - user=charge_details["user"], + user=wallet.user, data=CreateCharge( amount=sats, - webhook=webhook or "", + webhook=tipjar.webhook or "", description=description, - onchainwallet=charge_details["onchainwallet"], - lnbitswallet=charge_details["lnbitswallet"], - completelink=charge_details["completelink"], - completelinktext=charge_details["completelinktext"], - time=charge_details["time"], + onchainwallet=tipjar.onchain or "", + lnbitswallet=tipjar.wallet, + completelink="/tipjar/" + str(tipjar_id), + completelinktext="Thanks for the tip!", + time=1440, custom_css="", ), ) diff --git a/lnbits/extensions/watchonly/helpers.py b/lnbits/extensions/watchonly/helpers.py index 8db9ff573..40e9788f6 100644 --- a/lnbits/extensions/watchonly/helpers.py +++ b/lnbits/extensions/watchonly/helpers.py @@ -1,3 +1,5 @@ +from typing import Optional, Tuple + from embit.descriptor import Descriptor, Key from embit.descriptor.arguments import AllowedDerivation from embit.networks import NETWORKS @@ -12,7 +14,7 @@ def detect_network(k): return net -def parse_key(masterpub: str) -> Descriptor: +def parse_key(masterpub: str) -> Tuple[Descriptor, Optional[dict]]: """Parses masterpub or descriptor and returns a tuple: (Descriptor, network) To create addresses use descriptor.derive(num).address(network=network) """ @@ -34,6 +36,7 @@ def parse_key(masterpub: str) -> Descriptor: k.allowed_derivation = AllowedDerivation.default() # get version bytes version = k.key.version + desc = Descriptor() for network_name in NETWORKS: net = NETWORKS[network_name] # not found in this network @@ -47,8 +50,9 @@ def parse_key(masterpub: str) -> Descriptor: desc = Descriptor.from_string("wpkh(%s)" % str(k)) break # we didn't find correct version - if network is None: + if not network: raise ValueError("Unknown master public key version") + else: desc = Descriptor.from_string(masterpub) if not desc.is_wildcard: @@ -61,6 +65,7 @@ def parse_key(masterpub: str) -> Descriptor: if network is not None and network != net: raise ValueError("Keys from different networks") network = net + return desc, network diff --git a/lnbits/extensions/watchonly/views_api.py b/lnbits/extensions/watchonly/views_api.py index 2e3fc45d8..e0c427fe7 100644 --- a/lnbits/extensions/watchonly/views_api.py +++ b/lnbits/extensions/watchonly/views_api.py @@ -73,7 +73,8 @@ async def api_wallet_create_or_update( data: CreateWallet, w: WalletTypeInfo = Depends(require_admin_key) ): try: - (descriptor, network) = parse_key(data.masterpub) + descriptor, network = parse_key(data.masterpub) + assert network if data.network != network["name"]: raise ValueError( "Account network error. This account is for '{}'".format( @@ -308,12 +309,9 @@ async def api_psbt_utxos_tx( raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) -@watchonly_ext.put("/api/v1/psbt/extract") -async def api_psbt_extract_tx( - data: ExtractPsbt, w: WalletTypeInfo = Depends(require_admin_key) -): +@watchonly_ext.put("/api/v1/psbt/extract", dependencies=[Depends(require_admin_key)]) +async def api_psbt_extract_tx(data: ExtractPsbt): network = NETWORKS["main"] if data.network == "Mainnet" else NETWORKS["test"] - res = SignedTransaction() try: psbt = PSBT.from_base64(data.psbtBase64) for i, inp in enumerate(data.inputs): @@ -322,9 +320,9 @@ async def api_psbt_extract_tx( final_psbt = finalizer.finalize_psbt(psbt) if not final_psbt: raise ValueError("PSBT cannot be finalized!") - res.tx_hex = final_psbt.to_string() - transaction = Transaction.from_string(res.tx_hex) + tx_hex = final_psbt.to_string() + transaction = Transaction.from_string(tx_hex) tx = { "locktime": transaction.locktime, "version": transaction.version, @@ -336,10 +334,10 @@ async def api_psbt_extract_tx( tx["outputs"].append( {"amount": out.value, "address": out.script_pubkey.address(network)} ) - res.tx_json = json.dumps(tx) + signed_tx = SignedTransaction(tx_hex=tx_hex, tx_json=json.dumps(tx)) + return signed_tx.dict() except Exception as e: raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) - return res.dict() @watchonly_ext.post("/api/v1/tx")