mirror of
https://github.com/ElementsProject/lightning.git
synced 2025-01-07 14:29:33 +01:00
8b0635f7d3
Switched from pyelliptic to hmac/binascii/cryptography for standard functions; use our own ECDH implementation to better match the one from secp256k1; finally, add a function to create an encrypted onion.
330 lines
12 KiB
Python
330 lines
12 KiB
Python
#!/usr/bin/env python
|
|
|
|
import sys
|
|
|
|
from hashlib import sha256
|
|
from binascii import hexlify, unhexlify
|
|
import hmac
|
|
import random
|
|
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, modes, algorithms
|
|
from cryptography.hazmat.primitives.ciphers.algorithms import AES
|
|
from cryptography.hazmat.primitives.ciphers.modes import CTR
|
|
from cryptography.hazmat.backends import default_backend
|
|
# http://cryptography.io
|
|
|
|
from pyelliptic import ecc
|
|
|
|
class MyEx(Exception):
    """Module-specific exception type; adds nothing to ``Exception``."""
|
|
|
|
def hmac_sha256(k, m):
    """Return the raw HMAC-SHA256 digest of message ``m`` under key ``k``."""
    mac = hmac.new(k, m, digestmod=sha256)
    return mac.digest()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
## pyelliptic doesn't support compressed pubkey representations
|
|
## so we have to add some code...
|
|
from pyelliptic.openssl import OpenSSL
|
|
import ctypes
|
|
|
|
# Configure the raw ctypes entry point for point decompression first, then
# publish it on the pyelliptic wrapper under the same name as the C function.
_set_compressed = OpenSSL._lib.EC_POINT_set_compressed_coordinates_GFp
_set_compressed.restype = ctypes.c_int
_set_compressed.argtypes = [
    ctypes.c_void_p,  # group
    ctypes.c_void_p,  # point (output)
    ctypes.c_void_p,  # x coordinate (BIGNUM)
    ctypes.c_int,     # y-bit selector
    ctypes.c_void_p,  # BN_CTX (callers in this file pass 0)
]
OpenSSL.EC_POINT_set_compressed_coordinates_GFp = _set_compressed
|
|
|
|
def ecc_ecdh_key(sec, pub):
    """Multiply the public point ``pub`` by the private key of ``sec``.

    This replaces pyelliptic's own ``get_ecdh_key`` so that the caller gets
    both affine coordinates of the shared point instead of a derived key.

    Args:
        sec: ``ecc.ECC`` instance; its ``curve`` and ``privkey`` are used.
        pub: either an ``ecc.ECC`` instance or a binary public key accepted
            by ``ecc.ECC._decode_pubkey(pub, 'binary')``.

    Returns:
        ``(x, y)``: the big-endian coordinates of the resulting point, each
        left-padded with zero bytes to the byte length of the curve's field.

    Raises:
        Exception: if any of the OpenSSL calls reports failure.
    """
    assert isinstance(sec, ecc.ECC)
    if isinstance(pub, ecc.ECC):
        pub = pub.get_pubkey()

    pubkey_x, pubkey_y = ecc.ECC._decode_pubkey(pub, 'binary')

    # Everything allocated below is released in the ``finally`` clause, so
    # every handle starts out as None.
    other_key = other_pub_key_x = other_pub_key_y = other_pub_key = None
    own_priv_key = res = res_x = res_y = None
    try:
        other_key = OpenSSL.EC_KEY_new_by_curve_name(sec.curve)
        # A NULL pointer may surface as 0 or as None depending on the ctypes
        # restype, so test truthiness rather than comparing against 0.
        if not other_key:
            raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error())

        other_pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0)
        other_pub_key_y = OpenSSL.BN_bin2bn(pubkey_y, len(pubkey_y), 0)

        other_group = OpenSSL.EC_KEY_get0_group(other_key)
        other_pub_key = OpenSSL.EC_POINT_new(other_group)
        if not other_pub_key:
            raise Exception("[OpenSSl] EC_POINT_new FAIL ... " + OpenSSL.get_error())

        if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(other_group,
                                                        other_pub_key,
                                                        other_pub_key_x,
                                                        other_pub_key_y,
                                                        0)) == 0:
            raise Exception(
                "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ..." + OpenSSL.get_error())

        own_priv_key = OpenSSL.BN_bin2bn(sec.privkey, len(sec.privkey), 0)

        res = OpenSSL.EC_POINT_new(other_group)
        if not res:
            raise Exception("[OpenSSL] EC_POINT_new FAIL ... " + OpenSSL.get_error())

        # res = own_priv_key * other_pub_key
        if (OpenSSL.EC_POINT_mul(other_group, res, 0, other_pub_key, own_priv_key, 0)) == 0:
            raise Exception(
                "[OpenSSL] EC_POINT_mul FAIL ..." + OpenSSL.get_error())

        res_x = OpenSSL.BN_new()
        res_y = OpenSSL.BN_new()

        if (OpenSSL.EC_POINT_get_affine_coordinates_GFp(other_group, res,
                                                        res_x,
                                                        res_y, 0
                                                        )) == 0:
            raise Exception(
                "[OpenSSL] EC_POINT_get_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error())

        resx = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(res_x))
        resy = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(res_y))

        OpenSSL.BN_bn2bin(res_x, resx)
        OpenSSL.BN_bn2bin(res_y, resy)

        # The buffers are sized with BN_num_bytes, so a coordinate with
        # leading zero bytes comes out short. Pad to the field width (same
        # approach as get_pos_y_for_x) so callers that hash the coordinate
        # always see a fixed-length value.
        field_size = OpenSSL.EC_GROUP_get_degree(other_group)
        coord_len = (field_size + 7) // 8
        return (resx.raw.rjust(coord_len, b'\0'),
                resy.raw.rjust(coord_len, b'\0'))

    finally:
        if other_key:
            OpenSSL.EC_KEY_free(other_key)
        if other_pub_key_x:
            OpenSSL.BN_free(other_pub_key_x)
        if other_pub_key_y:
            OpenSSL.BN_free(other_pub_key_y)
        if other_pub_key:
            OpenSSL.EC_POINT_free(other_pub_key)
        if own_priv_key:
            OpenSSL.BN_free(own_priv_key)
        if res:
            OpenSSL.EC_POINT_free(res)
        if res_x:
            OpenSSL.BN_free(res_x)
        if res_y:
            OpenSSL.BN_free(res_y)
|
|
|
|
def get_pos_y_for_x(pubkey_x, yneg=0):
    """Recover the y coordinate of the secp256k1 point with x ``pubkey_x``.

    Args:
        pubkey_x: big-endian x coordinate as a byte string.
        yneg: forwarded unchanged as the y-bit selector of
            ``EC_POINT_set_compressed_coordinates_GFp``; defaults to 0.

    Returns:
        The y coordinate as a big-endian byte string, left-padded with zero
        bytes to the byte length of the curve's field.

    Raises:
        Exception: if OpenSSL fails to decompress the point or to read back
            its affine coordinates.
    """
    ec_key = point = bn_x = bn_y = None
    try:
        ec_key = OpenSSL.EC_KEY_new_by_curve_name(OpenSSL.get_curve('secp256k1'))
        group = OpenSSL.EC_KEY_get0_group(ec_key)
        bn_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0)
        point = OpenSSL.EC_POINT_new(group)

        # Decompress: build the full point from x plus the y-bit selector.
        status = OpenSSL.EC_POINT_set_compressed_coordinates_GFp(
            group, point, bn_x, yneg, 0)
        if status == 0:
            raise Exception("[OpenSSL] EC_POINT_set_compressed_coordinates_GFp FAIL ... " + OpenSSL.get_error())

        # Read both coordinates back; bn_x is simply overwritten with x.
        bn_y = OpenSSL.BN_new()
        status = OpenSSL.EC_POINT_get_affine_coordinates_GFp(
            group, point, bn_x, bn_y, 0)
        if status == 0:
            raise Exception("[OpenSSL] EC_POINT_get_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error())

        y_buf = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(bn_y))
        OpenSSL.BN_bn2bin(bn_y, y_buf)

        # Width of a field element in bytes; rjust leaves values that are
        # already wide enough untouched.
        coord_len = (OpenSSL.EC_GROUP_get_degree(group) + 7) // 8
        return y_buf.raw.rjust(coord_len, b'\0')
    finally:
        if ec_key is not None:
            OpenSSL.EC_KEY_free(ec_key)
        if point is not None:
            OpenSSL.EC_POINT_free(point)
        if bn_x is not None:
            OpenSSL.BN_free(bn_x)
        if bn_y is not None:
            OpenSSL.BN_free(bn_y)
|
|
|
|
class Onion(object):
    """Constants and key-derivation helpers shared by the onion classes."""

    HMAC_LEN = 32   # bytes of HMAC at the end of every layer
    PKEY_LEN = 32   # bytes of public-key x coordinate in every layer
    MSG_LEN = 128   # bytes of per-hop message in every layer
    # One layer's worth of zero bytes; encrypted to produce per-hop padding.
    ZEROES = b"\x00" * (HMAC_LEN + PKEY_LEN + MSG_LEN)

    @staticmethod
    def tweak_sha(sha, d):
        """Return the digest of ``sha`` extended by ``d``.

        Works on a copy, so the hash object passed in is left unmodified.
        """
        tweaked = sha.copy()
        tweaked.update(d)
        return tweaked.digest()

    @classmethod
    def get_ecdh_secrets(cls, sec, pkey_x, pkey_y):
        """Derive the per-hop secrets shared with the given public point.

        Args:
            sec: ``ecc.ECC`` key holding the local private key.
            pkey_x: x coordinate of the peer's public key (byte string).
            pkey_y: y coordinate of the peer's public key (byte string).

        Returns:
            ``(enckey, hmac_key, iv, pad_iv)``; ``hmac_key`` is a full
            SHA-256 digest, the other three are 16-byte prefixes of one.
        """
        # 0x04 prefix followed by both coordinates: uncompressed public key.
        pkey = b'\x04' + pkey_x + pkey_y
        tmp_key = ecc.ECC(curve='secp256k1', pubkey=pkey)
        sec_x, sec_y = ecc_ecdh_key(sec, tmp_key)

        # Compressed form of the shared point: 0x02 for even y, 0x03 for odd.
        # bytearray indexing yields an int on both Python 2 and Python 3.
        prefix = b'\x02' if bytearray(sec_y)[-1] % 2 == 0 else b'\x03'
        # Hash object seeded with SHA256(compressed point); each secret is
        # this state extended by a distinct one-byte tweak.
        seed = sha256(sha256(prefix + sec_x).digest())

        enckey = cls.tweak_sha(seed, b'\x00')[:16]
        hmac_key = cls.tweak_sha(seed, b'\x01')
        iv = cls.tweak_sha(seed, b'\x02')[:16]
        pad_iv = cls.tweak_sha(seed, b'\x03')[:16]

        return enckey, hmac_key, iv, pad_iv

    def enc_pad(self, enckey, pad_iv):
        """Return ``ZEROES`` encrypted with AES-CTR under ``enckey``/``pad_iv``."""
        aes = Cipher(AES(enckey), CTR(pad_iv),
                     default_backend()).encryptor()
        return aes.update(self.ZEROES)
|
|
|
|
class OnionDecrypt(Onion):
    """Peel one layer off an onion using the local key ``my_ecc``.

    The onion is read from its end: the last ``HMAC_LEN`` bytes are the
    HMAC, preceded by ``PKEY_LEN`` bytes of public-key x coordinate,
    preceded by ``MSG_LEN`` bytes of encrypted message; everything before
    that is the encrypted data to forward.

    The constructor derives the shared secrets and verifies the HMAC;
    ``decrypt()`` then fills in ``self.msg`` and ``self.fwd``.
    """

    def __init__(self, onion, my_ecc):
        """Parse ``onion`` and verify its HMAC.

        Raises:
            ValueError: if ``onion`` is too short to contain one layer.
            Exception: if the HMAC does not verify.
        """
        # Without this check the negative offsets computed below would
        # silently slice the wrong bytes out of a truncated onion.
        min_len = self.HMAC_LEN + self.PKEY_LEN + self.MSG_LEN
        if len(onion) < min_len:
            raise ValueError("onion too short: %d bytes, need at least %d"
                             % (len(onion), min_len))

        self.my_ecc = my_ecc

        hmac_end = len(onion)
        pkey_end = hmac_end - self.HMAC_LEN
        self.msg_end = pkey_end - self.PKEY_LEN
        self.fwd_end = self.msg_end - self.MSG_LEN

        self.onion = onion
        self.pkey = onion[self.msg_end:pkey_end]
        self.hmac = onion[pkey_end:hmac_end]

        self.get_secrets()

    def decrypt(self):
        """Decrypt this layer, setting ``self.fwd`` and ``self.msg``.

        ``self.fwd`` is the encrypted padding followed by the decrypted
        forwarding part; ``self.msg`` is this hop's decrypted message.
        """
        pad = self.enc_pad(self.enckey, self.pad_iv)

        # A single CTR stream covers the forwarded part and then the
        # message, so the two update() calls must stay in this order.
        aes = Cipher(AES(self.enckey), CTR(self.iv),
                     default_backend()).decryptor()
        self.fwd = pad + aes.update(self.onion[:self.fwd_end])
        self.msg = aes.update(self.onion[self.fwd_end:self.msg_end])

    def get_secrets(self):
        """Derive ``enckey``, ``iv`` and ``pad_iv`` and check the HMAC.

        Raises:
            Exception: if the HMAC over the onion does not verify.
        """
        pkey_x = self.pkey
        pkey_y = get_pos_y_for_x(pkey_x)  # always positive by design
        enckey, hmac_key, iv, pad_iv = self.get_ecdh_secrets(
            self.my_ecc, pkey_x, pkey_y)
        if not self.check_hmac(hmac_key):
            raise Exception("HMAC did not verify")
        self.enckey = enckey
        self.iv = iv
        self.pad_iv = pad_iv

    def check_hmac(self, hmac_key):
        """Return True if the trailing HMAC matches the rest of the onion."""
        calc = hmac_sha256(hmac_key, self.onion[:-self.HMAC_LEN])
        # Constant-time comparison (needs Python >= 2.7.7) so the check
        # does not leak how many leading bytes of the MAC were correct.
        return hmac.compare_digest(calc, self.hmac)
|
|
|
|
class OnionEncrypt(Onion):
    """Build an onion carrying one message per hop.

    After construction the finished onion is available as ``self.onion``.
    """

    # Number of layers a full-size onion holds; shorter routes are filled
    # up with random bytes.
    MAX_HOPS = 20

    def __init__(self, msgs, pubkeys):
        """Encrypt ``msgs[i]`` for the hop whose public key is ``pubkeys[i]``.

        Args:
            msgs: byte strings, each at most ``MSG_LEN`` bytes long; shorter
                ones are padded with zero bytes.
            pubkeys: binary public keys accepted by ``ecc.ECC(pubkey=...)``,
                one per message, in the order the hops are visited.
        """
        # Local import: os is not among this module's top-level imports.
        import os

        assert len(msgs) == len(pubkeys)
        assert 0 < len(msgs) <= self.MAX_HOPS
        assert all(len(m) <= self.MSG_LEN for m in msgs)

        msgs = [m + b"\0" * (self.MSG_LEN - len(m)) for m in msgs]
        pubkeys = [ecc.ECC(pubkey=pk, curve='secp256k1') for pk in pubkeys]
        n = len(msgs)

        # One ephemeral key per hop. Only the x coordinate is put into the
        # onion and the decrypting side recomputes y with the default
        # selector, so keep drawing keys until y is even.
        tmpkeys = []
        tmppubkeys = []
        for _ in range(n):
            while True:
                t = ecc.ECC(curve='secp256k1')
                # bytearray indexing yields an int on Python 2 and 3 alike.
                if bytearray(t.pubkey_y)[-1] % 2 == 0:
                    break
                # or do the math to "flip" the secret key and pub key
            tmpkeys.append(t)
            tmppubkeys.append(t.pubkey_x)

        secrets = [self.get_ecdh_secrets(tmpkey, pkey.pubkey_x, pkey.pubkey_y)
                   for tmpkey, pkey in zip(tmpkeys, pubkeys)]
        enckeys, hmacs, ivs, pad_ivs = zip(*secrets)

        # padding takes the form:
        #  E_(n-1)(0000s)
        #  D_(n-1)(
        #      E(n-2)(0000s)
        #      D(n-2)(
        #          ...
        #      )
        #  )
        padding = b""
        for i in range(n - 1):
            pad = self.enc_pad(enckeys[i], pad_ivs[i])
            aes = Cipher(AES(enckeys[i]), CTR(ivs[i]),
                         default_backend()).decryptor()
            padding = pad + aes.update(padding)

        if n < self.MAX_HOPS:
            # Fill the unused layers from the OS CSPRNG rather than the
            # predictable ``random`` module.
            padding += os.urandom(len(self.ZEROES) * (self.MAX_HOPS - n))

        # to encrypt the message we need to bump the counter past all
        # the padding, then just encrypt the final message
        aes = Cipher(AES(enckeys[-1]), CTR(ivs[-1]),
                     default_backend()).encryptor()
        aes.update(padding)  # don't care about cyphertext
        msgenc = aes.update(msgs[-1])

        msgenc = padding + msgenc + tmppubkeys[-1]
        del padding
        msgenc += hmac_sha256(hmacs[-1], msgenc)

        # *PHEW*
        # now iterate
        for i in reversed(range(n - 1)):
            # drop the padding this node will add
            msgenc = msgenc[len(self.ZEROES):]
            # adding the msg
            msgenc += msgs[i]
            # encrypt it
            aes = Cipher(AES(enckeys[i]), CTR(ivs[i]),
                         default_backend()).encryptor()
            msgenc = aes.update(msgenc)
            # add the tmp key
            msgenc += tmppubkeys[i]
            # add the hmac
            msgenc += hmac_sha256(hmacs[i], msgenc)
        self.onion = msgenc
|
|
|
|
def decode_from_file(f):
    """Read keypairs and an onion from ``f`` and peel it hop by hop.

    Recognised lines are `` * Keypair <idx>: <privhex> <pubhex>`` (indices
    must be consecutive from 0), `` * Message: <hex>`` and anything starting
    with ``Decrypting`` (ignored). Any other line is printed and must be
    blank. Each hop's decrypted message is printed, followed by ``done``.
    """
    keys = []
    msg = ""
    for ln in f.readlines():
        if ln.startswith("Decrypting"):
            continue

        if ln.startswith(" * Keypair "):
            fields = ln.strip().split()
            idx = int(fields[2].strip(":"))
            priv = unhexlify(fields[3])
            pub = unhexlify(fields[4])
            assert idx == len(keys)
            keys.append(ecc.ECC(privkey=priv, pubkey=pub, curve='secp256k1'))
            continue

        if ln.startswith(" * Message:"):
            msg = unhexlify(ln[len(" * Message:"):].strip())
            continue

        # Unrecognised line: echo it, and insist that it was only whitespace.
        print(ln)
        assert ln.strip() == ""

    assert msg != ""
    for key in keys:
        layer = OnionDecrypt(msg, key)
        layer.decrypt()
        print(layer.msg)
        # What this hop forwards becomes the next hop's input.
        msg = layer.fwd
    print("done")
|
|
|
|
if __name__ == "__main__":
    # "generate [n]" prints n fresh keypairs and an onion built for them;
    # any other invocation decodes such a listing from standard input.
    if len(sys.argv) > 1 and sys.argv[1] == "generate":
        n = int(sys.argv[2]) if len(sys.argv) == 3 else 20
        servers = [ecc.ECC(curve='secp256k1') for _ in range(n)]
        server_pubs = [s.get_pubkey() for s in servers]
        msgs = ["Howzit %d..." % (i,) for i in range(n)]

        o = OnionEncrypt(msgs, server_pubs)

        for i, s in enumerate(servers):
            print(" * Keypair %d: %s %s" % (
                i, hexlify(s.privkey), hexlify(s.get_pubkey())))
        print(" * Message: %s" % (hexlify(o.onion)))
    else:
        decode_from_file(sys.stdin)
|