diff --git a/lnbits/wallets/fake.py b/lnbits/wallets/fake.py index 9cc498485..4725baff6 100644 --- a/lnbits/wallets/fake.py +++ b/lnbits/wallets/fake.py @@ -1,9 +1,8 @@ import asyncio import hashlib -import random from datetime import datetime from os import urandom -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, Dict, Optional, Set from bolt11 import ( Bolt11, @@ -29,6 +28,8 @@ from .base import ( class FakeWallet(Wallet): queue: asyncio.Queue = asyncio.Queue(0) + payment_secrets: Dict[str, str] = dict() + paid_invoices: Set[str] = set() secret: str = settings.fake_wallet_secret privkey: str = hashlib.pbkdf2_hmac( "sha256", @@ -70,20 +71,18 @@ class FakeWallet(Wallet): if expiry: tags.add(TagChar.expire_time, expiry) - # random hash - checking_id = ( - self.privkey[:6] - + hashlib.sha256(str(random.getrandbits(256)).encode()).hexdigest()[6:] - ) - - tags.add(TagChar.payment_hash, checking_id) - if payment_secret: secret = payment_secret.hex() else: secret = urandom(32).hex() tags.add(TagChar.payment_secret, secret) + payment_hash = hashlib.sha256(secret.encode()).hexdigest() + + tags.add(TagChar.payment_hash, payment_hash) + + self.payment_secrets[payment_hash] = secret + bolt11 = Bolt11( currency="bc", amount_msat=MilliSatoshi(amount * 1000), @@ -93,7 +92,9 @@ class FakeWallet(Wallet): payment_request = encode(bolt11, self.privkey) - return InvoiceResponse(True, checking_id, payment_request) + return InvoiceResponse( + ok=True, checking_id=payment_hash, payment_request=payment_request + ) async def pay_invoice(self, bolt11: str, _: int) -> PaymentResponse: try: @@ -101,16 +102,23 @@ class FakeWallet(Wallet): except Bolt11Exception as exc: return PaymentResponse(ok=False, error_message=str(exc)) - if invoice.payment_hash[:6] == self.privkey[:6]: + if invoice.payment_hash in self.payment_secrets: await self.queue.put(invoice) - return PaymentResponse(True, invoice.payment_hash, 0) + self.paid_invoices.add(invoice.payment_hash) + return PaymentResponse( + ok=True, + checking_id=invoice.payment_hash, + fee_msat=0, + preimage=self.payment_secrets.get(invoice.payment_hash) or "0" * 64, + ) else: return PaymentResponse( ok=False, error_message="Only internal invoices can be used!" ) - async def get_invoice_status(self, _: str) -> PaymentStatus: - return PaymentStatus(None) + async def get_invoice_status(self, checking_id: str) -> PaymentStatus: + paid = checking_id in self.paid_invoices + return PaymentStatus(paid) async def get_payment_status(self, _: str) -> PaymentStatus: return PaymentStatus(None)