From b1403155fc8a4aa38b769d1e7b1233cb233bc1ad Mon Sep 17 00:00:00 2001 From: Chris Stewart Date: Mon, 30 Dec 2024 06:48:47 -0600 Subject: [PATCH] wallet: Refactor AddressHandling to be account specific (#5825) * wallet: Refactor AddressHandling to be account specific * fix compile * Add test * Fix AddressType.fromPurpose() for HDPurpose.Multisig * Fix nodeTest/test * Don't use HDPurpose.Multisig in test * Return None for HDPurpose.Multisig inside of AddressType.fromPurpose() --- .../org/bitcoins/core/hd/AddressType.scala | 10 +++++++ .../node/NeutrinoNodeWithWalletTest.scala | 8 +++--- .../bitcoins/wallet/AddressHandlingTest.scala | 28 ++++++++++++++++++- .../wallet/internal/AccountHandling.scala | 13 ++------- .../wallet/internal/AddressHandling.scala | 18 ++++++------ .../bitcoins/wallet/models/AddressDAO.scala | 7 +++-- 6 files changed, 57 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/bitcoins/core/hd/AddressType.scala b/core/src/main/scala/org/bitcoins/core/hd/AddressType.scala index e91266a46f..61178cb750 100644 --- a/core/src/main/scala/org/bitcoins/core/hd/AddressType.scala +++ b/core/src/main/scala/org/bitcoins/core/hd/AddressType.scala @@ -47,4 +47,14 @@ object AddressType extends StringFactory[AddressType] { sys.error(s"Could not find address type for string=$string") } } + + def fromPurpose(purpose: HDPurpose): Option[AddressType] = { + purpose match { + case HDPurpose.Legacy => Some(Legacy) + case HDPurpose.NestedSegWit => Some(NestedSegWit) + case HDPurpose.SegWit => Some(SegWit) + case HDPurpose.Taproot => Some(P2TR) + case _: HDPurpose | HDPurpose.Multisig => None + } + } } diff --git a/node-test/src/test/scala/org/bitcoins/node/NeutrinoNodeWithWalletTest.scala b/node-test/src/test/scala/org/bitcoins/node/NeutrinoNodeWithWalletTest.scala index 4cea51792b..fc7908ba3e 100644 --- a/node-test/src/test/scala/org/bitcoins/node/NeutrinoNodeWithWalletTest.scala +++ b/node-test/src/test/scala/org/bitcoins/node/NeutrinoNodeWithWalletTest.scala @@ -79,7 +79,7 @@ class NeutrinoNodeWithWalletTest extends NodeTestWithCachedBitcoindNewest { condition( expectedBalance = 6.bitcoin - TestAmount - fee, expectedUtxos = 3, - expectedAddresses = 7 + expectedAddresses = 4 ) } @@ -91,7 +91,7 @@ class NeutrinoNodeWithWalletTest extends NodeTestWithCachedBitcoindNewest { condition( expectedBalance = (6.bitcoin - TestAmount - firstTxFee) + TestAmount, expectedUtxos = 4, - expectedAddresses = 8 + expectedAddresses = 5 ) } @@ -189,7 +189,7 @@ class NeutrinoNodeWithWalletTest extends NodeTestWithCachedBitcoindNewest { for { addresses <- wallet.addressHandling.listAddresses() utxos <- wallet.utxoHandling.listUtxos() - _ = assert(addresses.size == 6) + _ = assert(addresses.size == 3) _ = assert(utxos.size == 3) address <- wallet.getNewAddress() @@ -199,7 +199,7 @@ class NeutrinoNodeWithWalletTest extends NodeTestWithCachedBitcoindNewest { addresses <- wallet.addressHandling.listAddresses() utxos <- wallet.utxoHandling.listUtxos() - _ = assert(addresses.size == 7) + _ = assert(addresses.size == 4) _ = assert(utxos.size == 3) _ <- bitcoind.getNewAddress diff --git a/wallet-test/src/test/scala/org/bitcoins/wallet/AddressHandlingTest.scala b/wallet-test/src/test/scala/org/bitcoins/wallet/AddressHandlingTest.scala index f6cb036641..c10f2fd086 100644 --- a/wallet-test/src/test/scala/org/bitcoins/wallet/AddressHandlingTest.scala +++ b/wallet-test/src/test/scala/org/bitcoins/wallet/AddressHandlingTest.scala @@ -1,6 +1,7 @@ package org.bitcoins.wallet import org.bitcoins.core.currency.{Bitcoins, Satoshis} +import org.bitcoins.core.hd.{AddressType, HDPurpose} import org.bitcoins.core.protocol.script.EmptyScriptPubKey import org.bitcoins.core.protocol.transaction.TransactionOutput import org.bitcoins.core.wallet.utxo.StorageLocationTag.HotStorage @@ -15,6 +16,7 @@ import org.bitcoins.wallet.models.{ScriptPubKeyDAO, SpendingInfoDAO} import org.scalatest.FutureOutcome import scala.concurrent.Future +import scala.util.Random class AddressHandlingTest extends BitcoinSWalletTest { type FixtureParam = FundedWallet @@ -58,7 +60,7 @@ class AddressHandlingTest extends BitcoinSWalletTest { s"Wallet must contain address in specific after generating it" ) assert( - doesNotExist, + !doesNotExist, s"Wallet must NOT contain address in default account when address is specified" ) } @@ -282,4 +284,28 @@ class AddressHandlingTest extends BitcoinSWalletTest { assert(spkOpt.isDefined) } } + + it must "listaddresses for current default purpose, not all purposes in the wallet" in { + (fundedWallet: FundedWallet) => + val wallet = fundedWallet.wallet + val randomPurpose = Random + .shuffle( + HDPurpose.all + // maybe should remove HDPurpose.MultiSig as it doesn't make sense + // to use a single multisig key for address generation? + .filterNot(_ == HDPurpose.Multisig) + .filterNot(_ == fundedWallet.walletConfig.defaultPurpose)) + .head + + val addrType = AddressType.fromPurpose(randomPurpose).get + val addrF = wallet.addressHandling.getNewAddress(addrType) + + for { + addr <- addrF + addresses <- wallet.addressHandling.listAddresses() + } yield { + assert(!addresses.exists(_.address == addr)) + } + } + } diff --git a/wallet/src/main/scala/org/bitcoins/wallet/internal/AccountHandling.scala b/wallet/src/main/scala/org/bitcoins/wallet/internal/AccountHandling.scala index 03d4684cf2..3f43504d93 100644 --- a/wallet/src/main/scala/org/bitcoins/wallet/internal/AccountHandling.scala +++ b/wallet/src/main/scala/org/bitcoins/wallet/internal/AccountHandling.scala @@ -359,22 +359,13 @@ case class AccountHandling( override def listSpentAddresses( account: HDAccount ): Future[Vector[AddressDb]] = { - val spentAddressesF = addressDAO.getSpentAddresses - - spentAddressesF.map { spentAddresses => - spentAddresses.filter(addr => HDAccount.isSameAccount(addr.path, account)) - } + addressDAO.getSpentAddresses(account) } override def listFundedAddresses( account: HDAccount ): Future[Vector[(AddressDb, CurrencyUnit)]] = { - val spentAddressesF = addressDAO.getFundedAddresses - - spentAddressesF.map { spentAddresses => - spentAddresses.filter(addr => - HDAccount.isSameAccount(addr._1.path, account)) - } + addressDAO.getFundedAddresses(account) } private def getNewAddressHelperAction( diff --git a/wallet/src/main/scala/org/bitcoins/wallet/internal/AddressHandling.scala b/wallet/src/main/scala/org/bitcoins/wallet/internal/AddressHandling.scala index f648980022..a3d6ead767 100644 --- a/wallet/src/main/scala/org/bitcoins/wallet/internal/AddressHandling.scala +++ b/wallet/src/main/scala/org/bitcoins/wallet/internal/AddressHandling.scala @@ -51,23 +51,24 @@ case class AddressHandling( private val networkParameters: NetworkParameters = walletConfig.network override def listAddresses(): Future[Vector[AddressDb]] = - addressDAO.findAllAddresses() + addressDAO.findAllAddressDbForAccount(walletConfig.defaultAccount) override def listSpentAddresses(): Future[Vector[AddressDb]] = { - addressDAO.getSpentAddresses + addressDAO.getSpentAddresses(walletConfig.defaultAccount) } override def listFundedAddresses() : Future[Vector[(AddressDb, CurrencyUnit)]] = { - addressDAO.getFundedAddresses + addressDAO.getFundedAddresses(walletConfig.defaultAccount) } override def listUnusedAddresses(): Future[Vector[AddressDb]] = { - addressDAO.getUnusedAddresses + addressDAO.getUnusedAddresses(walletConfig.defaultAccount) } - override def listScriptPubKeys(): Future[Vector[ScriptPubKeyDb]] = + override def listScriptPubKeys(): Future[Vector[ScriptPubKeyDb]] = { scriptPubKeyDAO.findAll() + } override def watchScriptPubKey( scriptPubKey: ScriptPubKey @@ -86,7 +87,7 @@ case class AddressHandling( case (out, index) if spks.map(_.scriptPubKey).contains(out.scriptPubKey) => (out, TransactionOutPoint(transaction.txId, UInt32(index))) - }.toVector + } /** Derives a new address in the wallet for the given account and chain type * (change/external). After deriving the address it inserts it into our table @@ -232,11 +233,10 @@ case class AddressHandling( /** @inheritdoc */ override def getUnusedAddress: Future[BitcoinAddress] = { for { - account <- accountHandling.getDefaultAccount() - addresses <- addressDAO.getUnusedAddresses(account.hdAccount) + addresses <- addressDAO.getUnusedAddresses(walletConfig.defaultAccount) address <- if (addresses.isEmpty) { - accountHandling.getNewAddress(account.hdAccount) + accountHandling.getNewAddress(walletConfig.defaultAccount) } else { Future.successful(addresses.head.address) } diff --git a/wallet/src/main/scala/org/bitcoins/wallet/models/AddressDAO.scala b/wallet/src/main/scala/org/bitcoins/wallet/models/AddressDAO.scala index 743bee3b64..621bf36b43 100644 --- a/wallet/src/main/scala/org/bitcoins/wallet/models/AddressDAO.scala +++ b/wallet/src/main/scala/org/bitcoins/wallet/models/AddressDAO.scala @@ -282,7 +282,7 @@ case class AddressDAO()(implicit getUnusedAddresses.map(_.filter(_.path.account == hdAccount)) } - def getSpentAddresses: Future[Vector[AddressDb]] = { + def getSpentAddresses(hdAccount: HDAccount): Future[Vector[AddressDb]] = { val query = table .join(spkTable) @@ -298,9 +298,11 @@ case class AddressDAO()(implicit res.map { case (addrRec, spkRec) => addrRec.toAddressDb(spkRec.scriptPubKey) }) + .map(_.filter(_.path.account == hdAccount)) } - def getFundedAddresses: Future[Vector[(AddressDb, CurrencyUnit)]] = { + def getFundedAddresses( + account: HDAccount): Future[Vector[(AddressDb, CurrencyUnit)]] = { val query = table .join(spkTable) .on(_.scriptPubKeyId === _.id) @@ -313,6 +315,7 @@ case class AddressDAO()(implicit .map(_.map { case ((addrRec, spkRec), utxoDb) => (addrRec.toAddressDb(spkRec.scriptPubKey), utxoDb.value) }) + .map(_.filter(_._1.path.account == account)) } def findByScriptPubKey(spk: ScriptPubKey): Future[Option[AddressDb]] = {