mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-19 05:45:21 +01:00
e86babe133
In this commit, we extend the DerivePrivKey method to allow callers that don't know the full KeyLocator information to attempt to derive a private key via a brute force mechanism. If we don't know the full KeyLocator, then given the KeyFamily, we can walk down the derivation path and compare keys one by one. In order to ensure we don't enter an infinite loop when given an unknown public key, we cap the number of keys derived at 100k. An upcoming feature to lnd that adds static channel backups will utilize this feature, as we need to derive the shachain root given only the public key and key family, as we don't currently store this KeyLocator on disk.
425 lines
12 KiB
Go
425 lines
12 KiB
Go
package keychain
|
|
|
|
import (
|
|
"fmt"
|
|
"io/ioutil"
|
|
"math/rand"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/btcsuite/btcd/btcec"
|
|
"github.com/btcsuite/btcd/chaincfg"
|
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
|
"github.com/btcsuite/btcwallet/waddrmgr"
|
|
"github.com/btcsuite/btcwallet/wallet"
|
|
"github.com/btcsuite/btcwallet/walletdb"
|
|
"github.com/davecgh/go-spew/spew"
|
|
|
|
_ "github.com/btcsuite/btcwallet/walletdb/bdb" // Required in order to create the default database.
|
|
)
|
|
|
|
// versionZeroKeyFamilies is a slice of all the known key families for first
|
|
// version of the key derivation schema defined in this package.
|
|
var versionZeroKeyFamilies = []KeyFamily{
|
|
KeyFamilyMultiSig,
|
|
KeyFamilyRevocationBase,
|
|
KeyFamilyHtlcBase,
|
|
KeyFamilyPaymentBase,
|
|
KeyFamilyDelayBase,
|
|
KeyFamilyRevocationRoot,
|
|
KeyFamilyNodeKey,
|
|
}
|
|
|
|
var (
|
|
testHDSeed = chainhash.Hash{
|
|
0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab,
|
|
0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4,
|
|
0x4f, 0x2f, 0x6f, 0x25, 0x98, 0xa3, 0xef, 0xb9,
|
|
0x69, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53,
|
|
}
|
|
)
|
|
|
|
func createTestBtcWallet(coinType uint32) (func(), *wallet.Wallet, error) {
|
|
tempDir, err := ioutil.TempDir("", "keyring-lnwallet")
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
loader := wallet.NewLoader(&chaincfg.SimNetParams, tempDir, 0)
|
|
|
|
pass := []byte("test")
|
|
|
|
baseWallet, err := loader.CreateNewWallet(
|
|
pass, pass, testHDSeed[:], time.Time{},
|
|
)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if err := baseWallet.Unlock(pass, nil); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Construct the key scope required to derive keys for the chose
|
|
// coinType.
|
|
chainKeyScope := waddrmgr.KeyScope{
|
|
Purpose: BIP0043Purpose,
|
|
Coin: coinType,
|
|
}
|
|
|
|
// We'll now ensure that the KeyScope: (1017, coinType) exists within
|
|
// the internal waddrmgr. We'll need this in order to properly generate
|
|
// the keys required for signing various contracts.
|
|
_, err = baseWallet.Manager.FetchScopedKeyManager(chainKeyScope)
|
|
if err != nil {
|
|
err := walletdb.Update(baseWallet.Database(), func(tx walletdb.ReadWriteTx) error {
|
|
addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
|
|
|
|
_, err := baseWallet.Manager.NewScopedKeyManager(
|
|
addrmgrNs, chainKeyScope, lightningAddrSchema,
|
|
)
|
|
return err
|
|
})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
cleanUp := func() {
|
|
baseWallet.Lock()
|
|
os.RemoveAll(tempDir)
|
|
}
|
|
|
|
return cleanUp, baseWallet, nil
|
|
}
|
|
|
|
func assertEqualKeyLocator(t *testing.T, a, b KeyLocator) {
|
|
t.Helper()
|
|
if a != b {
|
|
t.Fatalf("mismatched key locators: expected %v, "+
|
|
"got %v", spew.Sdump(a), spew.Sdump(b))
|
|
}
|
|
}
|
|
|
|
// secretKeyRingConstructor is a function signature that's used as a generic
|
|
// constructor for various implementations of the KeyRing interface. A string
|
|
// naming the returned interface, a function closure that cleans up any
|
|
// resources, and the clean up interface itself are to be returned.
|
|
type keyRingConstructor func() (string, func(), KeyRing, error)
|
|
|
|
// TestKeyRingDerivation tests that each known KeyRing implementation properly
|
|
// adheres to the expected behavior of the set of interfaces.
|
|
func TestKeyRingDerivation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
keyRingImplementations := []keyRingConstructor{
|
|
func() (string, func(), KeyRing, error) {
|
|
cleanUp, wallet, err := createTestBtcWallet(
|
|
CoinTypeBitcoin,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("unable to create wallet: %v", err)
|
|
}
|
|
|
|
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeBitcoin)
|
|
|
|
return "btcwallet", cleanUp, keyRing, nil
|
|
},
|
|
func() (string, func(), KeyRing, error) {
|
|
cleanUp, wallet, err := createTestBtcWallet(
|
|
CoinTypeLitecoin,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("unable to create wallet: %v", err)
|
|
}
|
|
|
|
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeLitecoin)
|
|
|
|
return "ltcwallet", cleanUp, keyRing, nil
|
|
},
|
|
func() (string, func(), KeyRing, error) {
|
|
cleanUp, wallet, err := createTestBtcWallet(
|
|
CoinTypeTestnet,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("unable to create wallet: %v", err)
|
|
}
|
|
|
|
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeTestnet)
|
|
|
|
return "testwallet", cleanUp, keyRing, nil
|
|
},
|
|
}
|
|
|
|
const numKeysToDerive = 10
|
|
|
|
// For each implementation constructor registered above, we'll execute
|
|
// an identical set of tests in order to ensure that the interface
|
|
// adheres to our nominal specification.
|
|
for _, keyRingConstructor := range keyRingImplementations {
|
|
keyRingName, cleanUp, keyRing, err := keyRingConstructor()
|
|
if err != nil {
|
|
t.Fatalf("unable to create key ring %v: %v", keyRingName,
|
|
err)
|
|
}
|
|
defer cleanUp()
|
|
|
|
success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) {
|
|
// First, we'll ensure that we're able to derive keys
|
|
// from each of the known key families.
|
|
for _, keyFam := range versionZeroKeyFamilies {
|
|
// First, we'll ensure that we can derive the
|
|
// *next* key in the keychain.
|
|
keyDesc, err := keyRing.DeriveNextKey(keyFam)
|
|
if err != nil {
|
|
t.Fatalf("unable to derive next for "+
|
|
"keyFam=%v: %v", keyFam, err)
|
|
}
|
|
assertEqualKeyLocator(t,
|
|
KeyLocator{
|
|
Family: keyFam,
|
|
Index: 0,
|
|
}, keyDesc.KeyLocator,
|
|
)
|
|
|
|
// We'll now re-derive that key to ensure that
|
|
// we're able to properly access the key via
|
|
// the random access derivation methods.
|
|
keyLoc := KeyLocator{
|
|
Family: keyFam,
|
|
Index: 0,
|
|
}
|
|
firstKeyDesc, err := keyRing.DeriveKey(keyLoc)
|
|
if err != nil {
|
|
t.Fatalf("unable to derive first key for "+
|
|
"keyFam=%v: %v", keyFam, err)
|
|
}
|
|
if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) {
|
|
t.Fatalf("mismatched keys: expected %x, "+
|
|
"got %x",
|
|
keyDesc.PubKey.SerializeCompressed(),
|
|
firstKeyDesc.PubKey.SerializeCompressed())
|
|
}
|
|
assertEqualKeyLocator(t,
|
|
KeyLocator{
|
|
Family: keyFam,
|
|
Index: 0,
|
|
}, firstKeyDesc.KeyLocator,
|
|
)
|
|
|
|
// If we now try to manually derive the next 10
|
|
// keys (including the original key), then we
|
|
// should get an identical public key back and
|
|
// their KeyLocator information
|
|
// should be set properly.
|
|
for i := 0; i < numKeysToDerive+1; i++ {
|
|
keyLoc := KeyLocator{
|
|
Family: keyFam,
|
|
Index: uint32(i),
|
|
}
|
|
keyDesc, err := keyRing.DeriveKey(keyLoc)
|
|
if err != nil {
|
|
t.Fatalf("unable to derive first key for "+
|
|
"keyFam=%v: %v", keyFam, err)
|
|
}
|
|
|
|
// Ensure that the key locator matches
|
|
// up as well.
|
|
assertEqualKeyLocator(
|
|
t, keyLoc, keyDesc.KeyLocator,
|
|
)
|
|
}
|
|
|
|
// If this succeeds, then we'll also try to
|
|
// derive a random index within the range.
|
|
randKeyIndex := uint32(rand.Int31())
|
|
keyLoc = KeyLocator{
|
|
Family: keyFam,
|
|
Index: randKeyIndex,
|
|
}
|
|
keyDesc, err = keyRing.DeriveKey(keyLoc)
|
|
if err != nil {
|
|
t.Fatalf("unable to derive key_index=%v "+
|
|
"for keyFam=%v: %v",
|
|
randKeyIndex, keyFam, err)
|
|
}
|
|
assertEqualKeyLocator(
|
|
t, keyLoc, keyDesc.KeyLocator,
|
|
)
|
|
}
|
|
})
|
|
if !success {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// secretKeyRingConstructor is a function signature that's used as a generic
|
|
// constructor for various implementations of the SecretKeyRing interface. A
|
|
// string naming the returned interface, a function closure that cleans up any
|
|
// resources, and the clean up interface itself are to be returned.
|
|
type secretKeyRingConstructor func() (string, func(), SecretKeyRing, error)
|
|
|
|
// TestSecretKeyRingDerivation tests that each known SecretKeyRing
|
|
// implementation properly adheres to the expected behavior of the set of
|
|
// interface.
|
|
func TestSecretKeyRingDerivation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
secretKeyRingImplementations := []secretKeyRingConstructor{
|
|
func() (string, func(), SecretKeyRing, error) {
|
|
cleanUp, wallet, err := createTestBtcWallet(
|
|
CoinTypeBitcoin,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("unable to create wallet: %v", err)
|
|
}
|
|
|
|
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeBitcoin)
|
|
|
|
return "btcwallet", cleanUp, keyRing, nil
|
|
},
|
|
func() (string, func(), SecretKeyRing, error) {
|
|
cleanUp, wallet, err := createTestBtcWallet(
|
|
CoinTypeLitecoin,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("unable to create wallet: %v", err)
|
|
}
|
|
|
|
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeLitecoin)
|
|
|
|
return "ltcwallet", cleanUp, keyRing, nil
|
|
},
|
|
func() (string, func(), SecretKeyRing, error) {
|
|
cleanUp, wallet, err := createTestBtcWallet(
|
|
CoinTypeTestnet,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("unable to create wallet: %v", err)
|
|
}
|
|
|
|
keyRing := NewBtcWalletKeyRing(wallet, CoinTypeTestnet)
|
|
|
|
return "testwallet", cleanUp, keyRing, nil
|
|
},
|
|
}
|
|
|
|
// For each implementation constructor registered above, we'll execute
|
|
// an identical set of tests in order to ensure that the interface
|
|
// adheres to our nominal specification.
|
|
for _, secretKeyRingConstructor := range secretKeyRingImplementations {
|
|
keyRingName, cleanUp, secretKeyRing, err := secretKeyRingConstructor()
|
|
if err != nil {
|
|
t.Fatalf("unable to create secret key ring %v: %v",
|
|
keyRingName, err)
|
|
}
|
|
defer cleanUp()
|
|
|
|
success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) {
|
|
// For, each key family, we'll ensure that we're able
|
|
// to obtain the private key of a randomly select child
|
|
// index within the key family.
|
|
for _, keyFam := range versionZeroKeyFamilies {
|
|
randKeyIndex := uint32(rand.Int31())
|
|
keyLoc := KeyLocator{
|
|
Family: keyFam,
|
|
Index: randKeyIndex,
|
|
}
|
|
|
|
// First, we'll query for the public key for
|
|
// this target key locator.
|
|
pubKeyDesc, err := secretKeyRing.DeriveKey(keyLoc)
|
|
if err != nil {
|
|
t.Fatalf("unable to derive pubkey "+
|
|
"(fam=%v, index=%v): %v",
|
|
keyLoc.Family,
|
|
keyLoc.Index, err)
|
|
}
|
|
|
|
// With the public key derive, ensure that
|
|
// we're able to obtain the corresponding
|
|
// private key correctly.
|
|
privKey, err := secretKeyRing.DerivePrivKey(KeyDescriptor{
|
|
KeyLocator: keyLoc,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unable to derive priv "+
|
|
"(fam=%v, index=%v): %v", keyLoc.Family,
|
|
keyLoc.Index, err)
|
|
}
|
|
|
|
// Finally, ensure that the keys match up
|
|
// properly.
|
|
if !pubKeyDesc.PubKey.IsEqual(privKey.PubKey()) {
|
|
t.Fatalf("pubkeys mismatched: expected %x, got %x",
|
|
pubKeyDesc.PubKey.SerializeCompressed(),
|
|
privKey.PubKey().SerializeCompressed())
|
|
}
|
|
|
|
// Next, we'll test that we're able to derive a
|
|
// key given only the public key and key
|
|
// family.
|
|
//
|
|
// Derive a new key from the key ring.
|
|
keyDesc, err := secretKeyRing.DeriveNextKey(keyFam)
|
|
if err != nil {
|
|
t.Fatalf("unable to derive key: %v", err)
|
|
}
|
|
|
|
// We'll now construct a key descriptor that
|
|
// requires us to scan the key range, and query
|
|
// for the key, we should be able to find it as
|
|
// it's valid.
|
|
keyDesc = KeyDescriptor{
|
|
PubKey: keyDesc.PubKey,
|
|
KeyLocator: KeyLocator{
|
|
Family: keyFam,
|
|
},
|
|
}
|
|
privKey, err = secretKeyRing.DerivePrivKey(keyDesc)
|
|
if err != nil {
|
|
t.Fatalf("unable to derive priv key "+
|
|
"via scanning: %v", err)
|
|
}
|
|
|
|
// Having to resort to scanning, we should be
|
|
// able to find the target public key.
|
|
if !keyDesc.PubKey.IsEqual(privKey.PubKey()) {
|
|
t.Fatalf("pubkeys mismatched: expected %x, got %x",
|
|
pubKeyDesc.PubKey.SerializeCompressed(),
|
|
privKey.PubKey().SerializeCompressed())
|
|
}
|
|
|
|
// We'll try again, but this time with an
|
|
// unknown public key.
|
|
_, pub := btcec.PrivKeyFromBytes(
|
|
btcec.S256(), testHDSeed[:],
|
|
)
|
|
keyDesc.PubKey = pub
|
|
|
|
// If we attempt to query for this key, then we
|
|
// should get ErrCannotDerivePrivKey.
|
|
privKey, err = secretKeyRing.DerivePrivKey(
|
|
keyDesc,
|
|
)
|
|
if err != ErrCannotDerivePrivKey {
|
|
t.Fatalf("expected %T, instead got %v",
|
|
ErrCannotDerivePrivKey, err)
|
|
}
|
|
|
|
// TODO(roasbeef): scalar mult once integrated
|
|
}
|
|
})
|
|
if !success {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func init() {
|
|
// We'll clamp the max range scan to constrain the run time of the
|
|
// private key scan test.
|
|
MaxKeyRangeScan = 3
|
|
}
|