wallet: Refactor AddressHandling to be account specific (#5825)

* wallet: Refactor AddressHandling to be account specific

* fix compile

* Add test

* Fix AddressType.fromPurpose() for HDPurpose.Multisig

* Fix nodeTest/test

* Don't use HDPurpose.Multisig in test

* Return None for HDPurpose.Multisig inside of AddressType.fromPurpose()
This commit is contained in:
Chris Stewart 2024-12-30 06:48:47 -06:00 committed by GitHub
parent 051f6ed5cb
commit b1403155fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 57 additions and 27 deletions

View file

@ -47,4 +47,14 @@ object AddressType extends StringFactory[AddressType] {
sys.error(s"Could not find address type for string=$string")
}
}
def fromPurpose(purpose: HDPurpose): Option[AddressType] = {
purpose match {
case HDPurpose.Legacy => Some(Legacy)
case HDPurpose.NestedSegWit => Some(NestedSegWit)
case HDPurpose.SegWit => Some(SegWit)
case HDPurpose.Taproot => Some(P2TR)
case _: HDPurpose | HDPurpose.Multisig => None
}
}
}

View file

@ -79,7 +79,7 @@ class NeutrinoNodeWithWalletTest extends NodeTestWithCachedBitcoindNewest {
condition(
expectedBalance = 6.bitcoin - TestAmount - fee,
expectedUtxos = 3,
expectedAddresses = 7
expectedAddresses = 4
)
}
@ -91,7 +91,7 @@ class NeutrinoNodeWithWalletTest extends NodeTestWithCachedBitcoindNewest {
condition(
expectedBalance = (6.bitcoin - TestAmount - firstTxFee) + TestAmount,
expectedUtxos = 4,
expectedAddresses = 8
expectedAddresses = 5
)
}
@ -189,7 +189,7 @@ class NeutrinoNodeWithWalletTest extends NodeTestWithCachedBitcoindNewest {
for {
addresses <- wallet.addressHandling.listAddresses()
utxos <- wallet.utxoHandling.listUtxos()
_ = assert(addresses.size == 6)
_ = assert(addresses.size == 3)
_ = assert(utxos.size == 3)
address <- wallet.getNewAddress()
@ -199,7 +199,7 @@ class NeutrinoNodeWithWalletTest extends NodeTestWithCachedBitcoindNewest {
addresses <- wallet.addressHandling.listAddresses()
utxos <- wallet.utxoHandling.listUtxos()
_ = assert(addresses.size == 7)
_ = assert(addresses.size == 4)
_ = assert(utxos.size == 3)
_ <-
bitcoind.getNewAddress

View file

@ -1,6 +1,7 @@
package org.bitcoins.wallet
import org.bitcoins.core.currency.{Bitcoins, Satoshis}
import org.bitcoins.core.hd.{AddressType, HDPurpose}
import org.bitcoins.core.protocol.script.EmptyScriptPubKey
import org.bitcoins.core.protocol.transaction.TransactionOutput
import org.bitcoins.core.wallet.utxo.StorageLocationTag.HotStorage
@ -15,6 +16,7 @@ import org.bitcoins.wallet.models.{ScriptPubKeyDAO, SpendingInfoDAO}
import org.scalatest.FutureOutcome
import scala.concurrent.Future
import scala.util.Random
class AddressHandlingTest extends BitcoinSWalletTest {
type FixtureParam = FundedWallet
@ -58,7 +60,7 @@ class AddressHandlingTest extends BitcoinSWalletTest {
s"Wallet must contain address in specific after generating it"
)
assert(
doesNotExist,
!doesNotExist,
s"Wallet must NOT contain address in default account when address is specified"
)
}
@ -282,4 +284,28 @@ class AddressHandlingTest extends BitcoinSWalletTest {
assert(spkOpt.isDefined)
}
}
it must "listaddresses for current default purpose, not all purposes in the wallet" in {
(fundedWallet: FundedWallet) =>
val wallet = fundedWallet.wallet
val randomPurpose = Random
.shuffle(
HDPurpose.all
// maybe should remove HDPurpose.MultiSig as it doesn't make sense
// to use a single multisig key for address generation?
.filterNot(_ == HDPurpose.Multisig)
.filterNot(_ == fundedWallet.walletConfig.defaultPurpose))
.head
val addrType = AddressType.fromPurpose(randomPurpose).get
val addrF = wallet.addressHandling.getNewAddress(addrType)
for {
addr <- addrF
addresses <- wallet.addressHandling.listAddresses()
} yield {
assert(!addresses.exists(_.address == addr))
}
}
}

View file

@ -359,22 +359,13 @@ case class AccountHandling(
override def listSpentAddresses(
account: HDAccount
): Future[Vector[AddressDb]] = {
val spentAddressesF = addressDAO.getSpentAddresses
spentAddressesF.map { spentAddresses =>
spentAddresses.filter(addr => HDAccount.isSameAccount(addr.path, account))
}
addressDAO.getSpentAddresses(account)
}
override def listFundedAddresses(
account: HDAccount
): Future[Vector[(AddressDb, CurrencyUnit)]] = {
val spentAddressesF = addressDAO.getFundedAddresses
spentAddressesF.map { spentAddresses =>
spentAddresses.filter(addr =>
HDAccount.isSameAccount(addr._1.path, account))
}
addressDAO.getFundedAddresses(account)
}
private def getNewAddressHelperAction(

View file

@ -51,23 +51,24 @@ case class AddressHandling(
private val networkParameters: NetworkParameters = walletConfig.network
override def listAddresses(): Future[Vector[AddressDb]] =
addressDAO.findAllAddresses()
addressDAO.findAllAddressDbForAccount(walletConfig.defaultAccount)
override def listSpentAddresses(): Future[Vector[AddressDb]] = {
addressDAO.getSpentAddresses
addressDAO.getSpentAddresses(walletConfig.defaultAccount)
}
override def listFundedAddresses()
: Future[Vector[(AddressDb, CurrencyUnit)]] = {
addressDAO.getFundedAddresses
addressDAO.getFundedAddresses(walletConfig.defaultAccount)
}
override def listUnusedAddresses(): Future[Vector[AddressDb]] = {
addressDAO.getUnusedAddresses
addressDAO.getUnusedAddresses(walletConfig.defaultAccount)
}
override def listScriptPubKeys(): Future[Vector[ScriptPubKeyDb]] =
override def listScriptPubKeys(): Future[Vector[ScriptPubKeyDb]] = {
scriptPubKeyDAO.findAll()
}
override def watchScriptPubKey(
scriptPubKey: ScriptPubKey
@ -86,7 +87,7 @@ case class AddressHandling(
case (out, index)
if spks.map(_.scriptPubKey).contains(out.scriptPubKey) =>
(out, TransactionOutPoint(transaction.txId, UInt32(index)))
}.toVector
}
/** Derives a new address in the wallet for the given account and chain type
* (change/external). After deriving the address it inserts it into our table
@ -232,11 +233,10 @@ case class AddressHandling(
/** @inheritdoc */
override def getUnusedAddress: Future[BitcoinAddress] = {
for {
account <- accountHandling.getDefaultAccount()
addresses <- addressDAO.getUnusedAddresses(account.hdAccount)
addresses <- addressDAO.getUnusedAddresses(walletConfig.defaultAccount)
address <-
if (addresses.isEmpty) {
accountHandling.getNewAddress(account.hdAccount)
accountHandling.getNewAddress(walletConfig.defaultAccount)
} else {
Future.successful(addresses.head.address)
}

View file

@ -282,7 +282,7 @@ case class AddressDAO()(implicit
getUnusedAddresses.map(_.filter(_.path.account == hdAccount))
}
def getSpentAddresses: Future[Vector[AddressDb]] = {
def getSpentAddresses(hdAccount: HDAccount): Future[Vector[AddressDb]] = {
val query =
table
.join(spkTable)
@ -298,9 +298,11 @@ case class AddressDAO()(implicit
res.map { case (addrRec, spkRec) =>
addrRec.toAddressDb(spkRec.scriptPubKey)
})
.map(_.filter(_.path.account == hdAccount))
}
def getFundedAddresses: Future[Vector[(AddressDb, CurrencyUnit)]] = {
def getFundedAddresses(
account: HDAccount): Future[Vector[(AddressDb, CurrencyUnit)]] = {
val query = table
.join(spkTable)
.on(_.scriptPubKeyId === _.id)
@ -313,6 +315,7 @@ case class AddressDAO()(implicit
.map(_.map { case ((addrRec, spkRec), utxoDb) =>
(addrRec.toAddressDb(spkRec.scriptPubKey), utxoDb.value)
})
.map(_.filter(_._1.path.account == account))
}
def findByScriptPubKey(spk: ScriptPubKey): Future[Option[AddressDb]] = {