import hashlib import re import time from binascii import unhexlify from decimal import Decimal from typing import List, NamedTuple, Optional import bitstring # type: ignore import embit import secp256k1 from bech32 import CHARSET, bech32_decode, bech32_encode from ecdsa import SECP256k1, VerifyingKey # type: ignore from ecdsa.util import sigdecode_string # type: ignore class Route(NamedTuple): pubkey: str short_channel_id: str base_fee_msat: int ppm_fee: int cltv: int class Invoice(object): payment_hash: str amount_msat: int = 0 description: Optional[str] = None description_hash: Optional[str] = None payee: Optional[str] = None date: int expiry: int = 3600 secret: Optional[str] = None route_hints: List[Route] = [] min_final_cltv_expiry: int = 18 def decode(pr: str) -> Invoice: """bolt11 decoder, based on https://github.com/rustyrussell/lightning-payencode/blob/master/lnaddr.py """ hrp, decoded_data = bech32_decode(pr) if hrp is None or decoded_data is None: raise ValueError("Bad bech32 checksum") if not hrp.startswith("ln"): raise ValueError("Does not start with ln") bitarray = _u5_to_bitarray(decoded_data) # final signature 65 bytes, split it off. if len(bitarray) < 65 * 8: raise ValueError("Too short to contain signature") # extract the signature signature = bitarray[-65 * 8 :].tobytes() # the tagged fields as a bitstream data = bitstring.ConstBitStream(bitarray[: -65 * 8]) # build the invoice object invoice = Invoice() # decode the amount from the hrp m = re.search(r"[^\d]+", hrp[2:]) if m: amountstr = hrp[2 + m.end() :] if amountstr != "": invoice.amount_msat = _unshorten_amount(amountstr) # pull out date invoice.date = data.read(35).uint while data.pos != data.len: tag, tagdata, data = _pull_tagged(data) data_length = len(tagdata) / 5 if tag == "d": invoice.description = _trim_to_bytes(tagdata).decode("utf-8") elif tag == "h" and data_length == 52: invoice.description_hash = _trim_to_bytes(tagdata).hex() elif tag == "p" and data_length == 52: invoice.payment_hash = _trim_to_bytes(tagdata).hex() elif tag == "x": invoice.expiry = tagdata.uint elif tag == "n": invoice.payee = _trim_to_bytes(tagdata).hex() # this won't work in most cases, we must extract the payee # from the signature elif tag == "s": invoice.secret = _trim_to_bytes(tagdata).hex() elif tag == "r": s = bitstring.ConstBitStream(tagdata) while s.pos + 264 + 64 + 32 + 32 + 16 < s.len: route = Route( pubkey=s.read(264).tobytes().hex(), short_channel_id=_readable_scid(s.read(64).intbe), base_fee_msat=s.read(32).intbe, ppm_fee=s.read(32).intbe, cltv=s.read(16).intbe, ) invoice.route_hints.append(route) # BOLT #11: # A reader MUST check that the `signature` is valid (see the `n` tagged # field specified below). # A reader MUST use the `n` field to validate the signature instead of # performing signature recovery if a valid `n` field is provided. message = bytearray([ord(c) for c in hrp]) + data.tobytes() sig = signature[0:64] if invoice.payee: key = VerifyingKey.from_string(unhexlify(invoice.payee), curve=SECP256k1) key.verify(sig, message, hashlib.sha256, sigdecode=sigdecode_string) else: keys = VerifyingKey.from_public_key_recovery( sig, message, SECP256k1, hashlib.sha256 ) signaling_byte = signature[64] key = keys[int(signaling_byte)] invoice.payee = key.to_string("compressed").hex() return invoice def encode(options): """Convert options into LnAddr and pass it to the encoder""" addr = LnAddr() addr.currency = options["currency"] addr.fallback = options["fallback"] if options["fallback"] else None if options["amount"]: addr.amount = options["amount"] if options["timestamp"]: addr.date = int(options["timestamp"]) addr.paymenthash = unhexlify(options["paymenthash"]) if options["description"]: addr.tags.append(("d", options["description"])) if options["description_hash"]: addr.tags.append(("h", options["description_hash"])) if options["expires"]: addr.tags.append(("x", options["expires"])) if options["fallback"]: addr.tags.append(("f", options["fallback"])) if options["route"]: for r in options["route"]: splits = r.split("/") route = [] while len(splits) >= 5: route.append( ( unhexlify(splits[0]), unhexlify(splits[1]), int(splits[2]), int(splits[3]), int(splits[4]), ) ) splits = splits[5:] assert len(splits) == 0 addr.tags.append(("r", route)) return lnencode(addr, options["privkey"]) def lnencode(addr, privkey): if addr.amount: amount = Decimal(str(addr.amount)) # We can only send down to millisatoshi. if amount * 10**12 % 10: raise ValueError( "Cannot encode {}: too many decimal places".format(addr.amount) ) amount = addr.currency + shorten_amount(amount) else: amount = addr.currency if addr.currency else "" hrp = "ln" + amount + "0n" # Start with the timestamp data = bitstring.pack("uint:35", addr.date) # Payment hash data += tagged_bytes("p", addr.paymenthash) tags_set = set() for k, v in addr.tags: # BOLT #11: # # A writer MUST NOT include more than one `d`, `h`, `n` or `x` fields, if k in ("d", "h", "n", "x"): if k in tags_set: raise ValueError("Duplicate '{}' tag".format(k)) if k == "r": route = bitstring.BitArray() for step in v: pubkey, channel, feebase, feerate, cltv = step route.append( bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack("intbe:32", feebase) + bitstring.pack("intbe:32", feerate) + bitstring.pack("intbe:16", cltv) ) data += tagged("r", route) elif k == "f": data += encode_fallback(v, addr.currency) elif k == "d": data += tagged_bytes("d", v.encode()) elif k == "x": # Get minimal length by trimming leading 5 bits at a time. expirybits = bitstring.pack("intbe:64", v)[4:64] while expirybits.startswith("0b00000"): expirybits = expirybits[5:] data += tagged("x", expirybits) elif k == "h": data += tagged_bytes("h", v) elif k == "n": data += tagged_bytes("n", v) else: # FIXME: Support unknown tags? raise ValueError("Unknown tag {}".format(k)) tags_set.add(k) # BOLT #11: # # A writer MUST include either a `d` or `h` field, and MUST NOT include # both. if "d" in tags_set and "h" in tags_set: raise ValueError("Cannot include both 'd' and 'h'") if not "d" in tags_set and not "h" in tags_set: raise ValueError("Must include either 'd' or 'h'") # We actually sign the hrp, then data (padded to 8 bits with zeroes). privkey = secp256k1.PrivateKey(bytes(unhexlify(privkey))) sig = privkey.ecdsa_sign_recoverable( bytearray([ord(c) for c in hrp]) + data.tobytes() ) # This doesn't actually serialize, but returns a pair of values :( sig, recid = privkey.ecdsa_recoverable_serialize(sig) data += bytes(sig) + bytes([recid]) return bech32_encode(hrp, bitarray_to_u5(data)) class LnAddr(object): def __init__( self, paymenthash=None, amount=None, currency="bc", tags=None, date=None ): self.date = int(time.time()) if not date else int(date) self.tags = [] if not tags else tags self.unknown_tags = [] self.paymenthash = paymenthash self.signature = None self.pubkey = None self.currency = currency self.amount = amount def __str__(self): return "LnAddr[{}, amount={}{} tags=[{}]]".format( hexlify(self.pubkey.serialize()).decode("utf-8"), self.amount, self.currency, ", ".join([k + "=" + str(v) for k, v in self.tags]), ) def shorten_amount(amount): """Given an amount in bitcoin, shorten it""" # Convert to pico initially amount = int(amount * 10**12) units = ["p", "n", "u", "m", ""] for unit in units: if amount % 1000 == 0: amount //= 1000 else: break return str(amount) + unit def _unshorten_amount(amount: str) -> int: """Given a shortened amount, return millisatoshis""" # BOLT #11: # The following `multiplier` letters are defined: # # * `m` (milli): multiply by 0.001 # * `u` (micro): multiply by 0.000001 # * `n` (nano): multiply by 0.000000001 # * `p` (pico): multiply by 0.000000000001 units = {"p": 10**12, "n": 10**9, "u": 10**6, "m": 10**3} unit = str(amount)[-1] # BOLT #11: # A reader SHOULD fail if `amount` contains a non-digit, or is followed by # anything except a `multiplier` in the table above. if not re.fullmatch(r"\d+[pnum]?", str(amount)): raise ValueError("Invalid amount '{}'".format(amount)) if unit in units: return int(int(amount[:-1]) * 100_000_000_000 / units[unit]) else: return int(amount) * 100_000_000_000 def _pull_tagged(stream): tag = stream.read(5).uint length = stream.read(5).uint * 32 + stream.read(5).uint return (CHARSET[tag], stream.read(length * 5), stream) def is_p2pkh(currency, prefix): return prefix == base58_prefix_map[currency][0] def is_p2sh(currency, prefix): return prefix == base58_prefix_map[currency][1] # Tagged field containing BitArray def tagged(char, l): # Tagged fields need to be zero-padded to 5 bits. while l.len % 5 != 0: l.append("0b0") return ( bitstring.pack( "uint:5, uint:5, uint:5", CHARSET.find(char), (l.len / 5) / 32, (l.len / 5) % 32, ) + l ) def tagged_bytes(char, l): return tagged(char, bitstring.BitArray(l)) def _trim_to_bytes(barr): # Adds a byte if necessary. b = barr.tobytes() if barr.len % 8 != 0: return b[:-1] return b def _readable_scid(short_channel_id: int) -> str: return "{blockheight}x{transactionindex}x{outputindex}".format( blockheight=((short_channel_id >> 40) & 0xFFFFFF), transactionindex=((short_channel_id >> 16) & 0xFFFFFF), outputindex=(short_channel_id & 0xFFFF), ) def _u5_to_bitarray(arr: List[int]) -> bitstring.BitArray: ret = bitstring.BitArray() for a in arr: ret += bitstring.pack("uint:5", a) return ret def bitarray_to_u5(barr): assert barr.len % 5 == 0 ret = [] s = bitstring.ConstBitStream(barr) while s.pos != s.len: ret.append(s.read(5).uint) return ret