keychain: extend TestKeyRingDerivation to check KeyLocators of derived keys

This commit is contained in:
Olaoluwa Osuntokun 2018-08-13 19:20:57 -07:00
parent ad25ae1a07
commit cf06b041a4
No known key found for this signature in database
GPG key ID: 964EA263DD637C21

View file

@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
@ -13,6 +14,7 @@ import (
"github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/walletdb"
"github.com/davecgh/go-spew/spew"
_ "github.com/btcsuite/btcwallet/walletdb/bdb" // Required in order to create the default database. _ "github.com/btcsuite/btcwallet/walletdb/bdb" // Required in order to create the default database.
) )
@ -91,6 +93,14 @@ func createTestBtcWallet(coinType uint32) (func(), *wallet.Wallet, error) {
return cleanUp, baseWallet, nil return cleanUp, baseWallet, nil
} }
func assertEqualKeyLocator(t *testing.T, a, b KeyLocator) {
_, _, line, _ := runtime.Caller(1)
if a != b {
t.Fatalf("line #%v: mismatched key locators: expected %v, "+
"got %v", line, spew.Sdump(a), spew.Sdump(b))
}
}
// secretKeyRingConstructor is a function signature that's used as a generic // secretKeyRingConstructor is a function signature that's used as a generic
// constructor for various implementations of the KeyRing interface. A string // constructor for various implementations of the KeyRing interface. A string
// naming the returned interface, a function closure that cleans up any // naming the returned interface, a function closure that cleans up any
@ -141,6 +151,8 @@ func TestKeyRingDerivation(t *testing.T) {
}, },
} }
const numKeysToDerive = 10
// For each implementation constructor registered above, we'll execute // For each implementation constructor registered above, we'll execute
// an identical set of tests in order to ensure that the interface // an identical set of tests in order to ensure that the interface
// adheres to our nominal specification. // adheres to our nominal specification.
@ -163,10 +175,16 @@ func TestKeyRingDerivation(t *testing.T) {
t.Fatalf("unable to derive next for "+ t.Fatalf("unable to derive next for "+
"keyFam=%v: %v", keyFam, err) "keyFam=%v: %v", keyFam, err)
} }
assertEqualKeyLocator(t,
KeyLocator{
Family: keyFam,
Index: 0,
}, keyDesc.KeyLocator,
)
// If we now try to manually derive the *first* // We'll now re-derive that key to ensure that
// key, then we should get an identical public // we're able to properly access the key via
// key back. // the random access derivation methods.
keyLoc := KeyLocator{ keyLoc := KeyLocator{
Family: keyFam, Family: keyFam,
Index: 0, Index: 0,
@ -176,13 +194,41 @@ func TestKeyRingDerivation(t *testing.T) {
t.Fatalf("unable to derive first key for "+ t.Fatalf("unable to derive first key for "+
"keyFam=%v: %v", keyFam, err) "keyFam=%v: %v", keyFam, err)
} }
if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) { if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) {
t.Fatalf("mismatched keys: expected %v, "+ t.Fatalf("mismatched keys: expected %x, "+
"got %x", "got %x",
keyDesc.PubKey.SerializeCompressed(), keyDesc.PubKey.SerializeCompressed(),
firstKeyDesc.PubKey.SerializeCompressed()) firstKeyDesc.PubKey.SerializeCompressed())
} }
assertEqualKeyLocator(t,
KeyLocator{
Family: keyFam,
Index: 0,
}, firstKeyDesc.KeyLocator,
)
// If we now try to manually derive the next 10
// keys (including the original key), then we
// should get an identical public key back and
// their KeyLocator information
// should be set properly.
for i := 0; i < numKeysToDerive+1; i++ {
keyLoc := KeyLocator{
Family: keyFam,
Index: uint32(i),
}
keyDesc, err := keyRing.DeriveKey(keyLoc)
if err != nil {
t.Fatalf("unable to derive first key for "+
"keyFam=%v: %v", keyFam, err)
}
// Ensure that the key locator matches
// up as well.
assertEqualKeyLocator(
t, keyLoc, keyDesc.KeyLocator,
)
}
// If this succeeds, then we'll also try to // If this succeeds, then we'll also try to
// derive a random index within the range. // derive a random index within the range.
@ -191,12 +237,15 @@ func TestKeyRingDerivation(t *testing.T) {
Family: keyFam, Family: keyFam,
Index: randKeyIndex, Index: randKeyIndex,
} }
_, err = keyRing.DeriveKey(keyLoc) keyDesc, err = keyRing.DeriveKey(keyLoc)
if err != nil { if err != nil {
t.Fatalf("unable to derive key_index=%v "+ t.Fatalf("unable to derive key_index=%v "+
"for keyFam=%v: %v", "for keyFam=%v: %v",
randKeyIndex, keyFam, err) randKeyIndex, keyFam, err)
} }
assertEqualKeyLocator(
t, keyLoc, keyDesc.KeyLocator,
)
} }
}) })
if !success { if !success {