Merge pull request #1469 from lnbits/pyright4-extensions-only

pyright but only for extensions
This commit is contained in:
Arc 2023-02-10 09:49:00 +00:00 committed by GitHub
commit 741f1a3daf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
32 changed files with 221 additions and 191 deletions

View file

@ -10,18 +10,18 @@ jobs:
python-version: ["3.9"] python-version: ["3.9"]
poetry-version: ["1.3.1"] poetry-version: ["1.3.1"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Set up Poetry ${{ matrix.poetry-version }} - name: Set up Poetry ${{ matrix.poetry-version }}
uses: abatilo/actions-poetry@v2 uses: abatilo/actions-poetry@v2
with: with:
poetry-version: ${{ matrix.poetry-version }} poetry-version: ${{ matrix.poetry-version }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: "poetry"
- name: Install dependencies - name: Install dependencies
run: | run: |
poetry config virtualenvs.create false
poetry install poetry install
- name: Run tests - name: Run tests
run: make flake8 run: make flake8

View file

@ -19,9 +19,9 @@ jobs:
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
cache: 'poetry' cache: "poetry"
- name: Install dependencies - name: Install dependencies
run: | run: |
poetry install poetry install
- name: Run tests - name: Run tests
run: poetry run mypy run: make mypy

View file

@ -10,18 +10,18 @@ jobs:
python-version: ["3.9"] python-version: ["3.9"]
poetry-version: ["1.3.1"] poetry-version: ["1.3.1"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Set up Poetry ${{ matrix.poetry-version }} - name: Set up Poetry ${{ matrix.poetry-version }}
uses: abatilo/actions-poetry@v2 uses: abatilo/actions-poetry@v2
with: with:
poetry-version: ${{ matrix.poetry-version }} poetry-version: ${{ matrix.poetry-version }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: "poetry"
- name: Install dependencies - name: Install dependencies
run: | run: |
poetry config virtualenvs.create false
poetry install poetry install
- name: Run tests - name: Run tests
run: make pylint run: make pylint

28
.github/workflows/pyright.yml vendored Normal file
View file

@ -0,0 +1,28 @@
name: pyright
on: [push, pull_request]
jobs:
check:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9"]
poetry-version: ["1.3.1"]
steps:
- uses: actions/checkout@v3
- name: Set up Poetry ${{ matrix.poetry-version }}
uses: abatilo/actions-poetry@v2
with:
poetry-version: ${{ matrix.poetry-version }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: "poetry"
- name: Install dependencies
run: |
poetry install
npm install
- name: Run tests
run: make pyright

View file

@ -4,13 +4,13 @@ all: format check requirements.txt
format: prettier isort black format: prettier isort black
check: mypy checkprettier checkisort checkblack check: mypy pyright pylint flake8 checkisort checkblack checkprettier
prettier: $(shell find lnbits -name "*.js" -o -name ".html") prettier: $(shell find lnbits -name "*.js" -o -name ".html")
./node_modules/.bin/prettier --write lnbits/static/js/*.js lnbits/core/static/js/*.js lnbits/extensions/*/templates/*/*.html ./lnbits/core/templates/core/*.html lnbits/templates/*.html lnbits/extensions/*/static/js/*.js lnbits/extensions/*/static/components/*/*.js lnbits/extensions/*/static/components/*/*.html ./node_modules/.bin/prettier --write lnbits/static/js/*.js lnbits/core/static/js/*.js lnbits/extensions/*/templates/*/*.html ./lnbits/core/templates/core/*.html lnbits/templates/*.html lnbits/extensions/*/static/js/*.js lnbits/extensions/*/static/components/*/*.js lnbits/extensions/*/static/components/*/*.html
pyright: pyright:
./node_modules/.bin/pyright poetry run ./node_modules/.bin/pyright
black: black:
poetry run black . poetry run black .

View file

@ -80,7 +80,6 @@ async def fetch_fiat_exchange_rate(currency: str, provider: str):
else: else:
data = {} data = {}
getter = exchange_rate_providers[provider]["getter"] getter = exchange_rate_providers[provider]["getter"]
print(getter) if not callable(getter):
if callable(getter): return None
rate = float(getter(data, replacements)) return float(getter(data, replacements))
return rate

View file

@ -68,10 +68,11 @@ class Proof(BaseModel):
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
assert "secret" in d, "no secret in proof" assert "secret" in d, "no secret in proof"
assert "amount" in d, "no amount in proof" assert "amount" in d, "no amount in proof"
assert "C" in d, "no C in proof"
return cls( return cls(
amount=d.get("amount"), amount=d["amount"],
C=d.get("C"), C=d["C"],
secret=d.get("secret"), secret=d["secret"],
reserved=d.get("reserved") or False, reserved=d.get("reserved") or False,
send_id=d.get("send_id") or "", send_id=d.get("send_id") or "",
time_created=d.get("time_created") or "", time_created=d.get("time_created") or "",

View file

@ -1,7 +1,8 @@
import math import math
from typing import Tuple
def si_classifier(val): def si_classifier(val) -> dict:
suffixes = { suffixes = {
24: {"long_suffix": "yotta", "short_suffix": "Y", "scalar": 10**24}, 24: {"long_suffix": "yotta", "short_suffix": "Y", "scalar": 10**24},
21: {"long_suffix": "zetta", "short_suffix": "Z", "scalar": 10**21}, 21: {"long_suffix": "zetta", "short_suffix": "Z", "scalar": 10**21},
@ -22,24 +23,22 @@ def si_classifier(val):
-24: {"long_suffix": "yocto", "short_suffix": "y", "scalar": 10**-24}, -24: {"long_suffix": "yocto", "short_suffix": "y", "scalar": 10**-24},
} }
exponent = int(math.floor(math.log10(abs(val)) / 3.0) * 3) exponent = int(math.floor(math.log10(abs(val)) / 3.0) * 3)
return suffixes.get(exponent, None) suffix = suffixes.get(exponent)
assert suffix, f"could not classify: {val}"
return suffix
def si_formatter(value): def si_formatter(value) -> Tuple:
""" """
Return a triple of scaled value, short suffix, long suffix, or None if Return a triple of scaled value, short suffix, long suffix, or None if
the value cannot be classified. the value cannot be classified.
""" """
classifier = si_classifier(value) classifier = si_classifier(value)
if classifier is None:
# Don't know how to classify this value
return None
scaled = value / classifier["scalar"] scaled = value / classifier["scalar"]
return (scaled, classifier["short_suffix"], classifier["long_suffix"]) return scaled, classifier["short_suffix"], classifier["long_suffix"]
def si_format(value, precision=4, long_form=False, separator=""): def si_format(value: float, precision=4, long_form=False, separator="") -> str:
""" """
"SI prefix" formatted string: return a string with the given precision "SI prefix" formatted string: return a string with the given precision
and an appropriate order-of-3-magnitudes suffix, e.g.: and an appropriate order-of-3-magnitudes suffix, e.g.:
@ -47,11 +46,6 @@ def si_format(value, precision=4, long_form=False, separator=""):
si_format(0.00000000123, long_form=True, separator=' ') => '1.230 nano' si_format(0.00000000123, long_form=True, separator=' ') => '1.230 nano'
""" """
scaled, short_suffix, long_suffix = si_formatter(value) scaled, short_suffix, long_suffix = si_formatter(value)
if scaled is None:
# Don't know how to format this value
return value
suffix = long_suffix if long_form else short_suffix suffix = long_suffix if long_form else short_suffix
if abs(scaled) < 10: if abs(scaled) < 10:

View file

@ -23,26 +23,30 @@ async def create_livestream(*, wallet_id: str) -> int:
if db.type == SQLITE: if db.type == SQLITE:
return result._result_proxy.lastrowid return result._result_proxy.lastrowid
else: else:
return result[0] return result[0] # type: ignore
async def get_livestream(ls_id: int) -> Optional[Livestream]: async def get_livestream(ls_id: int) -> Optional[Livestream]:
row = await db.fetchone( row = await db.fetchone(
"SELECT * FROM livestream.livestreams WHERE id = ?", (ls_id,) "SELECT * FROM livestream.livestreams WHERE id = ?", (ls_id,)
) )
return Livestream(**dict(row)) if row else None return Livestream(**row) if row else None
async def get_livestream_by_track(track_id: int) -> Optional[Livestream]: async def get_livestream_by_track(track_id: int) -> Optional[Livestream]:
row = await db.fetchone( row = await db.fetchone(
""" """
SELECT livestreams.* AS livestreams FROM livestream.livestreams SELECT * FROM livestream.tracks WHERE tracks.id = ?
INNER JOIN livestream.tracks AS tracks ON tracks.livestream = livestreams.id
WHERE tracks.id = ?
""", """,
(track_id,), (track_id,),
) )
return Livestream(**dict(row)) if row else None row2 = await db.fetchone(
"""
SELECT * FROM livestream.livestreams WHERE livestreams.id = ?
""",
(row.livestream,),
)
return Livestream(**row2) if row2 else None
async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream]: async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream]:
@ -55,7 +59,7 @@ async def get_or_create_livestream_by_wallet(wallet: str) -> Optional[Livestream
ls_id = await create_livestream(wallet_id=wallet) ls_id = await create_livestream(wallet_id=wallet)
return await get_livestream(ls_id) return await get_livestream(ls_id)
return Livestream(**dict(row)) if row else None return Livestream(**row) if row else None
async def update_current_track(ls_id: int, track_id: Optional[int]): async def update_current_track(ls_id: int, track_id: Optional[int]):
@ -121,7 +125,7 @@ async def get_track(track_id: Optional[int]) -> Optional[Track]:
""", """,
(track_id,), (track_id,),
) )
return Track(**dict(row)) if row else None return Track(**row) if row else None
async def get_tracks(livestream: int) -> List[Track]: async def get_tracks(livestream: int) -> List[Track]:
@ -132,7 +136,7 @@ async def get_tracks(livestream: int) -> List[Track]:
""", """,
(livestream,), (livestream,),
) )
return [Track(**dict(row)) for row in rows] return [Track(**row) for row in rows]
async def delete_track_from_livestream(livestream: int, track_id: int): async def delete_track_from_livestream(livestream: int, track_id: int):
@ -174,7 +178,7 @@ async def add_producer(livestream: int, name: str) -> int:
if db.type == SQLITE: if db.type == SQLITE:
return result._result_proxy.lastrowid return result._result_proxy.lastrowid
else: else:
return result[0] return result[0] # type: ignore
async def get_producer(producer_id: int) -> Optional[Producer]: async def get_producer(producer_id: int) -> Optional[Producer]:
@ -185,7 +189,7 @@ async def get_producer(producer_id: int) -> Optional[Producer]:
""", """,
(producer_id,), (producer_id,),
) )
return Producer(**dict(row)) if row else None return Producer(**row) if row else None
async def get_producers(livestream: int) -> List[Producer]: async def get_producers(livestream: int) -> List[Producer]:
@ -196,4 +200,4 @@ async def get_producers(livestream: int) -> List[Producer]:
""", """,
(livestream,), (livestream,),
) )
return [Producer(**dict(row)) for row in rows] return [Producer(**row) for row in rows]

View file

@ -1,4 +1,5 @@
import json import json
from sqlite3 import Row
from typing import Optional from typing import Optional
from fastapi import Query, Request from fastapi import Query, Request
@ -27,6 +28,10 @@ class Livestream(BaseModel):
url = request.url_for("livestream.lnurl_livestream", ls_id=self.id) url = request.url_for("livestream.lnurl_livestream", ls_id=self.id)
return lnurl_encode(url) return lnurl_encode(url)
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
class Track(BaseModel): class Track(BaseModel):
id: int id: int
@ -35,6 +40,10 @@ class Track(BaseModel):
name: str name: str
producer: int producer: int
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
@property @property
def min_sendable(self) -> int: def min_sendable(self) -> int:
return min(100_000, self.price_msat or 100_000) return min(100_000, self.price_msat or 100_000)
@ -88,3 +97,7 @@ class Producer(BaseModel):
user: str user: str
wallet: str wallet: str
name: str name: str
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))

View file

@ -2,6 +2,7 @@ from http import HTTPStatus
from urllib.parse import urlparse from urllib.parse import urlparse
from fastapi import Depends, HTTPException, Query, Request from fastapi import Depends, HTTPException, Query, Request
from loguru import logger
from lnbits.core.crud import get_user from lnbits.core.crud import get_user
from lnbits.core.services import check_transaction_status, create_invoice from lnbits.core.services import check_transaction_status, create_invoice
@ -70,10 +71,9 @@ async def api_domain_create(
if not cf_response or not cf_response["success"]: if not cf_response or not cf_response["success"]:
await delete_domain(domain.id) await delete_domain(domain.id)
logger.error("Cloudflare failed with: " + cf_response["errors"][0]["message"]) # type: ignore
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST, detail="Problem with cloudflare."
detail="Problem with cloudflare: "
+ cf_response["errors"][0]["message"],
) )
return domain.dict() return domain.dict()

View file

@ -1,7 +1,7 @@
import json import json
from sqlite3 import Row from sqlite3 import Row
from typing import Dict, Optional from typing import Dict, Optional
from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse from urllib.parse import ParseResult, urlparse, urlunparse
from fastapi.param_functions import Query from fastapi.param_functions import Query
from lnurl.types import LnurlPayMetadata from lnurl.types import LnurlPayMetadata
@ -61,9 +61,9 @@ class PayLink(BaseModel):
def success_action(self, payment_hash: str) -> Optional[Dict]: def success_action(self, payment_hash: str) -> Optional[Dict]:
if self.success_url: if self.success_url:
url: ParseResult = urlparse(self.success_url) url: ParseResult = urlparse(self.success_url)
qs: Dict = parse_qs(url.query) #qs = parse_qs(url.query)
qs["payment_hash"] = payment_hash #setattr(qs, "payment_hash", payment_hash)
url = url._replace(query=urlencode(qs, doseq=True)) #url = url._replace(query=urlencode(qs, doseq=True))
return { return {
"tag": "url", "tag": "url",
"description": self.success_text or "~", "description": self.success_text or "~",

View file

@ -6,6 +6,7 @@ and delivery to the specific person
import json import json
from collections import defaultdict from collections import defaultdict
from typing import AsyncGenerator
from fastapi import WebSocket from fastapi import WebSocket
from loguru import logger from loguru import logger
@ -34,7 +35,7 @@ class Notifier:
# Create notification generator: # Create notification generator:
self.generator = self.get_notification_generator() self.generator = self.get_notification_generator()
async def get_notification_generator(self): async def get_notification_generator(self) -> AsyncGenerator:
"""Notification Generator""" """Notification Generator"""
while True: while True:
@ -54,9 +55,8 @@ class Notifier:
logger.exception(f"There is no member in room: {room_name}") logger.exception(f"There is no member in room: {room_name}")
return None return None
async def push(self, message: str, room_name: str = None): async def push(self, message: str, room_name: str):
"""Push a message""" """Push a message"""
message_body = {"message": message, "room_name": room_name} message_body = {"message": message, "room_name": room_name}
await self.generator.asend(message_body) await self.generator.asend(message_body)

View file

@ -1,15 +1,8 @@
import json
from http import HTTPStatus from http import HTTPStatus
from fastapi import ( from fastapi import Depends, Query, Request, WebSocket, WebSocketDisconnect
BackgroundTasks,
Depends,
Query,
Request,
WebSocket,
WebSocketDisconnect,
)
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from loguru import logger
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.responses import HTMLResponse from starlette.responses import HTMLResponse
@ -147,24 +140,14 @@ notifier = Notifier()
@market_ext.websocket("/ws/{room_name}") @market_ext.websocket("/ws/{room_name}")
async def websocket_endpoint( async def websocket_endpoint(websocket: WebSocket, room_name: str):
websocket: WebSocket, room_name: str, background_tasks: BackgroundTasks
):
await notifier.connect(websocket, room_name) await notifier.connect(websocket, room_name)
try: try:
while True: while True:
data = await websocket.receive_text() data = await websocket.receive_text()
d = json.loads(data) room_members = notifier.get_members(room_name) or []
d["room_name"] = room_name
room_members = (
notifier.get_members(room_name)
if notifier.get_members(room_name) is not None
else []
)
if websocket not in room_members: if websocket not in room_members:
print("Sender not in room member: Reconnecting...") logger.warning("Sender not in room member: Reconnecting...")
await notifier.connect(websocket, room_name) await notifier.connect(websocket, room_name)
await notifier._notify(data, room_name) await notifier._notify(data, room_name)

View file

@ -1,4 +1,5 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Optional
from fastapi import Depends, Query from fastapi import Depends, Query
from loguru import logger from loguru import logger
@ -224,7 +225,7 @@ async def api_market_stalls(
@market_ext.put("/api/v1/stalls/{stall_id}") @market_ext.put("/api/v1/stalls/{stall_id}")
async def api_market_stall_create( async def api_market_stall_create(
data: createStalls, data: createStalls,
stall_id: str = None, stall_id: Optional[str] = None,
wallet: WalletTypeInfo = Depends(require_invoice_key), wallet: WalletTypeInfo = Depends(require_invoice_key),
): ):
@ -447,7 +448,7 @@ async def api_market_market_stalls(market_id: str):
@market_ext.put("/api/v1/markets/{market_id}") @market_ext.put("/api/v1/markets/{market_id}")
async def api_market_market_create( async def api_market_market_create(
data: CreateMarket, data: CreateMarket,
market_id: str = None, market_id: Optional[str] = None,
wallet: WalletTypeInfo = Depends(require_invoice_key), wallet: WalletTypeInfo = Depends(require_invoice_key),
): ):
if market_id: if market_id:
@ -506,7 +507,7 @@ async def api_get_settings(wallet: WalletTypeInfo = Depends(require_admin_key)):
@market_ext.put("/api/v1/settings/{usr}") @market_ext.put("/api/v1/settings/{usr}")
async def api_set_settings( async def api_set_settings(
data: SetSettings, data: SetSettings,
usr: str = None, usr: Optional[str] = None,
wallet: WalletTypeInfo = Depends(require_admin_key), wallet: WalletTypeInfo = Depends(require_admin_key),
): ):
if usr: if usr:

View file

@ -2,7 +2,7 @@ from os import getenv
from fastapi import Depends, Request from fastapi import Depends, Request
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from pyngrok import conf, ngrok from pyngrok import conf, ngrok # type: ignore
from lnbits.core.models import User from lnbits.core.models import User
from lnbits.decorators import check_user_exists from lnbits.decorators import check_user_exists

View file

@ -22,12 +22,12 @@ async def create_shop(*, wallet_id: str) -> int:
if db.type == SQLITE: if db.type == SQLITE:
return result._result_proxy.lastrowid return result._result_proxy.lastrowid
else: else:
return result[0] return result[0] # type: ignore
async def get_shop(id: int) -> Optional[Shop]: async def get_shop(id: int) -> Optional[Shop]:
row = await db.fetchone("SELECT * FROM offlineshop.shops WHERE id = ?", (id,)) row = await db.fetchone("SELECT * FROM offlineshop.shops WHERE id = ?", (id,))
return Shop(**dict(row)) if row else None return Shop(**row) if row else None
async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]: async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]:
@ -40,7 +40,7 @@ async def get_or_create_shop_by_wallet(wallet: str) -> Optional[Shop]:
ls_id = await create_shop(wallet_id=wallet) ls_id = await create_shop(wallet_id=wallet)
return await get_shop(ls_id) return await get_shop(ls_id)
return Shop(**dict(row)) if row else None return Shop(**row) if row else None
async def set_method(shop: int, method: str, wordlist: str = "") -> Optional[Shop]: async def set_method(shop: int, method: str, wordlist: str = "") -> Optional[Shop]:

View file

@ -52,7 +52,7 @@ class ShopCounter:
# cleanup confirmation words cache # cleanup confirmation words cache
to_remove = len(self.fulfilled_payments) - 23 to_remove = len(self.fulfilled_payments) - 23
if to_remove > 0: if to_remove > 0:
for i in range(to_remove): for _ in range(to_remove):
self.fulfilled_payments.popitem(False) self.fulfilled_payments.popitem(False)
return word return word
@ -64,6 +64,10 @@ class Shop(BaseModel):
method: str method: str
wordlist: str wordlist: str
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
@property @property
def otp_key(self) -> str: def otp_key(self) -> str:
return base64.b32encode( return base64.b32encode(

View file

@ -18,8 +18,6 @@ from .crud import (
) )
from .models import CreateSatsDicePayment from .models import CreateSatsDicePayment
##############LNURLP STUFF
@satsdice_ext.get( @satsdice_ext.get(
"/api/v1/lnurlp/{link_id}", "/api/v1/lnurlp/{link_id}",
@ -84,7 +82,7 @@ async def api_lnurlp_callback(
data = CreateSatsDicePayment( data = CreateSatsDicePayment(
satsdice_pay=link.id, satsdice_pay=link.id,
value=amount_received / 1000, value=int(amount_received / 1000),
payment_hash=payment_hash, payment_hash=payment_hash,
) )

View file

@ -13,7 +13,7 @@ from .helpers import fetch_onchain_balance
from .models import Charges, CreateCharge, SatsPayThemes from .models import Charges, CreateCharge, SatsPayThemes
async def create_charge(user: str, data: CreateCharge) -> Optional[Charges]: async def create_charge(user: str, data: CreateCharge) -> Charges:
data = CreateCharge(**data.dict()) data = CreateCharge(**data.dict())
charge_id = urlsafe_short_hash() charge_id = urlsafe_short_hash()
if data.onchainwallet: if data.onchainwallet:
@ -79,7 +79,9 @@ async def create_charge(user: str, data: CreateCharge) -> Optional[Charges]:
data.custom_css, data.custom_css,
), ),
) )
return await get_charge(charge_id) charge = await get_charge(charge_id)
assert charge, "Newly created charge does not exist"
return charge
async def update_charge(charge_id: str, **kwargs) -> Optional[Charges]: async def update_charge(charge_id: str, **kwargs) -> Optional[Charges]:

View file

@ -10,7 +10,7 @@ async def get_targets(source_wallet: str) -> List[Target]:
rows = await db.fetchall( rows = await db.fetchall(
"SELECT * FROM splitpayments.targets WHERE source = ?", (source_wallet,) "SELECT * FROM splitpayments.targets WHERE source = ?", (source_wallet,)
) )
return [Target(**dict(row)) for row in rows] return [Target(**row) for row in rows]
async def set_targets(source_wallet: str, targets: List[Target]): async def set_targets(source_wallet: str, targets: List[Target]):

View file

@ -1,6 +1,7 @@
from sqlite3 import Row
from typing import List, Optional from typing import List, Optional
from fastapi.param_functions import Query from fastapi import Query
from pydantic import BaseModel from pydantic import BaseModel
@ -11,6 +12,10 @@ class Target(BaseModel):
tag: str tag: str
alias: Optional[str] alias: Optional[str]
@classmethod
def from_row(cls, row: Row):
return cls(**dict(row))
class TargetPutList(BaseModel): class TargetPutList(BaseModel):
wallet: str = Query(...) wallet: str = Query(...)

View file

@ -26,11 +26,11 @@ async def get_charge_details(service_id):
These might be different depending for services implemented in the future. These might be different depending for services implemented in the future.
""" """
service = await get_service(service_id) service = await get_service(service_id)
assert service assert service, f"Could not fetch service: {service_id}"
wallet_id = service.wallet wallet_id = service.wallet
wallet = await get_wallet(wallet_id) wallet = await get_wallet(wallet_id)
assert wallet assert wallet, f"Could not fetch wallet: {wallet_id}"
user = wallet.user user = wallet.user
return { return {
@ -147,14 +147,16 @@ async def create_service(data: CreateService) -> Service:
if db.type == SQLITE: if db.type == SQLITE:
service_id = result._result_proxy.lastrowid service_id = result._result_proxy.lastrowid
else: else:
service_id = result[0] service_id = result[0] # type: ignore
service = await get_service(service_id) service = await get_service(service_id)
assert service assert service, f"Could not fetch service: {service_id}"
return service return service
async def get_service(service_id: int, by_state: str = None) -> Optional[Service]: async def get_service(
service_id: int, by_state: Optional[str] = None
) -> Optional[Service]:
"""Return a service either by ID or, available, by state """Return a service either by ID or, available, by state
Each Service's donation page is reached through its "state" hash Each Service's donation page is reached through its "state" hash
@ -184,7 +186,9 @@ async def authenticate_service(service_id, code, redirect_uri):
"""Use authentication code from third party API to retreive access token""" """Use authentication code from third party API to retreive access token"""
# The API token is passed in the querystring as 'code' # The API token is passed in the querystring as 'code'
service = await get_service(service_id) service = await get_service(service_id)
assert service, f"Could not fetch service: {service_id}"
wallet = await get_wallet(service.wallet) wallet = await get_wallet(service.wallet)
assert wallet, f"Could not fetch wallet: {service.wallet}"
user = wallet.user user = wallet.user
url = "https://streamlabs.com/api/v1.0/token" url = "https://streamlabs.com/api/v1.0/token"
data = { data = {
@ -208,8 +212,11 @@ async def service_add_token(service_id, token):
is not overwritten. is not overwritten.
Tokens for Streamlabs never need to be refreshed. Tokens for Streamlabs never need to be refreshed.
""" """
if (await get_service(service_id)).authenticated: service = await get_service(service_id)
assert service, f"Could not fetch service: {service_id}"
if service.authenticated:
return False return False
await db.execute( await db.execute(
"UPDATE streamalerts.Services SET authenticated = 1, token = ? where id = ?", "UPDATE streamalerts.Services SET authenticated = 1, token = ? where id = ?",
(token, service_id), (token, service_id),

View file

@ -1,5 +1,3 @@
import json
import httpx import httpx
from .models import Domains from .models import Domains
@ -20,9 +18,7 @@ async def cloudflare_create_subdomain(
"Content-Type": "application/json", "Content-Type": "application/json",
} }
aRecord = subdomain + "." + domain.domain aRecord = subdomain + "." + domain.domain
cf_response = ""
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try:
r = await client.post( r = await client.post(
url, url,
headers=header, headers=header,
@ -35,10 +31,8 @@ async def cloudflare_create_subdomain(
}, },
timeout=40, timeout=40,
) )
cf_response = json.loads(r.text) r.raise_for_status()
except AssertionError: return r.json()
cf_response = "Error occured"
return cf_response
async def cloudflare_deletesubdomain(domain: Domains, domain_id: str): async def cloudflare_deletesubdomain(domain: Domains, domain_id: str):
@ -52,8 +46,4 @@ async def cloudflare_deletesubdomain(domain: Domains, domain_id: str):
"Content-Type": "application/json", "Content-Type": "application/json",
} }
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: await client.delete(url + "/" + domain_id, headers=header, timeout=40)
r = await client.delete(url + "/" + domain_id, headers=header, timeout=40)
r.text
except AssertionError:
pass

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
import httpx import httpx
from loguru import logger
from lnbits.core.models import Payment from lnbits.core.models import Payment
from lnbits.helpers import get_current_extension_name from lnbits.helpers import get_current_extension_name
@ -21,7 +22,7 @@ async def wait_for_paid_invoices():
async def on_invoice_paid(payment: Payment) -> None: async def on_invoice_paid(payment: Payment) -> None:
if payment.extra.get("tag") != "lnsubdomain": if payment.extra.get("tag") != "lnsubdomain":
# not an lnurlp invoice # not an lnsubdomain invoice
return return
await payment.set_pending(False) await payment.set_pending(False)
@ -29,12 +30,17 @@ async def on_invoice_paid(payment: Payment) -> None:
domain = await get_domain(subdomain.domain) domain = await get_domain(subdomain.domain)
### Create subdomain ### Create subdomain
try:
cf_response = await cloudflare_create_subdomain( cf_response = await cloudflare_create_subdomain(
domain=domain, # type: ignore domain=domain, # type: ignore
subdomain=subdomain.subdomain, subdomain=subdomain.subdomain,
record_type=subdomain.record_type, record_type=subdomain.record_type,
ip=subdomain.ip, ip=subdomain.ip,
) )
except Exception as exc:
logger.error(exc)
logger.error("could not create subdomain on cloudflare")
return
### Use webhook to notify about cloudflare registration ### Use webhook to notify about cloudflare registration
if domain and domain.webhook: if domain and domain.webhook:

View file

@ -1,6 +1,7 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import Depends, Query from fastapi import Depends, Query
from loguru import logger
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from lnbits.core.crud import get_user from lnbits.core.crud import get_user
@ -121,21 +122,22 @@ async def api_subdomain_make_subdomain(domain_id, data: CreateSubdomain):
detail=f"{data.subdomain}.{domain.domain} domain already taken.", detail=f"{data.subdomain}.{domain.domain} domain already taken.",
) )
## Dry run cloudflare... (create and if create is sucessful delete it) ## Dry run cloudflare... (create and if create is successful delete it)
cf_response = await cloudflare_create_subdomain( try:
res_json = await cloudflare_create_subdomain(
domain=domain, domain=domain,
subdomain=data.subdomain, subdomain=data.subdomain,
record_type=data.record_type, record_type=data.record_type,
ip=data.ip, ip=data.ip,
) )
if cf_response["success"] is True:
await cloudflare_deletesubdomain( await cloudflare_deletesubdomain(
domain=domain, domain_id=cf_response["result"]["id"] domain=domain, domain_id=res_json["result"]["id"]
) )
else: except Exception as exc:
logger.warning(exc)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=f'Problem with cloudflare: {cf_response["errors"][0]["message"]}', detail="Problem with cloudflare.",
) )
## ALL OK - create an invoice and return it to the user ## ALL OK - create an invoice and return it to the user

View file

@ -1,20 +0,0 @@
from lnbits.core.crud import get_wallet
from .crud import get_tipjar
async def get_charge_details(tipjar_id):
"""Return the default details for a satspay charge"""
tipjar = await get_tipjar(tipjar_id)
wallet_id = tipjar.wallet
wallet = await get_wallet(wallet_id)
user = wallet.user
details = {
"time": 1440,
"user": user,
"lnbitswallet": wallet_id,
"onchainwallet": tipjar.onchain,
"completelink": "/tipjar/" + str(tipjar_id),
"completelinktext": "Thanks for the tip!",
}
return details

View file

@ -1,7 +1,6 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import Depends, Request from fastapi import Depends, Query, Request
from fastapi.param_functions import Query
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException

View file

@ -3,7 +3,7 @@ from http import HTTPStatus
from fastapi import Depends, Query from fastapi import Depends, Query
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from lnbits.core.crud import get_user from lnbits.core.crud import get_user, get_wallet
from lnbits.decorators import WalletTypeInfo, get_key_type from lnbits.decorators import WalletTypeInfo, get_key_type
# todo: use the API, not direct import # todo: use the API, not direct import
@ -22,7 +22,6 @@ from .crud import (
update_tip, update_tip,
update_tipjar, update_tipjar,
) )
from .helpers import get_charge_details
from .models import createTip, createTipJar, createTips from .models import createTip, createTipJar, createTips
@ -55,25 +54,32 @@ async def api_create_tip(data: createTips):
status_code=HTTPStatus.NOT_FOUND, detail="Tipjar does not exist." status_code=HTTPStatus.NOT_FOUND, detail="Tipjar does not exist."
) )
webhook = tipjar.webhook wallet_id = tipjar.wallet
charge_details = await get_charge_details(tipjar.id) wallet = await get_wallet(wallet_id)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Tipjar wallet does not exist."
)
name = data.name name = data.name
# Ensure that description string can be split reliably # Ensure that description string can be split reliably
name = name.replace('"', "''") name = name.replace('"', "''")
if not name: if not name:
name = "Anonymous" name = "Anonymous"
description = f"{name}: {message}" description = f"{name}: {message}"
charge = await create_charge( charge = await create_charge(
user=charge_details["user"], user=wallet.user,
data=CreateCharge( data=CreateCharge(
amount=sats, amount=sats,
webhook=webhook or "", webhook=tipjar.webhook or "",
description=description, description=description,
onchainwallet=charge_details["onchainwallet"], onchainwallet=tipjar.onchain or "",
lnbitswallet=charge_details["lnbitswallet"], lnbitswallet=tipjar.wallet,
completelink=charge_details["completelink"], completelink="/tipjar/" + str(tipjar_id),
completelinktext=charge_details["completelinktext"], completelinktext="Thanks for the tip!",
time=charge_details["time"], time=1440,
custom_css="", custom_css="",
), ),
) )

View file

@ -1,3 +1,5 @@
from typing import Optional, Tuple
from embit.descriptor import Descriptor, Key from embit.descriptor import Descriptor, Key
from embit.descriptor.arguments import AllowedDerivation from embit.descriptor.arguments import AllowedDerivation
from embit.networks import NETWORKS from embit.networks import NETWORKS
@ -12,11 +14,12 @@ def detect_network(k):
return net return net
def parse_key(masterpub: str) -> Descriptor: def parse_key(masterpub: str) -> Tuple[Descriptor, Optional[dict]]:
"""Parses masterpub or descriptor and returns a tuple: (Descriptor, network) """Parses masterpub or descriptor and returns a tuple: (Descriptor, network)
To create addresses use descriptor.derive(num).address(network=network) To create addresses use descriptor.derive(num).address(network=network)
""" """
network = None network = None
desc = None
# probably a single key # probably a single key
if "(" not in masterpub: if "(" not in masterpub:
k = Key.from_string(masterpub) k = Key.from_string(masterpub)
@ -47,8 +50,11 @@ def parse_key(masterpub: str) -> Descriptor:
desc = Descriptor.from_string("wpkh(%s)" % str(k)) desc = Descriptor.from_string("wpkh(%s)" % str(k))
break break
# we didn't find correct version # we didn't find correct version
if network is None: if not network:
raise ValueError("Unknown master public key version") raise ValueError("Unknown master public key version")
if not desc:
raise ValueError("descriptor not found, because version did not match")
else: else:
desc = Descriptor.from_string(masterpub) desc = Descriptor.from_string(masterpub)
if not desc.is_wildcard: if not desc.is_wildcard:
@ -61,6 +67,7 @@ def parse_key(masterpub: str) -> Descriptor:
if network is not None and network != net: if network is not None and network != net:
raise ValueError("Keys from different networks") raise ValueError("Keys from different networks")
network = net network = net
return desc, network return desc, network

View file

@ -73,7 +73,8 @@ async def api_wallet_create_or_update(
data: CreateWallet, w: WalletTypeInfo = Depends(require_admin_key) data: CreateWallet, w: WalletTypeInfo = Depends(require_admin_key)
): ):
try: try:
(descriptor, network) = parse_key(data.masterpub) descriptor, network = parse_key(data.masterpub)
assert network
if data.network != network["name"]: if data.network != network["name"]:
raise ValueError( raise ValueError(
"Account network error. This account is for '{}'".format( "Account network error. This account is for '{}'".format(
@ -308,12 +309,9 @@ async def api_psbt_utxos_tx(
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
@watchonly_ext.put("/api/v1/psbt/extract") @watchonly_ext.put("/api/v1/psbt/extract", dependencies=[Depends(require_admin_key)])
async def api_psbt_extract_tx( async def api_psbt_extract_tx(data: ExtractPsbt):
data: ExtractPsbt, w: WalletTypeInfo = Depends(require_admin_key)
):
network = NETWORKS["main"] if data.network == "Mainnet" else NETWORKS["test"] network = NETWORKS["main"] if data.network == "Mainnet" else NETWORKS["test"]
res = SignedTransaction()
try: try:
psbt = PSBT.from_base64(data.psbtBase64) psbt = PSBT.from_base64(data.psbtBase64)
for i, inp in enumerate(data.inputs): for i, inp in enumerate(data.inputs):
@ -322,9 +320,9 @@ async def api_psbt_extract_tx(
final_psbt = finalizer.finalize_psbt(psbt) final_psbt = finalizer.finalize_psbt(psbt)
if not final_psbt: if not final_psbt:
raise ValueError("PSBT cannot be finalized!") raise ValueError("PSBT cannot be finalized!")
res.tx_hex = final_psbt.to_string()
transaction = Transaction.from_string(res.tx_hex) tx_hex = final_psbt.to_string()
transaction = Transaction.from_string(tx_hex)
tx = { tx = {
"locktime": transaction.locktime, "locktime": transaction.locktime,
"version": transaction.version, "version": transaction.version,
@ -336,10 +334,10 @@ async def api_psbt_extract_tx(
tx["outputs"].append( tx["outputs"].append(
{"amount": out.value, "address": out.script_pubkey.address(network)} {"amount": out.value, "address": out.script_pubkey.address(network)}
) )
res.tx_json = json.dumps(tx) signed_tx = SignedTransaction(tx_hex=tx_hex, tx_json=json.dumps(tx))
return signed_tx.dict()
except Exception as e: except Exception as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
return res.dict()
@watchonly_ext.post("/api/v1/tx") @watchonly_ext.post("/api/v1/tx")

View file

@ -69,6 +69,9 @@ include = [
] ]
exclude = [ exclude = [
"lnbits/wallets/lnd_grpc_files", "lnbits/wallets/lnd_grpc_files",
"lnbits/wallets",
"lnbits/core",
"lnbits/*.py",
] ]
[tool.mypy] [tool.mypy]