mirror of
https://github.com/bitcoin/bips.git
synced 2024-11-19 01:40:05 +01:00
26bb1d8ea3
(cherry picked from commit 4f2e6e7ffbd2fdc095ab8d59827be9da18b790be)
882 lines
34 KiB
Python
882 lines
34 KiB
Python
# BIP327 reference implementation
|
|
#
|
|
# WARNING: This implementation is for demonstration purposes only and _not_ to
|
|
# be used in production environments. The code is vulnerable to timing attacks,
|
|
# for example.
|
|
|
|
from typing import Any, List, Optional, Tuple, NewType, NamedTuple
|
|
import hashlib
|
|
import secrets
|
|
import time
|
|
|
|
#
|
|
# The following helper functions were copied from the BIP-340 reference implementation:
|
|
# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py
|
|
#
|
|
|
|
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
|
|
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
|
|
|
|
# Points are tuples of X and Y coordinates and the point at infinity is
|
|
# represented by the None keyword.
|
|
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)
|
|
|
|
Point = Tuple[int, int]
|
|
|
|
# This implementation can be sped up by storing the midstate after hashing
|
|
# tag_hash instead of rehashing it all the time.
|
|
def tagged_hash(tag: str, msg: bytes) -> bytes:
|
|
tag_hash = hashlib.sha256(tag.encode()).digest()
|
|
return hashlib.sha256(tag_hash + tag_hash + msg).digest()
|
|
|
|
def is_infinite(P: Optional[Point]) -> bool:
|
|
return P is None
|
|
|
|
def x(P: Point) -> int:
|
|
assert not is_infinite(P)
|
|
return P[0]
|
|
|
|
def y(P: Point) -> int:
|
|
assert not is_infinite(P)
|
|
return P[1]
|
|
|
|
def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]:
|
|
if P1 is None:
|
|
return P2
|
|
if P2 is None:
|
|
return P1
|
|
if (x(P1) == x(P2)) and (y(P1) != y(P2)):
|
|
return None
|
|
if P1 == P2:
|
|
lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p
|
|
else:
|
|
lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p
|
|
x3 = (lam * lam - x(P1) - x(P2)) % p
|
|
return (x3, (lam * (x(P1) - x3) - y(P1)) % p)
|
|
|
|
def point_mul(P: Optional[Point], n: int) -> Optional[Point]:
|
|
R = None
|
|
for i in range(256):
|
|
if (n >> i) & 1:
|
|
R = point_add(R, P)
|
|
P = point_add(P, P)
|
|
return R
|
|
|
|
def bytes_from_int(x: int) -> bytes:
|
|
return x.to_bytes(32, byteorder="big")
|
|
|
|
def lift_x(b: bytes) -> Optional[Point]:
|
|
x = int_from_bytes(b)
|
|
if x >= p:
|
|
return None
|
|
y_sq = (pow(x, 3, p) + 7) % p
|
|
y = pow(y_sq, (p + 1) // 4, p)
|
|
if pow(y, 2, p) != y_sq:
|
|
return None
|
|
return (x, y if y & 1 == 0 else p-y)
|
|
|
|
def int_from_bytes(b: bytes) -> int:
|
|
return int.from_bytes(b, byteorder="big")
|
|
|
|
def has_even_y(P: Point) -> bool:
|
|
assert not is_infinite(P)
|
|
return y(P) % 2 == 0
|
|
|
|
def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool:
|
|
if len(msg) != 32:
|
|
raise ValueError('The message must be a 32-byte array.')
|
|
if len(pubkey) != 32:
|
|
raise ValueError('The public key must be a 32-byte array.')
|
|
if len(sig) != 64:
|
|
raise ValueError('The signature must be a 64-byte array.')
|
|
P = lift_x(pubkey)
|
|
r = int_from_bytes(sig[0:32])
|
|
s = int_from_bytes(sig[32:64])
|
|
if (P is None) or (r >= p) or (s >= n):
|
|
return False
|
|
e = int_from_bytes(tagged_hash("BIP0340/challenge", sig[0:32] + pubkey + msg)) % n
|
|
R = point_add(point_mul(G, s), point_mul(P, n - e))
|
|
if (R is None) or (not has_even_y(R)) or (x(R) != r):
|
|
return False
|
|
return True
|
|
|
|
#
|
|
# End of helper functions copied from BIP-340 reference implementation.
|
|
#
|
|
|
|
PlainPk = NewType('PlainPk', bytes)
|
|
XonlyPk = NewType('XonlyPk', bytes)
|
|
|
|
# There are two types of exceptions that can be raised by this implementation:
|
|
# - ValueError for indicating that an input doesn't conform to some function
|
|
# precondition (e.g. an input array is the wrong length, a serialized
|
|
# representation doesn't have the correct format).
|
|
# - InvalidContributionError for indicating that a signer (or the
|
|
# aggregator) is misbehaving in the protocol.
|
|
#
|
|
# Assertions are used to (1) satisfy the type-checking system, and (2) check for
|
|
# inconvenient events that can't happen except with negligible probability (e.g.
|
|
# output of a hash function is 0) and can't be manually triggered by any
|
|
# signer.
|
|
|
|
# This exception is raised if a party (signer or nonce aggregator) sends invalid
|
|
# values. Actual implementations should not crash when receiving invalid
|
|
# contributions. Instead, they should hold the offending party accountable.
|
|
class InvalidContributionError(Exception):
|
|
def __init__(self, signer, contrib):
|
|
self.signer = signer
|
|
# contrib is one of "pubkey", "pubnonce", "aggnonce", or "psig".
|
|
self.contrib = contrib
|
|
|
|
infinity = None
|
|
|
|
def xbytes(P: Point) -> bytes:
|
|
return bytes_from_int(x(P))
|
|
|
|
def cbytes(P: Point) -> bytes:
|
|
a = b'\x02' if has_even_y(P) else b'\x03'
|
|
return a + xbytes(P)
|
|
|
|
def cbytes_ext(P: Optional[Point]) -> bytes:
|
|
if is_infinite(P):
|
|
return (0).to_bytes(33, byteorder='big')
|
|
assert P is not None
|
|
return cbytes(P)
|
|
|
|
def point_negate(P: Optional[Point]) -> Optional[Point]:
|
|
if P is None:
|
|
return P
|
|
return (x(P), p - y(P))
|
|
|
|
def cpoint(x: bytes) -> Point:
|
|
if len(x) != 33:
|
|
raise ValueError('x is not a valid compressed point.')
|
|
P = lift_x(x[1:33])
|
|
if P is None:
|
|
raise ValueError('x is not a valid compressed point.')
|
|
if x[0] == 2:
|
|
return P
|
|
elif x[0] == 3:
|
|
P = point_negate(P)
|
|
assert P is not None
|
|
return P
|
|
else:
|
|
raise ValueError('x is not a valid compressed point.')
|
|
|
|
def cpoint_ext(x: bytes) -> Optional[Point]:
|
|
if x == (0).to_bytes(33, 'big'):
|
|
return None
|
|
else:
|
|
return cpoint(x)
|
|
|
|
# Return the plain public key corresponding to a given secret key
|
|
def individual_pk(seckey: bytes) -> PlainPk:
|
|
d0 = int_from_bytes(seckey)
|
|
if not (1 <= d0 <= n - 1):
|
|
raise ValueError('The secret key must be an integer in the range 1..n-1.')
|
|
P = point_mul(G, d0)
|
|
assert P is not None
|
|
return PlainPk(cbytes(P))
|
|
|
|
def key_sort(pubkeys: List[PlainPk]) -> List[PlainPk]:
|
|
pubkeys.sort()
|
|
return pubkeys
|
|
|
|
KeyAggContext = NamedTuple('KeyAggContext', [('Q', Point),
|
|
('gacc', int),
|
|
('tacc', int)])
|
|
|
|
def get_xonly_pk(keyagg_ctx: KeyAggContext) -> XonlyPk:
|
|
Q, _, _ = keyagg_ctx
|
|
return XonlyPk(xbytes(Q))
|
|
|
|
def key_agg(pubkeys: List[PlainPk]) -> KeyAggContext:
|
|
pk2 = get_second_key(pubkeys)
|
|
u = len(pubkeys)
|
|
Q = infinity
|
|
for i in range(u):
|
|
try:
|
|
P_i = cpoint(pubkeys[i])
|
|
except ValueError:
|
|
raise InvalidContributionError(i, "pubkey")
|
|
a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2)
|
|
Q = point_add(Q, point_mul(P_i, a_i))
|
|
# Q is not the point at infinity except with negligible probability.
|
|
assert(Q is not None)
|
|
gacc = 1
|
|
tacc = 0
|
|
return KeyAggContext(Q, gacc, tacc)
|
|
|
|
def hash_keys(pubkeys: List[PlainPk]) -> bytes:
|
|
return tagged_hash('KeyAgg list', b''.join(pubkeys))
|
|
|
|
def get_second_key(pubkeys: List[PlainPk]) -> PlainPk:
|
|
u = len(pubkeys)
|
|
for j in range(1, u):
|
|
if pubkeys[j] != pubkeys[0]:
|
|
return pubkeys[j]
|
|
return PlainPk(b'\x00'*33)
|
|
|
|
def key_agg_coeff(pubkeys: List[PlainPk], pk_: PlainPk) -> int:
|
|
pk2 = get_second_key(pubkeys)
|
|
return key_agg_coeff_internal(pubkeys, pk_, pk2)
|
|
|
|
def key_agg_coeff_internal(pubkeys: List[PlainPk], pk_: PlainPk, pk2: PlainPk) -> int:
|
|
L = hash_keys(pubkeys)
|
|
if pk_ == pk2:
|
|
return 1
|
|
return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n
|
|
|
|
def apply_tweak(keyagg_ctx: KeyAggContext, tweak: bytes, is_xonly: bool) -> KeyAggContext:
|
|
if len(tweak) != 32:
|
|
raise ValueError('The tweak must be a 32-byte array.')
|
|
Q, gacc, tacc = keyagg_ctx
|
|
if is_xonly and not has_even_y(Q):
|
|
g = n - 1
|
|
else:
|
|
g = 1
|
|
t = int_from_bytes(tweak)
|
|
if t >= n:
|
|
raise ValueError('The tweak must be less than n.')
|
|
Q_ = point_add(point_mul(Q, g), point_mul(G, t))
|
|
if Q_ is None:
|
|
raise ValueError('The result of tweaking cannot be infinity.')
|
|
gacc_ = g * gacc % n
|
|
tacc_ = (t + g * tacc) % n
|
|
return KeyAggContext(Q_, gacc_, tacc_)
|
|
|
|
def bytes_xor(a: bytes, b: bytes) -> bytes:
|
|
return bytes(x ^ y for x, y in zip(a, b))
|
|
|
|
def nonce_hash(rand: bytes, pk: PlainPk, aggpk: XonlyPk, i: int, msg_prefixed: bytes, extra_in: bytes) -> int:
|
|
buf = b''
|
|
buf += rand
|
|
buf += len(pk).to_bytes(1, 'big')
|
|
buf += pk
|
|
buf += len(aggpk).to_bytes(1, 'big')
|
|
buf += aggpk
|
|
buf += msg_prefixed
|
|
buf += len(extra_in).to_bytes(4, 'big')
|
|
buf += extra_in
|
|
buf += i.to_bytes(1, 'big')
|
|
return int_from_bytes(tagged_hash('MuSig/nonce', buf))
|
|
|
|
def nonce_gen_internal(rand_: bytes, sk: Optional[bytes], pk: PlainPk, aggpk: Optional[XonlyPk], msg: Optional[bytes], extra_in: Optional[bytes]) -> Tuple[bytearray, bytes]:
|
|
if sk is not None:
|
|
rand = bytes_xor(sk, tagged_hash('MuSig/aux', rand_))
|
|
else:
|
|
rand = rand_
|
|
if aggpk is None:
|
|
aggpk = XonlyPk(b'')
|
|
if msg is None:
|
|
msg_prefixed = b'\x00'
|
|
else:
|
|
msg_prefixed = b'\x01'
|
|
msg_prefixed += len(msg).to_bytes(8, 'big')
|
|
msg_prefixed += msg
|
|
if extra_in is None:
|
|
extra_in = b''
|
|
k_1 = nonce_hash(rand, pk, aggpk, 0, msg_prefixed, extra_in) % n
|
|
k_2 = nonce_hash(rand, pk, aggpk, 1, msg_prefixed, extra_in) % n
|
|
# k_1 == 0 or k_2 == 0 cannot occur except with negligible probability.
|
|
assert k_1 != 0
|
|
assert k_2 != 0
|
|
R_s1 = point_mul(G, k_1)
|
|
R_s2 = point_mul(G, k_2)
|
|
assert R_s1 is not None
|
|
assert R_s2 is not None
|
|
pubnonce = cbytes(R_s1) + cbytes(R_s2)
|
|
secnonce = bytearray(bytes_from_int(k_1) + bytes_from_int(k_2) + pk)
|
|
return secnonce, pubnonce
|
|
|
|
def nonce_gen(sk: Optional[bytes], pk: PlainPk, aggpk: Optional[XonlyPk], msg: Optional[bytes], extra_in: Optional[bytes]) -> Tuple[bytearray, bytes]:
|
|
if sk is not None and len(sk) != 32:
|
|
raise ValueError('The optional byte array sk must have length 32.')
|
|
if aggpk is not None and len(aggpk) != 32:
|
|
raise ValueError('The optional byte array aggpk must have length 32.')
|
|
rand_ = secrets.token_bytes(32)
|
|
return nonce_gen_internal(rand_, sk, pk, aggpk, msg, extra_in)
|
|
|
|
def nonce_agg(pubnonces: List[bytes]) -> bytes:
|
|
u = len(pubnonces)
|
|
aggnonce = b''
|
|
for j in (1, 2):
|
|
R_j = infinity
|
|
for i in range(u):
|
|
try:
|
|
R_ij = cpoint(pubnonces[i][(j-1)*33:j*33])
|
|
except ValueError:
|
|
raise InvalidContributionError(i, "pubnonce")
|
|
R_j = point_add(R_j, R_ij)
|
|
aggnonce += cbytes_ext(R_j)
|
|
return aggnonce
|
|
|
|
SessionContext = NamedTuple('SessionContext', [('aggnonce', bytes),
|
|
('pubkeys', List[PlainPk]),
|
|
('tweaks', List[bytes]),
|
|
('is_xonly', List[bool]),
|
|
('msg', bytes)])
|
|
|
|
def key_agg_and_tweak(pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool]) -> KeyAggContext:
|
|
if len(tweaks) != len(is_xonly):
|
|
raise ValueError('The `tweaks` and `is_xonly` arrays must have the same length.')
|
|
keyagg_ctx = key_agg(pubkeys)
|
|
v = len(tweaks)
|
|
for i in range(v):
|
|
keyagg_ctx = apply_tweak(keyagg_ctx, tweaks[i], is_xonly[i])
|
|
return keyagg_ctx
|
|
|
|
def get_session_values(session_ctx: SessionContext) -> Tuple[Point, int, int, int, Point, int]:
|
|
(aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx
|
|
Q, gacc, tacc = key_agg_and_tweak(pubkeys, tweaks, is_xonly)
|
|
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + xbytes(Q) + msg)) % n
|
|
try:
|
|
R_1 = cpoint_ext(aggnonce[0:33])
|
|
R_2 = cpoint_ext(aggnonce[33:66])
|
|
except ValueError:
|
|
# Nonce aggregator sent invalid nonces
|
|
raise InvalidContributionError(None, "aggnonce")
|
|
R_ = point_add(R_1, point_mul(R_2, b))
|
|
R = R_ if not is_infinite(R_) else G
|
|
assert R is not None
|
|
e = int_from_bytes(tagged_hash('BIP0340/challenge', xbytes(R) + xbytes(Q) + msg)) % n
|
|
return (Q, gacc, tacc, b, R, e)
|
|
|
|
def get_session_key_agg_coeff(session_ctx: SessionContext, P: Point) -> int:
|
|
(_, pubkeys, _, _, _) = session_ctx
|
|
pk = PlainPk(cbytes(P))
|
|
if pk not in pubkeys:
|
|
raise ValueError('The signer\'s pubkey must be included in the list of pubkeys.')
|
|
return key_agg_coeff(pubkeys, pk)
|
|
|
|
def sign(secnonce: bytearray, sk: bytes, session_ctx: SessionContext) -> bytes:
|
|
(Q, gacc, _, b, R, e) = get_session_values(session_ctx)
|
|
k_1_ = int_from_bytes(secnonce[0:32])
|
|
k_2_ = int_from_bytes(secnonce[32:64])
|
|
# Overwrite the secnonce argument with zeros such that subsequent calls of
|
|
# sign with the same secnonce raise a ValueError.
|
|
secnonce[:64] = bytearray(b'\x00'*64)
|
|
if not 0 < k_1_ < n:
|
|
raise ValueError('first secnonce value is out of range.')
|
|
if not 0 < k_2_ < n:
|
|
raise ValueError('second secnonce value is out of range.')
|
|
k_1 = k_1_ if has_even_y(R) else n - k_1_
|
|
k_2 = k_2_ if has_even_y(R) else n - k_2_
|
|
d_ = int_from_bytes(sk)
|
|
if not 0 < d_ < n:
|
|
raise ValueError('secret key value is out of range.')
|
|
P = point_mul(G, d_)
|
|
assert P is not None
|
|
pk = cbytes(P)
|
|
if not pk == secnonce[64:97]:
|
|
raise ValueError('Public key does not match nonce_gen argument')
|
|
a = get_session_key_agg_coeff(session_ctx, P)
|
|
g = 1 if has_even_y(Q) else n - 1
|
|
d = g * gacc * d_ % n
|
|
s = (k_1 + b * k_2 + e * a * d) % n
|
|
psig = bytes_from_int(s)
|
|
R_s1 = point_mul(G, k_1_)
|
|
R_s2 = point_mul(G, k_2_)
|
|
assert R_s1 is not None
|
|
assert R_s2 is not None
|
|
pubnonce = cbytes(R_s1) + cbytes(R_s2)
|
|
# Optional correctness check. The result of signing should pass signature verification.
|
|
assert partial_sig_verify_internal(psig, pubnonce, pk, session_ctx)
|
|
return psig
|
|
|
|
def det_nonce_hash(sk_: bytes, aggothernonce: bytes, aggpk: bytes, msg: bytes, i: int) -> int:
|
|
buf = b''
|
|
buf += sk_
|
|
buf += aggothernonce
|
|
buf += aggpk
|
|
buf += len(msg).to_bytes(8, 'big')
|
|
buf += msg
|
|
buf += i.to_bytes(1, 'big')
|
|
return int_from_bytes(tagged_hash('MuSig/deterministic/nonce', buf))
|
|
|
|
def deterministic_sign(sk: bytes, aggothernonce: bytes, pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, rand: Optional[bytes]) -> Tuple[bytes, bytes]:
|
|
if rand is not None:
|
|
sk_ = bytes_xor(sk, tagged_hash('MuSig/aux', rand))
|
|
else:
|
|
sk_ = sk
|
|
aggpk = get_xonly_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly))
|
|
|
|
k_1 = det_nonce_hash(sk_, aggothernonce, aggpk, msg, 0) % n
|
|
k_2 = det_nonce_hash(sk_, aggothernonce, aggpk, msg, 1) % n
|
|
# k_1 == 0 or k_2 == 0 cannot occur except with negligible probability.
|
|
assert k_1 != 0
|
|
assert k_2 != 0
|
|
|
|
R_s1 = point_mul(G, k_1)
|
|
R_s2 = point_mul(G, k_2)
|
|
assert R_s1 is not None
|
|
assert R_s2 is not None
|
|
pubnonce = cbytes(R_s1) + cbytes(R_s2)
|
|
secnonce = bytearray(bytes_from_int(k_1) + bytes_from_int(k_2) + individual_pk(sk))
|
|
try:
|
|
aggnonce = nonce_agg([pubnonce, aggothernonce])
|
|
except Exception:
|
|
raise InvalidContributionError(None, "aggothernonce")
|
|
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
|
|
psig = sign(secnonce, sk, session_ctx)
|
|
return (pubnonce, psig)
|
|
|
|
def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[PlainPk], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, i: int) -> bool:
|
|
if len(pubnonces) != len(pubkeys):
|
|
raise ValueError('The `pubnonces` and `pubkeys` arrays must have the same length.')
|
|
if len(tweaks) != len(is_xonly):
|
|
raise ValueError('The `tweaks` and `is_xonly` arrays must have the same length.')
|
|
aggnonce = nonce_agg(pubnonces)
|
|
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
|
|
return partial_sig_verify_internal(psig, pubnonces[i], pubkeys[i], session_ctx)
|
|
|
|
def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk: bytes, session_ctx: SessionContext) -> bool:
|
|
(Q, gacc, _, b, R, e) = get_session_values(session_ctx)
|
|
s = int_from_bytes(psig)
|
|
if s >= n:
|
|
return False
|
|
R_s1 = cpoint(pubnonce[0:33])
|
|
R_s2 = cpoint(pubnonce[33:66])
|
|
Re_s_ = point_add(R_s1, point_mul(R_s2, b))
|
|
Re_s = Re_s_ if has_even_y(R) else point_negate(Re_s_)
|
|
P = cpoint(pk)
|
|
a = get_session_key_agg_coeff(session_ctx, P)
|
|
g = 1 if has_even_y(Q) else n - 1
|
|
g_ = g * gacc % n
|
|
return point_mul(G, s) == point_add(Re_s, point_mul(P, e * a * g_ % n))
|
|
|
|
def partial_sig_agg(psigs: List[bytes], session_ctx: SessionContext) -> bytes:
|
|
(Q, _, tacc, _, R, e) = get_session_values(session_ctx)
|
|
s = 0
|
|
u = len(psigs)
|
|
for i in range(u):
|
|
s_i = int_from_bytes(psigs[i])
|
|
if s_i >= n:
|
|
raise InvalidContributionError(i, "psig")
|
|
s = (s + s_i) % n
|
|
g = 1 if has_even_y(Q) else n - 1
|
|
s = (s + e * g * tacc) % n
|
|
return xbytes(R) + bytes_from_int(s)
|
|
#
|
|
# The following code is only used for testing.
|
|
#
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
|
|
def fromhex_all(l):
|
|
return [bytes.fromhex(l_i) for l_i in l]
|
|
|
|
# Check that calling `try_fn` raises a `exception`. If `exception` is raised,
|
|
# examine it with `except_fn`.
|
|
def assert_raises(exception, try_fn, except_fn):
|
|
raised = False
|
|
try:
|
|
try_fn()
|
|
except exception as e:
|
|
raised = True
|
|
assert(except_fn(e))
|
|
except BaseException:
|
|
raise AssertionError("Wrong exception raised in a test.")
|
|
if not raised:
|
|
raise AssertionError("Exception was _not_ raised in a test where it was required.")
|
|
|
|
def get_error_details(test_case):
|
|
error = test_case["error"]
|
|
if error["type"] == "invalid_contribution":
|
|
exception = InvalidContributionError
|
|
if "contrib" in error:
|
|
except_fn = lambda e: e.signer == error["signer"] and e.contrib == error["contrib"]
|
|
else:
|
|
except_fn = lambda e: e.signer == error["signer"]
|
|
elif error["type"] == "value":
|
|
exception = ValueError
|
|
except_fn = lambda e: str(e) == error["message"]
|
|
else:
|
|
raise RuntimeError(f"Invalid error type: {error['type']}")
|
|
return exception, except_fn
|
|
|
|
def test_key_sort_vectors() -> None:
|
|
with open(os.path.join(sys.path[0], 'vectors', 'key_sort_vectors.json')) as f:
|
|
test_data = json.load(f)
|
|
|
|
X = fromhex_all(test_data["pubkeys"])
|
|
X_sorted = fromhex_all(test_data["sorted_pubkeys"])
|
|
|
|
assert key_sort(X) == X_sorted
|
|
|
|
def test_key_agg_vectors() -> None:
|
|
with open(os.path.join(sys.path[0], 'vectors', 'key_agg_vectors.json')) as f:
|
|
test_data = json.load(f)
|
|
|
|
X = fromhex_all(test_data["pubkeys"])
|
|
T = fromhex_all(test_data["tweaks"])
|
|
valid_test_cases = test_data["valid_test_cases"]
|
|
error_test_cases = test_data["error_test_cases"]
|
|
|
|
for test_case in valid_test_cases:
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
expected = bytes.fromhex(test_case["expected"])
|
|
|
|
assert get_xonly_pk(key_agg(pubkeys)) == expected
|
|
|
|
for test_case in error_test_cases:
|
|
exception, except_fn = get_error_details(test_case)
|
|
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
tweaks = [T[i] for i in test_case["tweak_indices"]]
|
|
is_xonly = test_case["is_xonly"]
|
|
|
|
assert_raises(exception, lambda: key_agg_and_tweak(pubkeys, tweaks, is_xonly), except_fn)
|
|
|
|
def test_nonce_gen_vectors() -> None:
|
|
with open(os.path.join(sys.path[0], 'vectors', 'nonce_gen_vectors.json')) as f:
|
|
test_data = json.load(f)
|
|
|
|
for test_case in test_data["test_cases"]:
|
|
def get_value(key) -> bytes:
|
|
return bytes.fromhex(test_case[key])
|
|
|
|
def get_value_maybe(key) -> Optional[bytes]:
|
|
if test_case[key] is not None:
|
|
return get_value(key)
|
|
else:
|
|
return None
|
|
|
|
rand_ = get_value("rand_")
|
|
sk = get_value_maybe("sk")
|
|
pk = PlainPk(get_value("pk"))
|
|
aggpk = get_value_maybe("aggpk")
|
|
if aggpk is not None:
|
|
aggpk = XonlyPk(aggpk)
|
|
msg = get_value_maybe("msg")
|
|
extra_in = get_value_maybe("extra_in")
|
|
expected_secnonce = get_value("expected_secnonce")
|
|
expected_pubnonce = get_value("expected_pubnonce")
|
|
|
|
assert nonce_gen_internal(rand_, sk, pk, aggpk, msg, extra_in) == (expected_secnonce, expected_pubnonce)
|
|
|
|
def test_nonce_agg_vectors() -> None:
|
|
with open(os.path.join(sys.path[0], 'vectors', 'nonce_agg_vectors.json')) as f:
|
|
test_data = json.load(f)
|
|
|
|
pnonce = fromhex_all(test_data["pnonces"])
|
|
valid_test_cases = test_data["valid_test_cases"]
|
|
error_test_cases = test_data["error_test_cases"]
|
|
|
|
for test_case in valid_test_cases:
|
|
pubnonces = [pnonce[i] for i in test_case["pnonce_indices"]]
|
|
expected = bytes.fromhex(test_case["expected"])
|
|
assert nonce_agg(pubnonces) == expected
|
|
|
|
for test_case in error_test_cases:
|
|
exception, except_fn = get_error_details(test_case)
|
|
pubnonces = [pnonce[i] for i in test_case["pnonce_indices"]]
|
|
assert_raises(exception, lambda: nonce_agg(pubnonces), except_fn)
|
|
|
|
def test_sign_verify_vectors() -> None:
|
|
with open(os.path.join(sys.path[0], 'vectors', 'sign_verify_vectors.json')) as f:
|
|
test_data = json.load(f)
|
|
|
|
sk = bytes.fromhex(test_data["sk"])
|
|
X = fromhex_all(test_data["pubkeys"])
|
|
# The public key corresponding to sk is at index 0
|
|
assert X[0] == individual_pk(sk)
|
|
|
|
secnonces = fromhex_all(test_data["secnonces"])
|
|
pnonce = fromhex_all(test_data["pnonces"])
|
|
# The public nonce corresponding to secnonces[0] is at index 0
|
|
k_1 = int_from_bytes(secnonces[0][0:32])
|
|
k_2 = int_from_bytes(secnonces[0][32:64])
|
|
R_s1 = point_mul(G, k_1)
|
|
R_s2 = point_mul(G, k_2)
|
|
assert R_s1 is not None and R_s2 is not None
|
|
assert pnonce[0] == cbytes(R_s1) + cbytes(R_s2)
|
|
|
|
aggnonces = fromhex_all(test_data["aggnonces"])
|
|
# The aggregate of the first three elements of pnonce is at index 0
|
|
assert (aggnonces[0] == nonce_agg([pnonce[0], pnonce[1], pnonce[2]]))
|
|
# The aggregate of the first and fourth elements of pnonce is at index 1,
|
|
# which is the infinity point encoded as a zeroed 33-byte array
|
|
assert (aggnonces[1] == nonce_agg([pnonce[0], pnonce[3]]))
|
|
|
|
msgs = fromhex_all(test_data["msgs"])
|
|
|
|
valid_test_cases = test_data["valid_test_cases"]
|
|
sign_error_test_cases = test_data["sign_error_test_cases"]
|
|
verify_fail_test_cases = test_data["verify_fail_test_cases"]
|
|
verify_error_test_cases = test_data["verify_error_test_cases"]
|
|
|
|
for test_case in valid_test_cases:
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
pubnonces = [pnonce[i] for i in test_case["nonce_indices"]]
|
|
aggnonce = aggnonces[test_case["aggnonce_index"]]
|
|
# Make sure that pubnonces and aggnonce in the test vector are
|
|
# consistent
|
|
assert nonce_agg(pubnonces) == aggnonce
|
|
msg = msgs[test_case["msg_index"]]
|
|
signer_index = test_case["signer_index"]
|
|
expected = bytes.fromhex(test_case["expected"])
|
|
|
|
session_ctx = SessionContext(aggnonce, pubkeys, [], [], msg)
|
|
# WARNING: An actual implementation should _not_ copy the secnonce.
|
|
# Reusing the secnonce, as we do here for testing purposes, can leak the
|
|
# secret key.
|
|
secnonce_tmp = bytearray(secnonces[0])
|
|
assert sign(secnonce_tmp, sk, session_ctx) == expected
|
|
assert partial_sig_verify(expected, pubnonces, pubkeys, [], [], msg, signer_index)
|
|
|
|
for test_case in sign_error_test_cases:
|
|
exception, except_fn = get_error_details(test_case)
|
|
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
aggnonce = aggnonces[test_case["aggnonce_index"]]
|
|
msg = msgs[test_case["msg_index"]]
|
|
secnonce = bytearray(secnonces[test_case["secnonce_index"]])
|
|
|
|
session_ctx = SessionContext(aggnonce, pubkeys, [], [], msg)
|
|
assert_raises(exception, lambda: sign(secnonce, sk, session_ctx), except_fn)
|
|
|
|
for test_case in verify_fail_test_cases:
|
|
sig = bytes.fromhex(test_case["sig"])
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
pubnonces = [pnonce[i] for i in test_case["nonce_indices"]]
|
|
msg = msgs[test_case["msg_index"]]
|
|
signer_index = test_case["signer_index"]
|
|
|
|
assert not partial_sig_verify(sig, pubnonces, pubkeys, [], [], msg, signer_index)
|
|
|
|
for test_case in verify_error_test_cases:
|
|
exception, except_fn = get_error_details(test_case)
|
|
|
|
sig = bytes.fromhex(test_case["sig"])
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
pubnonces = [pnonce[i] for i in test_case["nonce_indices"]]
|
|
msg = msgs[test_case["msg_index"]]
|
|
signer_index = test_case["signer_index"]
|
|
|
|
assert_raises(exception, lambda: partial_sig_verify(sig, pubnonces, pubkeys, [], [], msg, signer_index), except_fn)
|
|
|
|
def test_tweak_vectors() -> None:
|
|
with open(os.path.join(sys.path[0], 'vectors', 'tweak_vectors.json')) as f:
|
|
test_data = json.load(f)
|
|
|
|
sk = bytes.fromhex(test_data["sk"])
|
|
X = fromhex_all(test_data["pubkeys"])
|
|
# The public key corresponding to sk is at index 0
|
|
assert X[0] == individual_pk(sk)
|
|
|
|
secnonce = bytearray(bytes.fromhex(test_data["secnonce"]))
|
|
pnonce = fromhex_all(test_data["pnonces"])
|
|
# The public nonce corresponding to secnonce is at index 0
|
|
k_1 = int_from_bytes(secnonce[0:32])
|
|
k_2 = int_from_bytes(secnonce[32:64])
|
|
R_s1 = point_mul(G, k_1)
|
|
R_s2 = point_mul(G, k_2)
|
|
assert R_s1 is not None and R_s2 is not None
|
|
assert pnonce[0] == cbytes(R_s1) + cbytes(R_s2)
|
|
|
|
aggnonce = bytes.fromhex(test_data["aggnonce"])
|
|
# The aggnonce is the aggregate of the first three elements of pnonce
|
|
assert(aggnonce == nonce_agg([pnonce[0], pnonce[1], pnonce[2]]))
|
|
|
|
tweak = fromhex_all(test_data["tweaks"])
|
|
msg = bytes.fromhex(test_data["msg"])
|
|
|
|
valid_test_cases = test_data["valid_test_cases"]
|
|
error_test_cases = test_data["error_test_cases"]
|
|
|
|
for test_case in valid_test_cases:
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
pubnonces = [pnonce[i] for i in test_case["nonce_indices"]]
|
|
tweaks = [tweak[i] for i in test_case["tweak_indices"]]
|
|
is_xonly = test_case["is_xonly"]
|
|
signer_index = test_case["signer_index"]
|
|
expected = bytes.fromhex(test_case["expected"])
|
|
|
|
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
|
|
secnonce_tmp = bytearray(secnonce)
|
|
# WARNING: An actual implementation should _not_ copy the secnonce.
|
|
# Reusing the secnonce, as we do here for testing purposes, can leak the
|
|
# secret key.
|
|
assert sign(secnonce_tmp, sk, session_ctx) == expected
|
|
assert partial_sig_verify(expected, pubnonces, pubkeys, tweaks, is_xonly, msg, signer_index)
|
|
|
|
for test_case in error_test_cases:
|
|
exception, except_fn = get_error_details(test_case)
|
|
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
pubnonces = [pnonce[i] for i in test_case["nonce_indices"]]
|
|
tweaks = [tweak[i] for i in test_case["tweak_indices"]]
|
|
is_xonly = test_case["is_xonly"]
|
|
signer_index = test_case["signer_index"]
|
|
|
|
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
|
|
assert_raises(exception, lambda: sign(secnonce, sk, session_ctx), except_fn)
|
|
|
|
def test_det_sign_vectors() -> None:
|
|
with open(os.path.join(sys.path[0], 'vectors', 'det_sign_vectors.json')) as f:
|
|
test_data = json.load(f)
|
|
|
|
sk = bytes.fromhex(test_data["sk"])
|
|
X = fromhex_all(test_data["pubkeys"])
|
|
# The public key corresponding to sk is at index 0
|
|
assert X[0] == individual_pk(sk)
|
|
|
|
msgs = fromhex_all(test_data["msgs"])
|
|
|
|
valid_test_cases = test_data["valid_test_cases"]
|
|
error_test_cases = test_data["error_test_cases"]
|
|
|
|
for test_case in valid_test_cases:
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
aggothernonce = bytes.fromhex(test_case["aggothernonce"])
|
|
tweaks = fromhex_all(test_case["tweaks"])
|
|
is_xonly = test_case["is_xonly"]
|
|
msg = msgs[test_case["msg_index"]]
|
|
signer_index = test_case["signer_index"]
|
|
rand = bytes.fromhex(test_case["rand"]) if test_case["rand"] is not None else None
|
|
expected = fromhex_all(test_case["expected"])
|
|
|
|
pubnonce, psig = deterministic_sign(sk, aggothernonce, pubkeys, tweaks, is_xonly, msg, rand)
|
|
assert pubnonce == expected[0]
|
|
assert psig == expected[1]
|
|
|
|
pubnonces = [aggothernonce, pubnonce]
|
|
aggnonce = nonce_agg(pubnonces)
|
|
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
|
|
assert partial_sig_verify_internal(psig, pubnonce, pubkeys[signer_index], session_ctx)
|
|
|
|
for test_case in error_test_cases:
|
|
exception, except_fn = get_error_details(test_case)
|
|
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
aggothernonce = bytes.fromhex(test_case["aggothernonce"])
|
|
tweaks = fromhex_all(test_case["tweaks"])
|
|
is_xonly = test_case["is_xonly"]
|
|
msg = msgs[test_case["msg_index"]]
|
|
signer_index = test_case["signer_index"]
|
|
rand = bytes.fromhex(test_case["rand"]) if test_case["rand"] is not None else None
|
|
|
|
try_fn = lambda: deterministic_sign(sk, aggothernonce, pubkeys, tweaks, is_xonly, msg, rand)
|
|
assert_raises(exception, try_fn, except_fn)
|
|
|
|
def test_sig_agg_vectors() -> None:
|
|
with open(os.path.join(sys.path[0], 'vectors', 'sig_agg_vectors.json')) as f:
|
|
test_data = json.load(f)
|
|
|
|
X = fromhex_all(test_data["pubkeys"])
|
|
|
|
# These nonces are only required if the tested API takes the individual
|
|
# nonces and not the aggregate nonce.
|
|
pnonce = fromhex_all(test_data["pnonces"])
|
|
|
|
tweak = fromhex_all(test_data["tweaks"])
|
|
psig = fromhex_all(test_data["psigs"])
|
|
|
|
msg = bytes.fromhex(test_data["msg"])
|
|
|
|
valid_test_cases = test_data["valid_test_cases"]
|
|
error_test_cases = test_data["error_test_cases"]
|
|
|
|
for test_case in valid_test_cases:
|
|
pubnonces = [pnonce[i] for i in test_case["nonce_indices"]]
|
|
aggnonce = bytes.fromhex(test_case["aggnonce"])
|
|
assert aggnonce == nonce_agg(pubnonces)
|
|
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
tweaks = [tweak[i] for i in test_case["tweak_indices"]]
|
|
is_xonly = test_case["is_xonly"]
|
|
psigs = [psig[i] for i in test_case["psig_indices"]]
|
|
expected = bytes.fromhex(test_case["expected"])
|
|
|
|
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
|
|
sig = partial_sig_agg(psigs, session_ctx)
|
|
assert sig == expected
|
|
aggpk = get_xonly_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly))
|
|
assert schnorr_verify(msg, aggpk, sig)
|
|
|
|
for test_case in error_test_cases:
|
|
exception, except_fn = get_error_details(test_case)
|
|
|
|
pubnonces = [pnonce[i] for i in test_case["nonce_indices"]]
|
|
aggnonce = nonce_agg(pubnonces)
|
|
|
|
pubkeys = [X[i] for i in test_case["key_indices"]]
|
|
tweaks = [tweak[i] for i in test_case["tweak_indices"]]
|
|
is_xonly = test_case["is_xonly"]
|
|
psigs = [psig[i] for i in test_case["psig_indices"]]
|
|
|
|
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
|
|
assert_raises(exception, lambda: partial_sig_agg(psigs, session_ctx), except_fn)
|
|
|
|
def test_sign_and_verify_random(iters: int) -> None:
|
|
for i in range(iters):
|
|
sk_1 = secrets.token_bytes(32)
|
|
sk_2 = secrets.token_bytes(32)
|
|
pk_1 = individual_pk(sk_1)
|
|
pk_2 = individual_pk(sk_2)
|
|
pubkeys = [pk_1, pk_2]
|
|
|
|
# In this example, the message and aggregate pubkey are known
|
|
# before nonce generation, so they can be passed into the nonce
|
|
# generation function as a defense-in-depth measure to protect
|
|
# against nonce reuse.
|
|
#
|
|
# If these values are not known when nonce_gen is called, empty
|
|
# byte arrays can be passed in for the corresponding arguments
|
|
# instead.
|
|
msg = secrets.token_bytes(32)
|
|
v = secrets.randbelow(4)
|
|
tweaks = [secrets.token_bytes(32) for _ in range(v)]
|
|
is_xonly = [secrets.choice([False, True]) for _ in range(v)]
|
|
aggpk = get_xonly_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly))
|
|
|
|
# Use a non-repeating counter for extra_in
|
|
secnonce_1, pubnonce_1 = nonce_gen(sk_1, pk_1, aggpk, msg, i.to_bytes(4, 'big'))
|
|
|
|
# On even iterations use regular signing algorithm for signer 2,
|
|
# otherwise use deterministic signing algorithm
|
|
if i % 2 == 0:
|
|
# Use a clock for extra_in
|
|
t = time.clock_gettime_ns(time.CLOCK_MONOTONIC)
|
|
secnonce_2, pubnonce_2 = nonce_gen(sk_2, pk_2, aggpk, msg, t.to_bytes(8, 'big'))
|
|
else:
|
|
aggothernonce = nonce_agg([pubnonce_1])
|
|
rand = secrets.token_bytes(32)
|
|
pubnonce_2, psig_2 = deterministic_sign(sk_2, aggothernonce, pubkeys, tweaks, is_xonly, msg, rand)
|
|
|
|
pubnonces = [pubnonce_1, pubnonce_2]
|
|
aggnonce = nonce_agg(pubnonces)
|
|
|
|
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
|
|
psig_1 = sign(secnonce_1, sk_1, session_ctx)
|
|
assert partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, msg, 0)
|
|
# An exception is thrown if secnonce_1 is accidentally reused
|
|
assert_raises(ValueError, lambda: sign(secnonce_1, sk_1, session_ctx), lambda e: True)
|
|
|
|
# Wrong signer index
|
|
assert not partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, msg, 1)
|
|
|
|
# Wrong message
|
|
assert not partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, secrets.token_bytes(32), 0)
|
|
|
|
if i % 2 == 0:
|
|
psig_2 = sign(secnonce_2, sk_2, session_ctx)
|
|
assert partial_sig_verify(psig_2, pubnonces, pubkeys, tweaks, is_xonly, msg, 1)
|
|
|
|
sig = partial_sig_agg([psig_1, psig_2], session_ctx)
|
|
assert schnorr_verify(msg, aggpk, sig)
|
|
|
|
if __name__ == '__main__':
|
|
test_key_sort_vectors()
|
|
test_key_agg_vectors()
|
|
test_nonce_gen_vectors()
|
|
test_nonce_agg_vectors()
|
|
test_sign_verify_vectors()
|
|
test_tweak_vectors()
|
|
test_det_sign_vectors()
|
|
test_sig_agg_vectors()
|
|
test_sign_and_verify_random(6)
|