1
0
mirror of https://github.com/bitcoin/bips.git synced 2024-11-19 09:50:06 +01:00
bitcoin-bips/bip-0324/gen_test_vectors.py
Pieter Wuille cc177ab7bc BIP324 updates
Includes:
* Simpler (but equivalent) ElligatorSwift encoding function & spec
* Improved test vectors
* Test vector generation code
* Code for converting test vectors for libsecp256k1 code.
* Code for running test vectors against SwiftEC paper authors' code.
* Miscellaneous reference code improvements (style, comments).
2023-01-11 17:39:56 -05:00

419 lines
17 KiB
Python

"""Generate the BIP-0324 test vectors."""
import csv
import hashlib
import os
import sys
from reference import (
FE,
GE,
MINUS_3_SQRT,
hkdf_sha256,
SECP256K1_G,
ellswift_decode,
ellswift_ecdh_xonly,
xswiftec_inv,
xswiftec,
v2_ecdh,
initialize_v2_transport,
v2_enc_packet
)
# All generated CSV files are written into the directory named by sys.path[0].
_OUTPUT_DIR = sys.path[0]
FILENAME_PACKET_TEST = os.path.join(_OUTPUT_DIR, 'packet_encoding_test_vectors.csv')
FILENAME_XSWIFTEC_INV_TEST = os.path.join(_OUTPUT_DIR, 'xswiftec_inv_test_vectors.csv')
FILENAME_ELLSWIFT_DECODE_TEST = os.path.join(_OUTPUT_DIR, 'ellswift_decode_test_vectors.csv')
def xswiftec_flagged(u, t, simplified=False):
    """Variant of xswiftec that also reports which special conditions were encountered.

    Returns (x, flags), where flags is a list of strings in the order the conditions
    were detected. When simplified is true, the flag for X == 0 is not emitted.
    """
    notes = []
    # Zero inputs are replaced by one; the substitution is recorded as a flag.
    if u == 0:
        notes.append("u%p=0")
        u = FE(1)
    if t == 0:
        notes.append("t%p=0")
        t = FE(1)
    # If u^3 + t^2 + 7 vanishes, t is doubled before continuing.
    if u**3 + t**2 + 7 == 0:
        notes.append("(u'^3+t'^2+7)%p=0")
        t = 2 * t
    X = (u**3 + 7 - t**2) / (2 * t)
    Y = (X + t) / (MINUS_3_SQRT * u)
    if X == 0 and not simplified:
        notes.append("(u'^3-t'^2+7)%p=0")
    # Candidates are examined in the order x3, x2, x1; the first valid one is returned.
    candidates = (
        ("valid_x(x3)", u + 4 * Y**2),
        ("valid_x(x2)", (-X / Y - u) / 2),
        ("valid_x(x1)", (X / Y - u) / 2),
    )
    found = False
    # If no candidate turns out to be valid, the last one (x1) is returned.
    chosen = candidates[-1][1]
    for label, cand in candidates:
        if GE.is_valid_x(cand):
            notes.append(label)
            if not found:
                found = True
                chosen = cand
    return chosen, notes
def ellswift_create_deterministic(seed, features):
    """This is a variant of ellswift_create which doesn't use randomness.

    All choices are derived from seed through hkdf_sha256, with an incrementing
    counter mixed in; candidates are generated until one has an inverse under
    xswiftec_inv and shows the requested validity pattern.

    features is an integer selecting some properties of the result:
    - (f & 3) == 0: only x1 is valid on decoding (see xswiftec{_flagged})
    - (f & 3) == 1: only x2 is valid on decoding
    - (f & 3) == 2: only x3 is valid on decoding
    - (f & 3) == 3: x1,x2,x3 are all valid on decoding
    - (f & 4) == 4: u >= p
    - (f & 8) == 8: u mod p == 0

    Returns privkey, ellswift (32-byte big-endian u followed by t.to_bytes()).
    """
    # Required (have_x1, have_x2, have_x3) pattern, selected by the low two bits.
    # The selector must be masked with 3: masking with 4 can only give 0 or 4, so the
    # patterns for 1, 2 and 3 would never be enforced.
    wanted = {
        0: (True, False, False),
        1: (False, True, False),
        2: (False, False, True),
        3: (True, True, True),
    }[features & 3]
    cnt = 0
    while True:
        sec = hkdf_sha256(32, seed, (cnt).to_bytes(4, 'little'), b"sec")
        xval = (int.from_bytes(sec, 'big') * SECP256K1_G).x
        cnt += 1
        if features & 8:
            # u is congruent to zero: either 0 itself or FE.SIZE (when u >= p is requested).
            u = 0
            if features & 4:
                u += FE.SIZE
        else:
            udat = hkdf_sha256(64, seed, (cnt).to_bytes(4, 'little'), b"u")
            if features & 4:
                # Nonzero residue, encoded as an integer in FE.SIZE+1..2^256-1.
                u = FE.SIZE + 1 + int.from_bytes(udat, 'big') % (2**256 - FE.SIZE - 1)
            else:
                # Integer in 1..FE.SIZE-1.
                u = 1 + int.from_bytes(udat, 'big') % (FE.SIZE - 1)
        case = hkdf_sha256(1, seed, (cnt).to_bytes(4, 'little'), b"case")[0] & 7
        # xswiftec_flagged replaces a zero u by one, so invert with the corrected value.
        coru = FE(u) + ((features & 8) == 8)
        t = xswiftec_inv(xval, coru, case)
        if t is None:
            continue
        assert xswiftec(FE(u), t) == xval
        x2, flags = xswiftec_flagged(FE(u), t)
        assert x2 == xval
        observed = tuple(f"valid_x({name})" in flags for name in ("x1", "x2", "x3"))
        if observed != wanted:
            continue
        return sec, u.to_bytes(32, 'big') + t.to_bytes()
def ellswift_decode_flagged(ellswift, simplified=False):
    """Decode a 64-byte ElligatorSwift encoded coordinate, returning byte array + flag string.

    The first 32 bytes are read as big-endian u and the rest as big-endian t. The result
    is (x as 32 big-endian bytes, flags joined with ';').

    When simplified is true, the "u>=p" / "t>=p" flags are omitted, and the setting is
    also passed on to xswiftec_flagged so that its own simplified behaviour applies.
    """
    uv = int.from_bytes(ellswift[:32], 'big')
    tv = int.from_bytes(ellswift[32:], 'big')
    # Forward simplified; otherwise xswiftec_flagged's parameter would never take effect.
    x, flags = xswiftec_flagged(FE(uv), FE(tv), simplified=simplified)
    if not simplified:
        if uv >= FE.SIZE:
            flags.append("u>=p")
        if tv >= FE.SIZE:
            flags.append("t>=p")
    return int(x).to_bytes(32, 'big'), ";".join(flags)
def random_fe_int(_, seed, i, p):
    """tuple_expand callback: derive a pseudorandom integer in 0..FE.SIZE-1.

    The value depends on seed, the expansion count i, and the position p in the spec.
    """
    label = b"v%i_fe" % p
    counter_bytes = i.to_bytes(4, 'little')
    return int.from_bytes(hkdf_sha256(64, seed, counter_bytes, label), 'big') % FE.SIZE
def random_fe_int_high(_, seed, i, p):
    """tuple_expand callback: derive a pseudorandom integer in FE.SIZE..2^256-1.

    The value depends on seed, the expansion count i, and the position p in the spec.
    """
    label = b"v%i_fe_high" % p
    counter_bytes = i.to_bytes(4, 'little')
    offset = int.from_bytes(hkdf_sha256(64, seed, counter_bytes, label), 'big')
    return FE.SIZE + offset % (2**256 - FE.SIZE)
def fn_of(p_in, fn):
    """Build a tuple_expand callback that computes one position from another.

    The returned callback applies fn to the value at index p_in once that value is an
    int; until then it returns None, signalling that it has to wait.
    """
    def derive(values, _seed, _i, pos):
        # A position cannot be defined in terms of itself.
        assert pos != p_in
        source = values[p_in]
        return fn(source) if isinstance(source, int) else None
    return derive
def tuple_expand(out, tuplespec, prio, seed=None, cnt=1):
    """Expand a tuple specification cnt times, appending every result to out.

    Expansion proceeds recursively, scanning the spec from left to right:
    - A list element is replaced by each of its members in turn, giving one
      expansion per member.
    - A callable element is invoked as (spec, seed, expansion count, index in spec).
      A non-None result replaces the element; None means the callable must wait
      for other indices to be expanded first, and scanning continues.
    - A None element discards the expansion.

    Each fully expanded spec is appended to out as a
    (prio, expansion count, SHA256(result), result, seed) tuple.
    """
    def with_replacement(spec, pos, value):
        """Return a copy of spec (as a list) with spec[pos] set to value."""
        updated = list(spec)
        updated[pos] = value
        return updated

    def expand(spec, round_no):
        for pos, item in enumerate(spec):
            if item is None:
                return
            if isinstance(item, list):
                for choice in item:
                    expand(with_replacement(spec, pos, choice), round_no)
                return
            if callable(item):
                produced = item(spec, seed, round_no, pos)
                if produced is not None:
                    expand(with_replacement(spec, pos, produced), round_no)
                    return
        # Nothing left to expand: hash the 32-byte big-endian encodings of all elements.
        encoded = b"".join(int(item).to_bytes(32, 'big') for item in spec)
        out.append((prio, round_no, hashlib.sha256(encoded).digest(), spec, seed))

    for round_no in range(cnt):
        expand(tuplespec, round_no)
def gen_ellswift_decode_cases(seed, simplified=False):
    """Generate a set of interesting (ellswift, x, flags) ellswift decoding cases.

    Candidate (u, t) pairs are produced with tuple_expand, decoded with
    ellswift_decode_flagged, and only the first candidate (in sorted order) for each
    distinct flag string is kept.
    """
    # Spec element expanding to one integer in 0..p-1 and one in p..2^256-1.
    any_value = [random_fe_int, random_fe_int_high]
    # Spec element expanding to the integers which are 0 mod p.
    zero_mod_p = [0, FE.SIZE]
    # Helpers constructing u and t values such that u^3+t^2+7=0 or u^3-t^2+7=0.
    t_sum_zero = fn_of(0, lambda u: (-FE(u)**3 - 7).sqrts())
    t_diff_zero = fn_of(0, lambda u: (FE(u)**3 + 7).sqrts())
    u_sum_zero = fn_of(1, lambda t: (-FE(t)**2 - 7).cbrts())
    u_diff_zero = fn_of(1, lambda t: (FE(t)**2 - 7).cbrts())

    # (spec, priority, seed suffix, number of expansions)
    plan = [
        ([any_value, any_value], 0, b"random", 64),
        ([any_value, t_sum_zero], 1, b"t=sqrt(-u^3-7)", 64),
        ([u_sum_zero, any_value], 1, b"u=cbrt(-t^2-7)", 64),
        ([any_value, t_diff_zero], 1, b"t=sqrt(u^3+7)", 64),
        ([u_diff_zero, any_value], 1, b"u=cbrt(t^2-7)", 64),
        ([zero_mod_p, any_value], 2, b"u=0", 64),
        ([any_value, zero_mod_p], 2, b"t=0", 64),
        ([zero_mod_p, FE(8).sqrts()], 3, b"u=0;t=sqrt(8)", 1),
        ([FE(-8).cbrts(), zero_mod_p], 3, b"t=0;u=cbrt(-8)", 1),
        ([FE(-6).cbrts(), zero_mod_p], 3, b"t=0;u=cbrt(-6)", 1),
        ([zero_mod_p, zero_mod_p], 3, b"u=0;t=0", 1),
        # Unused.
        ([zero_mod_p, FE(-8).sqrts()], 4, b"u=0;t=sqrt(-8)", 1),
    ]
    inputs = []
    for spec, prio, suffix, count in plan:
        tuple_expand(inputs, spec, prio, seed + suffix, count)

    # Keep the first case per flag string; dicts preserve insertion order.
    first_by_flags = {}
    for _prio, _cnt, _hash, vs, _seed in sorted(inputs):
        encoded = b"".join(int(v).to_bytes(32, 'big') for v in (vs[0], vs[1]))
        decoded, flags = ellswift_decode_flagged(encoded, simplified)
        if flags not in first_by_flags:
            first_by_flags[flags] = (encoded, decoded, flags)
    return list(first_by_flags.values())
def gen_all_ellswift_decode_vectors(fil):
    """Write all ElligatorSwift decoding test vectors to fil as CSV.

    The columns are "ellswift", "x" (both hex) and "comment" (the flag string);
    rows are written in sorted order of the generated cases.
    """
    cases = gen_ellswift_decode_cases(b"")
    writer = csv.DictWriter(fil, ["ellswift", "x", "comment"])
    writer.writeheader()
    writer.writerows(
        {"ellswift": encoded.hex(), "x": decoded.hex(), "comment": comment}
        for encoded, decoded, comment in sorted(cases)
    )
def xswiftec_inv_flagged(x, u, case):
    """A variant of xswiftec_inv which also returns flags, describing conditions encountered.

    case is a 3-bit selector: bit value 2 picks which derivation of (v, s) is used,
    bit value 1 picks the variant within that derivation, and bit value 4 negates
    the square root w.

    Returns (t, flags). When no inverse exists for this case, t is None and the last
    flag is a "bad[...]" entry; otherwise the last flag is "ok". "info[...]" flags
    record special values met along the way.
    """
    flags = []

    def reject(reason):
        """Record why this case has no inverse and produce the failure result."""
        flags.append(reason)
        return None, flags

    if case & 2:
        s = x - u
        if s == 0:
            return reject("bad[s=0]")
        q = -s * (4 * (u**3 + 7) + 3 * s * u**2)
        if q == 0:
            flags.append("info[q=0]")
        r = q.sqrt()
        if r is None:
            return reject("bad[non_square(q)]")
        if case & 1:
            if r == 0:
                return reject("bad[r=0]")
            r = -r
        v = (-u + r / s) / 2
        if v == 0:
            flags.append("info[v=0]")
    else:
        if GE.is_valid_x(-x - u):
            return reject("bad[valid_x(-x-u)]")
        v = -x - u if case & 1 else x
        if v == 0:
            flags.append("info[v=0]")
        s = -(u**3 + 7) / (u**2 + u*v + v**2)
        assert s != 0  # would imply X=0 on curve
    w = s.sqrt()
    assert w != 0
    if w is None:
        return reject("bad[non_square(s)]")
    if case & 4:
        w = -w
    Y = w / 2
    assert Y != 0
    X = 2 * Y * (v + u / 2)
    if X == 0:
        flags.append("info[X=0]")
    flags.append("ok")
    return w * (u * (MINUS_3_SQRT - 1) / 2 - v), flags
def xswiftec_inv_combo_flagged(x, u):
    """Run xswiftec_inv_flagged for case=0..7 and aggregate the results and flags.

    Returns (ts, comment): ts holds the eight results (None where the case has no
    inverse), and comment joins one "caseN:flag&flag..." entry per case with ';'.
    """
    results = []
    summaries = []
    for idx in range(8):
        t, flags = xswiftec_inv_flagged(x, u, idx)
        # Every inverse that exists must map back to x.
        assert t is None or x == xswiftec(u, t)
        results.append(t)
        summaries.append(f"case{idx}:{'&'.join(flags)}")
    return results, ";".join(summaries)
def gen_all_xswiftec_inv_vectors(fil):
    """Write all xswiftec_inv test vectors to fil as CSV.

    Candidate (x, u) pairs are produced with tuple_expand; pairs with u == 0 or with an
    x that is not a valid x coordinate are skipped, and only the first pair (in sorted
    order) for each distinct aggregate flag string is kept.
    """
    # Two constants used below. Compute them only once.
    coef_a = (FE(MINUS_3_SQRT) - 1) / 2
    coef_b = (-FE(MINUS_3_SQRT) - 1) / 2
    # Helper functions that pick x and u with special properties.
    x_q_zero = fn_of(1, lambda u: (FE(u)**3 + 28) / (FE(-3) * FE(u)**2))
    x_divzero_a = fn_of(1, lambda u: FE(u) * coef_a)
    x_divzero_b = fn_of(1, lambda u: FE(u) * coef_b)
    x_v_zero = fn_of(1, lambda u: FE(-7) / FE(u)**2)
    u_x_zero = fn_of(0, lambda x: FE(-2) * FE(x))

    # (spec as [x, u], priority, seed, number of expansions)
    plan = [
        ([random_fe_int, random_fe_int], 0, b"uniform", 256),
        ([x_q_zero, random_fe_int], 1, b"x=-(u^3+28)/(3*u^2)", 64),
        ([x_v_zero, random_fe_int], 1, b"x=-7/u^2", 512),
        ([random_fe_int, fn_of(0, lambda x: x)], 2, b"u=x", 64),
        ([random_fe_int, fn_of(0, lambda x: -FE(x))], 2, b"u=-x", 64),
        # Unused.
        ([x_divzero_a, random_fe_int], 3, b"x=u*(sqrt(-3)-1)/2", 64),
        ([x_divzero_b, random_fe_int], 3, b"x=u*(-sqrt(-3)-1)/2", 64),
        ([random_fe_int, u_x_zero], 3, b"u=-2x", 64),
    ]
    inputs = []
    for spec, prio, label, count in plan:
        tuple_expand(inputs, spec, prio, label, count)

    # Keep the first pair per flag string; dicts preserve insertion order.
    first_by_flags = {}
    for _prio, _cnt, _hash, vs, _seed in sorted(inputs):
        x, u = FE(vs[0]), FE(vs[1])
        if u == 0 or not GE.is_valid_x(x):
            continue
        ts, flags = xswiftec_inv_combo_flagged(x, u)
        if flags not in first_by_flags:
            first_by_flags[flags] = (int(u), int(x), ts, flags)

    t_columns = [f"case{c}_t" for c in range(8)]
    writer = csv.DictWriter(fil, ["u", "x"] + t_columns + ["comment"])
    writer.writeheader()
    for u, x, ts, flags in sorted(first_by_flags.values()):
        row = {"u": FE(u), "x": FE(x), "comment": flags}
        # Cases without an inverse leave their column empty.
        row.update(
            (column, FE(t)) for column, t in zip(t_columns, ts) if t is not None
        )
        writer.writerow(row)
def gen_packet_encoding_vector(case):
    """Build one packet_encoding test vector (a dict of CSV fields) from a case spec.

    case is a dict with the keys "init", "contentlen", "multiply", "aadlen", "ignore",
    "idx", "theirs" and "features". Its string representation is used as the input
    keying material for all deterministic choices, so any change to the dict
    (including key order) changes the resulting vector.
    """
    ikm = str(case).encode('utf-8')
    initiating = case["init"]
    ignore_flag = int(case["ignore"])
    priv_ours, ellswift_ours = ellswift_create_deterministic(ikm, case["features"])
    x_ours = (int.from_bytes(priv_ours, 'big') * SECP256K1_G).x.to_bytes()
    # The deterministic encoding must decode to our own public key's x coordinate.
    assert x_ours == ellswift_decode(ellswift_ours)
    ellswift_theirs = case["theirs"]
    base_contents = hkdf_sha256(case["contentlen"], ikm, b"contents", b"")
    # The actual payload is the base contents repeated "multiply" times.
    payload = base_contents * case["multiply"]
    aad = hkdf_sha256(case["aadlen"], ikm, b"aad", b"")
    shared_secret = v2_ecdh(priv_ours, ellswift_theirs, ellswift_ours, initiating)
    peer = initialize_v2_transport(shared_secret, initiating)
    # Encrypt "idx" empty packets first, so the vector's packet is number idx.
    for _ in range(case["idx"]):
        v2_enc_packet(peer, b"")
    ciphertext = v2_enc_packet(peer, payload, aad, case["ignore"])
    # Ciphertexts over 128 bytes are represented by their last 128 bytes only.
    if len(ciphertext) > 128:
        full_hex, tail_hex = "", ciphertext[-128:].hex()
    else:
        full_hex, tail_hex = ciphertext.hex(), ""
    return {
        "in_idx": case['idx'],
        "in_priv_ours": priv_ours.hex(),
        "in_ellswift_ours": ellswift_ours.hex(),
        "in_ellswift_theirs": ellswift_theirs.hex(),
        "in_initiating": int(initiating),
        "in_contents": base_contents.hex(),
        "in_multiply": case['multiply'],
        "in_aad": aad.hex(),
        "in_ignore": ignore_flag,
        "mid_x_ours": x_ours.hex(),
        "mid_x_theirs": ellswift_decode(ellswift_theirs).hex(),
        "mid_x_shared": ellswift_ecdh_xonly(ellswift_theirs, priv_ours).hex(),
        "mid_shared_secret": shared_secret.hex(),
        "mid_initiator_l": peer['initiator_L'].hex(),
        "mid_initiator_p": peer['initiator_P'].hex(),
        "mid_responder_l": peer['responder_L'].hex(),
        "mid_responder_p": peer['responder_P'].hex(),
        "mid_send_garbage_terminator": peer["send_garbage_terminator"].hex(),
        "mid_recv_garbage_terminator": peer["recv_garbage_terminator"].hex(),
        "out_session_id": peer["session_id"].hex(),
        "out_ciphertext": full_hex,
        "out_ciphertext_endswith": tail_hex,
    }
def gen_all_packet_encoding_vectors(fil):
    """Write all packet encoding test vectors to fil as CSV, one row per case.

    The peer's ElligatorSwift encodings are taken from the simplified decode cases,
    reordered by the SHA256 of a fixed prefix plus the encoding.
    """
    ellswift = gen_ellswift_decode_cases(b"simplified_", simplified=True)
    ellswift.sort(key=lambda entry: hashlib.sha256(b"simplified:" + entry[0]).digest())
    fields = [
        "in_idx", "in_priv_ours", "in_ellswift_ours", "in_ellswift_theirs", "in_initiating",
        "in_contents", "in_multiply", "in_aad", "in_ignore", "mid_x_ours", "mid_x_theirs",
        "mid_x_shared", "mid_shared_secret", "mid_initiator_l", "mid_initiator_p",
        "mid_responder_l", "mid_responder_p", "mid_send_garbage_terminator",
        "mid_recv_garbage_terminator", "out_session_id", "out_ciphertext", "out_ciphertext_endswith"
    ]
    writer = csv.DictWriter(fil, fields)
    writer.writeheader()
    # (init, contentlen, multiply, aadlen, ignore, idx); case number n additionally
    # uses ellswift[n] as the peer's encoding and n as the features value.
    specs = [
        (True, 1, 1, 0, False, 1),
        (False, 17, 1, 0, False, 999),
        (True, 63, 1, 4095, False, 0),
        (False, 128, 1, 0, True, 223),
        (True, 193, 1, 0, False, 448),
        (False, 41, 97561, 0, False, 673),
        (True, 241, 69615, 0, True, 1024),
    ]
    # The key order below must not change: str(case) is used as keying material
    # by gen_packet_encoding_vector.
    cases = [
        {"init": init, "contentlen": contentlen, "multiply": multiply, "aadlen": aadlen,
         "ignore": ignore, "idx": idx, "theirs": ellswift[n][0], "features": n}
        for n, (init, contentlen, multiply, aadlen, ignore, idx) in enumerate(specs)
    ]
    for case in cases:
        writer.writerow(gen_packet_encoding_vector(case))
if __name__ == "__main__":
    # Generate each CSV file in turn. The files are opened with newline="" as the csv
    # module requires: the writer emits its own line terminators, and without this
    # setting platforms that translate newlines would produce "\r\r\n" line endings.
    for path, generate in (
        (FILENAME_PACKET_TEST, gen_all_packet_encoding_vectors),
        (FILENAME_XSWIFTEC_INV_TEST, gen_all_xswiftec_inv_vectors),
        (FILENAME_ELLSWIFT_DECODE_TEST, gen_all_ellswift_decode_vectors),
    ):
        print(f"Generating {path}...")
        with open(path, "w", encoding="utf-8", newline="") as fil:
            generate(fil)