From 2949253cb277cd771704b1b48e8beaf98cb4f233 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dni=20=E2=9A=A1?= Date: Thu, 5 Jan 2023 12:27:43 +0100 Subject: [PATCH] fix boltscard mypy issue --- lnbits/extensions/boltcards/crud.py | 13 +++++----- lnbits/extensions/boltcards/lnurl.py | 33 ++++++++++++------------ lnbits/extensions/boltcards/models.py | 16 ++++++------ lnbits/extensions/boltcards/tasks.py | 15 ++++++----- lnbits/extensions/boltcards/views.py | 3 +-- lnbits/extensions/boltcards/views_api.py | 27 ++++++++----------- pyproject.toml | 1 - 7 files changed, 51 insertions(+), 57 deletions(-) diff --git a/lnbits/extensions/boltcards/crud.py b/lnbits/extensions/boltcards/crud.py index 4fae31f92..cc5d51931 100644 --- a/lnbits/extensions/boltcards/crud.py +++ b/lnbits/extensions/boltcards/crud.py @@ -1,5 +1,5 @@ import secrets -from datetime import date, datetime +from datetime import datetime from typing import List, Optional, Union from lnbits.helpers import urlsafe_short_hash @@ -124,7 +124,6 @@ async def get_card_by_otp(otp: str) -> Optional[Card]: async def delete_card(card_id: str) -> None: # Delete cards - card = await get_card(card_id) await db.execute("DELETE FROM boltcards.cards WHERE id = ?", (card_id,)) # Delete hits hits = await get_hits([card_id]) @@ -146,7 +145,7 @@ async def update_card_counter(counter: int, id: str): async def enable_disable_card(enable: bool, id: str) -> Optional[Card]: - row = await db.execute( + await db.execute( "UPDATE boltcards.cards SET enable = ? WHERE id = ?", (enable, id), ) @@ -161,7 +160,7 @@ async def update_card_otp(otp: str, id: str): async def get_hit(hit_id: str) -> Optional[Hit]: - row = await db.fetchone(f"SELECT * FROM boltcards.hits WHERE id = ?", (hit_id)) + row = await db.fetchone(f"SELECT * FROM boltcards.hits WHERE id = ?", (hit_id,)) if not row: return None @@ -182,7 +181,7 @@ async def get_hits(cards_ids: Union[str, List[str]]) -> List[Hit]: return [Hit(**row) for row in rows] -async def get_hits_today(card_id: str) -> Optional[Hit]: +async def get_hits_today(card_id: str) -> List[Hit]: rows = await db.fetchall( f"SELECT * FROM boltcards.hits WHERE card_id = ?", (card_id,), @@ -259,7 +258,7 @@ async def create_refund(hit_id, refund_amount) -> Refund: async def get_refund(refund_id: str) -> Optional[Refund]: row = await db.fetchone( - f"SELECT * FROM boltcards.refunds WHERE id = ?", (refund_id) + f"SELECT * FROM boltcards.refunds WHERE id = ?", (refund_id,) ) if not row: return None @@ -267,7 +266,7 @@ async def get_refund(refund_id: str) -> Optional[Refund]: return Refund.parse_obj(refund) -async def get_refunds(hits_ids: Union[str, List[str]]) -> List[Refund]: +async def get_refunds(hits_ids: List[Hit]) -> List[Refund]: if len(hits_ids) == 0: return [] diff --git a/lnbits/extensions/boltcards/lnurl.py b/lnbits/extensions/boltcards/lnurl.py index 3a99073ab..d04303725 100644 --- a/lnbits/extensions/boltcards/lnurl.py +++ b/lnbits/extensions/boltcards/lnurl.py @@ -3,13 +3,9 @@ import secrets from http import HTTPStatus from urllib.parse import urlparse -from fastapi import Request -from fastapi.param_functions import Query -from fastapi.params import Depends, Query -from lnurl import encode as lnurl_encode # type: ignore -from lnurl.types import LnurlPayMetadata # type: ignore -from starlette.exceptions import HTTPException -from starlette.requests import Request +from fastapi import HTTPException, Query, Request +from lnurl import encode as lnurl_encode +from lnurl.types import LnurlPayMetadata from starlette.responses import HTMLResponse from lnbits import bolt11 @@ -28,14 +24,13 @@ from .crud import ( update_card_counter, update_card_otp, ) -from .models import CreateCardData from .nxp424 import decryptSUN, getSunMAC ###############LNURLWITHDRAW################# # /boltcards/api/v1/scan?p=00000000000000000000000000000000&c=0000000000000000 @boltcards_ext.get("/api/v1/scan/{external_id}") -async def api_scan(p, c, request: Request, external_id: str = None): +async def api_scan(p, c, request: Request, external_id: str = Query(None)): # some wallets send everything as lower case, no bueno p = p.upper() c = c.upper() @@ -63,6 +58,7 @@ async def api_scan(p, c, request: Request, external_id: str = None): await update_card_counter(ctr_int, card.id) # gathering some info for hit record + assert request.client ip = request.client.host if "x-real-ip" in request.headers: ip = request.headers["x-real-ip"] @@ -95,7 +91,6 @@ async def api_scan(p, c, request: Request, external_id: str = None): name="boltcards.lnurl_callback", ) async def lnurl_callback( - request: Request, pr: str = Query(None), k1: str = Query(None), ): @@ -120,7 +115,9 @@ async def lnurl_callback( return {"status": "ERROR", "reason": "Failed to decode payment request"} card = await get_card(hit.card_id) + assert card hit = await spend_hit(id=hit.id, amount=int(invoice.amount_msat / 1000)) + assert hit try: await pay_invoice( wallet_id=card.wallet, @@ -155,7 +152,7 @@ async def api_auth(a, request: Request): response = { "card_name": card.card_name, - "id": 1, + "id": str(1), "k0": card.k0, "k1": card.k1, "k2": card.k2, @@ -163,7 +160,7 @@ async def api_auth(a, request: Request): "k4": card.k2, "lnurlw_base": "lnurlw://" + lnurlw_base, "protocol_name": "new_bolt_card_response", - "protocol_version": 1, + "protocol_version": str(1), } return response @@ -179,7 +176,9 @@ async def api_auth(a, request: Request): ) async def lnurlp_response(req: Request, hit_id: str = Query(None)): hit = await get_hit(hit_id) + assert hit card = await get_card(hit.card_id) + assert card if not hit: return {"status": "ERROR", "reason": f"LNURL-pay record not found."} if not card.enable: @@ -199,17 +198,17 @@ async def lnurlp_response(req: Request, hit_id: str = Query(None)): response_class=HTMLResponse, name="boltcards.lnurlp_callback", ) -async def lnurlp_callback( - req: Request, hit_id: str = Query(None), amount: str = Query(None) -): +async def lnurlp_callback(hit_id: str = Query(None), amount: str = Query(None)): hit = await get_hit(hit_id) + assert hit card = await get_card(hit.card_id) + assert card if not hit: return {"status": "ERROR", "reason": f"LNURL-pay record not found."} - payment_hash, payment_request = await create_invoice( + _, payment_request = await create_invoice( wallet_id=card.wallet, - amount=int(amount) / 1000, + amount=int(int(amount) / 1000), memo=f"Refund {hit_id}", unhashed_description=LnurlPayMetadata( json.dumps([["text/plain", "Refund"]]) diff --git a/lnbits/extensions/boltcards/models.py b/lnbits/extensions/boltcards/models.py index 47ca1df09..5ea4be15d 100644 --- a/lnbits/extensions/boltcards/models.py +++ b/lnbits/extensions/boltcards/models.py @@ -1,14 +1,11 @@ +import json from sqlite3 import Row -from typing import Optional -from fastapi import Request -from fastapi.params import Query +from fastapi import Query, Request from lnurl import Lnurl -from lnurl import encode as lnurl_encode # type: ignore -from lnurl.models import LnurlPaySuccessAction, UrlAction # type: ignore -from lnurl.types import LnurlPayMetadata # type: ignore +from lnurl import encode as lnurl_encode +from lnurl.types import LnurlPayMetadata from pydantic import BaseModel -from pydantic.main import BaseModel ZERO_KEY = "00000000000000000000000000000000" @@ -32,6 +29,7 @@ class Card(BaseModel): otp: str time: int + @classmethod def from_row(cls, row: Row) -> "Card": return cls(**dict(row)) @@ -40,7 +38,7 @@ class Card(BaseModel): return lnurl_encode(url) async def lnurlpay_metadata(self) -> LnurlPayMetadata: - return LnurlPayMetadata(json.dumps([["text/plain", self.title]])) + return LnurlPayMetadata(json.dumps([["text/plain", self.card_name]])) class CreateCardData(BaseModel): @@ -69,6 +67,7 @@ class Hit(BaseModel): amount: int time: int + @classmethod def from_row(cls, row: Row) -> "Hit": return cls(**dict(row)) @@ -79,5 +78,6 @@ class Refund(BaseModel): refund_amount: int time: int + @classmethod def from_row(cls, row: Row) -> "Refund": return cls(**dict(row)) diff --git a/lnbits/extensions/boltcards/tasks.py b/lnbits/extensions/boltcards/tasks.py index c1e99b765..6addf0339 100644 --- a/lnbits/extensions/boltcards/tasks.py +++ b/lnbits/extensions/boltcards/tasks.py @@ -1,8 +1,6 @@ import asyncio import json -import httpx - from lnbits.core import db as core_db from lnbits.core.models import Payment from lnbits.helpers import get_current_extension_name @@ -21,22 +19,27 @@ async def wait_for_paid_invoices(): async def on_invoice_paid(payment: Payment) -> None: + if not payment.extra: + return + if not payment.extra.get("refund"): return if payment.extra.get("wh_status"): # this webhook has already been sent return - hit = await get_hit(payment.extra.get("refund")) + + hit = await get_hit(str(payment.extra.get("refund"))) if hit: - refund = await create_refund( - hit_id=hit.id, refund_amount=(payment.amount / 1000) - ) + await create_refund(hit_id=hit.id, refund_amount=(payment.amount / 1000)) await mark_webhook_sent(payment, 1) async def mark_webhook_sent(payment: Payment, status: int) -> None: + if not payment.extra: + return + payment.extra["wh_status"] = status await core_db.execute( diff --git a/lnbits/extensions/boltcards/views.py b/lnbits/extensions/boltcards/views.py index 8fcbb7def..273cfcbf1 100644 --- a/lnbits/extensions/boltcards/views.py +++ b/lnbits/extensions/boltcards/views.py @@ -1,5 +1,4 @@ -from fastapi import FastAPI, Request -from fastapi.params import Depends +from fastapi import Depends, Request from fastapi.templating import Jinja2Templates from starlette.responses import HTMLResponse diff --git a/lnbits/extensions/boltcards/views_api.py b/lnbits/extensions/boltcards/views_api.py index c18c33d05..feca12e0b 100644 --- a/lnbits/extensions/boltcards/views_api.py +++ b/lnbits/extensions/boltcards/views_api.py @@ -1,10 +1,6 @@ -import secrets from http import HTTPStatus -from fastapi.params import Depends, Query -from loguru import logger -from starlette.exceptions import HTTPException -from starlette.requests import Request +from fastapi import Depends, HTTPException, Query from lnbits.core.crud import get_user from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key @@ -15,13 +11,11 @@ from .crud import ( delete_card, enable_disable_card, get_card, - get_card_by_otp, get_card_by_uid, get_cards, get_hits, get_refunds, update_card, - update_card_otp, ) from .models import CreateCardData @@ -33,7 +27,8 @@ async def api_cards( wallet_ids = [g.wallet.id] if all_wallets: - wallet_ids = (await get_user(g.wallet.user)).wallet_ids + user = await get_user(g.wallet.user) + wallet_ids = user.wallet_ids if user else [] return [card.dict() for card in await get_cards(wallet_ids)] @@ -41,9 +36,8 @@ async def api_cards( @boltcards_ext.post("/api/v1/cards", status_code=HTTPStatus.CREATED) @boltcards_ext.put("/api/v1/cards/{card_id}", status_code=HTTPStatus.OK) async def api_card_create_or_update( - # req: Request, data: CreateCardData, - card_id: str = None, + card_id: str = Query(None), wallet: WalletTypeInfo = Depends(require_admin_key), ): try: @@ -95,6 +89,7 @@ async def api_card_create_or_update( status_code=HTTPStatus.BAD_REQUEST, ) card = await create_card(wallet_id=wallet.wallet.id, data=data) + assert card return card.dict() @@ -110,6 +105,7 @@ async def enable_card( if card.wallet != wallet.wallet.id: raise HTTPException(detail="Not your card.", status_code=HTTPStatus.FORBIDDEN) card = await enable_disable_card(enable=enable, id=card_id) + assert card return card.dict() @@ -136,7 +132,8 @@ async def api_hits( wallet_ids = [g.wallet.id] if all_wallets: - wallet_ids = (await get_user(g.wallet.user)).wallet_ids + user = await get_user(g.wallet.user) + wallet_ids = user.wallet_ids if user else [] cards = await get_cards(wallet_ids) cards_ids = [] @@ -153,15 +150,13 @@ async def api_refunds( wallet_ids = [g.wallet.id] if all_wallets: - wallet_ids = (await get_user(g.wallet.user)).wallet_ids + user = await get_user(g.wallet.user) + wallet_ids = user.wallet_ids if user else [] cards = await get_cards(wallet_ids) cards_ids = [] for card in cards: cards_ids.append(card.id) hits = await get_hits(cards_ids) - hits_ids = [] - for hit in hits: - hits_ids.append(hit.id) - return [refund.dict() for refund in await get_refunds(hits_ids)] + return [refund.dict() for refund in await get_refunds(hits)] diff --git a/pyproject.toml b/pyproject.toml index e2116ed08..cb9b82a9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,6 @@ files = "lnbits" exclude = """(?x)( ^lnbits/extensions/bleskomat. | ^lnbits/extensions/boltz. - | ^lnbits/extensions/boltcards. | ^lnbits/extensions/livestream. | ^lnbits/extensions/lnaddress. | ^lnbits/extensions/lnurldevice.