mirror of
https://github.com/bitcoin/bips.git
synced 2025-01-19 05:45:07 +01:00
cc177ab7bc
Includes: * Simpler (but equivalent) ElligatorSwift encoding function & spec * Improved test vectors * Test vector generation code * Code for converting test vectors for libsecp256k1 code. * Code for running test vectors against SwiftEC paper authors' code. * Miscellaneous reference code improvements (style, comments).
419 lines
17 KiB
Python
419 lines
17 KiB
Python
"""Generate the BIP-0324 test vectors."""
|
|
|
|
import csv
|
|
import hashlib
|
|
import os
|
|
import sys
|
|
from reference import (
|
|
FE,
|
|
GE,
|
|
MINUS_3_SQRT,
|
|
hkdf_sha256,
|
|
SECP256K1_G,
|
|
ellswift_decode,
|
|
ellswift_ecdh_xonly,
|
|
xswiftec_inv,
|
|
xswiftec,
|
|
v2_ecdh,
|
|
initialize_v2_transport,
|
|
v2_enc_packet
|
|
)
|
|
|
|
FILENAME_PACKET_TEST = os.path.join(sys.path[0], 'packet_encoding_test_vectors.csv')
|
|
FILENAME_XSWIFTEC_INV_TEST = os.path.join(sys.path[0], 'xswiftec_inv_test_vectors.csv')
|
|
FILENAME_ELLSWIFT_DECODE_TEST = os.path.join(sys.path[0], 'ellswift_decode_test_vectors.csv')
|
|
|
|
def xswiftec_flagged(u, t, simplified=False):
|
|
"""A variant of xswiftec which also returns 'flags', describing conditions encountered."""
|
|
flags = []
|
|
if u == 0:
|
|
flags.append("u%p=0")
|
|
u = FE(1)
|
|
if t == 0:
|
|
flags.append("t%p=0")
|
|
t = FE(1)
|
|
if u**3 + t**2 + 7 == 0:
|
|
flags.append("(u'^3+t'^2+7)%p=0")
|
|
t = 2 * t
|
|
X = (u**3 + 7 - t**2) / (2 * t)
|
|
Y = (X + t) / (MINUS_3_SQRT * u)
|
|
if X == 0:
|
|
if not simplified:
|
|
flags.append("(u'^3-t'^2+7)%p=0")
|
|
x3 = u + 4 * Y**2
|
|
if GE.is_valid_x(x3):
|
|
flags.append("valid_x(x3)")
|
|
x2 = (-X / Y - u) / 2
|
|
if GE.is_valid_x(x2):
|
|
flags.append("valid_x(x2)")
|
|
x1 = (X / Y - u) / 2
|
|
if GE.is_valid_x(x1):
|
|
flags.append("valid_x(x1)")
|
|
for x in (x3, x2, x1):
|
|
if GE.is_valid_x(x):
|
|
break
|
|
return x, flags
|
|
|
|
|
|
def ellswift_create_deterministic(seed, features):
|
|
"""This is a variant of ellswift_create which doesn't use randomness.
|
|
|
|
features is an integer selecting some properties of the result:
|
|
- (f & 3) == 0: only x1 is valid on decoding (see xswiftec{_flagged})
|
|
- (f & 3) == 1: only x2 is valid on decoding
|
|
- (f & 3) == 2: only x3 is valid on decoding
|
|
- (f & 3) == 3: x1,x2,x3 are all valid on decoding
|
|
- (f & 4) == 4: u >= p
|
|
- (f & 8) == 8: u mod n == 0
|
|
|
|
Returns privkey, ellswift
|
|
"""
|
|
|
|
cnt = 0
|
|
while True:
|
|
sec = hkdf_sha256(32, seed, (cnt).to_bytes(4, 'little'), b"sec")
|
|
xval = (int.from_bytes(sec, 'big') * SECP256K1_G).x
|
|
cnt += 1
|
|
if features & 8:
|
|
u = 0
|
|
if features & 4:
|
|
u += FE.SIZE
|
|
else:
|
|
udat = hkdf_sha256(64, seed, (cnt).to_bytes(4, 'little'), b"u")
|
|
if features & 4:
|
|
u = FE.SIZE + 1 + int.from_bytes(udat, 'big') % (2**256 - FE.SIZE - 1)
|
|
else:
|
|
u = 1 + int.from_bytes(udat, 'big') % (FE.SIZE - 1)
|
|
case = hkdf_sha256(1, seed, (cnt).to_bytes(4, 'little'), b"case")[0] & 7
|
|
coru = FE(u) + ((features & 8) == 8)
|
|
t = xswiftec_inv(xval, coru, case)
|
|
if t is None:
|
|
continue
|
|
assert xswiftec(FE(u), t) == xval
|
|
x2, flags = xswiftec_flagged(FE(u), t)
|
|
assert x2 == xval
|
|
have_x1 = "valid_x(x1)" in flags
|
|
have_x2 = "valid_x(x2)" in flags
|
|
have_x3 = "valid_x(x3)" in flags
|
|
if (features & 4) == 0 and not (have_x1 and not have_x2 and not have_x3):
|
|
continue
|
|
if (features & 4) == 1 and not (not have_x1 and have_x2 and not have_x3):
|
|
continue
|
|
if (features & 4) == 2 and not (not have_x1 and not have_x2 and have_x3):
|
|
continue
|
|
if (features & 4) == 3 and not (have_x1 and have_x2 and have_x3):
|
|
continue
|
|
return sec, u.to_bytes(32, 'big') + t.to_bytes()
|
|
|
|
def ellswift_decode_flagged(ellswift, simplified=False):
|
|
"""Decode a 64-byte ElligatorSwift encoded coordinate, returning byte array + flag string."""
|
|
uv = int.from_bytes(ellswift[:32], 'big')
|
|
tv = int.from_bytes(ellswift[32:], 'big')
|
|
x, flags = xswiftec_flagged(FE(uv), FE(tv))
|
|
if not simplified:
|
|
if uv >= FE.SIZE:
|
|
flags.append("u>=p")
|
|
if tv >= FE.SIZE:
|
|
flags.append("t>=p")
|
|
return int(x).to_bytes(32, 'big'), ";".join(flags)
|
|
|
|
def random_fe_int(_, seed, i, p):
|
|
"""Function to use in tuple_expand, generating a random integer in 0..p-1."""
|
|
rng_out = hkdf_sha256(64, seed, i.to_bytes(4, 'little'), b"v%i_fe" % p)
|
|
return int.from_bytes(rng_out, 'big') % FE.SIZE
|
|
|
|
def random_fe_int_high(_, seed, i, p):
|
|
"""Function to use in tuple_expand, generating a random integer in p..2^256-1."""
|
|
rng_out = hkdf_sha256(64, seed, i.to_bytes(4, 'little'), b"v%i_fe_high" % p)
|
|
return FE.SIZE + int.from_bytes(rng_out, 'big') % (2**256 - FE.SIZE)
|
|
|
|
def fn_of(p_in, fn):
|
|
"""Function to use in tuple_expand, to pick one variable in function of another."""
|
|
def inner(vs, _seed, _i, p):
|
|
assert p != p_in
|
|
if isinstance(vs[p_in], int):
|
|
return fn(vs[p_in])
|
|
return None
|
|
return inner
|
|
|
|
def tuple_expand(out, tuplespec, prio, seed=None, cnt=1):
|
|
"""Given a tuple specification, expand it cnt times, and add results to out.
|
|
|
|
Expansion is defined recursively:
|
|
- If any of the spec elements is a list, each element of the list results
|
|
in an expansion (by replacing the list with its element).
|
|
- If any of the spec elements is a function, that function is invoked with
|
|
(spec, seed, expansion count, index in spec) as arguments. If the function
|
|
needs to wait for other indices to be expanded, it can return None.
|
|
|
|
The output consists of (prio, expansion count, SHA256(result), result, seed)
|
|
tuples."""
|
|
|
|
def recurse(vs, seed, i, change_pos=None, change=None):
|
|
if change_pos is not None:
|
|
vs = list(vs)
|
|
vs[change_pos] = change
|
|
for p, v in enumerate(vs):
|
|
if v is None:
|
|
return
|
|
if isinstance(v, list):
|
|
for ve in v:
|
|
recurse(vs, seed, i, p, ve)
|
|
return
|
|
if callable(v):
|
|
res = v(vs, seed, i, p)
|
|
if res is not None:
|
|
recurse(vs, seed, i, p, res)
|
|
return
|
|
h = hashlib.sha256()
|
|
for v in vs:
|
|
h.update(int(v).to_bytes(32, 'big'))
|
|
out.append((prio, i, h.digest(), vs, seed))
|
|
for i in range(cnt):
|
|
recurse(tuplespec, seed, i)
|
|
|
|
def gen_ellswift_decode_cases(seed, simplified=False):
|
|
"""Generate a set of interesting (ellswift, x, flags) ellswift decoding cases."""
|
|
inputs = []
|
|
|
|
# Aggregate for use in tuple_expand, expanding to int in 0..p-1, and one in p..2^256-1.
|
|
RANDOM_VAL = [random_fe_int, random_fe_int_high]
|
|
# Aggregate for use in tuple_expand, expanding to integers which %p equal 0.
|
|
ZERO_VAL = [0, FE.SIZE]
|
|
# Helpers for constructing u and t values such that u^3+t^2+7=0 or u^3-t^2+7=0.
|
|
T_FOR_SUM_ZERO = fn_of(0, lambda u: (-FE(u)**3 - 7).sqrts())
|
|
T_FOR_DIFF_ZERO = fn_of(0, lambda u: (FE(u)**3 + 7).sqrts())
|
|
U_FOR_SUM_ZERO = fn_of(1, lambda t: (-FE(t)**2 - 7).cbrts())
|
|
U_FOR_DIFF_ZERO = fn_of(1, lambda t: (FE(t)**2 - 7).cbrts())
|
|
|
|
tuple_expand(inputs, [RANDOM_VAL, RANDOM_VAL], 0, seed + b"random", 64)
|
|
tuple_expand(inputs, [RANDOM_VAL, T_FOR_SUM_ZERO], 1, seed + b"t=sqrt(-u^3-7)", 64)
|
|
tuple_expand(inputs, [U_FOR_SUM_ZERO, RANDOM_VAL], 1, seed + b"u=cbrt(-t^2-7)", 64)
|
|
tuple_expand(inputs, [RANDOM_VAL, T_FOR_DIFF_ZERO], 1, seed + b"t=sqrt(u^3+7)", 64)
|
|
tuple_expand(inputs, [U_FOR_DIFF_ZERO, RANDOM_VAL], 1, seed + b"u=cbrt(t^2-7)", 64)
|
|
tuple_expand(inputs, [ZERO_VAL, RANDOM_VAL], 2, seed + b"u=0", 64)
|
|
tuple_expand(inputs, [RANDOM_VAL, ZERO_VAL], 2, seed + b"t=0", 64)
|
|
tuple_expand(inputs, [ZERO_VAL, FE(8).sqrts()], 3, seed + b"u=0;t=sqrt(8)")
|
|
tuple_expand(inputs, [FE(-8).cbrts(), ZERO_VAL], 3, seed + b"t=0;u=cbrt(-8)")
|
|
tuple_expand(inputs, [FE(-6).cbrts(), ZERO_VAL], 3, seed + b"t=0;u=cbrt(-6)")
|
|
tuple_expand(inputs, [ZERO_VAL, ZERO_VAL], 3, seed + b"u=0;t=0")
|
|
# Unused.
|
|
tuple_expand(inputs, [ZERO_VAL, FE(-8).sqrts()], 4, seed + b"u=0;t=sqrt(-8)")
|
|
|
|
seen = set()
|
|
cases = []
|
|
for _prio, _cnt, _hash, vs, _seed in sorted(inputs):
|
|
inp = int(vs[0]).to_bytes(32, 'big') + int(vs[1]).to_bytes(32, 'big')
|
|
outp, flags = ellswift_decode_flagged(inp, simplified)
|
|
if flags not in seen:
|
|
cases.append((inp, outp, flags))
|
|
seen.add(flags)
|
|
|
|
return cases
|
|
|
|
def gen_all_ellswift_decode_vectors(fil):
|
|
"""Generate all xelligatorswift decoding test vectors."""
|
|
|
|
cases = gen_ellswift_decode_cases(b"")
|
|
writer = csv.DictWriter(fil, ["ellswift", "x", "comment"])
|
|
writer.writeheader()
|
|
for val, x, flags in sorted(cases):
|
|
writer.writerow({"ellswift": val.hex(), "x": x.hex(), "comment": flags})
|
|
|
|
def xswiftec_inv_flagged(x, u, case):
|
|
"""A variant of xswiftec_inv which also returns flags, describing conditions encountered."""
|
|
|
|
flags = []
|
|
|
|
if case & 2 == 0:
|
|
if GE.is_valid_x(-x - u):
|
|
flags.append("bad[valid_x(-x-u)]")
|
|
return None, flags
|
|
v = x if case & 1 == 0 else -x - u
|
|
if v == 0:
|
|
flags.append("info[v=0]")
|
|
s = -(u**3 + 7) / (u**2 + u*v + v**2)
|
|
assert s != 0 # would imply X=0 on curve
|
|
else:
|
|
s = x - u
|
|
if s == 0:
|
|
flags.append("bad[s=0]")
|
|
return None, flags
|
|
q = (-s * (4 * (u**3 + 7) + 3 * s * u**2))
|
|
if q == 0:
|
|
flags.append("info[q=0]")
|
|
r = q.sqrt()
|
|
if r is None:
|
|
flags.append("bad[non_square(q)]")
|
|
return None, flags
|
|
if case & 1:
|
|
if r == 0:
|
|
flags.append("bad[r=0]")
|
|
return None, flags
|
|
r = -r
|
|
v = (-u + r / s) / 2
|
|
if v == 0:
|
|
flags.append("info[v=0]")
|
|
w = s.sqrt()
|
|
assert w != 0
|
|
if w is None:
|
|
flags.append("bad[non_square(s)]")
|
|
return None, flags
|
|
if case & 4:
|
|
w = -w
|
|
Y = w / 2
|
|
assert Y != 0
|
|
X = 2 * Y * (v + u / 2)
|
|
if X == 0:
|
|
flags.append("info[X=0]")
|
|
flags.append("ok")
|
|
return w * (u * (MINUS_3_SQRT - 1) / 2 - v), flags
|
|
|
|
def xswiftec_inv_combo_flagged(x, u):
|
|
"""Compute the aggregate results and flags from xswiftec_inv_flagged for case=0..7."""
|
|
ts = []
|
|
allflags = []
|
|
for case in range(8):
|
|
t, flags = xswiftec_inv_flagged(x, u, case)
|
|
if t is not None:
|
|
assert x == xswiftec(u, t)
|
|
ts.append(t)
|
|
allflags.append(f"case{case}:{'&'.join(flags)}")
|
|
return ts, ";".join(allflags)
|
|
|
|
def gen_all_xswiftec_inv_vectors(fil):
|
|
"""Generate all xswiftec_inv test vectors."""
|
|
|
|
# Two constants used below. Compute them only once.
|
|
C1 = (FE(MINUS_3_SQRT) - 1) / 2
|
|
C2 = (-FE(MINUS_3_SQRT) - 1) / 2
|
|
# Helper functions that pick x and u with special properties.
|
|
TRIGGER_Q_ZERO = fn_of(1, lambda u: (FE(u)**3 + 28) / (FE(-3) * FE(u)**2))
|
|
TRIGGER_DIVZERO_A = fn_of(1, lambda u: FE(u) * C1)
|
|
TRIGGER_DIVZERO_B = fn_of(1, lambda u: FE(u) * C2)
|
|
TRIGGER_V_ZERO = fn_of(1, lambda u: FE(-7) / FE(u)**2)
|
|
TRIGGER_X_ZERO = fn_of(0, lambda x: FE(-2) * FE(x))
|
|
|
|
inputs = []
|
|
tuple_expand(inputs, [random_fe_int, random_fe_int], 0, b"uniform", 256)
|
|
tuple_expand(inputs, [TRIGGER_Q_ZERO, random_fe_int], 1, b"x=-(u^3+28)/(3*u^2)", 64)
|
|
tuple_expand(inputs, [TRIGGER_V_ZERO, random_fe_int], 1, b"x=-7/u^2", 512)
|
|
tuple_expand(inputs, [random_fe_int, fn_of(0, lambda x: x)], 2, b"u=x", 64)
|
|
tuple_expand(inputs, [random_fe_int, fn_of(0, lambda x: -FE(x))], 2, b"u=-x", 64)
|
|
# Unused.
|
|
tuple_expand(inputs, [TRIGGER_DIVZERO_A, random_fe_int], 3, b"x=u*(sqrt(-3)-1)/2", 64)
|
|
tuple_expand(inputs, [TRIGGER_DIVZERO_B, random_fe_int], 3, b"x=u*(-sqrt(-3)-1)/2", 64)
|
|
tuple_expand(inputs, [random_fe_int, TRIGGER_X_ZERO], 3, b"u=-2x", 64)
|
|
|
|
seen = set()
|
|
cases = []
|
|
for _prio, _cnt, _hash, vs, _seed in sorted(inputs):
|
|
x, u = FE(vs[0]), FE(vs[1])
|
|
if u == 0:
|
|
continue
|
|
if not GE.is_valid_x(x):
|
|
continue
|
|
ts, flags = xswiftec_inv_combo_flagged(x, u)
|
|
if flags not in seen:
|
|
cases.append((int(u), int(x), ts, flags))
|
|
seen.add(flags)
|
|
|
|
writer = csv.DictWriter(fil, ["u", "x"] + [f"case{c}_t" for c in range(8)] + ["comment"])
|
|
writer.writeheader()
|
|
for u, x, ts, flags in sorted(cases):
|
|
row = {"u": FE(u), "x": FE(x), "comment": flags}
|
|
for c in range(8):
|
|
if ts[c] is not None:
|
|
row[f"case{c}_t"] = FE(ts[c])
|
|
writer.writerow(row)
|
|
|
|
def gen_packet_encoding_vector(case):
|
|
"""Given a dict case with specs, construct a packet_encoding test vector as a CSV line."""
|
|
ikm = str(case).encode('utf-8')
|
|
in_initiating = case["init"]
|
|
in_ignore = int(case["ignore"])
|
|
in_priv_ours, in_ellswift_ours = ellswift_create_deterministic(ikm, case["features"])
|
|
mid_x_ours = (int.from_bytes(in_priv_ours, 'big') * SECP256K1_G).x.to_bytes()
|
|
assert mid_x_ours == ellswift_decode(in_ellswift_ours)
|
|
in_ellswift_theirs = case["theirs"]
|
|
in_contents = hkdf_sha256(case["contentlen"], ikm, b"contents", b"")
|
|
contents = in_contents * case["multiply"]
|
|
in_aad = hkdf_sha256(case["aadlen"], ikm, b"aad", b"")
|
|
mid_shared_secret = v2_ecdh(in_priv_ours, in_ellswift_theirs, in_ellswift_ours, in_initiating)
|
|
|
|
peer = initialize_v2_transport(mid_shared_secret, in_initiating)
|
|
for _ in range(case["idx"]):
|
|
v2_enc_packet(peer, b"")
|
|
ciphertext = v2_enc_packet(peer, contents, in_aad, case["ignore"])
|
|
long_msg = len(ciphertext) > 128
|
|
|
|
return {
|
|
"in_idx": case['idx'],
|
|
"in_priv_ours": in_priv_ours.hex(),
|
|
"in_ellswift_ours": in_ellswift_ours.hex(),
|
|
"in_ellswift_theirs": in_ellswift_theirs.hex(),
|
|
"in_initiating": int(in_initiating),
|
|
"in_contents": in_contents.hex(),
|
|
"in_multiply": case['multiply'],
|
|
"in_aad": in_aad.hex(),
|
|
"in_ignore": in_ignore,
|
|
"mid_x_ours": mid_x_ours.hex(),
|
|
"mid_x_theirs": ellswift_decode(in_ellswift_theirs).hex(),
|
|
"mid_x_shared": ellswift_ecdh_xonly(in_ellswift_theirs, in_priv_ours).hex(),
|
|
"mid_shared_secret": mid_shared_secret.hex(),
|
|
"mid_initiator_l": peer['initiator_L'].hex(),
|
|
"mid_initiator_p": peer['initiator_P'].hex(),
|
|
"mid_responder_l": peer['responder_L'].hex(),
|
|
"mid_responder_p": peer['responder_P'].hex(),
|
|
"mid_send_garbage_terminator": peer["send_garbage_terminator"].hex(),
|
|
"mid_recv_garbage_terminator": peer["recv_garbage_terminator"].hex(),
|
|
"out_session_id": peer["session_id"].hex(),
|
|
"out_ciphertext": "" if long_msg else ciphertext.hex(),
|
|
"out_ciphertext_endswith": ciphertext[-128:].hex() if long_msg else ""
|
|
}
|
|
|
|
def gen_all_packet_encoding_vectors(fil):
|
|
"""Return a list of CSV lines, one for each packet encoding vector."""
|
|
|
|
ellswift = gen_ellswift_decode_cases(b"simplified_", simplified=True)
|
|
ellswift.sort(key=lambda x: hashlib.sha256(b"simplified:" + x[0]).digest())
|
|
|
|
fields = [
|
|
"in_idx", "in_priv_ours", "in_ellswift_ours", "in_ellswift_theirs", "in_initiating",
|
|
"in_contents", "in_multiply", "in_aad", "in_ignore", "mid_x_ours", "mid_x_theirs",
|
|
"mid_x_shared", "mid_shared_secret", "mid_initiator_l", "mid_initiator_p",
|
|
"mid_responder_l", "mid_responder_p", "mid_send_garbage_terminator",
|
|
"mid_recv_garbage_terminator", "out_session_id", "out_ciphertext", "out_ciphertext_endswith"
|
|
]
|
|
|
|
writer = csv.DictWriter(fil, fields)
|
|
writer.writeheader()
|
|
for case in [
|
|
{"init": True, "contentlen": 1, "multiply": 1, "aadlen": 0, "ignore": False, "idx": 1,
|
|
"theirs": ellswift[0][0], "features": 0},
|
|
{"init": False, "contentlen": 17, "multiply": 1, "aadlen": 0, "ignore": False, "idx": 999,
|
|
"theirs": ellswift[1][0], "features": 1},
|
|
{"init": True, "contentlen": 63, "multiply": 1, "aadlen": 4095, "ignore": False, "idx": 0,
|
|
"theirs": ellswift[2][0], "features": 2},
|
|
{"init": False, "contentlen": 128, "multiply": 1, "aadlen": 0, "ignore": True, "idx": 223,
|
|
"theirs": ellswift[3][0], "features": 3},
|
|
{"init": True, "contentlen": 193, "multiply": 1, "aadlen": 0, "ignore": False, "idx": 448,
|
|
"theirs": ellswift[4][0], "features": 4},
|
|
{"init": False, "contentlen": 41, "multiply": 97561, "aadlen": 0, "ignore": False,
|
|
"idx": 673, "theirs": ellswift[5][0], "features": 5},
|
|
{"init": True, "contentlen": 241, "multiply": 69615, "aadlen": 0, "ignore": True,
|
|
"idx": 1024, "theirs": ellswift[6][0], "features": 6},
|
|
]:
|
|
writer.writerow(gen_packet_encoding_vector(case))
|
|
|
|
if __name__ == "__main__":
|
|
print(f"Generating {FILENAME_PACKET_TEST}...")
|
|
with open(FILENAME_PACKET_TEST, "w", encoding="utf-8") as fil_packet:
|
|
gen_all_packet_encoding_vectors(fil_packet)
|
|
print(f"Generating {FILENAME_XSWIFTEC_INV_TEST}...")
|
|
with open(FILENAME_XSWIFTEC_INV_TEST, "w", encoding="utf-8") as fil_xswiftec_inv:
|
|
gen_all_xswiftec_inv_vectors(fil_xswiftec_inv)
|
|
print(f"Generating {FILENAME_ELLSWIFT_DECODE_TEST}...")
|
|
with open(FILENAME_ELLSWIFT_DECODE_TEST, "w", encoding="utf-8") as fil_ellswift_decode:
|
|
gen_all_ellswift_decode_vectors(fil_ellswift_decode)
|