From 7d1e22c7de7e628c637938b8420fdc6dd8e98026 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Mon, 15 Jul 2024 13:34:26 +0300 Subject: [PATCH] fix: always create default wallet for user (#2580) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: always create default wallet for user * no assert in api --------- Co-authored-by: dni ⚡ --- lnbits/core/services.py | 3 +++ lnbits/core/views/api.py | 7 ++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lnbits/core/services.py b/lnbits/core/services.py index 155ff470c..ce4df160e 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -785,6 +785,7 @@ async def create_user_account( email: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None, + wallet_name: Optional[str] = None, user_config: Optional[UserConfig] = None, ) -> User: if not settings.new_accounts_allowed: @@ -805,6 +806,8 @@ async def create_user_account( password = pwd_context.hash(password) if password else None account = await create_account(user_id, username, email, password, user_config) + wallet = await create_wallet(user_id=account.id, wallet_name=wallet_name) + account.wallets = [wallet] for ext_id in settings.lnbits_user_default_extensions: await update_user_extension(user_id=account.id, extension=ext_id, active=True) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 9a9f2ac74..5202c9269 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -36,9 +36,6 @@ from lnbits.utils.exchange_rates import ( satoshis_amount_as_fiat, ) -from ..crud import ( - create_wallet, -) from ..services import create_user_account, perform_lnurlauth # backwards compatibility for extension @@ -69,8 +66,8 @@ async def api_create_account(data: CreateWallet) -> Wallet: status_code=HTTPStatus.FORBIDDEN, detail="Account creation is disabled.", ) - account = await create_user_account() - return await create_wallet(user_id=account.id, wallet_name=data.name) + account = await create_user_account(wallet_name=data.name) + return account.wallets[0] @api_router.get("/api/v1/lnurlscan/{code}")