multi: refactor testListAddresses

This commit is contained in:
yyforyongyu 2022-08-18 17:30:34 +08:00
parent 4104a72b3a
commit 2c12c8a77c
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
6 changed files with 94 additions and 77 deletions

View File

@ -98,3 +98,29 @@ func (h *HarnessRPC) LabelTransaction(req *walletrpc.LabelTransactionRequest) {
_, err := h.WalletKit.LabelTransaction(ctxt, req)
h.NoError(err, "LabelTransaction")
}
// DeriveNextKey makes a RPC call to the DeriveNextKey and asserts.
func (h *HarnessRPC) DeriveNextKey(
req *walletrpc.KeyReq) *signrpc.KeyDescriptor {
ctxt, cancel := context.WithTimeout(h.runCtx, DefaultTimeout)
defer cancel()
key, err := h.WalletKit.DeriveNextKey(ctxt, req)
h.NoError(err, "DeriveNextKey")
return key
}
// ListAddresses makes a RPC call to the ListAddresses and asserts.
func (h *HarnessRPC) ListAddresses(
req *walletrpc.ListAddressesRequest) *walletrpc.ListAddressesResponse {
ctxt, cancel := context.WithTimeout(h.runCtx, DefaultTimeout)
defer cancel()
key, err := h.WalletKit.ListAddresses(ctxt, req)
h.NoError(err, "ListAddresses")
return key
}

View File

@ -5,6 +5,8 @@ import (
"io"
"math"
"os"
"strconv"
"strings"
"github.com/lightningnetwork/lnd/lntest"
)
@ -53,3 +55,40 @@ func errNumNotMatched(name string, subject string,
return fmt.Errorf("%s: assert %s failed: want %d, got: %d, total: "+
"%d, previously had: %d", name, subject, want, got, total, old)
}
// parseDerivationPath parses a path in the form of m/x'/y'/z'/a/b into a slice
// of [x, y, z, a, b], meaning that the apostrophe is ignored and 2^31 is _not_
// added to the numbers.
func ParseDerivationPath(path string) ([]uint32, error) {
path = strings.TrimSpace(path)
if len(path) == 0 {
return nil, fmt.Errorf("path cannot be empty")
}
if !strings.HasPrefix(path, "m/") {
return nil, fmt.Errorf("path must start with m/")
}
// Just the root key, no path was provided. This is valid but not useful
// in most cases.
rest := strings.ReplaceAll(path, "m/", "")
if rest == "" {
return []uint32{}, nil
}
parts := strings.Split(rest, "/")
indices := make([]uint32, len(parts))
for i := 0; i < len(parts); i++ {
part := parts[i]
if strings.Contains(parts[i], "'") {
part = strings.TrimRight(parts[i], "'")
}
parsed, err := strconv.ParseInt(part, 10, 32)
if err != nil {
return nil, fmt.Errorf("could not parse part \"%s\": "+
"%v", part, err)
}
indices[i] = uint32(parsed)
}
return indices, nil
}

View File

@ -103,4 +103,8 @@ var allTestCasesTemp = []*lntemp.TestCase{
Name: "node sign verify",
TestFunc: testNodeSignVerify,
},
{
Name: "list addresses",
TestFunc: testListAddresses,
},
}

View File

@ -989,16 +989,13 @@ func testSweepAllCoins(ht *lntemp.HarnessTest) {
// testListAddresses tests that we get all the addresses and their
// corresponding balance correctly.
func testListAddresses(net *lntest.NetworkHarness, t *harnessTest) {
ctxb := context.Background()
func testListAddresses(ht *lntemp.HarnessTest) {
// First, we'll make a new node - Alice, which will be generating
// new addresses.
alice := net.NewNode(t.t, "Alice", nil)
defer shutdownAndAssert(net, t, alice)
alice := ht.NewNode("Alice", nil)
// Next, we'll give Alice exactly 1 utxo of 1 BTC.
net.SendCoins(t.t, btcutil.SatoshiPerBitcoin, alice)
ht.FundCoins(btcutil.SatoshiPerBitcoin, alice)
type addressDetails struct {
Balance int64
@ -1010,81 +1007,75 @@ func testListAddresses(net *lntest.NetworkHarness, t *harnessTest) {
// Create an address generated from internal keys.
keyLoc := &walletrpc.KeyReq{KeyFamily: 123}
keyDesc, err := alice.WalletKitClient.DeriveNextKey(ctxb, keyLoc)
require.NoError(t.t, err)
keyDesc := alice.RPC.DeriveNextKey(keyLoc)
// Hex Encode the public key.
pubkeyString := hex.EncodeToString(keyDesc.RawKeyBytes)
// Create a p2tr address.
resp, err := alice.NewAddress(ctxb, &lnrpc.NewAddressRequest{
resp := alice.RPC.NewAddress(&lnrpc.NewAddressRequest{
Type: lnrpc.AddressType_TAPROOT_PUBKEY,
})
require.NoError(t.t, err)
generatedAddr[resp.Address] = addressDetails{
Balance: 200_000,
Type: walletrpc.AddressType_TAPROOT_PUBKEY,
}
// Create a p2wkh address.
resp, err = alice.NewAddress(ctxb, &lnrpc.NewAddressRequest{
resp = alice.RPC.NewAddress(&lnrpc.NewAddressRequest{
Type: lnrpc.AddressType_WITNESS_PUBKEY_HASH,
})
require.NoError(t.t, err)
generatedAddr[resp.Address] = addressDetails{
Balance: 300_000,
Type: walletrpc.AddressType_WITNESS_PUBKEY_HASH,
}
// Create a np2wkh address.
resp, err = alice.NewAddress(ctxb, &lnrpc.NewAddressRequest{
resp = alice.RPC.NewAddress(&lnrpc.NewAddressRequest{
Type: lnrpc.AddressType_NESTED_PUBKEY_HASH,
})
require.NoError(t.t, err)
generatedAddr[resp.Address] = addressDetails{
Balance: 400_000,
Type: walletrpc.AddressType_HYBRID_NESTED_WITNESS_PUBKEY_HASH,
}
for addr, addressDetail := range generatedAddr {
_, err := alice.SendCoins(ctxb, &lnrpc.SendCoinsRequest{
alice.RPC.SendCoins(&lnrpc.SendCoinsRequest{
Addr: addr,
Amount: addressDetail.Balance,
SpendUnconfirmed: true,
})
require.NoError(t.t, err)
}
mineBlocks(t, net, 1, 3)
ht.MineBlocksAndAssertNumTxes(1, 3)
// Get all the accounts except LND's custom accounts.
addressLists, err := alice.WalletKitClient.ListAddresses(
ctxb, &walletrpc.ListAddressesRequest{},
addressLists := alice.RPC.ListAddresses(
&walletrpc.ListAddressesRequest{},
)
require.NoError(t.t, err)
foundAddresses := 0
for _, addressList := range addressLists.AccountWithAddresses {
addresses := addressList.Addresses
derivationPath, err := parseDerivationPath(
derivationPath, err := lntemp.ParseDerivationPath(
addressList.DerivationPath,
)
require.NoError(t.t, err)
require.NoError(ht, err)
// Should not get an account with KeyFamily - 123.
require.NotEqual(
t.t, uint32(keyLoc.KeyFamily), derivationPath[2],
ht, uint32(keyLoc.KeyFamily), derivationPath[2],
)
for _, address := range addresses {
if _, ok := generatedAddr[address.Address]; ok {
addrDetails := generatedAddr[address.Address]
require.Equal(
t.t, addrDetails.Balance,
ht, addrDetails.Balance,
address.Balance,
)
require.Equal(
t.t, addrDetails.Type,
ht, addrDetails.Type,
addressList.AddressType,
)
foundAddresses++
@ -1092,23 +1083,22 @@ func testListAddresses(net *lntest.NetworkHarness, t *harnessTest) {
}
}
require.Equal(t.t, len(generatedAddr), foundAddresses)
require.Equal(ht, len(generatedAddr), foundAddresses)
foundAddresses = 0
// Get all the accounts (including LND's custom accounts).
addressLists, err = alice.WalletKitClient.ListAddresses(
ctxb, &walletrpc.ListAddressesRequest{
addressLists = alice.RPC.ListAddresses(
&walletrpc.ListAddressesRequest{
ShowCustomAccounts: true,
},
)
require.NoError(t.t, err)
for _, addressList := range addressLists.AccountWithAddresses {
addresses := addressList.Addresses
derivationPath, err := parseDerivationPath(
derivationPath, err := lntemp.ParseDerivationPath(
addressList.DerivationPath,
)
require.NoError(t.t, err)
require.NoError(ht, err)
for _, address := range addresses {
// Check if the KeyFamily in derivation path is 123.
@ -1116,15 +1106,15 @@ func testListAddresses(net *lntest.NetworkHarness, t *harnessTest) {
// For LND's custom accounts, the address
// represents the public key.
pubkey := address.Address
require.Equal(t.t, pubkeyString, pubkey)
require.Equal(ht, pubkeyString, pubkey)
} else if _, ok := generatedAddr[address.Address]; ok {
addrDetails := generatedAddr[address.Address]
require.Equal(
t.t, addrDetails.Balance,
ht, addrDetails.Balance,
address.Balance,
)
require.Equal(
t.t, addrDetails.Type,
ht, addrDetails.Type,
addressList.AddressType,
)
foundAddresses++
@ -1132,7 +1122,7 @@ func testListAddresses(net *lntest.NetworkHarness, t *harnessTest) {
}
}
require.Equal(t.t, len(generatedAddr), foundAddresses)
require.Equal(ht, len(generatedAddr), foundAddresses)
}
func assertChannelConstraintsEqual(ht *lntemp.HarnessTest,

View File

@ -4,10 +4,6 @@
package itest
var allTestCases = []*testCase{
{
name: "list addresses",
test: testListAddresses,
},
{
name: "recovery info",
test: testGetRecoveryInfo,

View File

@ -5,8 +5,6 @@ import (
"crypto/rand"
"fmt"
"io"
"strconv"
"strings"
"testing"
"time"
@ -501,42 +499,6 @@ func getOutputIndex(t *harnessTest, miner *lntest.HarnessMiner,
return p2trOutputIndex
}
// parseDerivationPath parses a path in the form of m/x'/y'/z'/a/b into a slice
// of [x, y, z, a, b], meaning that the apostrophe is ignored and 2^31 is _not_
// added to the numbers.
func parseDerivationPath(path string) ([]uint32, error) {
path = strings.TrimSpace(path)
if len(path) == 0 {
return nil, fmt.Errorf("path cannot be empty")
}
if !strings.HasPrefix(path, "m/") {
return nil, fmt.Errorf("path must start with m/")
}
// Just the root key, no path was provided. This is valid but not useful
// in most cases.
rest := strings.ReplaceAll(path, "m/", "")
if rest == "" {
return []uint32{}, nil
}
parts := strings.Split(rest, "/")
indices := make([]uint32, len(parts))
for i := 0; i < len(parts); i++ {
part := parts[i]
if strings.Contains(parts[i], "'") {
part = strings.TrimRight(parts[i], "'")
}
parsed, err := strconv.ParseInt(part, 10, 32)
if err != nil {
return nil, fmt.Errorf("could not parse part \"%s\": "+
"%v", part, err)
}
indices[i] = uint32(parsed)
}
return indices, nil
}
// acceptChannel is used to accept a single channel that comes across. This
// should be run in a goroutine and is used to test nodes with the zero-conf
// feature bit.