[REFACTOR] cleanup views/api.py (#1865)

pull out models
line lengths
change models for tests
This commit is contained in:
dni ⚡ 2023-08-18 11:22:22 +02:00 committed by GitHub
parent 905afc1f5c
commit 59acd3a2ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 98 additions and 72 deletions

View file

@ -7,6 +7,7 @@ from sqlite3 import Row
from typing import Callable, Dict, List, Optional from typing import Callable, Dict, List, Optional
from ecdsa import SECP256k1, SigningKey from ecdsa import SECP256k1, SigningKey
from fastapi import Query
from lnurl import encode as lnurl_encode from lnurl import encode as lnurl_encode
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
@ -179,7 +180,8 @@ class Payment(FromRowModel):
return PaymentStatus(None) return PaymentStatus(None)
logger.debug( logger.debug(
f"Checking {'outgoing' if self.is_out else 'incoming'} pending payment {self.checking_id}" f"Checking {'outgoing' if self.is_out else 'incoming'} "
f"pending payment {self.checking_id}"
) )
WALLET = get_wallet_class() WALLET = get_wallet_class()
@ -193,7 +195,8 @@ class Payment(FromRowModel):
if self.is_in and status.pending and self.is_expired and self.expiry: if self.is_in and status.pending and self.is_expired and self.expiry:
expiration_date = datetime.datetime.fromtimestamp(self.expiry) expiration_date = datetime.datetime.fromtimestamp(self.expiry)
logger.debug( logger.debug(
f"Deleting expired incoming pending payment {self.checking_id}: expired {expiration_date}" f"Deleting expired incoming pending payment {self.checking_id}: "
f"expired {expiration_date}"
) )
await self.delete(conn) await self.delete(conn)
elif self.is_out and status.failed: elif self.is_out and status.failed:
@ -203,7 +206,8 @@ class Payment(FromRowModel):
await self.delete(conn) await self.delete(conn)
elif not status.pending: elif not status.pending:
logger.info( logger.info(
f"Marking '{'in' if self.is_in else 'out'}' {self.checking_id} as not pending anymore: {status}" f"Marking '{'in' if self.is_in else 'out'}' "
f"{self.checking_id} as not pending anymore: {status}"
) )
await self.update_status(status, conn=conn) await self.update_status(status, conn=conn)
return status return status
@ -257,3 +261,41 @@ class TinyURL(BaseModel):
@classmethod @classmethod
def from_row(cls, row: Row): def from_row(cls, row: Row):
return cls(**dict(row)) return cls(**dict(row))
class ConversionData(BaseModel):
from_: str = "sat"
amount: float
to: str = "usd"
class Callback(BaseModel):
callback: str
class DecodePayment(BaseModel):
data: str
class CreateLnurl(BaseModel):
description_hash: str
callback: str
amount: int
comment: Optional[str] = None
description: Optional[str] = None
class CreateInvoice(BaseModel):
unit: str = "sat"
internal: bool = False
out: bool = True
amount: float = Query(None, ge=0)
memo: Optional[str] = None
description_hash: Optional[str] = None
unhashed_description: Optional[str] = None
expiry: Optional[int] = None
lnurl_callback: Optional[str] = None
lnurl_balance_check: Optional[str] = None
extra: Optional[dict] = None
webhook: Optional[str] = None
bolt11: Optional[str] = None

View file

@ -519,7 +519,8 @@ class WebsocketConnectionManager:
def __init__(self): def __init__(self):
self.active_connections: List[WebSocket] = [] self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket): async def connect(self, websocket: WebSocket, item_id: str):
logger.debug(f"Websocket connected to {item_id}")
await websocket.accept() await websocket.accept()
self.active_connections.append(websocket) self.active_connections.append(websocket)

View file

@ -14,7 +14,6 @@ from fastapi import (
Body, Body,
Depends, Depends,
Header, Header,
Query,
Request, Request,
Response, Response,
WebSocket, WebSocket,
@ -22,8 +21,6 @@ from fastapi import (
) )
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from loguru import logger from loguru import logger
from pydantic import BaseModel
from pydantic.fields import Field
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from starlette.responses import RedirectResponse, StreamingResponse from starlette.responses import RedirectResponse, StreamingResponse
@ -32,7 +29,17 @@ from lnbits.core.helpers import (
migrate_extension_database, migrate_extension_database,
stop_extension_background_work, stop_extension_background_work,
) )
from lnbits.core.models import Payment, PaymentFilters, User, Wallet from lnbits.core.models import (
Callback,
ConversionData,
CreateInvoice,
CreateLnurl,
DecodePayment,
Payment,
PaymentFilters,
User,
Wallet,
)
from lnbits.db import Filters, Page from lnbits.db import Filters, Page
from lnbits.decorators import ( from lnbits.decorators import (
WalletTypeInfo, WalletTypeInfo,
@ -179,23 +186,7 @@ async def api_payments_paginated(
return page return page
class CreateInvoiceData(BaseModel): async def api_payments_create_invoice(data: CreateInvoice, wallet: Wallet):
out: Optional[bool] = True
amount: float = Query(None, ge=0)
memo: Optional[str] = None
unit: Optional[str] = "sat"
description_hash: Optional[str] = None
unhashed_description: Optional[str] = None
expiry: Optional[int] = None
lnurl_callback: Optional[str] = None
lnurl_balance_check: Optional[str] = None
extra: Optional[dict] = None
webhook: Optional[str] = None
internal: Optional[bool] = False
bolt11: Optional[str] = None
async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet):
if data.description_hash or data.unhashed_description: if data.description_hash or data.unhashed_description:
try: try:
description_hash = ( description_hash = (
@ -209,7 +200,10 @@ async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet):
except ValueError: except ValueError:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail="'description_hash' and 'unhashed_description' must be a valid hex strings", detail=(
"'description_hash' and 'unhashed_description' "
"must be a valid hex strings"
),
) )
memo = "" memo = ""
else: else:
@ -310,7 +304,7 @@ async def api_payments_pay_invoice(bolt11: str, wallet: Wallet):
) )
async def api_payments_create( async def api_payments_create(
wallet: WalletTypeInfo = Depends(require_invoice_key), wallet: WalletTypeInfo = Depends(require_invoice_key),
invoiceData: CreateInvoiceData = Body(...), invoiceData: CreateInvoice = Body(...),
): ):
if invoiceData.out is True and wallet.wallet_type == 0: if invoiceData.out is True and wallet.wallet_type == 0:
if not invoiceData.bolt11: if not invoiceData.bolt11:
@ -331,17 +325,9 @@ async def api_payments_create(
) )
class CreateLNURLData(BaseModel):
description_hash: str
callback: str
amount: int
comment: Optional[str] = None
description: Optional[str] = None
@core_app.post("/api/v1/payments/lnurl") @core_app.post("/api/v1/payments/lnurl")
async def api_payments_pay_lnurl( async def api_payments_pay_lnurl(
data: CreateLNURLData, wallet: WalletTypeInfo = Depends(require_admin_key) data: CreateLnurl, wallet: WalletTypeInfo = Depends(require_admin_key)
): ):
domain = urlparse(data.callback).netloc domain = urlparse(data.callback).netloc
@ -377,13 +363,19 @@ async def api_payments_pay_lnurl(
if invoice.amount_msat != data.amount: if invoice.amount_msat != data.amount:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=f"{domain} returned an invalid invoice. Expected {data.amount} msat, got {invoice.amount_msat}.", detail=(
f"{domain} returned an invalid invoice. Expected {data.amount} msat, "
f"got {invoice.amount_msat}.",
),
) )
if invoice.description_hash != data.description_hash: if invoice.description_hash != data.description_hash:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=f"{domain} returned an invalid invoice. Expected description_hash == {data.description_hash}, got {invoice.description_hash}.", detail=(
f"{domain} returned an invalid invoice. Expected description_hash == "
f"{data.description_hash}, got {invoice.description_hash}.",
),
) )
extra = {} extra = {}
@ -468,10 +460,11 @@ async def api_payments_sse(
async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)): async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)):
# We use X_Api_Key here because we want this call to work with and without keys # We use X_Api_Key here because we want this call to work with and without keys
# If a valid key is given, we also return the field "details", otherwise not # If a valid key is given, we also return the field "details", otherwise not
wallet = await get_wallet_for_key(X_Api_Key) if isinstance(X_Api_Key, str) else None # type: ignore wallet = await get_wallet_for_key(X_Api_Key) if isinstance(X_Api_Key, str) else None
# we have to specify the wallet id here, because postgres and sqlite return internal payments in different order # we have to specify the wallet id here, because postgres and sqlite return
# and get_standalone_payment otherwise just fetches the first one, causing unpredictable results # internal payments in different order and get_standalone_payment otherwise
# just fetches the first one, causing unpredictable results
payment = await get_standalone_payment( payment = await get_standalone_payment(
payment_hash, wallet_id=wallet.id if wallet else None payment_hash, wallet_id=wallet.id if wallet else None
) )
@ -623,10 +616,6 @@ async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type
return params return params
class DecodePayment(BaseModel):
data: str
@core_app.post("/api/v1/payments/decode", status_code=HTTPStatus.OK) @core_app.post("/api/v1/payments/decode", status_code=HTTPStatus.OK)
async def api_payments_decode(data: DecodePayment, response: Response): async def api_payments_decode(data: DecodePayment, response: Response):
payment_str = data.data payment_str = data.data
@ -653,15 +642,11 @@ async def api_payments_decode(data: DecodePayment, response: Response):
return {"message": "Failed to decode"} return {"message": "Failed to decode"}
class Callback(BaseModel):
callback: str = Query(...)
@core_app.post("/api/v1/lnurlauth") @core_app.post("/api/v1/lnurlauth")
async def api_perform_lnurlauth( async def api_perform_lnurlauth(
callback: Callback, wallet: WalletTypeInfo = Depends(require_admin_key) data: Callback, wallet: WalletTypeInfo = Depends(require_admin_key)
): ):
err = await perform_lnurlauth(callback.callback, wallet=wallet) err = await perform_lnurlauth(data.callback, wallet=wallet)
if err: if err:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=err.reason status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=err.reason
@ -680,12 +665,6 @@ async def api_list_currencies_available():
return list(currencies.keys()) return list(currencies.keys())
class ConversionData(BaseModel):
from_: str = Field("sat", alias="from")
amount: float
to: str = Query("usd")
@core_app.post("/api/v1/conversion") @core_app.post("/api/v1/conversion")
async def api_fiat_as_sats(data: ConversionData): async def api_fiat_as_sats(data: ConversionData):
output = {} output = {}
@ -705,7 +684,7 @@ async def api_fiat_as_sats(data: ConversionData):
@core_app.get("/api/v1/qrcode/{data}", response_class=StreamingResponse) @core_app.get("/api/v1/qrcode/{data}", response_class=StreamingResponse)
async def img(request: Request, data): async def img(data):
qr = pyqrcode.create(data) qr = pyqrcode.create(data)
stream = BytesIO() stream = BytesIO()
qr.svg(stream, scale=3) qr.svg(stream, scale=3)
@ -725,12 +704,9 @@ async def img(request: Request, data):
) )
# UNIVERSAL WEBSOCKET MANAGER
@core_app.websocket("/api/v1/ws/{item_id}") @core_app.websocket("/api/v1/ws/{item_id}")
async def websocket_connect(websocket: WebSocket, item_id: str): async def websocket_connect(websocket: WebSocket, item_id: str):
await websocketManager.connect(websocket) await websocketManager.connect(websocket, item_id)
try: try:
while True: while True:
await websocket.receive_text() await websocket.receive_text()
@ -808,7 +784,10 @@ async def api_install_extension(
ext_info.clean_extension_files() ext_info.clean_extension_files()
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"Failed to install extension {ext_info.id} ({ext_info.installed_version}).", detail=(
f"Failed to install extension {ext_info.id} "
f"({ext_info.installed_version})."
),
) )
@ -831,7 +810,10 @@ async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)
if installed_ext and ext_id in installed_ext.dependencies: if installed_ext and ext_id in installed_ext.dependencies:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=f"Cannot uninstall. Extension '{installed_ext.name}' depends on this one.", detail=(
f"Cannot uninstall. Extension '{installed_ext.name}' "
"depends on this one."
),
) )
try: try:

View file

@ -8,8 +8,9 @@ from httpx import AsyncClient
from lnbits.app import create_app from lnbits.app import create_app
from lnbits.commands import migrate_databases from lnbits.commands import migrate_databases
from lnbits.core.crud import create_account, create_wallet from lnbits.core.crud import create_account, create_wallet
from lnbits.core.models import CreateInvoice
from lnbits.core.services import update_wallet_balance from lnbits.core.services import update_wallet_balance
from lnbits.core.views.api import CreateInvoiceData, api_payments_create_invoice from lnbits.core.views.api import api_payments_create_invoice
from lnbits.db import Database from lnbits.db import Database
from lnbits.settings import settings from lnbits.settings import settings
from tests.helpers import get_hold_invoice, get_random_invoice_data, get_real_invoice from tests.helpers import get_hold_invoice, get_random_invoice_data, get_real_invoice
@ -142,7 +143,7 @@ async def adminkey_headers_to(to_wallet):
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def invoice(to_wallet): async def invoice(to_wallet):
data = await get_random_invoice_data() data = await get_random_invoice_data()
invoiceData = CreateInvoiceData(**data) invoiceData = CreateInvoice(**data)
invoice = await api_payments_create_invoice(invoiceData, to_wallet) invoice = await api_payments_create_invoice(invoiceData, to_wallet)
yield invoice yield invoice
del invoice del invoice

View file

@ -11,7 +11,7 @@ from lnbits.core.views.admin_api import api_auditor
from lnbits.core.views.api import api_payment from lnbits.core.views.api import api_payment
from lnbits.db import DB_TYPE, SQLITE from lnbits.db import DB_TYPE, SQLITE
from lnbits.wallets import get_wallet_class from lnbits.wallets import get_wallet_class
from tests.conftest import CreateInvoiceData, api_payments_create_invoice from tests.conftest import CreateInvoice, api_payments_create_invoice
from ...helpers import ( from ...helpers import (
cancel_invoice, cancel_invoice,
@ -219,9 +219,9 @@ async def test_get_payments(client, from_wallet, adminkey_headers_from):
ts = time() ts = time()
fake_data = [ fake_data = [
CreateInvoiceData(amount=10, memo="aaaa"), CreateInvoice(amount=10, memo="aaaa"),
CreateInvoiceData(amount=100, memo="bbbb"), CreateInvoice(amount=100, memo="bbbb"),
CreateInvoiceData(amount=1000, memo="aabb"), CreateInvoice(amount=1000, memo="aabb"),
] ]
for invoice in fake_data: for invoice in fake_data:
@ -384,7 +384,7 @@ async def test_pay_real_invoice(
@pytest.mark.skipif(is_fake, reason="this only works in regtest") @pytest.mark.skipif(is_fake, reason="this only works in regtest")
async def test_create_real_invoice(client, adminkey_headers_from, inkey_headers_from): async def test_create_real_invoice(client, adminkey_headers_from, inkey_headers_from):
prev_balance = await get_node_balance_sats() prev_balance = await get_node_balance_sats()
create_invoice = CreateInvoiceData(out=False, amount=1000, memo="test") create_invoice = CreateInvoice(out=False, amount=1000, memo="test")
response = await client.post( response = await client.post(
"/api/v1/payments", "/api/v1/payments",
json=create_invoice.dict(), json=create_invoice.dict(),
@ -604,7 +604,7 @@ async def test_receive_real_invoice_set_pending_and_check_state(
5. We recheck the state of the invoice with the backend 5. We recheck the state of the invoice with the backend
6. We verify that the invoice is now marked as paid in the database 6. We verify that the invoice is now marked as paid in the database
""" """
create_invoice = CreateInvoiceData(out=False, amount=1000, memo="test") create_invoice = CreateInvoice(out=False, amount=1000, memo="test")
response = await client.post( response = await client.post(
"/api/v1/payments", "/api/v1/payments",
json=create_invoice.dict(), json=create_invoice.dict(),