diff --git a/lnbits/core/models.py b/lnbits/core/models.py index cd7c79f0c..01e485c9e 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -7,6 +7,7 @@ from sqlite3 import Row from typing import Callable, Dict, List, Optional from ecdsa import SECP256k1, SigningKey +from fastapi import Query from lnurl import encode as lnurl_encode from loguru import logger from pydantic import BaseModel @@ -179,7 +180,8 @@ class Payment(FromRowModel): return PaymentStatus(None) logger.debug( - f"Checking {'outgoing' if self.is_out else 'incoming'} pending payment {self.checking_id}" + f"Checking {'outgoing' if self.is_out else 'incoming'} " + f"pending payment {self.checking_id}" ) WALLET = get_wallet_class() @@ -193,7 +195,8 @@ class Payment(FromRowModel): if self.is_in and status.pending and self.is_expired and self.expiry: expiration_date = datetime.datetime.fromtimestamp(self.expiry) logger.debug( - f"Deleting expired incoming pending payment {self.checking_id}: expired {expiration_date}" + f"Deleting expired incoming pending payment {self.checking_id}: " + f"expired {expiration_date}" ) await self.delete(conn) elif self.is_out and status.failed: @@ -203,7 +206,8 @@ class Payment(FromRowModel): await self.delete(conn) elif not status.pending: logger.info( - f"Marking '{'in' if self.is_in else 'out'}' {self.checking_id} as not pending anymore: {status}" + f"Marking '{'in' if self.is_in else 'out'}' " + f"{self.checking_id} as not pending anymore: {status}" ) await self.update_status(status, conn=conn) return status @@ -257,3 +261,41 @@ class TinyURL(BaseModel): @classmethod def from_row(cls, row: Row): return cls(**dict(row)) + + +class ConversionData(BaseModel): + from_: str = "sat" + amount: float + to: str = "usd" + + +class Callback(BaseModel): + callback: str + + +class DecodePayment(BaseModel): + data: str + + +class CreateLnurl(BaseModel): + description_hash: str + callback: str + amount: int + comment: Optional[str] = None + description: Optional[str] = None + + +class CreateInvoice(BaseModel): + unit: str = "sat" + internal: bool = False + out: bool = True + amount: float = Query(None, ge=0) + memo: Optional[str] = None + description_hash: Optional[str] = None + unhashed_description: Optional[str] = None + expiry: Optional[int] = None + lnurl_callback: Optional[str] = None + lnurl_balance_check: Optional[str] = None + extra: Optional[dict] = None + webhook: Optional[str] = None + bolt11: Optional[str] = None diff --git a/lnbits/core/services.py b/lnbits/core/services.py index e4b3c9b04..651c16a29 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -519,7 +519,8 @@ class WebsocketConnectionManager: def __init__(self): self.active_connections: List[WebSocket] = [] - async def connect(self, websocket: WebSocket): + async def connect(self, websocket: WebSocket, item_id: str): + logger.debug(f"Websocket connected to {item_id}") await websocket.accept() self.active_connections.append(websocket) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 245803d87..c4956bb3f 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -14,7 +14,6 @@ from fastapi import ( Body, Depends, Header, - Query, Request, Response, WebSocket, @@ -22,8 +21,6 @@ from fastapi import ( ) from fastapi.exceptions import HTTPException from loguru import logger -from pydantic import BaseModel -from pydantic.fields import Field from sse_starlette.sse import EventSourceResponse from starlette.responses import RedirectResponse, StreamingResponse @@ -32,7 +29,17 @@ from lnbits.core.helpers import ( migrate_extension_database, stop_extension_background_work, ) -from lnbits.core.models import Payment, PaymentFilters, User, Wallet +from lnbits.core.models import ( + Callback, + ConversionData, + CreateInvoice, + CreateLnurl, + DecodePayment, + Payment, + PaymentFilters, + User, + Wallet, +) from lnbits.db import Filters, Page from lnbits.decorators import ( WalletTypeInfo, @@ -179,23 +186,7 @@ async def api_payments_paginated( return page -class CreateInvoiceData(BaseModel): - out: Optional[bool] = True - amount: float = Query(None, ge=0) - memo: Optional[str] = None - unit: Optional[str] = "sat" - description_hash: Optional[str] = None - unhashed_description: Optional[str] = None - expiry: Optional[int] = None - lnurl_callback: Optional[str] = None - lnurl_balance_check: Optional[str] = None - extra: Optional[dict] = None - webhook: Optional[str] = None - internal: Optional[bool] = False - bolt11: Optional[str] = None - - -async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet): +async def api_payments_create_invoice(data: CreateInvoice, wallet: Wallet): if data.description_hash or data.unhashed_description: try: description_hash = ( @@ -209,7 +200,10 @@ async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet): except ValueError: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, - detail="'description_hash' and 'unhashed_description' must be a valid hex strings", + detail=( + "'description_hash' and 'unhashed_description' " + "must be a valid hex strings" + ), ) memo = "" else: @@ -310,7 +304,7 @@ async def api_payments_pay_invoice(bolt11: str, wallet: Wallet): ) async def api_payments_create( wallet: WalletTypeInfo = Depends(require_invoice_key), - invoiceData: CreateInvoiceData = Body(...), + invoiceData: CreateInvoice = Body(...), ): if invoiceData.out is True and wallet.wallet_type == 0: if not invoiceData.bolt11: @@ -331,17 +325,9 @@ async def api_payments_create( ) -class CreateLNURLData(BaseModel): - description_hash: str - callback: str - amount: int - comment: Optional[str] = None - description: Optional[str] = None - - @core_app.post("/api/v1/payments/lnurl") async def api_payments_pay_lnurl( - data: CreateLNURLData, wallet: WalletTypeInfo = Depends(require_admin_key) + data: CreateLnurl, wallet: WalletTypeInfo = Depends(require_admin_key) ): domain = urlparse(data.callback).netloc @@ -377,13 +363,19 @@ async def api_payments_pay_lnurl( if invoice.amount_msat != data.amount: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, - detail=f"{domain} returned an invalid invoice. Expected {data.amount} msat, got {invoice.amount_msat}.", + detail=( + f"{domain} returned an invalid invoice. Expected {data.amount} msat, " + f"got {invoice.amount_msat}.", + ), ) if invoice.description_hash != data.description_hash: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, - detail=f"{domain} returned an invalid invoice. Expected description_hash == {data.description_hash}, got {invoice.description_hash}.", + detail=( + f"{domain} returned an invalid invoice. Expected description_hash == " + f"{data.description_hash}, got {invoice.description_hash}.", + ), ) extra = {} @@ -468,10 +460,11 @@ async def api_payments_sse( async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)): # We use X_Api_Key here because we want this call to work with and without keys # If a valid key is given, we also return the field "details", otherwise not - wallet = await get_wallet_for_key(X_Api_Key) if isinstance(X_Api_Key, str) else None # type: ignore + wallet = await get_wallet_for_key(X_Api_Key) if isinstance(X_Api_Key, str) else None - # we have to specify the wallet id here, because postgres and sqlite return internal payments in different order - # and get_standalone_payment otherwise just fetches the first one, causing unpredictable results + # we have to specify the wallet id here, because postgres and sqlite return + # internal payments in different order and get_standalone_payment otherwise + # just fetches the first one, causing unpredictable results payment = await get_standalone_payment( payment_hash, wallet_id=wallet.id if wallet else None ) @@ -623,10 +616,6 @@ async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type return params -class DecodePayment(BaseModel): - data: str - - @core_app.post("/api/v1/payments/decode", status_code=HTTPStatus.OK) async def api_payments_decode(data: DecodePayment, response: Response): payment_str = data.data @@ -653,15 +642,11 @@ async def api_payments_decode(data: DecodePayment, response: Response): return {"message": "Failed to decode"} -class Callback(BaseModel): - callback: str = Query(...) - - @core_app.post("/api/v1/lnurlauth") async def api_perform_lnurlauth( - callback: Callback, wallet: WalletTypeInfo = Depends(require_admin_key) + data: Callback, wallet: WalletTypeInfo = Depends(require_admin_key) ): - err = await perform_lnurlauth(callback.callback, wallet=wallet) + err = await perform_lnurlauth(data.callback, wallet=wallet) if err: raise HTTPException( status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=err.reason @@ -680,12 +665,6 @@ async def api_list_currencies_available(): return list(currencies.keys()) -class ConversionData(BaseModel): - from_: str = Field("sat", alias="from") - amount: float - to: str = Query("usd") - - @core_app.post("/api/v1/conversion") async def api_fiat_as_sats(data: ConversionData): output = {} @@ -705,7 +684,7 @@ async def api_fiat_as_sats(data: ConversionData): @core_app.get("/api/v1/qrcode/{data}", response_class=StreamingResponse) -async def img(request: Request, data): +async def img(data): qr = pyqrcode.create(data) stream = BytesIO() qr.svg(stream, scale=3) @@ -725,12 +704,9 @@ async def img(request: Request, data): ) -# UNIVERSAL WEBSOCKET MANAGER - - @core_app.websocket("/api/v1/ws/{item_id}") async def websocket_connect(websocket: WebSocket, item_id: str): - await websocketManager.connect(websocket) + await websocketManager.connect(websocket, item_id) try: while True: await websocket.receive_text() @@ -808,7 +784,10 @@ async def api_install_extension( ext_info.clean_extension_files() raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=f"Failed to install extension {ext_info.id} ({ext_info.installed_version}).", + detail=( + f"Failed to install extension {ext_info.id} " + f"({ext_info.installed_version})." + ), ) @@ -831,7 +810,10 @@ async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin) if installed_ext and ext_id in installed_ext.dependencies: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, - detail=f"Cannot uninstall. Extension '{installed_ext.name}' depends on this one.", + detail=( + f"Cannot uninstall. Extension '{installed_ext.name}' " + "depends on this one." + ), ) try: diff --git a/tests/conftest.py b/tests/conftest.py index 68db40263..ac9becd71 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,8 +8,9 @@ from httpx import AsyncClient from lnbits.app import create_app from lnbits.commands import migrate_databases from lnbits.core.crud import create_account, create_wallet +from lnbits.core.models import CreateInvoice from lnbits.core.services import update_wallet_balance -from lnbits.core.views.api import CreateInvoiceData, api_payments_create_invoice +from lnbits.core.views.api import api_payments_create_invoice from lnbits.db import Database from lnbits.settings import settings from tests.helpers import get_hold_invoice, get_random_invoice_data, get_real_invoice @@ -142,7 +143,7 @@ async def adminkey_headers_to(to_wallet): @pytest_asyncio.fixture(scope="session") async def invoice(to_wallet): data = await get_random_invoice_data() - invoiceData = CreateInvoiceData(**data) + invoiceData = CreateInvoice(**data) invoice = await api_payments_create_invoice(invoiceData, to_wallet) yield invoice del invoice diff --git a/tests/core/views/test_api.py b/tests/core/views/test_api.py index 12397db71..9cd4c881e 100644 --- a/tests/core/views/test_api.py +++ b/tests/core/views/test_api.py @@ -11,7 +11,7 @@ from lnbits.core.views.admin_api import api_auditor from lnbits.core.views.api import api_payment from lnbits.db import DB_TYPE, SQLITE from lnbits.wallets import get_wallet_class -from tests.conftest import CreateInvoiceData, api_payments_create_invoice +from tests.conftest import CreateInvoice, api_payments_create_invoice from ...helpers import ( cancel_invoice, @@ -219,9 +219,9 @@ async def test_get_payments(client, from_wallet, adminkey_headers_from): ts = time() fake_data = [ - CreateInvoiceData(amount=10, memo="aaaa"), - CreateInvoiceData(amount=100, memo="bbbb"), - CreateInvoiceData(amount=1000, memo="aabb"), + CreateInvoice(amount=10, memo="aaaa"), + CreateInvoice(amount=100, memo="bbbb"), + CreateInvoice(amount=1000, memo="aabb"), ] for invoice in fake_data: @@ -384,7 +384,7 @@ async def test_pay_real_invoice( @pytest.mark.skipif(is_fake, reason="this only works in regtest") async def test_create_real_invoice(client, adminkey_headers_from, inkey_headers_from): prev_balance = await get_node_balance_sats() - create_invoice = CreateInvoiceData(out=False, amount=1000, memo="test") + create_invoice = CreateInvoice(out=False, amount=1000, memo="test") response = await client.post( "/api/v1/payments", json=create_invoice.dict(), @@ -604,7 +604,7 @@ async def test_receive_real_invoice_set_pending_and_check_state( 5. We recheck the state of the invoice with the backend 6. We verify that the invoice is now marked as paid in the database """ - create_invoice = CreateInvoiceData(out=False, amount=1000, memo="test") + create_invoice = CreateInvoice(out=False, amount=1000, memo="test") response = await client.post( "/api/v1/payments", json=create_invoice.dict(),