Add Equihash implementation for use in ASIC-resistant PoW

Implement the Equihash (https://eprint.iacr.org/2015/946.pdf) algorithm
for solving/verifying memory-hard client-puzzles/proof-of-work problems
for ASIC-resistant DoS attack protection. The scheme is asymmetric, so
that even though solving a puzzle is slow and memory-intensive, needing
100's of kB to MB's of memory, the solution verification is instant.

Instead of a single 64-bit counter/nonce, as in the case of Hashcash,
Equihash solutions are larger objects ranging from 10's of bytes to a
few kB, depending on the puzzle parameters used. These need to be
stored in entirety, in the proof-of-work field of each offer payload.

Include logic for fine-grained difficulty control in Equihash with a
double-precision floating point number. This is based on lexicographic
comparison with a target hash, like in Bitcoin, instead of just
counting the number of leading zeros of a hash.

The code is unused at present. Also add some simple unit tests.
This commit is contained in:
Steven Barclay 2021-11-23 06:27:33 +00:00
parent 5da8df266a
commit c2b3a078ff
No known key found for this signature in database
GPG key ID: 9FED6BF1176D500B
3 changed files with 441 additions and 0 deletions

View file

@ -0,0 +1,349 @@
/*
* This file is part of Bisq.
*
* Bisq is free software: you can redistribute it and/or modify it
* under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or (at
* your option) any later version.
*
* Bisq is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
* License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Bisq. If not, see <http://www.gnu.org/licenses/>.
*/
package bisq.common.crypto;
import bisq.common.util.Utilities;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.MultimapBuilder;
import com.google.common.primitives.Bytes;
import com.google.common.primitives.ImmutableIntArray;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import com.google.common.primitives.UnsignedInts;
import org.bouncycastle.crypto.digests.Blake2bDigest;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.math.BigInteger;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.Optional;
import lombok.ToString;
import static com.google.common.base.Preconditions.checkArgument;
import static java.math.BigInteger.ONE;
/**
* An ASIC-resistant Proof-of-Work scheme based on the Generalized Birthday Problem (GBP),
* as described in <a href="https://eprint.iacr.org/2015/946.pdf">this paper</a> and used
* in ZCash and some other cryptocurrencies. Like
* <a href="https://en.wikipedia.org/wiki/Hashcash">Hashcash</a> but unlike many other
* memory-hard and ASIC resistant PoW schemes, it is <i>asymmetric</i>, meaning that it
* supports fast verification. This makes it suitable for DoS attack protection.<br><p>
* <br>
* The Generalized Birthday Problem is an attempt to find <i>2<sup>k</sup></i>
* <i>n</i>-bit hashes (out of a list of length <i>N</i>) which XOR to zero. When <i>N</i>
* equals <i>2<sup>1+n/(k+1)</sup></i>, this has at least a handful of solutions on
* average, which can be found using Wagner's Algorithm, as described in the paper and
* implemented here. The rough idea is to split each hash into <i>k+1</i>
* <i>n/(k+1)</i>-bit blocks and place them into a table to look for collisions on a given
* block. All (partially) colliding pairs are XORed together and used to build a new table
* of roughly the same size as the last, before moving on to the next block and repeating
* the process until a full collision can be found. Keeping track of the tuple of hashes
* used to form each intermediate XOR result (which doubles in length each iteration)
* gives the final solution. The table-based approach needed to find solutions to the GBP
* makes it a memory-hard problem.<br><p>
* <br>
* In this implementation and the reference
* <a href="https://github.com/khovratovich/equihash">C++ implementation</a> included with
* the paper, the hash function BLAKE2b is used to supply 256 bits, which is shortened and
* split into <i>k+1</i> 32-bit blocks. The blocks are masked to provide <i>n/(k+1)</i>
* bits each and <i>n</i> bits in total. This allows working with 32-bit integers
* throughout, for efficiency.
*/
@SuppressWarnings("UnstableApiUsage")
public class Equihash {
private static final int HASH_BIT_LENGTH = 256;
private final int k, N;
private final int tableCapacity;
private final int inputNum, inputBits;
private final int[] hashUpperBound;
public Equihash(int n, int k, double difficulty) {
checkArgument(k > 0 && k < HASH_BIT_LENGTH / 32,
"Tree depth k must be a positive integer less than %s.",
HASH_BIT_LENGTH / 32);
checkArgument(n > 0 && n < HASH_BIT_LENGTH && n % (k + 1) == 0,
"Collision bit count n must be a positive multiple of k + 1 and less than %s.",
HASH_BIT_LENGTH);
checkArgument(n / (k + 1) < 30,
"Sub-collision bit count n / (k + 1) must be less than 30, got %s.",
n / (k + 1));
this.k = k;
inputNum = 1 << k;
inputBits = n / (k + 1) + 1;
N = 1 << inputBits;
tableCapacity = (int) (N * 1.1);
hashUpperBound = hashUpperBound(difficulty);
}
@VisibleForTesting
static int[] hashUpperBound(double difficulty) {
return Utilities.bytesToIntsBE(Utilities.copyRightAligned(
inverseDifficultyMinusOne(difficulty).toByteArray(), HASH_BIT_LENGTH / 8
));
}
private static BigInteger inverseDifficultyMinusOne(double difficulty) {
checkArgument(difficulty >= 1.0, "Difficulty must be at least 1.");
int exponent = Math.getExponent(difficulty) - 52;
var mantissa = BigInteger.valueOf((long) Math.scalb(difficulty, -exponent));
var inverse = ONE.shiftLeft(HASH_BIT_LENGTH - exponent).add(mantissa).subtract(ONE).divide(mantissa);
return inverse.subtract(ONE).max(BigInteger.ZERO);
}
public Puzzle puzzle(byte[] seed) {
return new Puzzle(seed);
}
public class Puzzle {
private final byte[] seed;
private Puzzle(byte[] seed) {
this.seed = seed;
}
@ToString
public class Solution {
private final long nonce;
private final int[] inputs;
private Solution(long nonce, int... inputs) {
this.nonce = nonce;
this.inputs = inputs;
}
public boolean verify() {
return withHashPrefix(seed, nonce).verify(inputs);
}
public byte[] serialize() {
int bitLen = 64 + inputNum * inputBits;
int byteLen = (bitLen + 7) / 8;
byte[] paddedBytes = new byte[byteLen + 3 & -4];
IntBuffer intBuffer = ByteBuffer.wrap(paddedBytes).asIntBuffer();
intBuffer.put((int) (nonce >> 32)).put((int) nonce);
int off = 64;
long buf = 0;
for (int v : inputs) {
off -= inputBits;
buf |= UnsignedInts.toLong(v) << off;
if (off <= 32) {
intBuffer.put((int) (buf >> 32));
buf <<= 32;
off += 32;
}
}
if (off < 64) {
intBuffer.put((int) (buf >> 32));
}
return (byteLen & 3) == 0 ? paddedBytes : Arrays.copyOf(paddedBytes, byteLen);
}
}
public Solution deserializeSolution(byte[] bytes) {
int bitLen = 64 + inputNum * inputBits;
int byteLen = (bitLen + 7) / 8;
checkArgument(bytes.length == byteLen,
"Incorrect solution byte length. Expected %s but got %s.",
byteLen, bytes.length);
checkArgument(byteLen == 0 || (byte) (bytes[byteLen - 1] << ((bitLen + 7 & 7) + 1)) == 0,
"Nonzero padding bits found at end of solution byte array.");
byte[] paddedBytes = (byteLen & 3) == 0 ? bytes : Arrays.copyOf(bytes, byteLen + 3 & -4);
IntBuffer intBuffer = ByteBuffer.wrap(paddedBytes).asIntBuffer();
long nonce = ((long) intBuffer.get() << 32) | UnsignedInts.toLong(intBuffer.get());
int[] inputs = new int[inputNum];
int off = 0;
long buf = 0;
for (int i = 0; i < inputs.length; i++) {
if (off < inputBits) {
buf = buf << 32 | UnsignedInts.toLong(intBuffer.get());
off += 32;
}
off -= inputBits;
inputs[i] = (int) (buf >>> off) & (N - 1);
}
return new Solution(nonce, inputs);
}
public Solution findSolution() {
Optional<int[]> inputs;
for (int nonce = 0; ; nonce++) {
if ((inputs = withHashPrefix(seed, nonce).findInputs()).isPresent()) {
return new Solution(nonce, inputs.get());
}
}
}
}
private WithHashPrefix withHashPrefix(byte[] seed, long nonce) {
return new WithHashPrefix(Bytes.concat(seed, Longs.toByteArray(nonce)));
}
private class WithHashPrefix {
private final byte[] prefixBytes;
private WithHashPrefix(byte[] prefixBytes) {
this.prefixBytes = prefixBytes;
}
private int[] hashInputs(int... inputs) {
var digest = new Blake2bDigest(HASH_BIT_LENGTH);
digest.update(prefixBytes, 0, prefixBytes.length);
byte[] inputBytes = Utilities.intsToBytesBE(inputs);
digest.update(inputBytes, 0, inputBytes.length);
byte[] outputBytes = new byte[HASH_BIT_LENGTH / 8];
digest.doFinal(outputBytes, 0);
return Utilities.bytesToIntsBE(outputBytes);
}
Optional<int[]> findInputs() {
var table = computeAllHashes();
for (int i = 0; i < k; i++) {
table = findCollisions(table, i + 1 < k);
}
for (int i = 0; i < table.numRows; i++) {
if (table.getRow(i).stream().distinct().count() == inputNum) {
int[] inputs = sortInputs(table.getRow(i).toArray());
if (testDifficultyCondition(inputs)) {
return Optional.of(inputs);
}
}
}
return Optional.empty();
}
private XorTable computeAllHashes() {
var tableValues = ImmutableIntArray.builder((k + 2) * N);
for (int i = 0; i < N; i++) {
int[] hash = hashInputs(i);
for (int j = 0; j <= k; j++) {
tableValues.add(hash[j] & (N / 2 - 1));
}
tableValues.add(i);
}
return new XorTable(k + 1, 1, tableValues.build());
}
private boolean testDifficultyCondition(int[] inputs) {
int[] difficultyHash = hashInputs(inputs);
return UnsignedInts.lexicographicalComparator().compare(difficultyHash, hashUpperBound) <= 0;
}
boolean verify(int[] inputs) {
if (inputs.length != inputNum || Arrays.stream(inputs).distinct().count() < inputNum) {
return false;
}
if (Arrays.stream(inputs).anyMatch(i -> i < 0 || i >= N)) {
return false;
}
if (!Arrays.equals(inputs, sortInputs(inputs))) {
return false;
}
if (!testDifficultyCondition(inputs)) {
return false;
}
int[] hashBlockSums = new int[k + 1];
for (int i = 0; i < inputs.length; i++) {
int[] hash = hashInputs(inputs[i]);
for (int j = 0; j <= k; j++) {
hashBlockSums[j] ^= hash[j] & (N / 2 - 1);
}
for (int ii = i + 1 + inputNum, j = 0; (ii & 1) == 0; ii /= 2, j++) {
if (hashBlockSums[j] != 0) {
return false;
}
}
}
return true;
}
}
private static class XorTable {
private final int hashWidth, indexTupleWidth, rowWidth, numRows;
private final ImmutableIntArray values;
XorTable(int hashWidth, int indexTupleWidth, ImmutableIntArray values) {
this.hashWidth = hashWidth;
this.indexTupleWidth = indexTupleWidth;
this.values = values;
rowWidth = hashWidth + indexTupleWidth;
numRows = (values.length() + rowWidth - 1) / rowWidth;
}
ImmutableIntArray getRow(int index) {
return values.subArray(index * rowWidth, index * rowWidth + hashWidth + indexTupleWidth);
}
}
// Apply a single iteration of Wagner's Algorithm.
private XorTable findCollisions(XorTable table, boolean isPartial) {
int newHashWidth = isPartial ? table.hashWidth - 1 : 0;
int newIndexTupleWidth = table.indexTupleWidth * 2;
int newRowWidth = newHashWidth + newIndexTupleWidth;
var newTableValues = ImmutableIntArray.builder(
newRowWidth * (isPartial ? tableCapacity : 10));
ListMultimap<Integer, Integer> indexMultimap = MultimapBuilder.hashKeys().arrayListValues().build();
for (int i = 0; i < table.numRows; i++) {
var row = table.getRow(i);
var collisionIndices = indexMultimap.get(row.get(0));
collisionIndices.forEach(ii -> {
var collidingRow = table.getRow(ii);
if (isPartial) {
for (int j = 1; j < table.hashWidth; j++) {
newTableValues.add(collidingRow.get(j) ^ row.get(j));
}
} else if (!collidingRow.subArray(1, table.hashWidth).equals(row.subArray(1, table.hashWidth))) {
return;
}
newTableValues.addAll(collidingRow.subArray(table.hashWidth, collidingRow.length()));
newTableValues.addAll(row.subArray(table.hashWidth, row.length()));
});
indexMultimap.put(row.get(0), i);
}
return new XorTable(newHashWidth, newIndexTupleWidth, newTableValues.build());
}
private static int[] sortInputs(int[] inputs) {
Deque<int[]> sublistStack = new ArrayDeque<>();
int[] topSublist;
for (int input : inputs) {
topSublist = new int[]{input};
while (!sublistStack.isEmpty() && sublistStack.peek().length == topSublist.length) {
topSublist = UnsignedInts.lexicographicalComparator().compare(sublistStack.peek(), topSublist) < 0
? Ints.concat(sublistStack.pop(), topSublist)
: Ints.concat(topSublist, sublistStack.pop());
}
sublistStack.push(topSublist);
}
return sublistStack.pop();
}
}

View file

@ -42,6 +42,7 @@ import java.text.DecimalFormat;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.file.Paths;
import java.io.File;
@ -525,6 +526,26 @@ public class Utilities {
return result;
}
public static byte[] copyRightAligned(byte[] src, int newLength) {
byte[] dest = new byte[newLength];
int srcPos = Math.max(src.length - newLength, 0);
int destPos = Math.max(newLength - src.length, 0);
System.arraycopy(src, srcPos, dest, destPos, newLength - destPos);
return dest;
}
public static byte[] intsToBytesBE(int[] ints) {
byte[] bytes = new byte[ints.length * 4];
ByteBuffer.wrap(bytes).asIntBuffer().put(ints);
return bytes;
}
public static int[] bytesToIntsBE(byte[] bytes) {
int[] ints = new int[bytes.length / 4];
ByteBuffer.wrap(bytes).asIntBuffer().get(ints);
return ints;
}
// Helper to filter unique elements by key
public static <T> Predicate<T> distinctByKey(Function<? super T, Object> keyExtractor) {
Map<Object, Boolean> map = new ConcurrentHashMap<>();

View file

@ -0,0 +1,71 @@
/*
* This file is part of Bisq.
*
* Bisq is free software: you can redistribute it and/or modify it
* under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or (at
* your option) any later version.
*
* Bisq is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
* License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Bisq. If not, see <http://www.gnu.org/licenses/>.
*/
package bisq.common.crypto;
import bisq.common.crypto.Equihash.Puzzle.Solution;
import com.google.common.base.Strings;
import java.util.Arrays;
import java.util.stream.Collectors;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class EquihashTest {
@Test
public void testHashUpperBound() {
assertEquals("ffffffff ffffffff ffffffff ffffffff ffffffff ffffffff ffffffff ffffffff", hub(1));
assertEquals("aaaaaaaa aaaaaaaa aaaaaaaa aaaaaaaa aaaaaaaa aaaaaaaa aaaaaaaa aaaaaaaa", hub(1.5));
assertEquals("7fffffff ffffffff ffffffff ffffffff ffffffff ffffffff ffffffff ffffffff", hub(2));
assertEquals("55555555 55555555 55555555 55555555 55555555 55555555 55555555 55555555", hub(3));
assertEquals("3fffffff ffffffff ffffffff ffffffff ffffffff ffffffff ffffffff ffffffff", hub(4));
assertEquals("33333333 33333333 33333333 33333333 33333333 33333333 33333333 33333333", hub(5));
assertEquals("051eb851 eb851eb8 51eb851e b851eb85 1eb851eb 851eb851 eb851eb8 51eb851e", hub(50.0));
assertEquals("0083126e 978d4fdf 3b645a1c ac083126 e978d4fd f3b645a1 cac08312 6e978d4f", hub(500.0));
assertEquals("00000000 00000000 2f394219 248446ba a23d2ec7 29af3d61 0607aa01 67dd94ca", hub(1.0e20));
assertEquals("00000000 00000000 00000000 00000000 ffffffff ffffffff ffffffff ffffffff", hub(0x1.0p128));
assertEquals("00000000 00000000 00000000 00000000 00000000 00000000 00000000 00000000", hub(Double.POSITIVE_INFINITY));
}
@Test
public void testFindSolution() {
Equihash equihash = new Equihash(90, 5, 5.0);
byte[] seed = new byte[64];
Solution solution = equihash.puzzle(seed).findSolution();
byte[] solutionBytes = solution.serialize();
Solution roundTrippedSolution = equihash.puzzle(seed).deserializeSolution(solutionBytes);
assertTrue(solution.verify());
assertEquals(72, solutionBytes.length);
assertEquals(solution.toString(), roundTrippedSolution.toString());
}
private static String hub(double difficulty) {
return hexString(Equihash.hashUpperBound(difficulty));
}
private static String hexString(int[] ints) {
return Arrays.stream(ints)
.mapToObj(n -> Strings.padStart(Integer.toHexString(n), 8, '0'))
.collect(Collectors.joining(" "));
}
}