diff --git a/lnbits/extensions/lnurlpos/lnurl.py b/lnbits/extensions/lnurlpos/lnurl.py index dccacef03..2ca1b0f88 100644 --- a/lnbits/extensions/lnurlpos/lnurl.py +++ b/lnbits/extensions/lnurlpos/lnurl.py @@ -3,6 +3,12 @@ import hashlib from http import HTTPStatus from typing import Optional +from embit import bech32 +from embit import compact +import base64 +from io import BytesIO +import hmac + from fastapi import Request from fastapi.param_functions import Query from starlette.exceptions import HTTPException @@ -18,39 +24,73 @@ from .crud import ( update_lnurlpospayment, ) +def bech32_decode(bech): + """tweaked version of bech32_decode that ignores length limitations""" + if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or + (bech.lower() != bech and bech.upper() != bech)): + return + bech = bech.lower() + pos = bech.rfind('1') + if pos < 1 or pos + 7 > len(bech): + return + if not all(x in bech32.CHARSET for x in bech[pos+1:]): + return + hrp = bech[:pos] + data = [bech32.CHARSET.find(x) for x in bech[pos+1:]] + encoding = bech32.bech32_verify_checksum(hrp, data) + if encoding is None: + return + return bytes(bech32.convertbits(data[:-6], 5, 8, False)) + +def xor_decrypt(key, blob): + s = BytesIO(blob) + variant = s.read(1)[0] + if variant != 1: + raise RuntimeError("Not implemented") + # reading nonce + l = s.read(1)[0] + nonce = s.read(l) + if len(nonce) != l: + raise RuntimeError("Missing nonce bytes") + if l < 8: + raise RuntimeError("Nonce is too short") + # reading payload + l = s.read(1)[0] + payload = s.read(l) + if len(payload) > 32: + raise RuntimeError("Payload is too long for this encryption method") + if len(payload) != l: + raise RuntimeError("Missing payload bytes") + hmacval = s.read() + expected = hmac.new(key, b"Data:" + blob[:-len(hmacval)], digestmod="sha256").digest() + if len(hmacval) < 8: + raise RuntimeError("HMAC is too short") + if hmacval != expected[:len(hmacval)]: + raise RuntimeError("HMAC is invalid") + secret = hmac.new(key, b"Round secret:" + nonce, digestmod="sha256").digest() + payload = bytearray(payload) + for i in range(len(payload)): + payload[i] = payload[i] ^ secret[i] + s = BytesIO(payload) + pin = compact.read_from(s) + amount_in_cent = compact.read_from(s) + return pin, amount_in_cent @lnurlpos_ext.get( - "/api/v1/lnurl/{nonce}/{payload}/{pos_id}", + "/api/v1/lnurl/{pos_id}", status_code=HTTPStatus.OK, - name="lnurlpos.lnurl_response", + name="lnurlpos.lnurl_v1_params", ) -async def lnurl_response( - request: Request, - nonce: str = Query(None), - pos_id: str = Query(None), - payload: str = Query(None), -): - return await handle_lnurl_firstrequest( - request, pos_id, nonce, payload, verify_checksum=False - ) - - -@lnurlpos_ext.get( - "/api/v2/lnurl/{pos_id}", - status_code=HTTPStatus.OK, - name="lnurlpos.lnurl_v2_params", -) -async def lnurl_v2_params( +async def lnurl_v1_params( request: Request, pos_id: str = Query(None), - n: str = Query(None), p: str = Query(None), ): - return await handle_lnurl_firstrequest(request, pos_id, n, p, verify_checksum=True) + return await handle_lnurl_firstrequest(request, pos_id, p) async def handle_lnurl_firstrequest( - request: Request, pos_id: str, nonce: str, payload: str, verify_checksum: bool + request: Request, pos_id: str, payload: str ): pos = await get_lnurlpos(pos_id) if not pos: @@ -59,53 +99,13 @@ async def handle_lnurl_firstrequest( "reason": f"lnurlpos {pos_id} not found on this server", } - try: - nonceb = bytes.fromhex(nonce) - except ValueError: - try: - nonce += "=" * ((4 - len(nonce) % 4) % 4) - nonceb = base64.urlsafe_b64decode(nonce) - except: - return { - "status": "ERROR", - "reason": f"Invalid hex or base64 nonce: {nonce}", - } - - try: - payloadb = bytes.fromhex(payload) - except ValueError: - try: - payload += "=" * ((4 - len(payload) % 4) % 4) - payloadb = base64.urlsafe_b64decode(payload) - except: - return { - "status": "ERROR", - "reason": f"Invalid hex or base64 payload: {payload}", - } - - # check payload and nonce sizes - if len(payloadb) != 8 or len(nonceb) != 8: - return {"status": "ERROR", "reason": "Expected 8 bytes"} - - # verify hmac - if verify_checksum: - expected = hmac.new( - pos.key.encode(), payloadb[:-2], digestmod="sha256" - ).digest() - if expected[:2] != payloadb[-2:]: - return {"status": "ERROR", "reason": "Invalid HMAC"} - - # decrypt - s = hmac.new(pos.key.encode(), nonceb, digestmod="sha256").digest() - res = bytearray(payloadb) - for i in range(len(res)): - res[i] = res[i] ^ s[i] - - pin = int.from_bytes(res[0:2], "little") - amount = int.from_bytes(res[2:6], "little") + if len(payload) % 4 > 0: + payload += "="*(4-(len(payload)%4)) + data = base64.urlsafe_b64decode(payload) + pin, amount_in_cent = xor_decrypt(pos.key.encode(), data) price_msat = ( - await fiat_amount_as_satoshis(float(amount) / 100, pos.currency) + await fiat_amount_as_satoshis(float(amount_in_cent) / 100, pos.currency) if pos.currency != "sat" else amount ) * 1000 @@ -161,7 +161,7 @@ async def lnurl_callback(request: Request, paymentid: str = Query(None)): "successAction": { "tag": "url", "description": "Check the attached link", - "url": req.url_for("lnurlpos.displaypin", paymentid=paymentid), + "url": request.url_for("lnurlpos.displaypin", paymentid=paymentid), }, "routes": [], }