refactor: extract models (#2759)

This commit is contained in:
Vlad Stan 2024-11-05 13:26:12 +02:00 committed by GitHub
parent acb1b1ed91
commit ba5f79da2d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 791 additions and 716 deletions

View file

@ -21,16 +21,14 @@ from lnbits.core.crud import (
get_installed_extensions,
update_installed_extension_state,
)
from lnbits.core.extensions.extension_manager import (
deactivate_extension,
)
from lnbits.core.extensions.helpers import version_parse
from lnbits.core.helpers import migrate_extension_database
from lnbits.core.services.extensions import deactivate_extension, get_valid_extensions
from lnbits.core.tasks import ( # watchdog_task
killswitch_task,
wait_for_paid_invoices,
)
from lnbits.exceptions import register_exception_handlers
from lnbits.helpers import version_parse
from lnbits.settings import settings
from lnbits.tasks import (
cancel_all_tasks,
@ -48,7 +46,7 @@ from lnbits.wallets import get_funding_source, set_funding_source
from .commands import migrate_databases
from .core import init_core_routers
from .core.db import core_app_extra
from .core.extensions.models import Extension, ExtensionMeta, InstallableExtension
from .core.models.extensions import Extension, ExtensionMeta, InstallableExtension
from .core.services import check_admin_settings, check_webpush_settings
from .middleware import (
CustomGZipMiddleware,
@ -397,7 +395,7 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None:
async def check_and_register_extensions(app: FastAPI):
await check_installed_extensions(app)
for ext in Extension.get_valid_extensions(False):
for ext in await get_valid_extensions(False):
try:
register_ext_routes(app, ext)
except Exception as exc:

View file

@ -25,13 +25,13 @@ from lnbits.core.crud import (
remove_deleted_wallets,
update_payment,
)
from lnbits.core.extensions.models import (
from lnbits.core.helpers import is_valid_url, migrate_databases
from lnbits.core.models import Payment, PaymentState
from lnbits.core.models.extensions import (
CreateExtension,
ExtensionRelease,
InstallableExtension,
)
from lnbits.core.helpers import is_valid_url, migrate_databases
from lnbits.core.models import Payment, PaymentState
from lnbits.core.services import check_admin_settings
from lnbits.core.views.extension_api import (
api_install_extension,

View file

@ -1,7 +1,7 @@
from typing import Optional
from lnbits.core.db import db
from lnbits.core.extensions.models import (
from lnbits.core.models.extensions import (
InstallableExtension,
UserExtension,
)

View file

@ -1,56 +0,0 @@
import hashlib
from typing import Any, Optional
from urllib import request
import httpx
from loguru import logger
from packaging import version
from lnbits.settings import settings
def version_parse(v: str):
"""
Wrapper for version.parse() that does not throw if the version is invalid.
Instead it return the lowest possible version ("0.0.0")
"""
try:
return version.parse(v)
except Exception:
return version.parse("0.0.0")
async def github_api_get(url: str, error_msg: Optional[str]) -> Any:
headers = {"User-Agent": settings.user_agent}
if settings.lnbits_ext_github_token:
headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}"
async with httpx.AsyncClient(headers=headers) as client:
resp = await client.get(url)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()
def download_url(url, save_path):
with request.urlopen(url, timeout=60) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"

View file

@ -13,8 +13,8 @@ from lnbits.core.crud import (
update_migration_version,
)
from lnbits.core.db import db as core_db
from lnbits.core.extensions.models import InstallableExtension
from lnbits.core.models import DbVersion
from lnbits.core.models.extensions import InstallableExtension
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection
from lnbits.settings import settings

View file

@ -1,490 +0,0 @@
from __future__ import annotations
import hashlib
import hmac
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Callable, Optional
from ecdsa import SECP256k1, SigningKey
from fastapi import Query
from passlib.context import CryptContext
from pydantic import BaseModel, Field, validator
from lnbits.db import FilterModel
from lnbits.helpers import url_for
from lnbits.lnurl import encode as lnurl_encode
from lnbits.settings import settings
from lnbits.utils.exchange_rates import allowed_currencies
from lnbits.wallets import get_funding_source
from lnbits.wallets.base import (
PaymentFailedStatus,
PaymentPendingStatus,
PaymentStatus,
PaymentSuccessStatus,
)
class BaseWallet(BaseModel):
id: str
name: str
adminkey: str
inkey: str
balance_msat: int
class Wallet(BaseModel):
id: str
user: str
name: str
adminkey: str
inkey: str
deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
currency: Optional[str] = None
balance_msat: int = Field(default=0, no_database=True)
@property
def balance(self) -> int:
return int(self.balance_msat // 1000)
@property
def withdrawable_balance(self) -> int:
from .services import fee_reserve
return self.balance_msat - fee_reserve(self.balance_msat)
@property
def lnurlwithdraw_full(self) -> str:
url = url_for("/withdraw", external=True, usr=self.user, wal=self.id)
try:
return lnurl_encode(url)
except Exception:
return ""
def lnurlauth_key(self, domain: str) -> SigningKey:
hashing_key = hashlib.sha256(self.id.encode()).digest()
linking_key = hmac.digest(hashing_key, domain.encode(), "sha256")
return SigningKey.from_string(
linking_key, curve=SECP256k1, hashfunc=hashlib.sha256
)
class KeyType(Enum):
admin = 0
invoice = 1
invalid = 2
# backwards compatibility
def __eq__(self, other):
return self.value == other
@dataclass
class WalletTypeInfo:
key_type: KeyType
wallet: Wallet
class UserExtra(BaseModel):
email_verified: Optional[bool] = False
first_name: Optional[str] = None
last_name: Optional[str] = None
display_name: Optional[str] = None
picture: Optional[str] = None
# Auth provider, possible values:
# - "env": the user was created automatically by the system
# - "lnbits": the user was created via register form (username/pass or user_id only)
# - "google | github | ...": the user was created using an SSO provider
provider: Optional[str] = "lnbits" # auth provider
class Account(BaseModel):
id: str
username: Optional[str] = None
password_hash: Optional[str] = None
pubkey: Optional[str] = None
email: Optional[str] = None
extra: UserExtra = UserExtra()
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@property
def is_super_user(self) -> bool:
return self.id == settings.super_user
@property
def is_admin(self) -> bool:
return self.id in settings.lnbits_admin_users or self.is_super_user
def hash_password(self, password: str) -> str:
"""sets and returns the hashed password"""
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
self.password_hash = pwd_context.hash(password)
return self.password_hash
def verify_password(self, password: str) -> bool:
"""returns True if the password matches the hash"""
if not self.password_hash:
return False
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
return pwd_context.verify(password, self.password_hash)
class AccountOverview(Account):
transaction_count: Optional[int] = 0
wallet_count: Optional[int] = 0
balance_msat: Optional[int] = 0
last_payment: Optional[datetime] = None
class AccountFilters(FilterModel):
__search_fields__ = ["id", "email", "username"]
__sort_fields__ = [
"balance_msat",
"email",
"username",
"transaction_count",
"wallet_count",
"last_payment",
]
id: str
last_payment: Optional[datetime] = None
transaction_count: Optional[int] = None
wallet_count: Optional[int] = None
username: Optional[str] = None
email: Optional[str] = None
class User(BaseModel):
id: str
created_at: datetime
updated_at: datetime
email: Optional[str] = None
username: Optional[str] = None
pubkey: Optional[str] = None
extensions: list[str] = []
wallets: list[Wallet] = []
admin: bool = False
super_user: bool = False
has_password: bool = False
extra: UserExtra = UserExtra()
@property
def wallet_ids(self) -> list[str]:
return [wallet.id for wallet in self.wallets]
def get_wallet(self, wallet_id: str) -> Optional[Wallet]:
w = [wallet for wallet in self.wallets if wallet.id == wallet_id]
return w[0] if w else None
@classmethod
def is_extension_for_user(cls, ext: str, user: str) -> bool:
if ext not in settings.lnbits_admin_extensions:
return True
if user == settings.super_user:
return True
if user in settings.lnbits_admin_users:
return True
return False
class CreateUser(BaseModel):
email: Optional[str] = Query(default=None)
username: str = Query(default=..., min_length=2, max_length=20)
password: str = Query(default=..., min_length=8, max_length=50)
password_repeat: str = Query(default=..., min_length=8, max_length=50)
class UpdateUser(BaseModel):
user_id: str
email: Optional[str] = Query(default=None)
username: Optional[str] = Query(default=..., min_length=2, max_length=20)
extra: Optional[UserExtra] = None
class UpdateUserPassword(BaseModel):
user_id: str
password_old: Optional[str] = None
password: str = Query(default=..., min_length=8, max_length=50)
password_repeat: str = Query(default=..., min_length=8, max_length=50)
username: str = Query(default=..., min_length=2, max_length=20)
class UpdateUserPubkey(BaseModel):
user_id: str
pubkey: str = Query(default=..., max_length=64)
class ResetUserPassword(BaseModel):
reset_key: str
password: str = Query(default=..., min_length=8, max_length=50)
password_repeat: str = Query(default=..., min_length=8, max_length=50)
class UpdateSuperuserPassword(BaseModel):
username: str = Query(default=..., min_length=2, max_length=20)
password: str = Query(default=..., min_length=8, max_length=50)
password_repeat: str = Query(default=..., min_length=8, max_length=50)
class LoginUsr(BaseModel):
usr: str
class LoginUsernamePassword(BaseModel):
username: str
password: str
class AccessTokenPayload(BaseModel):
sub: str
usr: Optional[str] = None
email: Optional[str] = None
auth_time: Optional[int] = 0
class PaymentState(str, Enum):
PENDING = "pending"
SUCCESS = "success"
FAILED = "failed"
def __str__(self) -> str:
return self.value
class PaymentExtra(BaseModel):
comment: Optional[str] = None
success_action: Optional[str] = None
lnurl_response: Optional[str] = None
class PayInvoice(BaseModel):
payment_request: str
description: Optional[str] = None
max_sat: Optional[int] = None
extra: Optional[dict] = {}
class CreatePayment(BaseModel):
wallet_id: str
payment_hash: str
bolt11: str
amount_msat: int
memo: str
extra: Optional[dict] = {}
preimage: Optional[str] = None
expiry: Optional[datetime] = None
webhook: Optional[str] = None
fee: int = 0
class Payment(BaseModel):
checking_id: str
payment_hash: str
wallet_id: str
amount: int
fee: int
bolt11: str
status: str = PaymentState.PENDING
memo: Optional[str] = None
expiry: Optional[datetime] = None
webhook: Optional[str] = None
webhook_status: Optional[int] = None
preimage: Optional[str] = None
tag: Optional[str] = None
extension: Optional[str] = None
time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
extra: dict = {}
@property
def pending(self) -> bool:
return self.status == PaymentState.PENDING.value
@property
def success(self) -> bool:
return self.status == PaymentState.SUCCESS.value
@property
def failed(self) -> bool:
return self.status == PaymentState.FAILED.value
@property
def msat(self) -> int:
return self.amount
@property
def sat(self) -> int:
return self.amount // 1000
@property
def is_in(self) -> bool:
return self.amount > 0
@property
def is_out(self) -> bool:
return self.amount < 0
@property
def is_expired(self) -> bool:
return self.expiry < datetime.now(timezone.utc) if self.expiry else False
@property
def is_internal(self) -> bool:
return self.checking_id.startswith("internal_")
async def check_status(self) -> PaymentStatus:
if self.is_internal:
if self.success:
return PaymentSuccessStatus()
if self.failed:
return PaymentFailedStatus()
return PaymentPendingStatus()
funding_source = get_funding_source()
if self.is_out:
status = await funding_source.get_payment_status(self.checking_id)
else:
status = await funding_source.get_invoice_status(self.checking_id)
return status
class PaymentFilters(FilterModel):
__search_fields__ = ["memo", "amount"]
checking_id: str
amount: int
fee: int
memo: Optional[str]
time: datetime
bolt11: str
preimage: str
payment_hash: str
expiry: Optional[datetime]
extra: dict = {}
wallet_id: str
webhook: Optional[str]
webhook_status: Optional[int]
class PaymentHistoryPoint(BaseModel):
date: datetime
income: int
spending: int
balance: int
def _do_nothing(*_):
pass
class CoreAppExtra:
register_new_ext_routes: Callable = _do_nothing
register_new_ratelimiter: Callable
class TinyURL(BaseModel):
id: str
url: str
endless: bool
wallet: str
time: float
class ConversionData(BaseModel):
from_: str = "sat"
amount: float
to: str = "usd"
class Callback(BaseModel):
callback: str
class DecodePayment(BaseModel):
data: str
filter_fields: Optional[list[str]] = []
class CreateLnurl(BaseModel):
description_hash: str
callback: str
amount: int
comment: Optional[str] = None
description: Optional[str] = None
unit: Optional[str] = None
class CreateInvoice(BaseModel):
unit: str = "sat"
internal: bool = False
out: bool = True
amount: float = Query(None, ge=0)
memo: Optional[str] = None
description_hash: Optional[str] = None
unhashed_description: Optional[str] = None
expiry: Optional[int] = None
extra: Optional[dict] = None
webhook: Optional[str] = None
bolt11: Optional[str] = None
lnurl_callback: Optional[str] = None
@validator("unit")
@classmethod
def unit_is_from_allowed_currencies(cls, v):
if v != "sat" and v not in allowed_currencies():
raise ValueError("The provided unit is not supported")
return v
class CreateTopup(BaseModel):
id: str
amount: int
class CreateLnurlAuth(BaseModel):
callback: str
class CreateWallet(BaseModel):
name: Optional[str] = None
class CreateWebPushSubscription(BaseModel):
subscription: str
class WebPushSubscription(BaseModel):
endpoint: str
user: str
data: str
host: str
timestamp: datetime
class BalanceDelta(BaseModel):
lnbits_balance_msats: int
node_balance_msats: int
@property
def delta_msats(self):
return self.node_balance_msats - self.lnbits_balance_msats
class SimpleStatus(BaseModel):
success: bool
message: str
class DbVersion(BaseModel):
db: str
version: int
class PayLnurlWData(BaseModel):
lnurl_w: str

View file

@ -0,0 +1,91 @@
from .lnurl import CreateLnurl, CreateLnurlAuth, PayLnurlWData
from .misc import (
BalanceDelta,
Callback,
ConversionData,
CoreAppExtra,
DbVersion,
SimpleStatus,
)
from .payments import (
CreateInvoice,
CreatePayment,
DecodePayment,
PayInvoice,
Payment,
PaymentExtra,
PaymentFilters,
PaymentHistoryPoint,
PaymentState,
)
from .tinyurl import TinyURL
from .users import (
AccessTokenPayload,
Account,
AccountFilters,
AccountOverview,
CreateTopup,
CreateUser,
LoginUsernamePassword,
LoginUsr,
ResetUserPassword,
UpdateSuperuserPassword,
UpdateUser,
UpdateUserPassword,
UpdateUserPubkey,
User,
UserExtra,
)
from .wallets import BaseWallet, CreateWallet, KeyType, Wallet, WalletTypeInfo
from .webpush import CreateWebPushSubscription, WebPushSubscription
__all__ = [
# lnurl
"CreateLnurl",
"CreateLnurlAuth",
"PayLnurlWData",
# misc
"BalanceDelta",
"Callback",
"ConversionData",
"CoreAppExtra",
"DbVersion",
"SimpleStatus",
# payments
"CreateInvoice",
"CreatePayment",
"DecodePayment",
"PayInvoice",
"Payment",
"PaymentExtra",
"PaymentFilters",
"PaymentHistoryPoint",
"PaymentState",
# tinyurl
"TinyURL",
# users
"AccessTokenPayload",
"Account",
"AccountFilters",
"AccountOverview",
"CreateTopup",
"CreateUser",
"LoginUsernamePassword",
"LoginUsr",
"ResetUserPassword",
"UpdateSuperuserPassword",
"UpdateUser",
"UpdateUserPassword",
"UpdateUserPubkey",
"User",
"UserExtra",
# wallets
"BaseWallet",
"CreateWallet",
"KeyType",
"Wallet",
"WalletTypeInfo",
# webpush
"CreateWebPushSubscription",
"WebPushSubscription",
]

View file

@ -14,15 +14,12 @@ import httpx
from loguru import logger
from pydantic import BaseModel
from lnbits.settings import settings
from .helpers import (
from lnbits.helpers import (
download_url,
file_hash,
github_api_get,
icon_to_github_url,
version_parse,
)
from lnbits.settings import settings
class ExplicitRelease(BaseModel):
@ -145,14 +142,9 @@ class UserExtension(BaseModel):
class Extension(NamedTuple):
code: str
is_valid: bool
is_admin_only: bool
name: Optional[str] = None
short_description: Optional[str] = None
tile: Optional[str] = None
contributors: Optional[list[str]] = None
hidden: bool = False
migration_module: Optional[str] = None
db_name: Optional[str] = None
upgrade_hash: Optional[str] = ""
@property
@ -175,76 +167,12 @@ class Extension(NamedTuple):
return Extension(
code=ext_info.id,
is_valid=True,
is_admin_only=False, # todo: is admin only
name=ext_info.name,
short_description=ext_info.short_description,
tile=ext_info.icon,
upgrade_hash=ext_info.hash if ext_info.module_installed else "",
)
@classmethod
def get_valid_extensions(
cls, include_deactivated: Optional[bool] = True
) -> list[Extension]:
valid_extensions = [
extension for extension in cls._extensions() if extension.is_valid
]
if include_deactivated:
return valid_extensions
if settings.lnbits_extensions_deactivate_all:
return []
return [
e
for e in valid_extensions
if e.code not in settings.lnbits_deactivated_extensions
]
@classmethod
def get_valid_extension(
cls, ext_id: str, include_deactivated: Optional[bool] = True
) -> Optional[Extension]:
all_extensions = cls.get_valid_extensions(include_deactivated)
return next((e for e in all_extensions if e.code == ext_id), None)
@classmethod
def _extensions(cls) -> list[Extension]:
p = Path(settings.lnbits_extensions_path, "extensions")
Path(p).mkdir(parents=True, exist_ok=True)
extension_folders: list[Path] = [f for f in p.iterdir() if f.is_dir()]
# todo: remove this property somehow, it is too expensive
output: list[Extension] = []
for extension_folder in extension_folders:
extension_code = extension_folder.parts[-1]
try:
with open(extension_folder / "config.json") as json_file:
config = json.load(json_file)
is_valid = True
is_admin_only = extension_code in settings.lnbits_admin_extensions
except Exception:
config = {}
is_valid = False
is_admin_only = False
output.append(
Extension(
extension_code,
is_valid,
is_admin_only,
config.get("name"),
config.get("short_description"),
config.get("tile"),
config.get("contributors"),
config.get("hidden") or False,
config.get("migration_module"),
config.get("db_name"),
)
)
return output
class ExtensionRelease(BaseModel):
name: str
@ -393,10 +321,6 @@ class InstallableExtension(BaseModel):
stars: int = 0
meta: Optional[ExtensionMeta] = None
@property
def is_admin_only(self) -> bool:
return self.id in settings.lnbits_admin_extensions
@property
def hash(self) -> str:
if self.meta and self.meta.installed_release:
@ -765,3 +689,23 @@ class ExtensionDetailsRequest(BaseModel):
ext_id: str
source_repo: str
version: str
async def github_api_get(url: str, error_msg: Optional[str]) -> Any:
headers = {"User-Agent": settings.user_agent}
if settings.lnbits_ext_github_token:
headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}"
async with httpx.AsyncClient(headers=headers) as client:
resp = await client.get(url)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"

View file

@ -0,0 +1,20 @@
from typing import Optional
from pydantic import BaseModel
class CreateLnurl(BaseModel):
description_hash: str
callback: str
amount: int
comment: Optional[str] = None
description: Optional[str] = None
unit: Optional[str] = None
class CreateLnurlAuth(BaseModel):
callback: str
class PayLnurlWData(BaseModel):
lnurl_w: str

View file

@ -0,0 +1,43 @@
from __future__ import annotations
from typing import Callable
from pydantic import BaseModel
def _do_nothing(*_):
pass
class CoreAppExtra:
register_new_ext_routes: Callable = _do_nothing
register_new_ratelimiter: Callable
class ConversionData(BaseModel):
from_: str = "sat"
amount: float
to: str = "usd"
class Callback(BaseModel):
callback: str
class BalanceDelta(BaseModel):
lnbits_balance_msats: int
node_balance_msats: int
@property
def delta_msats(self):
return self.node_balance_msats - self.lnbits_balance_msats
class SimpleStatus(BaseModel):
success: bool
message: str
class DbVersion(BaseModel):
db: str
version: int

View file

@ -0,0 +1,176 @@
from __future__ import annotations
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from fastapi import Query
from pydantic import BaseModel, Field, validator
from lnbits.db import FilterModel
from lnbits.utils.exchange_rates import allowed_currencies
from lnbits.wallets import get_funding_source
from lnbits.wallets.base import (
PaymentFailedStatus,
PaymentPendingStatus,
PaymentStatus,
PaymentSuccessStatus,
)
class PaymentState(str, Enum):
PENDING = "pending"
SUCCESS = "success"
FAILED = "failed"
def __str__(self) -> str:
return self.value
class PaymentExtra(BaseModel):
comment: Optional[str] = None
success_action: Optional[str] = None
lnurl_response: Optional[str] = None
class PayInvoice(BaseModel):
payment_request: str
description: Optional[str] = None
max_sat: Optional[int] = None
extra: Optional[dict] = {}
class CreatePayment(BaseModel):
wallet_id: str
payment_hash: str
bolt11: str
amount_msat: int
memo: str
extra: Optional[dict] = {}
preimage: Optional[str] = None
expiry: Optional[datetime] = None
webhook: Optional[str] = None
fee: int = 0
class Payment(BaseModel):
checking_id: str
payment_hash: str
wallet_id: str
amount: int
fee: int
bolt11: str
status: str = PaymentState.PENDING
memo: Optional[str] = None
expiry: Optional[datetime] = None
webhook: Optional[str] = None
webhook_status: Optional[int] = None
preimage: Optional[str] = None
tag: Optional[str] = None
extension: Optional[str] = None
time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
extra: dict = {}
@property
def pending(self) -> bool:
return self.status == PaymentState.PENDING.value
@property
def success(self) -> bool:
return self.status == PaymentState.SUCCESS.value
@property
def failed(self) -> bool:
return self.status == PaymentState.FAILED.value
@property
def msat(self) -> int:
return self.amount
@property
def sat(self) -> int:
return self.amount // 1000
@property
def is_in(self) -> bool:
return self.amount > 0
@property
def is_out(self) -> bool:
return self.amount < 0
@property
def is_expired(self) -> bool:
return self.expiry < datetime.now(timezone.utc) if self.expiry else False
@property
def is_internal(self) -> bool:
return self.checking_id.startswith("internal_")
async def check_status(self) -> PaymentStatus:
if self.is_internal:
if self.success:
return PaymentSuccessStatus()
if self.failed:
return PaymentFailedStatus()
return PaymentPendingStatus()
funding_source = get_funding_source()
if self.is_out:
status = await funding_source.get_payment_status(self.checking_id)
else:
status = await funding_source.get_invoice_status(self.checking_id)
return status
class PaymentFilters(FilterModel):
__search_fields__ = ["memo", "amount"]
checking_id: str
amount: int
fee: int
memo: Optional[str]
time: datetime
bolt11: str
preimage: str
payment_hash: str
expiry: Optional[datetime]
extra: dict = {}
wallet_id: str
webhook: Optional[str]
webhook_status: Optional[int]
class PaymentHistoryPoint(BaseModel):
date: datetime
income: int
spending: int
balance: int
class DecodePayment(BaseModel):
data: str
filter_fields: Optional[list[str]] = []
class CreateInvoice(BaseModel):
unit: str = "sat"
internal: bool = False
out: bool = True
amount: float = Query(None, ge=0)
memo: Optional[str] = None
description_hash: Optional[str] = None
unhashed_description: Optional[str] = None
expiry: Optional[int] = None
extra: Optional[dict] = None
webhook: Optional[str] = None
bolt11: Optional[str] = None
lnurl_callback: Optional[str] = None
@validator("unit")
@classmethod
def unit_is_from_allowed_currencies(cls, v):
if v != "sat" and v not in allowed_currencies():
raise ValueError("The provided unit is not supported")
return v

View file

@ -0,0 +1,9 @@
from pydantic import BaseModel
class TinyURL(BaseModel):
id: str
url: str
endless: bool
wallet: str
time: float

177
lnbits/core/models/users.py Normal file
View file

@ -0,0 +1,177 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Optional
from fastapi import Query
from passlib.context import CryptContext
from pydantic import BaseModel, Field
from lnbits.db import FilterModel
from lnbits.settings import settings
from .wallets import Wallet
class UserExtra(BaseModel):
email_verified: Optional[bool] = False
first_name: Optional[str] = None
last_name: Optional[str] = None
display_name: Optional[str] = None
picture: Optional[str] = None
# Auth provider, possible values:
# - "env": the user was created automatically by the system
# - "lnbits": the user was created via register form (username/pass or user_id only)
# - "google | github | ...": the user was created using an SSO provider
provider: Optional[str] = "lnbits" # auth provider
class Account(BaseModel):
id: str
username: Optional[str] = None
password_hash: Optional[str] = None
pubkey: Optional[str] = None
email: Optional[str] = None
extra: UserExtra = UserExtra()
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@property
def is_super_user(self) -> bool:
return self.id == settings.super_user
@property
def is_admin(self) -> bool:
return self.id in settings.lnbits_admin_users or self.is_super_user
def hash_password(self, password: str) -> str:
"""sets and returns the hashed password"""
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
self.password_hash = pwd_context.hash(password)
return self.password_hash
def verify_password(self, password: str) -> bool:
"""returns True if the password matches the hash"""
if not self.password_hash:
return False
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
return pwd_context.verify(password, self.password_hash)
class AccountOverview(Account):
transaction_count: Optional[int] = 0
wallet_count: Optional[int] = 0
balance_msat: Optional[int] = 0
last_payment: Optional[datetime] = None
class AccountFilters(FilterModel):
__search_fields__ = ["id", "email", "username"]
__sort_fields__ = [
"balance_msat",
"email",
"username",
"transaction_count",
"wallet_count",
"last_payment",
]
id: str
last_payment: Optional[datetime] = None
transaction_count: Optional[int] = None
wallet_count: Optional[int] = None
username: Optional[str] = None
email: Optional[str] = None
class User(BaseModel):
id: str
created_at: datetime
updated_at: datetime
email: Optional[str] = None
username: Optional[str] = None
pubkey: Optional[str] = None
extensions: list[str] = []
wallets: list[Wallet] = []
admin: bool = False
super_user: bool = False
has_password: bool = False
extra: UserExtra = UserExtra()
@property
def wallet_ids(self) -> list[str]:
return [wallet.id for wallet in self.wallets]
def get_wallet(self, wallet_id: str) -> Optional[Wallet]:
w = [wallet for wallet in self.wallets if wallet.id == wallet_id]
return w[0] if w else None
@classmethod
def is_extension_for_user(cls, ext: str, user: str) -> bool:
if ext not in settings.lnbits_admin_extensions:
return True
if user == settings.super_user:
return True
if user in settings.lnbits_admin_users:
return True
return False
class CreateUser(BaseModel):
email: Optional[str] = Query(default=None)
username: str = Query(default=..., min_length=2, max_length=20)
password: str = Query(default=..., min_length=8, max_length=50)
password_repeat: str = Query(default=..., min_length=8, max_length=50)
class UpdateUser(BaseModel):
user_id: str
email: Optional[str] = Query(default=None)
username: Optional[str] = Query(default=..., min_length=2, max_length=20)
extra: Optional[UserExtra] = None
class UpdateUserPassword(BaseModel):
user_id: str
password_old: Optional[str] = None
password: str = Query(default=..., min_length=8, max_length=50)
password_repeat: str = Query(default=..., min_length=8, max_length=50)
username: str = Query(default=..., min_length=2, max_length=20)
class UpdateUserPubkey(BaseModel):
user_id: str
pubkey: str = Query(default=..., max_length=64)
class ResetUserPassword(BaseModel):
reset_key: str
password: str = Query(default=..., min_length=8, max_length=50)
password_repeat: str = Query(default=..., min_length=8, max_length=50)
class UpdateSuperuserPassword(BaseModel):
username: str = Query(default=..., min_length=2, max_length=20)
password: str = Query(default=..., min_length=8, max_length=50)
password_repeat: str = Query(default=..., min_length=8, max_length=50)
class LoginUsr(BaseModel):
usr: str
class LoginUsernamePassword(BaseModel):
username: str
password: str
class AccessTokenPayload(BaseModel):
sub: str
usr: Optional[str] = None
email: Optional[str] = None
auth_time: Optional[int] = 0
class CreateTopup(BaseModel):
id: str
amount: int

View file

@ -0,0 +1,80 @@
from __future__ import annotations
import hashlib
import hmac
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from ecdsa import SECP256k1, SigningKey
from pydantic import BaseModel, Field
from lnbits.helpers import url_for
from lnbits.lnurl import encode as lnurl_encode
from lnbits.settings import settings
class BaseWallet(BaseModel):
id: str
name: str
adminkey: str
inkey: str
balance_msat: int
class Wallet(BaseModel):
id: str
user: str
name: str
adminkey: str
inkey: str
deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
currency: Optional[str] = None
balance_msat: int = Field(default=0, no_database=True)
@property
def balance(self) -> int:
return int(self.balance_msat // 1000)
@property
def withdrawable_balance(self) -> int:
return self.balance_msat - settings.fee_reserve(self.balance_msat)
@property
def lnurlwithdraw_full(self) -> str:
url = url_for("/withdraw", external=True, usr=self.user, wal=self.id)
try:
return lnurl_encode(url)
except Exception:
return ""
def lnurlauth_key(self, domain: str) -> SigningKey:
hashing_key = hashlib.sha256(self.id.encode()).digest()
linking_key = hmac.digest(hashing_key, domain.encode(), "sha256")
return SigningKey.from_string(
linking_key, curve=SECP256k1, hashfunc=hashlib.sha256
)
class CreateWallet(BaseModel):
name: Optional[str] = None
class KeyType(Enum):
admin = 0
invoice = 1
invalid = 2
# backwards compatibility
def __eq__(self, other):
return self.value == other
@dataclass
class WalletTypeInfo:
key_type: KeyType
wallet: Wallet

View file

@ -0,0 +1,15 @@
from datetime import datetime
from pydantic import BaseModel
class CreateWebPushSubscription(BaseModel):
subscription: str
class WebPushSubscription(BaseModel):
endpoint: str
user: str
data: str
host: str
timestamp: datetime

View file

@ -1,5 +1,6 @@
import asyncio
import importlib
from typing import Optional
from loguru import logger
@ -11,10 +12,11 @@ from lnbits.core.crud import (
get_installed_extension,
update_installed_extension_state,
)
from lnbits.core.crud.extensions import get_installed_extensions
from lnbits.core.helpers import migrate_extension_database
from lnbits.settings import settings
from .models import Extension, InstallableExtension
from ..models.extensions import Extension, InstallableExtension
async def install_extension(ext_info: InstallableExtension) -> Extension:
@ -70,7 +72,7 @@ async def stop_extension_background_work(ext_id: str) -> bool:
Extensions SHOULD expose a `api_stop()` function.
"""
upgrade_hash = settings.lnbits_upgraded_extensions.get(ext_id, "")
ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash)
ext = Extension(ext_id, True, upgrade_hash=upgrade_hash)
try:
logger.info(f"Stopping background work for extension '{ext.module_name}'.")
@ -96,3 +98,38 @@ async def stop_extension_background_work(ext_id: str) -> bool:
return False
return True
async def get_valid_extensions(
include_deactivated: Optional[bool] = True,
) -> list[Extension]:
installed_extensions = await get_installed_extensions()
valid_extensions = [Extension.from_installable_ext(e) for e in installed_extensions]
if include_deactivated:
return valid_extensions
if settings.lnbits_extensions_deactivate_all:
return []
return [
e
for e in valid_extensions
if e.code not in settings.lnbits_deactivated_extensions
]
async def get_valid_extension(
ext_id: str, include_deactivated: Optional[bool] = True
) -> Optional[Extension]:
ext = await get_installed_extension(ext_id)
if not ext:
return None
if include_deactivated:
return Extension.from_installable_ext(ext)
if settings.lnbits_extensions_deactivate_all:
return None
return Extension.from_installable_ext(ext)

View file

@ -175,14 +175,8 @@ def fee_reserve_total(amount_msat: int, internal: bool = False) -> int:
return fee_reserve(amount_msat, internal) + service_fee(amount_msat, internal)
# WARN: this same value must be used for balance check and passed to
# funding_source.pay_invoice(), it may cause a vulnerability if the values differ
def fee_reserve(amount_msat: int, internal: bool = False) -> int:
if internal:
return 0
reserve_min = settings.lnbits_reserve_fee_min
reserve_percent = settings.lnbits_reserve_fee_percent
return max(int(reserve_min), int(amount_msat * reserve_percent / 100.0))
return settings.fee_reserve(amount_msat, internal)
def service_fee(amount_msat: int, internal: bool = False) -> int:

View file

@ -4,7 +4,7 @@ from uuid import UUID, uuid4
from loguru import logger
from lnbits.core.extensions.models import UserExtension
from lnbits.core.models.extensions import UserExtension
from lnbits.settings import (
EditableSettings,
SuperSettings,

View file

@ -137,7 +137,7 @@
multiple
hint="Extensions only user with admin privileges can use"
label="Admin extensions"
:options="g.extensions.map(e => e.code)"
:options="g.extensions"
></q-select>
</div>
@ -149,7 +149,7 @@
multiple
hint="Extensions that will be enabled by default for the users."
label="User extensions"
:options="g.extensions.map(e => e.code)"
:options="g.extensions"
></q-select>
</div>
<div class="col-12 col-md-6">

View file

@ -410,7 +410,7 @@ def _new_sso(provider: str) -> Optional[SSOBase]:
def _find_auth_provider_class(provider: str) -> Callable:
sso_modules = ["lnbits.core.sso", "fastapi_sso.sso"]
sso_modules = ["lnbits.core.models.sso", "fastapi_sso.sso"]
for module in sso_modules:
try:
provider_module = importlib.import_module(f"{module}.{provider}")

View file

@ -10,13 +10,12 @@ from fastapi import (
)
from loguru import logger
from lnbits.core.extensions.extension_manager import (
activate_extension,
deactivate_extension,
install_extension,
uninstall_extension,
from lnbits.core.crud.extensions import get_user_extensions
from lnbits.core.models import (
SimpleStatus,
User,
)
from lnbits.core.extensions.models import (
from lnbits.core.models.extensions import (
CreateExtension,
Extension,
ExtensionConfig,
@ -28,11 +27,15 @@ from lnbits.core.extensions.models import (
UserExtension,
UserExtensionInfo,
)
from lnbits.core.models import (
SimpleStatus,
User,
)
from lnbits.core.services import check_transaction_status, create_invoice
from lnbits.core.services.extensions import (
activate_extension,
deactivate_extension,
get_valid_extension,
get_valid_extensions,
install_extension,
uninstall_extension,
)
from lnbits.decorators import (
check_admin,
check_user_exists,
@ -168,7 +171,7 @@ async def api_update_pay_to_enable(
async def api_enable_extension(
ext_id: str, user: User = Depends(check_user_exists)
) -> SimpleStatus:
if ext_id not in [e.code for e in Extension.get_valid_extensions()]:
if ext_id not in [e.code for e in await get_valid_extensions()]:
raise HTTPException(
HTTPStatus.NOT_FOUND, f"Extension '{ext_id}' doesn't exist."
)
@ -236,7 +239,7 @@ async def api_enable_extension(
async def api_disable_extension(
ext_id: str, user: User = Depends(check_user_exists)
) -> SimpleStatus:
if ext_id not in [e.code for e in Extension.get_valid_extensions()]:
if ext_id not in [e.code for e in await get_valid_extensions()]:
raise HTTPException(
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
)
@ -256,7 +259,7 @@ async def api_activate_extension(ext_id: str) -> SimpleStatus:
try:
logger.info(f"Activating extension: '{ext_id}'.")
ext = Extension.get_valid_extension(ext_id)
ext = await get_valid_extension(ext_id)
assert ext, f"Extension '{ext_id}' doesn't exist."
await activate_extension(ext)
@ -275,7 +278,7 @@ async def api_deactivate_extension(ext_id: str) -> SimpleStatus:
try:
logger.info(f"Deactivating extension: '{ext_id}'.")
ext = Extension.get_valid_extension(ext_id)
ext = await get_valid_extension(ext_id)
assert ext, f"Extension '{ext_id}' doesn't exist."
await deactivate_extension(ext_id)
@ -300,7 +303,7 @@ async def api_uninstall_extension(ext_id: str) -> SimpleStatus:
installed_extensions = await get_installed_extensions()
# check that other extensions do not depend on this one
for valid_ext_id in [ext.code for ext in Extension.get_valid_extensions()]:
for valid_ext_id in [ext.code for ext in await get_valid_extensions()]:
installed_ext = next(
(ext for ext in installed_extensions if ext.id == valid_ext_id), None
)
@ -453,7 +456,7 @@ async def get_pay_to_enable_invoice(
@extension_router.get(
"/release/{org}/{repo}/{tag_name}",
dependencies=[Depends(check_admin)],
dependencies=[Depends(check_user_exists)],
)
async def get_extension_release(org: str, repo: str, tag_name: str):
try:
@ -472,6 +475,19 @@ async def get_extension_release(org: str, repo: str, tag_name: str):
) from exc
@extension_router.get("")
async def api_get_user_extensions(
user: User = Depends(check_user_exists),
) -> list[Extension]:
user_extensions_ids = [ue.extension for ue in await get_user_extensions(user.id)]
return [
ext
for ext in await get_valid_extensions(False)
if ext.code in user_extensions_ids
]
@extension_router.delete(
"/{ext_id}/db",
dependencies=[Depends(check_admin)],

View file

@ -11,10 +11,11 @@ from fastapi.routing import APIRouter
from lnurl import decode as lnurl_decode
from pydantic.types import UUID4
from lnbits.core.extensions.models import Extension, ExtensionMeta, InstallableExtension
from lnbits.core.helpers import to_valid_user_id
from lnbits.core.models import User
from lnbits.core.models.extensions import ExtensionMeta, InstallableExtension
from lnbits.core.services import create_invoice, create_user_account
from lnbits.core.services.extensions import get_valid_extensions
from lnbits.decorators import check_admin, check_user_exists
from lnbits.helpers import template_renderer
from lnbits.settings import settings
@ -102,7 +103,7 @@ async def extensions(request: Request, user: User = Depends(check_user_exists)):
e.short_description = installed_ext.short_description
e.icon = installed_ext.icon
all_ext_ids = [ext.code for ext in Extension.get_valid_extensions()]
all_ext_ids = [ext.code for ext in await get_valid_extensions()]
inactive_extensions = [e.id for e in await get_installed_extensions(active=False)]
db_versions = await get_db_versions()

View file

@ -1,15 +1,17 @@
import hashlib
import json
import re
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Optional, Type
from urllib import request
import jinja2
import jwt
import shortuuid
from packaging import version
from pydantic.schema import field_schema
from lnbits.core.extensions.models import Extension
from lnbits.jinja2_templating import Jinja2Templates
from lnbits.nodes import get_node_class
from lnbits.requestvars import g
@ -91,7 +93,8 @@ def template_renderer(additional_folders: Optional[list] = None) -> Jinja2Templa
settings.lnbits_node_ui and get_node_class() is not None
)
t.env.globals["LNBITS_NODE_UI_AVAILABLE"] = get_node_class() is not None
t.env.globals["EXTENSIONS"] = Extension.get_valid_extensions(False)
t.env.globals["EXTENSIONS"] = list(settings.lnbits_all_extensions_ids)
if settings.lnbits_custom_logo:
t.env.globals["USE_CUSTOM_LOGO"] = settings.lnbits_custom_logo
@ -211,3 +214,31 @@ def filter_dict_keys(data: dict, filter_keys: Optional[list[str]]) -> dict:
# return shallow clone of the dict even if there are no filters
return {**data}
return {key: data[key] for key in filter_keys if key in data}
def version_parse(v: str):
"""
Wrapper for version.parse() that does not throw if the version is invalid.
Instead it return the lowest possible version ("0.0.0")
"""
try:
# todo: handle -rc0x
return version.parse(v)
except Exception:
return version.parse("0.0.0")
def download_url(url, save_path):
with request.urlopen(url, timeout=60) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()

View file

@ -223,14 +223,27 @@ class ThemesSettings(LNbitsSettings):
class OpsSettings(LNbitsSettings):
lnbits_baseurl: str = Field(default="http://127.0.0.1:5000/")
lnbits_hide_api: bool = Field(default=False)
lnbits_denomination: str = Field(default="sats")
class FeeSettings(LNbitsSettings):
lnbits_reserve_fee_min: int = Field(default=2000)
lnbits_reserve_fee_percent: float = Field(default=1.0)
lnbits_service_fee: float = Field(default=0)
lnbits_service_fee_ignore_internal: bool = Field(default=True)
lnbits_service_fee_max: int = Field(default=0)
lnbits_service_fee_wallet: Optional[str] = Field(default=None)
lnbits_hide_api: bool = Field(default=False)
lnbits_denomination: str = Field(default="sats")
# WARN: this same value must be used for balance check and passed to
# funding_source.pay_invoice(), it may cause a vulnerability if the values differ
def fee_reserve(self, amount_msat: int, internal: bool = False) -> int:
if internal:
return 0
reserve_min = self.lnbits_reserve_fee_min
reserve_percent = self.lnbits_reserve_fee_percent
return max(int(reserve_min), int(amount_msat * reserve_percent / 100.0))
class SecuritySettings(LNbitsSettings):
@ -489,6 +502,7 @@ class EditableSettings(
ExtensionsSettings,
ThemesSettings,
OpsSettings,
FeeSettings,
SecuritySettings,
FundingSourcesSettings,
LightningSettings,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -217,8 +217,7 @@ window.LNbits = {
'name',
'shortDescription',
'tile',
'contributors',
'hidden'
'contributors'
],
data
)
@ -719,33 +718,7 @@ window.windowMixin = {
this.g.wallet = Object.freeze(window.LNbits.map.wallet(window.wallet))
}
if (window.extensions) {
const user = this.g.user
const extensions = Object.freeze(
window.extensions
.map(function (data) {
return window.LNbits.map.extension(data)
})
.filter(function (obj) {
return !obj.hidden
})
.filter(function (obj) {
if (window.user?.admin) return obj
return !obj.isAdminOnly
})
.map(function (obj) {
if (user) {
obj.isEnabled = user.extensions.indexOf(obj.code) !== -1
} else {
obj.isEnabled = false
}
return obj
})
.sort(function (a, b) {
const nameA = a.name.toUpperCase()
const nameB = b.name.toUpperCase()
return nameA < nameB ? -1 : nameA > nameB ? 1 : 0
})
)
const extensions = Object.freeze(window.extensions)
this.g.extensions = extensions
}

View file

@ -87,19 +87,22 @@ window.app.component('lnbits-extension-list', {
})
}
},
created: function () {
if (window.extensions) {
this.extensions = window.extensions
created: async function () {
if (window.user) {
this.user = LNbits.map.user(window.user)
}
try {
const {data} = await LNbits.api.request('GET', '/api/v1/extension')
this.extensions = data
.map(function (data) {
return LNbits.map.extension(data)
})
.sort(function (a, b) {
return a.name.localeCompare(b.name)
})
}
if (window.user) {
this.user = LNbits.map.user(window.user)
} catch (error) {
LNbits.utils.notifyApiError(error)
}
}
})

View file

@ -1,7 +1,6 @@
import pytest
from lnbits.core.services import (
fee_reserve,
fee_reserve_total,
service_fee,
)
@ -9,8 +8,8 @@ from lnbits.settings import Settings
@pytest.mark.asyncio
async def test_fee_reserve_internal():
fee = fee_reserve(10_000, internal=True)
async def test_fee_reserve_internal(settings: Settings):
fee = settings.fee_reserve(10_000, internal=True)
assert fee == 0
@ -18,7 +17,7 @@ async def test_fee_reserve_internal():
async def test_fee_reserve_min(settings: Settings):
settings.lnbits_reserve_fee_percent = 2
settings.lnbits_reserve_fee_min = 500
fee = fee_reserve(10000)
fee = settings.fee_reserve(10000)
assert fee == 500
@ -26,7 +25,7 @@ async def test_fee_reserve_min(settings: Settings):
async def test_fee_reserve_percent(settings: Settings):
settings.lnbits_reserve_fee_percent = 1
settings.lnbits_reserve_fee_min = 100
fee = fee_reserve(100000)
fee = settings.fee_reserve(100000)
assert fee == 1000
@ -70,6 +69,6 @@ async def test_fee_reserve_total(settings: Settings):
settings.lnbits_service_fee_wallet = "wallet_id"
amount = 100_000
fee = service_fee(amount)
reserve = fee_reserve(amount)
reserve = settings.fee_reserve(amount)
total = fee_reserve_total(amount)
assert fee + reserve == total