fix pyright lnbits/extensions

Co-authored-by: dni  <office@dnilabs.com>
This commit is contained in:
Pavol Rusnak 2023-02-02 12:57:52 +00:00 committed by dni ⚡
parent 09eab69e12
commit 3bd68cb394
No known key found for this signature in database
GPG key ID: 886317704CC4E618
25 changed files with 142 additions and 160 deletions

View file

@ -80,7 +80,5 @@ async def fetch_fiat_exchange_rate(currency: str, provider: str):
else: else:
data = {} data = {}
getter = exchange_rate_providers[provider]["getter"] getter = exchange_rate_providers[provider]["getter"]
print(getter) assert callable(getter), "cannot call getter function"
if callable(getter): return float(getter(data, replacements))
rate = float(getter(data, replacements))
return rate

View file

@ -68,10 +68,11 @@ class Proof(BaseModel):
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
assert "secret" in d, "no secret in proof" assert "secret" in d, "no secret in proof"
assert "amount" in d, "no amount in proof" assert "amount" in d, "no amount in proof"
assert "C" in d, "no C in proof"
return cls( return cls(
amount=d.get("amount"), amount=d["amount"],
C=d.get("C"), C=d["C"],
secret=d.get("secret"), secret=d["secret"],
reserved=d.get("reserved") or False, reserved=d.get("reserved") or False,
send_id=d.get("send_id") or "", send_id=d.get("send_id") or "",
time_created=d.get("time_created") or "", time_created=d.get("time_created") or "",

View file

@ -1,7 +1,8 @@
import math import math
from typing import Tuple
def si_classifier(val): def si_classifier(val) -> dict:
suffixes = { suffixes = {
24: {"long_suffix": "yotta", "short_suffix": "Y", "scalar": 10**24}, 24: {"long_suffix": "yotta", "short_suffix": "Y", "scalar": 10**24},
21: {"long_suffix": "zetta", "short_suffix": "Z", "scalar": 10**21}, 21: {"long_suffix": "zetta", "short_suffix": "Z", "scalar": 10**21},
@ -22,24 +23,22 @@ def si_classifier(val):
-24: {"long_suffix": "yocto", "short_suffix": "y", "scalar": 10**-24}, -24: {"long_suffix": "yocto", "short_suffix": "y", "scalar": 10**-24},
} }
exponent = int(math.floor(math.log10(abs(val)) / 3.0) * 3) exponent = int(math.floor(math.log10(abs(val)) / 3.0) * 3)
return suffixes.get(exponent, None) suffix = suffixes.get(exponent)
assert suffix, f"could not classify: {val}"
return suffix
def si_formatter(value): def si_formatter(value) -> Tuple:
""" """
Return a triple of scaled value, short suffix, long suffix, or None if Return a triple of scaled value, short suffix, long suffix, or None if
the value cannot be classified. the value cannot be classified.
""" """
classifier = si_classifier(value) classifier = si_classifier(value)
if classifier is None:
# Don't know how to classify this value
return None
scaled = value / classifier["scalar"] scaled = value / classifier["scalar"]
return (scaled, classifier["short_suffix"], classifier["long_suffix"]) return scaled, classifier["short_suffix"], classifier["long_suffix"]
def si_format(value, precision=4, long_form=False, separator=""): def si_format(value: float, precision=4, long_form=False, separator="") -> str:
""" """
"SI prefix" formatted string: return a string with the given precision "SI prefix" formatted string: return a string with the given precision
and an appropriate order-of-3-magnitudes suffix, e.g.: and an appropriate order-of-3-magnitudes suffix, e.g.:
@ -47,11 +46,6 @@ def si_format(value, precision=4, long_form=False, separator=""):
si_format(0.00000000123, long_form=True, separator=' ') => '1.230 nano' si_format(0.00000000123, long_form=True, separator=' ') => '1.230 nano'
""" """
scaled, short_suffix, long_suffix = si_formatter(value) scaled, short_suffix, long_suffix = si_formatter(value)
if scaled is None:
# Don't know how to format this value
return value
suffix = long_suffix if long_form else short_suffix suffix = long_suffix if long_form else short_suffix
if abs(scaled) < 10: if abs(scaled) < 10:

View file

@ -23,14 +23,14 @@ async def create_livestream(*, wallet_id: str) -> int:
if db.type == SQLITE: if db.type == SQLITE:
return result._result_proxy.lastrowid return result._result_proxy.lastrowid
else: else:
return result[0] return result[0] # type: ignore
async def get_livestream(ls_id: int) -> Optional[Livestream]: async def get_livestream(ls_id: int) -> Optional[Livestream]:
row = await db.fetchone( row = await db.fetchone(
"SELECT * FROM livestream.livestreams WHERE id = ?", (ls_id,) "SELECT * FROM livestream.livestreams WHERE id = ?", (ls_id,)
) )
return Livestream(**dict(row)) if row else None return Livestream(**row) if row else None
async def get_livestream_by_track(track_id: int) -> Optional[Livestream]: async def get_livestream_by_track(track_id: int) -> Optional[Livestream]:
@ -42,7 +42,7 @@ async def get_livestream_by_track(track_id: int) -> Optional[Livestream]:
""", """,
(track_id,), (track_id,),
) )
return Livestream(**dict(row)) if row else None return Livestream(**row) if row else None
async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream]: async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream]:
@ -55,7 +55,7 @@ async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream
ls_id = await create_livestream(wallet_id=wallet) ls_id = await create_livestream(wallet_id=wallet)
return await get_livestream(ls_id) return await get_livestream(ls_id)
return Livestream(**dict(row)) if row else None return Livestream(**row) if row else None
async def update_current_track(ls_id: int, track_id: Optional[int]): async def update_current_track(ls_id: int, track_id: Optional[int]):
@ -121,7 +121,7 @@ async def get_track(track_id: Optional[int]) -> Optional[Track]:
""", """,
(track_id,), (track_id,),
) )
return Track(**dict(row)) if row else None return Track(**row) if row else None
async def get_tracks(livestream: int) -> List[Track]: async def get_tracks(livestream: int) -> List[Track]:
@ -132,7 +132,7 @@ async def get_tracks(livestream: int) -> List[Track]:
""", """,
(livestream,), (livestream,),
) )
return [Track(**dict(row)) for row in rows] return [Track(**row) for row in rows]
async def delete_track_from_livestream(livestream: int, track_id: int): async def delete_track_from_livestream(livestream: int, track_id: int):
@ -174,7 +174,7 @@ async def add_producer(livestream: int, name: str) -> int:
if db.type == SQLITE: if db.type == SQLITE:
return result._result_proxy.lastrowid return result._result_proxy.lastrowid
else: else:
return result[0] return result[0] # type: ignore
async def get_producer(producer_id: int) -> Optional[Producer]: async def get_producer(producer_id: int) -> Optional[Producer]:
@ -185,7 +185,7 @@ async def get_producer(producer_id: int) -> Optional[Producer]:
""", """,
(producer_id,), (producer_id,),
) )
return Producer(**dict(row)) if row else None return Producer(**row) if row else None
async def get_producers(livestream: int) -> List[Producer]: async def get_producers(livestream: int) -> List[Producer]:
@ -196,4 +196,4 @@ async def get_producers(livestream: int) -> List[Producer]:
""", """,
(livestream,), (livestream,),
) )
return [Producer(**dict(row)) for row in rows] return [Producer(**row) for row in rows]

View file

@ -1,4 +1,5 @@
import json import json
from sqlite3 import Row
from typing import Optional from typing import Optional
from fastapi import Query, Request from fastapi import Query, Request
@ -27,6 +28,10 @@ class Livestream(BaseModel):
url = request.url_for("livestream.lnurl_livestream", ls_id=self.id) url = request.url_for("livestream.lnurl_livestream", ls_id=self.id)
return lnurl_encode(url) return lnurl_encode(url)
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
class Track(BaseModel): class Track(BaseModel):
id: int id: int
@ -35,6 +40,10 @@ class Track(BaseModel):
name: str name: str
producer: int producer: int
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
@property @property
def min_sendable(self) -> int: def min_sendable(self) -> int:
return min(100_000, self.price_msat or 100_000) return min(100_000, self.price_msat or 100_000)
@ -88,3 +97,7 @@ class Producer(BaseModel):
user: str user: str
wallet: str wallet: str
name: str name: str
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))

View file

@ -71,9 +71,7 @@ async def api_domain_create(
if not cf_response or not cf_response["success"]: if not cf_response or not cf_response["success"]:
await delete_domain(domain.id) await delete_domain(domain.id)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST, detail="Problem with cloudflare."
detail="Problem with cloudflare: "
+ cf_response["errors"][0]["message"],
) )
return domain.dict() return domain.dict()

View file

@ -61,8 +61,8 @@ class PayLink(BaseModel):
def success_action(self, payment_hash: str) -> Optional[Dict]: def success_action(self, payment_hash: str) -> Optional[Dict]:
if self.success_url: if self.success_url:
url: ParseResult = urlparse(self.success_url) url: ParseResult = urlparse(self.success_url)
qs: Dict = parse_qs(url.query) qs = parse_qs(url.query)
qs["payment_hash"] = payment_hash setattr(qs, "payment_hash", payment_hash)
url = url._replace(query=urlencode(qs, doseq=True)) url = url._replace(query=urlencode(qs, doseq=True))
return { return {
"tag": "url", "tag": "url",

View file

@ -6,6 +6,7 @@ and delivery to the specific person
import json import json
from collections import defaultdict from collections import defaultdict
from typing import AsyncGenerator
from fastapi import WebSocket from fastapi import WebSocket
from loguru import logger from loguru import logger
@ -34,7 +35,7 @@ class Notifier:
# Create notification generator: # Create notification generator:
self.generator = self.get_notification_generator() self.generator = self.get_notification_generator()
async def get_notification_generator(self): async def get_notification_generator(self) -> AsyncGenerator:
"""Notification Generator""" """Notification Generator"""
while True: while True:
@ -54,9 +55,8 @@ class Notifier:
logger.exception(f"There is no member in room: {room_name}") logger.exception(f"There is no member in room: {room_name}")
return None return None
async def push(self, message: str, room_name: str = None): async def push(self, message: str, room_name: str):
"""Push a message""" """Push a message"""
message_body = {"message": message, "room_name": room_name} message_body = {"message": message, "room_name": room_name}
await self.generator.asend(message_body) await self.generator.asend(message_body)

View file

@ -1,15 +1,8 @@
import json
from http import HTTPStatus from http import HTTPStatus
from fastapi import ( from fastapi import Depends, Query, Request, WebSocket, WebSocketDisconnect
BackgroundTasks,
Depends,
Query,
Request,
WebSocket,
WebSocketDisconnect,
)
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from loguru import logger
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.responses import HTMLResponse from starlette.responses import HTMLResponse
@ -147,24 +140,14 @@ notifier = Notifier()
@market_ext.websocket("/ws/{room_name}") @market_ext.websocket("/ws/{room_name}")
async def websocket_endpoint( async def websocket_endpoint(websocket: WebSocket, room_name: str):
websocket: WebSocket, room_name: str, background_tasks: BackgroundTasks
):
await notifier.connect(websocket, room_name) await notifier.connect(websocket, room_name)
try: try:
while True: while True:
data = await websocket.receive_text() data = await websocket.receive_text()
d = json.loads(data) room_members = notifier.get_members(room_name) or []
d["room_name"] = room_name
room_members = (
notifier.get_members(room_name)
if notifier.get_members(room_name) is not None
else []
)
if websocket not in room_members: if websocket not in room_members:
print("Sender not in room member: Reconnecting...") logger.warning("Sender not in room member: Reconnecting...")
await notifier.connect(websocket, room_name) await notifier.connect(websocket, room_name)
await notifier._notify(data, room_name) await notifier._notify(data, room_name)

View file

@ -1,4 +1,5 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Optional
from fastapi import Depends, Query from fastapi import Depends, Query
from loguru import logger from loguru import logger
@ -224,7 +225,7 @@ async def api_market_stalls(
@market_ext.put("/api/v1/stalls/{stall_id}") @market_ext.put("/api/v1/stalls/{stall_id}")
async def api_market_stall_create( async def api_market_stall_create(
data: createStalls, data: createStalls,
stall_id: str = None, stall_id: Optional[str] = None,
wallet: WalletTypeInfo = Depends(require_invoice_key), wallet: WalletTypeInfo = Depends(require_invoice_key),
): ):
@ -447,7 +448,7 @@ async def api_market_market_stalls(market_id: str):
@market_ext.put("/api/v1/markets/{market_id}") @market_ext.put("/api/v1/markets/{market_id}")
async def api_market_market_create( async def api_market_market_create(
data: CreateMarket, data: CreateMarket,
market_id: str = None, market_id: Optional[str] = None,
wallet: WalletTypeInfo = Depends(require_invoice_key), wallet: WalletTypeInfo = Depends(require_invoice_key),
): ):
if market_id: if market_id:
@ -506,7 +507,7 @@ async def api_get_settings(wallet: WalletTypeInfo = Depends(require_admin_key)):
@market_ext.put("/api/v1/settings/{usr}") @market_ext.put("/api/v1/settings/{usr}")
async def api_set_settings( async def api_set_settings(
data: SetSettings, data: SetSettings,
usr: str = None, usr: Optional[str] = None,
wallet: WalletTypeInfo = Depends(require_admin_key), wallet: WalletTypeInfo = Depends(require_admin_key),
): ):
if usr: if usr:

View file

@ -2,7 +2,7 @@ from os import getenv
from fastapi import Depends, Request from fastapi import Depends, Request
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from pyngrok import conf, ngrok from pyngrok import conf, ngrok # type: ignore
from lnbits.core.models import User from lnbits.core.models import User
from lnbits.decorators import check_user_exists from lnbits.decorators import check_user_exists

View file

@ -22,12 +22,12 @@ async def create_shop(*, wallet_id: str) -> int:
if db.type == SQLITE: if db.type == SQLITE:
return result._result_proxy.lastrowid return result._result_proxy.lastrowid
else: else:
return result[0] return result[0] # type: ignore
async def get_shop(id: int) -> Optional[Shop]: async def get_shop(id: int) -> Optional[Shop]:
row = await db.fetchone("SELECT * FROM offlineshop.shops WHERE id = ?", (id,)) row = await db.fetchone("SELECT * FROM offlineshop.shops WHERE id = ?", (id,))
return Shop(**dict(row)) if row else None return Shop(**row) if row else None
async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]: async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]:
@ -40,7 +40,7 @@ async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]:
ls_id = await create_shop(wallet_id=wallet) ls_id = await create_shop(wallet_id=wallet)
return await get_shop(ls_id) return await get_shop(ls_id)
return Shop(**dict(row)) if row else None return Shop(**row) if row else None
async def set_method(shop: int, method: str, wordlist: str = "") -> Optional[Shop]: async def set_method(shop: int, method: str, wordlist: str = "") -> Optional[Shop]:

View file

@ -52,7 +52,7 @@ class ShopCounter:
# cleanup confirmation words cache # cleanup confirmation words cache
to_remove = len(self.fulfilled_payments) - 23 to_remove = len(self.fulfilled_payments) - 23
if to_remove > 0: if to_remove > 0:
for i in range(to_remove): for _ in range(to_remove):
self.fulfilled_payments.popitem(False) self.fulfilled_payments.popitem(False)
return word return word
@ -64,6 +64,10 @@ class Shop(BaseModel):
method: str method: str
wordlist: str wordlist: str
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
@property @property
def otp_key(self) -> str: def otp_key(self) -> str:
return base64.b32encode( return base64.b32encode(

View file

@ -18,8 +18,6 @@ from .crud import (
) )
from .models import CreateSatsDicePayment from .models import CreateSatsDicePayment
##############LNURLP STUFF
@satsdice_ext.get( @satsdice_ext.get(
"/api/v1/lnurlp/{link_id}", "/api/v1/lnurlp/{link_id}",
@ -84,7 +82,7 @@ async def api_lnurlp_callback(
data = CreateSatsDicePayment( data = CreateSatsDicePayment(
satsdice_pay=link.id, satsdice_pay=link.id,
value=amount_received / 1000, value=int(amount_received / 1000),
payment_hash=payment_hash, payment_hash=payment_hash,
) )

View file

@ -13,7 +13,7 @@ from .helpers import fetch_onchain_balance
from .models import Charges, CreateCharge, SatsPayThemes from .models import Charges, CreateCharge, SatsPayThemes
async def create_charge(user: str, data: CreateCharge) -> Optional[Charges]: async def create_charge(user: str, data: CreateCharge) -> Charges:
data = CreateCharge(**data.dict()) data = CreateCharge(**data.dict())
charge_id = urlsafe_short_hash() charge_id = urlsafe_short_hash()
if data.onchainwallet: if data.onchainwallet:
@ -79,7 +79,9 @@ async def create_charge(user: str, data: CreateCharge) -> Optional[Charges]:
data.custom_css, data.custom_css,
), ),
) )
return await get_charge(charge_id) charge = await get_charge(charge_id)
assert charge, "Newly created charge does not exist"
return charge
async def update_charge(charge_id: str, **kwargs) -> Optional[Charges]: async def update_charge(charge_id: str, **kwargs) -> Optional[Charges]:

View file

@ -10,7 +10,7 @@ async def get_targets(source_wallet: str) -> List[Target]:
rows = await db.fetchall( rows = await db.fetchall(
"SELECT * FROM splitpayments.targets WHERE source = ?", (source_wallet,) "SELECT * FROM splitpayments.targets WHERE source = ?", (source_wallet,)
) )
return [Target(**dict(row)) for row in rows] return [Target(**row) for row in rows]
async def set_targets(source_wallet: str, targets: List[Target]): async def set_targets(source_wallet: str, targets: List[Target]):

View file

@ -1,6 +1,7 @@
from sqlite3 import Row
from typing import List, Optional from typing import List, Optional
from fastapi.param_functions import Query from fastapi import Query
from pydantic import BaseModel from pydantic import BaseModel
@ -11,6 +12,10 @@ class Target(BaseModel):
tag: str tag: str
alias: Optional[str] alias: Optional[str]
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
class TargetPutList(BaseModel): class TargetPutList(BaseModel):
wallet: str = Query(...) wallet: str = Query(...)

View file

@ -147,14 +147,16 @@ async def create_service(data: CreateService) -> Service:
if db.type == SQLITE: if db.type == SQLITE:
service_id = result._result_proxy.lastrowid service_id = result._result_proxy.lastrowid
else: else:
service_id = result[0] service_id = result[0] # type: ignore
service = await get_service(service_id) service = await get_service(service_id)
assert service assert service
return service return service
async def get_service(service_id: int, by_state: str = None) -> Optional[Service]: async def get_service(
service_id: int, by_state: Optional[str] = None
) -> Optional[Service]:
"""Return a service either by ID or, available, by state """Return a service either by ID or, available, by state
Each Service's donation page is reached through its "state" hash Each Service's donation page is reached through its "state" hash
@ -184,7 +186,9 @@ async def authenticate_service(service_id, code, redirect_uri):
"""Use authentication code from third party API to retreive access token""" """Use authentication code from third party API to retreive access token"""
# The API token is passed in the querystring as 'code' # The API token is passed in the querystring as 'code'
service = await get_service(service_id) service = await get_service(service_id)
assert service
wallet = await get_wallet(service.wallet) wallet = await get_wallet(service.wallet)
assert wallet
user = wallet.user user = wallet.user
url = "https://streamlabs.com/api/v1.0/token" url = "https://streamlabs.com/api/v1.0/token"
data = { data = {
@ -208,8 +212,11 @@ async def service_add_token(service_id, token):
is not overwritten. is not overwritten.
Tokens for Streamlabs never need to be refreshed. Tokens for Streamlabs never need to be refreshed.
""" """
if (await get_service(service_id)).authenticated: service = await get_service(service_id)
assert service
if service.authenticated:
return False return False
await db.execute( await db.execute(
"UPDATE streamalerts.Services SET authenticated = 1, token = ? where id = ?", "UPDATE streamalerts.Services SET authenticated = 1, token = ? where id = ?",
(token, service_id), (token, service_id),

View file

@ -1,5 +1,3 @@
import json
import httpx import httpx
from .models import Domains from .models import Domains
@ -20,25 +18,21 @@ async def cloudflare_create_subdomain(
"Content-Type": "application/json", "Content-Type": "application/json",
} }
aRecord = subdomain + "." + domain.domain aRecord = subdomain + "." + domain.domain
cf_response = ""
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: r = await client.post(
r = await client.post( url,
url, headers=header,
headers=header, json={
json={ "type": record_type,
"type": record_type, "name": aRecord,
"name": aRecord, "content": ip,
"content": ip, "ttl": 0,
"ttl": 0, "proxied": False,
"proxied": False, },
}, timeout=40,
timeout=40, )
) r.raise_for_status()
cf_response = json.loads(r.text) return r.json()
except AssertionError:
cf_response = "Error occured"
return cf_response
async def cloudflare_deletesubdomain(domain: Domains, domain_id: str): async def cloudflare_deletesubdomain(domain: Domains, domain_id: str):
@ -52,8 +46,4 @@ async def cloudflare_deletesubdomain(domain: Domains, domain_id: str):
"Content-Type": "application/json", "Content-Type": "application/json",
} }
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: await client.delete(url + "/" + domain_id, headers=header, timeout=40)
r = await client.delete(url + "/" + domain_id, headers=header, timeout=40)
r.text
except AssertionError:
pass

View file

@ -121,21 +121,21 @@ async def api_subdomain_make_subdomain(domain_id, data: CreateSubdomain):
detail=f"{data.subdomain}.{domain.domain} domain already taken.", detail=f"{data.subdomain}.{domain.domain} domain already taken.",
) )
## Dry run cloudflare... (create and if create is sucessful delete it) ## Dry run cloudflare... (create and if create is successful delete it)
cf_response = await cloudflare_create_subdomain( try:
domain=domain, res_json = await cloudflare_create_subdomain(
subdomain=data.subdomain, domain=domain,
record_type=data.record_type, subdomain=data.subdomain,
ip=data.ip, record_type=data.record_type,
) ip=data.ip,
if cf_response["success"] is True:
await cloudflare_deletesubdomain(
domain=domain, domain_id=cf_response["result"]["id"]
) )
else: await cloudflare_deletesubdomain(
domain=domain, domain_id=res_json["result"]["id"]
)
except:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=f'Problem with cloudflare: {cf_response["errors"][0]["message"]}', detail="Problem with cloudflare.",
) )
## ALL OK - create an invoice and return it to the user ## ALL OK - create an invoice and return it to the user

View file

@ -1,20 +0,0 @@
from lnbits.core.crud import get_wallet
from .crud import get_tipjar
async def get_charge_details(tipjar_id):
"""Return the default details for a satspay charge"""
tipjar = await get_tipjar(tipjar_id)
wallet_id = tipjar.wallet
wallet = await get_wallet(wallet_id)
user = wallet.user
details = {
"time": 1440,
"user": user,
"lnbitswallet": wallet_id,
"onchainwallet": tipjar.onchain,
"completelink": "/tipjar/" + str(tipjar_id),
"completelinktext": "Thanks for the tip!",
}
return details

View file

@ -1,7 +1,6 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import Depends, Request from fastapi import Depends, Query, Request
from fastapi.param_functions import Query
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException

View file

@ -3,7 +3,7 @@ from http import HTTPStatus
from fastapi import Depends, Query from fastapi import Depends, Query
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from lnbits.core.crud import get_user from lnbits.core.crud import get_user, get_wallet
from lnbits.decorators import WalletTypeInfo, get_key_type from lnbits.decorators import WalletTypeInfo, get_key_type
# todo: use the API, not direct import # todo: use the API, not direct import
@ -22,7 +22,6 @@ from .crud import (
update_tip, update_tip,
update_tipjar, update_tipjar,
) )
from .helpers import get_charge_details
from .models import createTip, createTipJar, createTips from .models import createTip, createTipJar, createTips
@ -55,25 +54,32 @@ async def api_create_tip(data: createTips):
status_code=HTTPStatus.NOT_FOUND, detail="Tipjar does not exist." status_code=HTTPStatus.NOT_FOUND, detail="Tipjar does not exist."
) )
webhook = tipjar.webhook wallet_id = tipjar.wallet
charge_details = await get_charge_details(tipjar.id) wallet = await get_wallet(wallet_id)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Tipjar wallet does not exist."
)
name = data.name name = data.name
# Ensure that description string can be split reliably # Ensure that description string can be split reliably
name = name.replace('"', "''") name = name.replace('"', "''")
if not name: if not name:
name = "Anonymous" name = "Anonymous"
description = f"{name}: {message}" description = f"{name}: {message}"
charge = await create_charge( charge = await create_charge(
user=charge_details["user"], user=wallet.user,
data=CreateCharge( data=CreateCharge(
amount=sats, amount=sats,
webhook=webhook or "", webhook=tipjar.webhook or "",
description=description, description=description,
onchainwallet=charge_details["onchainwallet"], onchainwallet=tipjar.onchain or "",
lnbitswallet=charge_details["lnbitswallet"], lnbitswallet=tipjar.wallet,
completelink=charge_details["completelink"], completelink="/tipjar/" + str(tipjar_id),
completelinktext=charge_details["completelinktext"], completelinktext="Thanks for the tip!",
time=charge_details["time"], time=1440,
custom_css="", custom_css="",
), ),
) )

View file

@ -1,3 +1,5 @@
from typing import Optional, Tuple
from embit.descriptor import Descriptor, Key from embit.descriptor import Descriptor, Key
from embit.descriptor.arguments import AllowedDerivation from embit.descriptor.arguments import AllowedDerivation
from embit.networks import NETWORKS from embit.networks import NETWORKS
@ -12,7 +14,7 @@ def detect_network(k):
return net return net
def parse_key(masterpub: str) -> Descriptor: def parse_key(masterpub: str) -> Tuple[Descriptor, Optional[dict]]:
"""Parses masterpub or descriptor and returns a tuple: (Descriptor, network) """Parses masterpub or descriptor and returns a tuple: (Descriptor, network)
To create addresses use descriptor.derive(num).address(network=network) To create addresses use descriptor.derive(num).address(network=network)
""" """
@ -34,6 +36,7 @@ def parse_key(masterpub: str) -> Descriptor:
k.allowed_derivation = AllowedDerivation.default() k.allowed_derivation = AllowedDerivation.default()
# get version bytes # get version bytes
version = k.key.version version = k.key.version
desc = Descriptor()
for network_name in NETWORKS: for network_name in NETWORKS:
net = NETWORKS[network_name] net = NETWORKS[network_name]
# not found in this network # not found in this network
@ -47,8 +50,9 @@ def parse_key(masterpub: str) -> Descriptor:
desc = Descriptor.from_string("wpkh(%s)" % str(k)) desc = Descriptor.from_string("wpkh(%s)" % str(k))
break break
# we didn't find correct version # we didn't find correct version
if network is None: if not network:
raise ValueError("Unknown master public key version") raise ValueError("Unknown master public key version")
else: else:
desc = Descriptor.from_string(masterpub) desc = Descriptor.from_string(masterpub)
if not desc.is_wildcard: if not desc.is_wildcard:
@ -61,6 +65,7 @@ def parse_key(masterpub: str) -> Descriptor:
if network is not None and network != net: if network is not None and network != net:
raise ValueError("Keys from different networks") raise ValueError("Keys from different networks")
network = net network = net
return desc, network return desc, network

View file

@ -73,7 +73,8 @@ async def api_wallet_create_or_update(
data: CreateWallet, w: WalletTypeInfo = Depends(require_admin_key) data: CreateWallet, w: WalletTypeInfo = Depends(require_admin_key)
): ):
try: try:
(descriptor, network) = parse_key(data.masterpub) descriptor, network = parse_key(data.masterpub)
assert network
if data.network != network["name"]: if data.network != network["name"]:
raise ValueError( raise ValueError(
"Account network error. This account is for '{}'".format( "Account network error. This account is for '{}'".format(
@ -308,12 +309,9 @@ async def api_psbt_utxos_tx(
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
@watchonly_ext.put("/api/v1/psbt/extract") @watchonly_ext.put("/api/v1/psbt/extract", dependencies=[Depends(require_admin_key)])
async def api_psbt_extract_tx( async def api_psbt_extract_tx(data: ExtractPsbt):
data: ExtractPsbt, w: WalletTypeInfo = Depends(require_admin_key)
):
network = NETWORKS["main"] if data.network == "Mainnet" else NETWORKS["test"] network = NETWORKS["main"] if data.network == "Mainnet" else NETWORKS["test"]
res = SignedTransaction()
try: try:
psbt = PSBT.from_base64(data.psbtBase64) psbt = PSBT.from_base64(data.psbtBase64)
for i, inp in enumerate(data.inputs): for i, inp in enumerate(data.inputs):
@ -322,9 +320,9 @@ async def api_psbt_extract_tx(
final_psbt = finalizer.finalize_psbt(psbt) final_psbt = finalizer.finalize_psbt(psbt)
if not final_psbt: if not final_psbt:
raise ValueError("PSBT cannot be finalized!") raise ValueError("PSBT cannot be finalized!")
res.tx_hex = final_psbt.to_string()
transaction = Transaction.from_string(res.tx_hex) tx_hex = final_psbt.to_string()
transaction = Transaction.from_string(tx_hex)
tx = { tx = {
"locktime": transaction.locktime, "locktime": transaction.locktime,
"version": transaction.version, "version": transaction.version,
@ -336,10 +334,10 @@ async def api_psbt_extract_tx(
tx["outputs"].append( tx["outputs"].append(
{"amount": out.value, "address": out.script_pubkey.address(network)} {"amount": out.value, "address": out.script_pubkey.address(network)}
) )
res.tx_json = json.dumps(tx) signed_tx = SignedTransaction(tx_hex=tx_hex, tx_json=json.dumps(tx))
return signed_tx.dict()
except Exception as e: except Exception as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
return res.dict()
@watchonly_ext.post("/api/v1/tx") @watchonly_ext.post("/api/v1/tx")