diff --git a/bip-DLEQ.mediawiki b/bip-DLEQ.mediawiki
index bb2735f0..726daea1 100644
--- a/bip-DLEQ.mediawiki
+++ b/bip-DLEQ.mediawiki
@@ -38,6 +38,7 @@ Input:
* The public key ''B'': a point on the curve
* The generator point ''G'': a point on the curve
* Auxiliary random data ''r'': a 32-byte array
+* An optional message ''m'': a 32-byte array
The algorithm ''GenerateProof(a, B, r)'' is defined as:
* Fail if ''a = 0'' or ''a ≥ n''.
@@ -50,7 +51,8 @@ The algorithm ''GenerateProof(a, B, r)'' is defined as:
* Fail if ''k = 0''.
* Let ''R1 = k⋅G''.
* Let ''R2 = k⋅B''.
-* Let ''e = int(hashBIP0???/challenge(cbytes(A) || cbytes(B) || cbytes(C) || cbytes(G) || cbytes(R1) || cbytes(R2)))''.
+* Let ''m' = m if m is provided, otherwise an empty byte array''.
+* Let ''e = int(hashBIP0???/challenge(cbytes(A) || cbytes(B) || cbytes(C) || cbytes(G) || cbytes(R1) || cbytes(R2) || cbytes(m')))''.
* Let ''s = (k + e⋅a) mod n''.
* Let ''proof = bytes(32, e) || bytes(32, s)''.
* If ''VerifyProof(A, B, C, proof)'' (see below) returns failure, abort.
@@ -64,6 +66,7 @@ Input:
* The result of multiplying the secret and public keys used in the proof generation ''C'': a point on the curve
* The generator point used in the proof generation ''G'': a point on the curve
* A proof ''proof'': a 64-byte array
+* An optional message ''m'': a 32-byte array
The algorithm ''VerifyProof(A, B, C, G, proof)'' is defined as:
* Let ''e = int(proof[0:32])''.
@@ -72,7 +75,8 @@ The algorithm ''VerifyProof(A, B, C, G, proof)'' is defined as:
* Fail if ''is_infinite(R1)''.
* Let ''R2 = s⋅B - e⋅C''.
* Fail if ''is_infinite(R2)''.
-* Fail if ''e ≠ int(hashBIP0???/challenge(cbytes(A) || cbytes(B) || cbytes(C) || cbytes(G) || cbytes(R1) || cbytes(R2)))''.
+* Let ''m' = m if m is provided, otherwise an empty byte array''.
+* Fail if ''e ≠ int(hashBIP0???/challenge(cbytes(A) || cbytes(B) || cbytes(C) || cbytes(G) || cbytes(R1) || cbytes(R2) || cbytes(m')))''.
* Return success iff no failure occurred before reaching this point.
== Test Vectors and Reference Code ==
diff --git a/bip-DLEQ/reference.py b/bip-DLEQ/reference.py
index f508776c..231617ac 100644
--- a/bip-DLEQ/reference.py
+++ b/bip-DLEQ/reference.py
@@ -24,13 +24,30 @@ def xor_bytes(lhs: bytes, rhs: bytes) -> bytes:
return bytes([lhs[i] ^ rhs[i] for i in range(len(lhs))])
-def dleq_challenge(A: GE, B: GE, C: GE, R1: GE, R2: GE) -> int:
- return int.from_bytes(TaggedHash(DLEQ_TAG_CHALLENGE,
- A.to_bytes_compressed() + B.to_bytes_compressed() + C.to_bytes_compressed() +
- R1.to_bytes_compressed() + R2.to_bytes_compressed()), 'big')
+def dleq_challenge(
+ A: GE, B: GE, C: GE, R1: GE, R2: GE, G: GE = G, m: bytes | None = None
+) -> int:
+ if m is not None:
+ assert len(m) == 32
+ m = bytes([]) if m is None else m.to_bytes(32, "big")
+ return int.from_bytes(
+ TaggedHash(
+ DLEQ_TAG_CHALLENGE,
+ A.to_bytes_compressed()
+ + B.to_bytes_compressed()
+ + C.to_bytes_compressed()
+ + G.to_bytes_compressed()
+ + R1.to_bytes_compressed()
+ + R2.to_bytes_compressed()
+ + m,
+ ),
+ "big",
+ )
-def dleq_generate_proof(a: int, B: GE, r: bytes) -> bytes | None:
+def dleq_generate_proof(
+ a: int, B: GE, r: bytes, G: GE = G, m: bytes | None = None
+) -> bytes | None:
assert len(r) == 32
if not (0 < a < GE.ORDER):
return None
@@ -38,25 +55,29 @@ def dleq_generate_proof(a: int, B: GE, r: bytes) -> bytes | None:
return None
A = a * G
C = a * B
- t = xor_bytes(a.to_bytes(32, 'big'), TaggedHash(DLEQ_TAG_AUX, r))
- rand = TaggedHash(DLEQ_TAG_NONCE, t + A.to_bytes_compressed() + C.to_bytes_compressed())
- k = int.from_bytes(rand, 'big') % GE.ORDER
+ t = xor_bytes(a.to_bytes(32, "big"), TaggedHash(DLEQ_TAG_AUX, r))
+ rand = TaggedHash(
+ DLEQ_TAG_NONCE, t + A.to_bytes_compressed() + C.to_bytes_compressed()
+ )
+ k = int.from_bytes(rand, "big") % GE.ORDER
if k == 0:
return None
R1 = k * G
R2 = k * B
e = dleq_challenge(A, B, C, R1, R2)
s = (k + e * a) % GE.ORDER
- proof = e.to_bytes(32, 'big') + s.to_bytes(32, 'big')
+ proof = e.to_bytes(32, "big") + s.to_bytes(32, "big")
if not dleq_verify_proof(A, B, C, proof):
return None
return proof
-def dleq_verify_proof(A: GE, B: GE, C: GE, proof: bytes) -> bool:
+def dleq_verify_proof(
+ A: GE, B: GE, C: GE, proof: bytes, G: GE = G, m: bytes | None = None
+) -> bool:
assert len(proof) == 64
- e = int.from_bytes(proof[:32], 'big')
- s = int.from_bytes(proof[32:], 'big')
+ e = int.from_bytes(proof[:32], "big")
+ s = int.from_bytes(proof[32:], "big")
if s >= GE.ORDER:
return False
# TODO: implement subtraction operator (__sub__) for GE class to simplify these terms
@@ -97,6 +118,25 @@ class DLEQTests(unittest.TestCase):
# flip a random bit in the dleq proof and check that verification fails
for _ in range(5):
proof_damaged = list(proof)
- proof_damaged[random.randrange(len(proof))] ^= (1 << (random.randrange(8)))
+ proof_damaged[random.randrange(len(proof))] ^= 1 << (
+ random.randrange(8)
+ )
+ success = dleq_verify_proof(A, B, C, bytes(proof_damaged))
+ self.assertFalse(success)
+
+ # create the same dleq proof with a message
+ message = random.randbytes(32)
+ proof = dleq_generate_proof(a, B, rand_aux, m=message)
+ self.assertTrue(proof is not None)
+ # verify dleq proof with a message
+ success = dleq_verify_proof(A, B, C, proof, m=message)
+ self.assertTrue(success)
+
+ # flip a random bit in the dleq proof and check that verification fails
+ for _ in range(5):
+ proof_damaged = list(proof)
+ proof_damaged[random.randrange(len(proof))] ^= 1 << (
+ random.randrange(8)
+ )
success = dleq_verify_proof(A, B, C, bytes(proof_damaged))
self.assertFalse(success)