diff --git a/bip-DLEQ.mediawiki b/bip-DLEQ.mediawiki index bb2735f0..726daea1 100644 --- a/bip-DLEQ.mediawiki +++ b/bip-DLEQ.mediawiki @@ -38,6 +38,7 @@ Input: * The public key ''B'': a point on the curve * The generator point ''G'': a point on the curve * Auxiliary random data ''r'': a 32-byte array +* An optional message ''m'': a 32-byte array The algorithm ''GenerateProof(a, B, r)'' is defined as: * Fail if ''a = 0'' or ''a ≥ n''. @@ -50,7 +51,8 @@ The algorithm ''GenerateProof(a, B, r)'' is defined as: * Fail if ''k = 0''. * Let ''R1 = k⋅G''. * Let ''R2 = k⋅B''. -* Let ''e = int(hashBIP0???/challenge(cbytes(A) || cbytes(B) || cbytes(C) || cbytes(G) || cbytes(R1) || cbytes(R2)))''. +* Let ''m' = m if m is provided, otherwise an empty byte array''. +* Let ''e = int(hashBIP0???/challenge(cbytes(A) || cbytes(B) || cbytes(C) || cbytes(G) || cbytes(R1) || cbytes(R2) || cbytes(m')))''. * Let ''s = (k + e⋅a) mod n''. * Let ''proof = bytes(32, e) || bytes(32, s)''. * If ''VerifyProof(A, B, C, proof)'' (see below) returns failure, abort. @@ -64,6 +66,7 @@ Input: * The result of multiplying the secret and public keys used in the proof generation ''C'': a point on the curve * The generator point used in the proof generation ''G'': a point on the curve * A proof ''proof'': a 64-byte array +* An optional message ''m'': a 32-byte array The algorithm ''VerifyProof(A, B, C, G, proof)'' is defined as: * Let ''e = int(proof[0:32])''. @@ -72,7 +75,8 @@ The algorithm ''VerifyProof(A, B, C, G, proof)'' is defined as: * Fail if ''is_infinite(R1)''. * Let ''R2 = s⋅B - e⋅C''. * Fail if ''is_infinite(R2)''. -* Fail if ''e ≠ int(hashBIP0???/challenge(cbytes(A) || cbytes(B) || cbytes(C) || cbytes(G) || cbytes(R1) || cbytes(R2)))''. +* Let ''m' = m if m is provided, otherwise an empty byte array''. +* Fail if ''e ≠ int(hashBIP0???/challenge(cbytes(A) || cbytes(B) || cbytes(C) || cbytes(G) || cbytes(R1) || cbytes(R2) || cbytes(m')))''. * Return success iff no failure occurred before reaching this point. == Test Vectors and Reference Code == diff --git a/bip-DLEQ/reference.py b/bip-DLEQ/reference.py index f508776c..231617ac 100644 --- a/bip-DLEQ/reference.py +++ b/bip-DLEQ/reference.py @@ -24,13 +24,30 @@ def xor_bytes(lhs: bytes, rhs: bytes) -> bytes: return bytes([lhs[i] ^ rhs[i] for i in range(len(lhs))]) -def dleq_challenge(A: GE, B: GE, C: GE, R1: GE, R2: GE) -> int: - return int.from_bytes(TaggedHash(DLEQ_TAG_CHALLENGE, - A.to_bytes_compressed() + B.to_bytes_compressed() + C.to_bytes_compressed() + - R1.to_bytes_compressed() + R2.to_bytes_compressed()), 'big') +def dleq_challenge( + A: GE, B: GE, C: GE, R1: GE, R2: GE, G: GE = G, m: bytes | None = None +) -> int: + if m is not None: + assert len(m) == 32 + m = bytes([]) if m is None else m.to_bytes(32, "big") + return int.from_bytes( + TaggedHash( + DLEQ_TAG_CHALLENGE, + A.to_bytes_compressed() + + B.to_bytes_compressed() + + C.to_bytes_compressed() + + G.to_bytes_compressed() + + R1.to_bytes_compressed() + + R2.to_bytes_compressed() + + m, + ), + "big", + ) -def dleq_generate_proof(a: int, B: GE, r: bytes) -> bytes | None: +def dleq_generate_proof( + a: int, B: GE, r: bytes, G: GE = G, m: bytes | None = None +) -> bytes | None: assert len(r) == 32 if not (0 < a < GE.ORDER): return None @@ -38,25 +55,29 @@ def dleq_generate_proof(a: int, B: GE, r: bytes) -> bytes | None: return None A = a * G C = a * B - t = xor_bytes(a.to_bytes(32, 'big'), TaggedHash(DLEQ_TAG_AUX, r)) - rand = TaggedHash(DLEQ_TAG_NONCE, t + A.to_bytes_compressed() + C.to_bytes_compressed()) - k = int.from_bytes(rand, 'big') % GE.ORDER + t = xor_bytes(a.to_bytes(32, "big"), TaggedHash(DLEQ_TAG_AUX, r)) + rand = TaggedHash( + DLEQ_TAG_NONCE, t + A.to_bytes_compressed() + C.to_bytes_compressed() + ) + k = int.from_bytes(rand, "big") % GE.ORDER if k == 0: return None R1 = k * G R2 = k * B e = dleq_challenge(A, B, C, R1, R2) s = (k + e * a) % GE.ORDER - proof = e.to_bytes(32, 'big') + s.to_bytes(32, 'big') + proof = e.to_bytes(32, "big") + s.to_bytes(32, "big") if not dleq_verify_proof(A, B, C, proof): return None return proof -def dleq_verify_proof(A: GE, B: GE, C: GE, proof: bytes) -> bool: +def dleq_verify_proof( + A: GE, B: GE, C: GE, proof: bytes, G: GE = G, m: bytes | None = None +) -> bool: assert len(proof) == 64 - e = int.from_bytes(proof[:32], 'big') - s = int.from_bytes(proof[32:], 'big') + e = int.from_bytes(proof[:32], "big") + s = int.from_bytes(proof[32:], "big") if s >= GE.ORDER: return False # TODO: implement subtraction operator (__sub__) for GE class to simplify these terms @@ -97,6 +118,25 @@ class DLEQTests(unittest.TestCase): # flip a random bit in the dleq proof and check that verification fails for _ in range(5): proof_damaged = list(proof) - proof_damaged[random.randrange(len(proof))] ^= (1 << (random.randrange(8))) + proof_damaged[random.randrange(len(proof))] ^= 1 << ( + random.randrange(8) + ) + success = dleq_verify_proof(A, B, C, bytes(proof_damaged)) + self.assertFalse(success) + + # create the same dleq proof with a message + message = random.randbytes(32) + proof = dleq_generate_proof(a, B, rand_aux, m=message) + self.assertTrue(proof is not None) + # verify dleq proof with a message + success = dleq_verify_proof(A, B, C, proof, m=message) + self.assertTrue(success) + + # flip a random bit in the dleq proof and check that verification fails + for _ in range(5): + proof_damaged = list(proof) + proof_damaged[random.randrange(len(proof))] ^= 1 << ( + random.randrange(8) + ) success = dleq_verify_proof(A, B, C, bytes(proof_damaged)) self.assertFalse(success)