[REFACTOR] cleanup views/api.py (#1865)

pull out models
line lengths
change models for tests
This commit is contained in:
dni ⚡ 2023-08-18 11:22:22 +02:00 committed by GitHub
parent 905afc1f5c
commit 59acd3a2ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 98 additions and 72 deletions

View file

@ -7,6 +7,7 @@ from sqlite3 import Row
from typing import Callable, Dict, List, Optional
from ecdsa import SECP256k1, SigningKey
from fastapi import Query
from lnurl import encode as lnurl_encode
from loguru import logger
from pydantic import BaseModel
@ -179,7 +180,8 @@ class Payment(FromRowModel):
return PaymentStatus(None)
logger.debug(
f"Checking {'outgoing' if self.is_out else 'incoming'} pending payment {self.checking_id}"
f"Checking {'outgoing' if self.is_out else 'incoming'} "
f"pending payment {self.checking_id}"
)
WALLET = get_wallet_class()
@ -193,7 +195,8 @@ class Payment(FromRowModel):
if self.is_in and status.pending and self.is_expired and self.expiry:
expiration_date = datetime.datetime.fromtimestamp(self.expiry)
logger.debug(
f"Deleting expired incoming pending payment {self.checking_id}: expired {expiration_date}"
f"Deleting expired incoming pending payment {self.checking_id}: "
f"expired {expiration_date}"
)
await self.delete(conn)
elif self.is_out and status.failed:
@ -203,7 +206,8 @@ class Payment(FromRowModel):
await self.delete(conn)
elif not status.pending:
logger.info(
f"Marking '{'in' if self.is_in else 'out'}' {self.checking_id} as not pending anymore: {status}"
f"Marking '{'in' if self.is_in else 'out'}' "
f"{self.checking_id} as not pending anymore: {status}"
)
await self.update_status(status, conn=conn)
return status
@ -257,3 +261,41 @@ class TinyURL(BaseModel):
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
class ConversionData(BaseModel):
from_: str = "sat"
amount: float
to: str = "usd"
class Callback(BaseModel):
callback: str
class DecodePayment(BaseModel):
data: str
class CreateLnurl(BaseModel):
description_hash: str
callback: str
amount: int
comment: Optional[str] = None
description: Optional[str] = None
class CreateInvoice(BaseModel):
unit: str = "sat"
internal: bool = False
out: bool = True
amount: float = Query(None, ge=0)
memo: Optional[str] = None
description_hash: Optional[str] = None
unhashed_description: Optional[str] = None
expiry: Optional[int] = None
lnurl_callback: Optional[str] = None
lnurl_balance_check: Optional[str] = None
extra: Optional[dict] = None
webhook: Optional[str] = None
bolt11: Optional[str] = None

View file

@ -519,7 +519,8 @@ class WebsocketConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
async def connect(self, websocket: WebSocket, item_id: str):
logger.debug(f"Websocket connected to {item_id}")
await websocket.accept()
self.active_connections.append(websocket)

View file

@ -14,7 +14,6 @@ from fastapi import (
Body,
Depends,
Header,
Query,
Request,
Response,
WebSocket,
@ -22,8 +21,6 @@ from fastapi import (
)
from fastapi.exceptions import HTTPException
from loguru import logger
from pydantic import BaseModel
from pydantic.fields import Field
from sse_starlette.sse import EventSourceResponse
from starlette.responses import RedirectResponse, StreamingResponse
@ -32,7 +29,17 @@ from lnbits.core.helpers import (
migrate_extension_database,
stop_extension_background_work,
)
from lnbits.core.models import Payment, PaymentFilters, User, Wallet
from lnbits.core.models import (
Callback,
ConversionData,
CreateInvoice,
CreateLnurl,
DecodePayment,
Payment,
PaymentFilters,
User,
Wallet,
)
from lnbits.db import Filters, Page
from lnbits.decorators import (
WalletTypeInfo,
@ -179,23 +186,7 @@ async def api_payments_paginated(
return page
class CreateInvoiceData(BaseModel):
out: Optional[bool] = True
amount: float = Query(None, ge=0)
memo: Optional[str] = None
unit: Optional[str] = "sat"
description_hash: Optional[str] = None
unhashed_description: Optional[str] = None
expiry: Optional[int] = None
lnurl_callback: Optional[str] = None
lnurl_balance_check: Optional[str] = None
extra: Optional[dict] = None
webhook: Optional[str] = None
internal: Optional[bool] = False
bolt11: Optional[str] = None
async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet):
async def api_payments_create_invoice(data: CreateInvoice, wallet: Wallet):
if data.description_hash or data.unhashed_description:
try:
description_hash = (
@ -209,7 +200,10 @@ async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet):
except ValueError:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="'description_hash' and 'unhashed_description' must be a valid hex strings",
detail=(
"'description_hash' and 'unhashed_description' "
"must be a valid hex strings"
),
)
memo = ""
else:
@ -310,7 +304,7 @@ async def api_payments_pay_invoice(bolt11: str, wallet: Wallet):
)
async def api_payments_create(
wallet: WalletTypeInfo = Depends(require_invoice_key),
invoiceData: CreateInvoiceData = Body(...),
invoiceData: CreateInvoice = Body(...),
):
if invoiceData.out is True and wallet.wallet_type == 0:
if not invoiceData.bolt11:
@ -331,17 +325,9 @@ async def api_payments_create(
)
class CreateLNURLData(BaseModel):
description_hash: str
callback: str
amount: int
comment: Optional[str] = None
description: Optional[str] = None
@core_app.post("/api/v1/payments/lnurl")
async def api_payments_pay_lnurl(
data: CreateLNURLData, wallet: WalletTypeInfo = Depends(require_admin_key)
data: CreateLnurl, wallet: WalletTypeInfo = Depends(require_admin_key)
):
domain = urlparse(data.callback).netloc
@ -377,13 +363,19 @@ async def api_payments_pay_lnurl(
if invoice.amount_msat != data.amount:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"{domain} returned an invalid invoice. Expected {data.amount} msat, got {invoice.amount_msat}.",
detail=(
f"{domain} returned an invalid invoice. Expected {data.amount} msat, "
f"got {invoice.amount_msat}.",
),
)
if invoice.description_hash != data.description_hash:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"{domain} returned an invalid invoice. Expected description_hash == {data.description_hash}, got {invoice.description_hash}.",
detail=(
f"{domain} returned an invalid invoice. Expected description_hash == "
f"{data.description_hash}, got {invoice.description_hash}.",
),
)
extra = {}
@ -468,10 +460,11 @@ async def api_payments_sse(
async def api_payment(payment_hash, X_Api_Key: Optional[str] = Header(None)):
# We use X_Api_Key here because we want this call to work with and without keys
# If a valid key is given, we also return the field "details", otherwise not
wallet = await get_wallet_for_key(X_Api_Key) if isinstance(X_Api_Key, str) else None # type: ignore
wallet = await get_wallet_for_key(X_Api_Key) if isinstance(X_Api_Key, str) else None
# we have to specify the wallet id here, because postgres and sqlite return internal payments in different order
# and get_standalone_payment otherwise just fetches the first one, causing unpredictable results
# we have to specify the wallet id here, because postgres and sqlite return
# internal payments in different order and get_standalone_payment otherwise
# just fetches the first one, causing unpredictable results
payment = await get_standalone_payment(
payment_hash, wallet_id=wallet.id if wallet else None
)
@ -623,10 +616,6 @@ async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type
return params
class DecodePayment(BaseModel):
data: str
@core_app.post("/api/v1/payments/decode", status_code=HTTPStatus.OK)
async def api_payments_decode(data: DecodePayment, response: Response):
payment_str = data.data
@ -653,15 +642,11 @@ async def api_payments_decode(data: DecodePayment, response: Response):
return {"message": "Failed to decode"}
class Callback(BaseModel):
callback: str = Query(...)
@core_app.post("/api/v1/lnurlauth")
async def api_perform_lnurlauth(
callback: Callback, wallet: WalletTypeInfo = Depends(require_admin_key)
data: Callback, wallet: WalletTypeInfo = Depends(require_admin_key)
):
err = await perform_lnurlauth(callback.callback, wallet=wallet)
err = await perform_lnurlauth(data.callback, wallet=wallet)
if err:
raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=err.reason
@ -680,12 +665,6 @@ async def api_list_currencies_available():
return list(currencies.keys())
class ConversionData(BaseModel):
from_: str = Field("sat", alias="from")
amount: float
to: str = Query("usd")
@core_app.post("/api/v1/conversion")
async def api_fiat_as_sats(data: ConversionData):
output = {}
@ -705,7 +684,7 @@ async def api_fiat_as_sats(data: ConversionData):
@core_app.get("/api/v1/qrcode/{data}", response_class=StreamingResponse)
async def img(request: Request, data):
async def img(data):
qr = pyqrcode.create(data)
stream = BytesIO()
qr.svg(stream, scale=3)
@ -725,12 +704,9 @@ async def img(request: Request, data):
)
# UNIVERSAL WEBSOCKET MANAGER
@core_app.websocket("/api/v1/ws/{item_id}")
async def websocket_connect(websocket: WebSocket, item_id: str):
await websocketManager.connect(websocket)
await websocketManager.connect(websocket, item_id)
try:
while True:
await websocket.receive_text()
@ -808,7 +784,10 @@ async def api_install_extension(
ext_info.clean_extension_files()
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"Failed to install extension {ext_info.id} ({ext_info.installed_version}).",
detail=(
f"Failed to install extension {ext_info.id} "
f"({ext_info.installed_version})."
),
)
@ -831,7 +810,10 @@ async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)
if installed_ext and ext_id in installed_ext.dependencies:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Cannot uninstall. Extension '{installed_ext.name}' depends on this one.",
detail=(
f"Cannot uninstall. Extension '{installed_ext.name}' "
"depends on this one."
),
)
try:

View file

@ -8,8 +8,9 @@ from httpx import AsyncClient
from lnbits.app import create_app
from lnbits.commands import migrate_databases
from lnbits.core.crud import create_account, create_wallet
from lnbits.core.models import CreateInvoice
from lnbits.core.services import update_wallet_balance
from lnbits.core.views.api import CreateInvoiceData, api_payments_create_invoice
from lnbits.core.views.api import api_payments_create_invoice
from lnbits.db import Database
from lnbits.settings import settings
from tests.helpers import get_hold_invoice, get_random_invoice_data, get_real_invoice
@ -142,7 +143,7 @@ async def adminkey_headers_to(to_wallet):
@pytest_asyncio.fixture(scope="session")
async def invoice(to_wallet):
data = await get_random_invoice_data()
invoiceData = CreateInvoiceData(**data)
invoiceData = CreateInvoice(**data)
invoice = await api_payments_create_invoice(invoiceData, to_wallet)
yield invoice
del invoice

View file

@ -11,7 +11,7 @@ from lnbits.core.views.admin_api import api_auditor
from lnbits.core.views.api import api_payment
from lnbits.db import DB_TYPE, SQLITE
from lnbits.wallets import get_wallet_class
from tests.conftest import CreateInvoiceData, api_payments_create_invoice
from tests.conftest import CreateInvoice, api_payments_create_invoice
from ...helpers import (
cancel_invoice,
@ -219,9 +219,9 @@ async def test_get_payments(client, from_wallet, adminkey_headers_from):
ts = time()
fake_data = [
CreateInvoiceData(amount=10, memo="aaaa"),
CreateInvoiceData(amount=100, memo="bbbb"),
CreateInvoiceData(amount=1000, memo="aabb"),
CreateInvoice(amount=10, memo="aaaa"),
CreateInvoice(amount=100, memo="bbbb"),
CreateInvoice(amount=1000, memo="aabb"),
]
for invoice in fake_data:
@ -384,7 +384,7 @@ async def test_pay_real_invoice(
@pytest.mark.skipif(is_fake, reason="this only works in regtest")
async def test_create_real_invoice(client, adminkey_headers_from, inkey_headers_from):
prev_balance = await get_node_balance_sats()
create_invoice = CreateInvoiceData(out=False, amount=1000, memo="test")
create_invoice = CreateInvoice(out=False, amount=1000, memo="test")
response = await client.post(
"/api/v1/payments",
json=create_invoice.dict(),
@ -604,7 +604,7 @@ async def test_receive_real_invoice_set_pending_and_check_state(
5. We recheck the state of the invoice with the backend
6. We verify that the invoice is now marked as paid in the database
"""
create_invoice = CreateInvoiceData(out=False, amount=1000, memo="test")
create_invoice = CreateInvoice(out=False, amount=1000, memo="test")
response = await client.post(
"/api/v1/payments",
json=create_invoice.dict(),