HD Wallets: fix a lookahead bug.

We were not previously triggering lookahead before calculating a Bloom filter, which means we might have missed transactions in some edge cases. Add a test to catch this and then fix up various unit tests to have fewer magic numbers and be more robust to changes.
This commit is contained in:
Mike Hearn 2014-08-11 16:37:31 +02:00
parent 54a543bf77
commit d824666c2f
4 changed files with 74 additions and 51 deletions

View File

@ -478,7 +478,14 @@ public class DeterministicKeyChain implements EncryptableKeyChain {
public int numKeys() {
// We need to return here the total number of keys including the lookahead zone, not the number of keys we
// have issued via getKey/freshReceiveKey.
return basicKeyChain.numKeys();
lock.lock();
try {
maybeLookAhead();
return basicKeyChain.numKeys();
} finally {
lock.unlock();
}
}
/**
@ -821,8 +828,15 @@ public class DeterministicKeyChain implements EncryptableKeyChain {
@Override
public BloomFilter getFilter(int size, double falsePositiveRate, long tweak) {
checkArgument(size >= numBloomFilterEntries());
return basicKeyChain.getFilter(size, falsePositiveRate, tweak);
lock.lock();
try {
checkArgument(size >= numBloomFilterEntries());
maybeLookAhead();
return basicKeyChain.getFilter(size, falsePositiveRate, tweak);
} finally {
lock.unlock();
}
}
/**
@ -919,15 +933,14 @@ public class DeterministicKeyChain implements EncryptableKeyChain {
final int lookaheadThreshold = getLookaheadThreshold();
final int needed = issued + lookaheadSize + lookaheadThreshold - numChildren;
log.info("{} keys needed = {} issued + {} lookahead size + {} lookahead threshold - {} num children",
needed, issued, lookaheadSize, lookaheadThreshold, numChildren);
if (needed <= lookaheadThreshold)
return new ArrayList<DeterministicKey>();
log.info("{} keys needed for {} = {} issued + {} lookahead size + {} lookahead threshold - {} num children",
needed, parent.getPathAsString(), issued, lookaheadSize, lookaheadThreshold, numChildren);
List<DeterministicKey> result = new ArrayList<DeterministicKey>(needed);
long now = System.currentTimeMillis();
log.info("Pre-generating {} keys for {}", needed, parent.getPathAsString());
int nextChild = numChildren;
for (int i = 0; i < needed; i++) {
DeterministicKey key = HDKeyDerivation.deriveThisOrNextChildKey(parent, nextChild);

View File

@ -36,7 +36,6 @@ import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.util.*;
@ -574,7 +573,8 @@ public class PeerGroupTest extends TestWithPeerGroup {
public void testBloomResendOnNewKey() throws Exception {
// Check that when we add a new key to the wallet, the Bloom filter is re-calculated and re-sent but only once
// we exceed the lookahead threshold.
wallet.setKeychainLookaheadSize(20);
wallet.setKeychainLookaheadSize(5);
wallet.setKeychainLookaheadThreshold(4);
peerGroup.startAsync();
peerGroup.awaitRunning();
// Create a couple of peers.
@ -582,28 +582,24 @@ public class PeerGroupTest extends TestWithPeerGroup {
InboundMessageQueuer p2 = connectPeer(2);
peerGroup.waitForJobQueue();
BloomFilter f1 = p1.lastReceivedFilter;
int threshold = wallet.getKeychainLookaheadThreshold();
wallet.freshReceiveKey(); // Force generation with the new lookahead size.
peerGroup.waitForJobQueue();
assertEquals(BloomFilter.class, outbound(p1).getClass());
assertEquals(MemoryPoolMessage.class, outbound(p1).getClass());
ECKey key = null;
// We have to run ahead of the lookahead zone for this test. There should only be one bloom filter recalc.
for (int i = 0; i < threshold + 2; i++) {
for (int i = 0; i < wallet.getKeychainLookaheadSize() + wallet.getKeychainLookaheadThreshold() + 1; i++) {
key = wallet.freshReceiveKey();
}
// Wait here. Bloom filters are recalculated asynchronously so if we didn't wait, we might not pass the
// test below where we expect each key to generate a new filter because this thread could generate all
// the keys before the peergroup thread does the recalculation, causing only one filter to be sent.
peerGroup.waitForJobQueue();
BloomFilter f3 = (BloomFilter) outbound(p1);
assertNotNull(f3);
assertEquals(MemoryPoolMessage.class, outbound(p1).getClass());
BloomFilter bf, f2 = null;
while ((bf = (BloomFilter) outbound(p1)) != null) {
assertEquals(MemoryPoolMessage.class, outbound(p1).getClass());
f2 = bf;
}
assertNotNull(key);
assertNotNull(f2);
assertNull(outbound(p1));
// Check the last filter received.
assertNotEquals(f1, f3);
assertTrue(f3.contains(key.getPubKey()));
assertTrue(f3.contains(key.getPubKeyHash()));
assertNotEquals(f1, f2);
assertTrue(f2.contains(key.getPubKey()));
assertTrue(f2.contains(key.getPubKeyHash()));
assertFalse(f1.contains(key.getPubKey()));
assertFalse(f1.contains(key.getPubKeyHash()));
}

View File

@ -252,17 +252,36 @@ public class DeterministicKeyChainTest {
}
@Test
public void bloom() {
public void bloom1() {
DeterministicKey key2 = chain.getKey(KeyChain.KeyPurpose.RECEIVE_FUNDS);
DeterministicKey key1 = chain.getKey(KeyChain.KeyPurpose.RECEIVE_FUNDS);
// The filter includes the internal keys as well (for now), although I'm not sure if we should allow funds to
// be received on them or not ....
assertEquals(36, chain.numBloomFilterEntries());
BloomFilter filter = chain.getFilter(36, 0.001, 1);
// ((13*2)+2+3)*2
int numEntries =
(((chain.getLookaheadSize() + chain.getLookaheadThreshold()) * 2) // * 2 because of internal/external
+ chain.numLeafKeysIssued()
+ 3 // one account key + two chain keys (internal/external)
) * 2; // because the filter contains keys and key hashes.
assertEquals(numEntries, chain.numBloomFilterEntries());
BloomFilter filter = chain.getFilter(numEntries, 0.001, 1);
assertTrue(filter.contains(key1.getPubKey()));
assertTrue(filter.contains(key1.getPubKeyHash()));
assertTrue(filter.contains(key2.getPubKey()));
assertTrue(filter.contains(key2.getPubKeyHash()));
// The lookahead zone is tested in bloom2 and via KeyChainGroupTest.bloom
}
@Test
public void bloom2() throws Exception {
// Verify that if when we watch a key, the filter contains at least 100 keys.
DeterministicKey[] keys = new DeterministicKey[100];
for (int i = 0; i < keys.length; i++)
keys[i] = chain.getKey(KeyChain.KeyPurpose.RECEIVE_FUNDS);
chain = DeterministicKeyChain.watch(chain.getWatchingKey());
int e = chain.numBloomFilterEntries();
BloomFilter filter = chain.getFilter(e, 0.001, 1);
for (DeterministicKey key : keys)
assertTrue("key " + key, filter.contains(key.getPubKeyHash()));
}
private String protoToString(List<Protos.Key> keys) {

View File

@ -66,17 +66,20 @@ public class KeyChainGroupTest {
@Test
public void freshCurrentKeys() throws Exception {
assertEquals(INITIAL_KEYS, group.numKeys());
assertEquals(2 * INITIAL_KEYS, group.getBloomFilterElementCount());
int numKeys = ((group.getLookaheadSize() + group.getLookaheadThreshold()) * 2) // * 2 because of internal/external
+ 1 // keys issued
+ 3 /* account key + int/ext parent keys */;
assertEquals(numKeys, group.numKeys());
assertEquals(2 * numKeys, group.getBloomFilterElementCount());
ECKey r1 = group.currentKey(KeyChain.KeyPurpose.RECEIVE_FUNDS);
final int keys = INITIAL_KEYS + LOOKAHEAD_SIZE + group.getLookaheadThreshold() + 1;
assertEquals(keys, group.numKeys());
assertEquals(2 * keys, group.getBloomFilterElementCount());
assertEquals(numKeys, group.numKeys());
assertEquals(2 * numKeys, group.getBloomFilterElementCount());
ECKey i1 = new ECKey();
group.importKeys(i1);
assertEquals(keys + 1, group.numKeys());
assertEquals(2 * (keys + 1), group.getBloomFilterElementCount());
numKeys++;
assertEquals(numKeys, group.numKeys());
assertEquals(2 * numKeys, group.getBloomFilterElementCount());
ECKey r2 = group.currentKey(KeyChain.KeyPurpose.RECEIVE_FUNDS);
assertEquals(r1, r2);
@ -117,11 +120,12 @@ public class KeyChainGroupTest {
@Test
public void imports() throws Exception {
ECKey key1 = new ECKey();
int numKeys = group.numKeys();
assertFalse(group.removeImportedKey(key1));
assertEquals(1, group.importKeys(ImmutableList.of(key1)));
assertEquals(INITIAL_KEYS + 1, group.numKeys()); // Lookahead is triggered by requesting a key, so none yet.
assertEquals(numKeys + 1, group.numKeys()); // Lookahead is triggered by requesting a key, so none yet.
group.removeImportedKey(key1);
assertEquals(INITIAL_KEYS, group.numKeys());
assertEquals(numKeys, group.numKeys());
}
@Test
@ -154,16 +158,10 @@ public class KeyChainGroupTest {
@Test
public void currentP2SHAddress() throws Exception {
group = createMarriedKeyChainGroup();
assertEquals(INITIAL_KEYS, group.numKeys());
Address a1 = group.currentAddress(KeyChain.KeyPurpose.RECEIVE_FUNDS);
assertEquals(INITIAL_KEYS + 1 + LOOKAHEAD_SIZE + group.getLookaheadThreshold(), group.numKeys());
assertTrue(a1.isP2SHAddress());
Address a2 = group.currentAddress(KeyChain.KeyPurpose.RECEIVE_FUNDS);
assertEquals(a1, a2);
assertEquals(INITIAL_KEYS + 1 + LOOKAHEAD_SIZE + group.getLookaheadThreshold(), group.numKeys());
Address a3 = group.currentAddress(KeyChain.KeyPurpose.CHANGE);
assertNotEquals(a2, a3);
}
@ -175,8 +173,9 @@ public class KeyChainGroupTest {
Address a2 = group.freshAddress(KeyChain.KeyPurpose.RECEIVE_FUNDS);
assertTrue(a1.isP2SHAddress());
assertNotEquals(a1, a2);
// numKeys does not include following chains. Possibly it should.
assertEquals(INITIAL_KEYS + 1 + group.getLookaheadSize() + group.getLookaheadThreshold(), group.numKeys());
assertEquals(((group.getLookaheadSize() + group.getLookaheadThreshold()) * 2) // * 2 because of internal/external
+ 2 // keys issued
+ 3, group.numKeys());
Address a3 = group.currentAddress(KeyChain.KeyPurpose.RECEIVE_FUNDS);
assertEquals(a2, a3);
@ -289,7 +288,6 @@ public class KeyChainGroupTest {
KeyCrypterScrypt scrypt = new KeyCrypterScrypt(2);
final KeyParameter aesKey = scrypt.deriveKey("password");
group.encrypt(scrypt, aesKey);
assertEquals(4, group.numKeys());
assertTrue(group.freshKey(KeyChain.KeyPurpose.RECEIVE_FUNDS).isEncrypted());
final ECKey key = group.currentKey(KeyChain.KeyPurpose.RECEIVE_FUNDS);
group.decrypt(aesKey);
@ -298,12 +296,9 @@ public class KeyChainGroupTest {
@Test
public void bloom() throws Exception {
assertEquals(INITIAL_KEYS * 2, group.getBloomFilterElementCount());
ECKey key1 = group.freshKey(KeyChain.KeyPurpose.RECEIVE_FUNDS);
ECKey key2 = new ECKey();
final int size = (INITIAL_KEYS + LOOKAHEAD_SIZE + group.getLookaheadThreshold() + 1 /* for the just created key */) * 2;
assertEquals(size, group.getBloomFilterElementCount());
BloomFilter filter = group.getBloomFilter(size, 0.001, (long)(Math.random() * Long.MAX_VALUE));
BloomFilter filter = group.getBloomFilter(group.getBloomFilterElementCount(), 0.001, (long)(Math.random() * Long.MAX_VALUE));
assertTrue(filter.contains(key1.getPubKeyHash()));
assertTrue(filter.contains(key1.getPubKey()));
assertFalse(filter.contains(key2.getPubKey()));