core-lightning/contrib/pyln-proto/pyln/proto/invoice.py
Christian Decker 6c19818314 pyln-proto: Add invoice utilities
We are about to disect a couple of invoices for features, so let's add a class
that can encode and decode invoices from bolt11 strings. This is pretty much
the lnaddr.py file created by @rustyrussell with some minor changes. I'm
planning to clean this up further which is why I'm only exporting the
`Invoice` class for now.
2020-02-01 16:50:58 +01:00

471 lines
15 KiB
Python
Executable File

from .bech32 import bech32_encode, bech32_decode, CHARSET
from binascii import hexlify, unhexlify
from decimal import Decimal
from io import BufferedReader, BytesIO
import base58
import bitstring
import hashlib
import re
import secp256k1
import time
import struct
# BOLT #11:
#
# A writer MUST encode `amount` as a positive decimal integer with no
# leading zeroes, SHOULD use the shortest representation possible.
def shorten_amount(amount):
""" Given an amount in bitcoin, shorten it
"""
# Convert to pico initially
amount = int(amount * 10**12)
units = ['p', 'n', 'u', 'm', '']
for unit in units:
if amount % 1000 == 0:
amount //= 1000
else:
break
return str(amount) + unit
def unshorten_amount(amount):
""" Given a shortened amount, convert it into a decimal
"""
# BOLT #11:
# The following `multiplier` letters are defined:
#
# * `m` (milli): multiply by 0.001
# * `u` (micro): multiply by 0.000001
# * `n` (nano): multiply by 0.000000001
# * `p` (pico): multiply by 0.000000000001
units = {
'p': 10**12,
'n': 10**9,
'u': 10**6,
'm': 10**3,
}
unit = str(amount)[-1]
# BOLT #11:
# A reader SHOULD fail if `amount` contains a non-digit, or is followed by
# anything except a `multiplier` in the table above.
if not re.fullmatch(r'\d+[pnum]?', str(amount)):
raise ValueError("Invalid amount '{}'".format(amount))
if unit in units.keys():
return Decimal(amount[:-1]) / units[unit]
else:
return Decimal(amount)
# Bech32 spits out array of 5-bit values. Shim here.
def u5_to_bitarray(arr):
ret = bitstring.BitArray()
for a in arr:
ret += bitstring.pack("uint:5", a)
return ret
def bitarray_to_u5(barr):
assert barr.len % 5 == 0
ret = []
s = bitstring.ConstBitStream(barr)
while s.pos != s.len:
ret.append(s.read(5).uint)
return ret
def encode_fallback(fallback, currency):
""" Encode all supported fallback addresses.
"""
if currency == 'bc' or currency == 'tb':
fbhrp, witness = bech32_decode(fallback)
if fbhrp:
if fbhrp != currency:
raise ValueError("Not a bech32 address for this currency")
wver = witness[0]
if wver > 16:
raise ValueError("Invalid witness version {}".format(witness[0]))
wprog = u5_to_bitarray(witness[1:])
else:
addr = base58.b58decode_check(fallback)
if is_p2pkh(currency, addr[0]):
wver = 17
elif is_p2sh(currency, addr[0]):
wver = 18
else:
raise ValueError("Unknown address type for {}".format(currency))
wprog = addr[1:]
return tagged('f', bitstring.pack("uint:5", wver) + wprog)
else:
raise NotImplementedError("Support for currency {} not implemented".format(currency))
def parse_fallback(fallback, currency):
if currency == 'bc' or currency == 'tb':
wver = fallback[0:5].uint
if wver == 17:
addr = base58.b58encode_check(bytes([base58_prefix_map[currency][0]])
+ fallback[5:].tobytes())
elif wver == 18:
addr = base58.b58encode_check(bytes([base58_prefix_map[currency][1]])
+ fallback[5:].tobytes())
elif wver <= 16:
addr = bech32_encode(currency, bitarray_to_u5(fallback))
else:
return None
else:
addr = fallback.tobytes()
return addr
# Map of classical and witness address prefixes
base58_prefix_map = {
'bc': (0, 5),
'tb': (111, 196)
}
def is_p2pkh(currency, prefix):
return prefix == base58_prefix_map[currency][0]
def is_p2sh(currency, prefix):
return prefix == base58_prefix_map[currency][1]
# Tagged field containing BitArray
def tagged(char, l):
# Tagged fields need to be zero-padded to 5 bits.
while l.len % 5 != 0:
l.append('0b0')
return bitstring.pack("uint:5, uint:5, uint:5",
CHARSET.find(char),
(l.len / 5) / 32, (l.len / 5) % 32) + l
# Tagged field containing bytes
def tagged_bytes(char, l):
return tagged(char, bitstring.BitArray(l))
# Discard trailing bits, convert to bytes.
def trim_to_bytes(barr):
# Adds a byte if necessary.
b = barr.tobytes()
if barr.len % 8 != 0:
return b[:-1]
return b
# Try to pull out tagged data: returns tag, tagged data and remainder.
def pull_tagged(stream):
tag = stream.read(5).uint
length = stream.read(5).uint * 32 + stream.read(5).uint
return (CHARSET[tag], stream.read(length * 5), stream)
class Invoice(object):
def __init__(self, paymenthash=None, amount=None, currency='bc', tags=None, date=None):
self.date = int(time.time()) if not date else int(date)
self.tags = [] if not tags else tags
self.unknown_tags = []
self.paymenthash = paymenthash
self.signature = None
self.pubkey = None
self.currency = currency
self.amount = amount
self.min_final_cltv_expiry = None
self.route_hints = None
def __str__(self):
return "Invoice[{}, amount={}{} tags=[{}]]".format(
hexlify(self.pubkey.serialize()).decode('utf-8'),
self.amount, self.currency,
", ".join([k + '=' + str(v) for k, v in self.tags])
)
@property
def hexpubkey(self):
return hexlify(self.pubkey.serialize()).decode('ASCII')
@property
def hexpaymenthash(self):
return hexlify(self.paymenthash).decode('ASCII')
def _get_tagged(self, tag):
return [t[1] for t in self.tags + self.unknown_tags if t[0] == tag]
@property
def featurebits(self):
features = self._get_tagged('9')
assert(len(features) <= 1)
if features == []:
return 0
else:
return features[0]
def encode(self, privkey):
if self.amount:
amount = Decimal(str(self.amount))
# We can only send down to millisatoshi.
if amount * 10**12 % 10:
raise ValueError("Cannot encode {}: too many decimal places".format(
self.amount))
amount = self.currency + shorten_amount(amount)
else:
amount = self.currency if self.currency else ''
hrp = 'ln' + amount
# Start with the timestamp
data = bitstring.pack('uint:35', self.date)
# Payment hash
data += tagged_bytes('p', self.paymenthash)
tags_set = set()
if self.route_hints is not None:
for rh in self.route_hints.route_hints:
data += tagged_bytes('r', rh.to_bytes())
for k, v in self.tags:
# BOLT #11:
#
# A writer MUST NOT include more than one `d`, `h`, `n` or `x` fields,
if k in ('d', 'h', 'n', 'x'):
if k in tags_set:
raise ValueError("Duplicate '{}' tag".format(k))
if k == 'r':
pubkey, channel, fee, cltv = v
route = bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:64', fee) + bitstring.pack('intbe:16', cltv)
data += tagged('r', route)
elif k == 'f':
data += encode_fallback(v, self.currency)
elif k == 'd':
data += tagged_bytes('d', v.encode())
elif k == 'x':
# Get minimal length by trimming leading 5 bits at a time.
expirybits = bitstring.pack('intbe:64', v)[4:64]
while expirybits.startswith('0b00000'):
expirybits = expirybits[5:]
data += tagged('x', expirybits)
elif k == 'h':
data += tagged_bytes('h', hashlib.sha256(v.encode('utf-8')).digest())
elif k == 'n':
data += tagged_bytes('n', v)
else:
# FIXME: Support unknown tags?
raise ValueError("Unknown tag {}".format(k))
tags_set.add(k)
# BOLT #11:
#
# A writer MUST include either a `d` or `h` field, and MUST NOT include
# both.
if 'd' in tags_set and 'h' in tags_set:
raise ValueError("Cannot include both 'd' and 'h'")
if 'd' not in tags_set and 'h' not in tags_set:
raise ValueError("Must include either 'd' or 'h'")
# We actually sign the hrp, then data (padded to 8 bits with zeroes).
privkey = secp256k1.PrivateKey(bytes(unhexlify(privkey)))
sig = privkey.ecdsa_sign_recoverable(bytearray([ord(c) for c in hrp]) + data.tobytes())
# This doesn't actually serialize, but returns a pair of values :(
sig, recid = privkey.ecdsa_recoverable_serialize(sig)
data += bytes(sig) + bytes([recid])
return bech32_encode(hrp, bitarray_to_u5(data))
@classmethod
def decode(cls, b):
hrp, data = bech32_decode(b)
if not hrp:
raise ValueError("Bad bech32 checksum")
# BOLT #11:
#
# A reader MUST fail if it does not understand the `prefix`.
if not hrp.startswith('ln'):
raise ValueError("Does not start with ln")
data = u5_to_bitarray(data)
# Final signature 65 bytes, split it off.
if len(data) < 65 * 8:
raise ValueError("Too short to contain signature")
sigdecoded = data[-65 * 8:].tobytes()
data = bitstring.ConstBitStream(data[:-65 * 8])
inv = Invoice()
inv.pubkey = None
m = re.search(r'[^\d]+', hrp[2:])
if m:
inv.currency = m.group(0)
amountstr = hrp[2 + m.end():]
# BOLT #11:
#
# A reader SHOULD indicate if amount is unspecified, otherwise it MUST
# multiply `amount` by the `multiplier` value (if any) to derive the
# amount required for payment.
if amountstr != '':
inv.amount = unshorten_amount(amountstr)
inv.date = data.read(35).uint
while data.pos != data.len:
tag, tagdata, data = pull_tagged(data)
# BOLT #11:
#
# A reader MUST skip over unknown fields, an `f` field with unknown
# `version`, or a `p`, `h`, `n` or `r` field which does not have
# `data_length` 52, 52, 53 or 82 respectively.
data_length = len(tagdata) / 5
if tag == 'r':
inv.route_hints = RouteHintSet.from_bytes(trim_to_bytes(tagdata))
continue
if data_length != 82:
inv.unknown_tags.append((tag, tagdata))
continue
tagbytes = trim_to_bytes(tagdata)
inv.tags.append(('r', (
tagbytes[0:33],
tagbytes[33:41],
tagdata[41 * 8:49 * 8].intbe,
tagdata[49 * 8:51 * 8].intbe
)))
elif tag == 'f':
fallback = parse_fallback(tagdata, inv.currency)
if fallback:
inv.tags.append(('f', fallback))
else:
# Incorrect version.
inv.unknown_tags.append((tag, tagdata))
continue
elif tag == 'd':
inv.tags.append(('d', trim_to_bytes(tagdata).decode('utf-8')))
elif tag == 'h':
if data_length != 52:
inv.unknown_tags.append((tag, tagdata))
continue
inv.tags.append(('h', trim_to_bytes(tagdata)))
elif tag == 'x':
inv.tags.append(('x', tagdata.uint))
elif tag == 'p':
if data_length != 52:
inv.unknown_tags.append((tag, tagdata))
continue
inv.paymenthash = trim_to_bytes(tagdata)
elif tag == 'n':
if data_length != 53:
inv.unknown_tags.append((tag, tagdata))
continue
inv.pubkey = secp256k1.PublicKey(flags=secp256k1.ALL_FLAGS)
inv.pubkey.deserialize(trim_to_bytes(tagdata))
elif tag == 'c':
inv.min_final_cltv_expiry = tagdata.uint
else:
inv.unknown_tags.append((tag, tagdata))
# BOLT #11:
#
# A reader MUST check that the `signature` is valid (see the `n` tagged
# field specified below).
if inv.pubkey: # Specified by `n`
# BOLT #11:
#
# A reader MUST use the `n` field to validate the signature instead of
# performing signature recovery if a valid `n` field is provided.
inv.signature = inv.pubkey.ecdsa_deserialize_compact(sigdecoded[0:64])
if not inv.pubkey.ecdsa_verify(bytearray([ord(c) for c in hrp]) + data.tobytes(), inv.signature):
raise ValueError('Invalid signature')
else: # Recover pubkey from signature.
inv.pubkey = secp256k1.PublicKey(flags=secp256k1.ALL_FLAGS)
inv.signature = inv.pubkey.ecdsa_recoverable_deserialize(
sigdecoded[0:64], sigdecoded[64])
inv.pubkey.public_key = inv.pubkey.ecdsa_recover(
bytearray([ord(c) for c in hrp]) + data.tobytes(), inv.signature)
return inv
class RouteHint(object):
length = 33 + 8 + 4 + 4 + 2
def __init__(self):
self.pubkey = None
self.short_channel_id = None
self.fee_base_msat = None
self.fee_proportional_millionths = None
self.cltv_expiry_delta = None
@classmethod
def from_bytes(cls, b):
inst = RouteHint()
inst.pubkey = b.read(33)
inst.short_channel_id, = struct.unpack("!Q", b.read(8))
inst.fee_base_msat, inst.fee_proportional_millionths, inst.cltv_expiry_delta = struct.unpack("!IIH", b.read(10))
return inst
def to_bytes(self):
return self.pubkey + struct.pack(
"!QIIH", self.short_channel_id, self.fee_base_msat,
self.fee_proportional_millionths, self.cltv_expiry_delta
)
def __str__(self):
pubkey = hexlify(self.pubkey).decode('ASCII')
return f"RouteHint<pubkey={pubkey}, short_channel_id={self.short_channel_id}, fee_base_msat={self.fee_base_msat}, fee_prop={self.fee_proportional_millionths}, cltv_expiry_delta={self.cltv_expiry_delta}>"
class RouteHintSet(object):
def __init__(self):
self.route_hints = []
@classmethod
def from_bytes(cls, b):
if isinstance(b, bytes):
b = BufferedReader(BytesIO(b))
if not isinstance(b, BufferedReader):
raise TypeError('from_bytes can only read from bytes-arrays or BufferedReader')
if len(b.raw.getvalue()) % RouteHint.length != 0:
raise TypeError("byte string is not a multiple of the route hint size: {}".format(
len(b.raw.getvalue())
))
instance = RouteHintSet()
while b.peek():
instance.route_hints.append(RouteHint.from_bytes(b))
return instance
def to_bytes(self):
return b''.join([rh.to_bytes() for rh in self.route_hints])
def __str__(self):
return "RouteHintSet[{}]".format(
", ".join([str(rh) for rh in self.route_hints])
)
def add(self, rh: RouteHint):
self.route_hints.append(rh)