Add notion of address types in wallet

In this commit we
1) Add the notion of address types in the wallet, and remove
    the emphasis on account types. Users now just request
    a segwit/nested-segwit/legacy address, and we take care
    of querying for the correct account
2) Fix a bug where a wallet could only get addresses for
    its default address type. This was a pretty minor bug,
    where a few values that should have been dynamic was
    hard coded.
This commit is contained in:
Torkel Rogstad 2019-07-10 17:14:05 +02:00
parent efc1ce4405
commit 4bdc7966d0
9 changed files with 196 additions and 39 deletions

View file

@ -0,0 +1,18 @@
package org.bitcoins.core.hd
/** The address types covered by BIP44, BIP49 and BIP84 */
sealed abstract class AddressType
object AddressType {
/** Uses BIP84 address derivation, gives bech32 address (`bc1...`) */
final case object SegWit extends AddressType
/** Uses BIP49 address derivation, gives SegWit addresses wrapped
* in P2SH addresses (`3...`)
*/
final case object NestedSegWit extends AddressType
/** Uses BIP44 address derivation (`1...`) */
final case object Legacy extends AddressType
}

View file

@ -11,6 +11,9 @@ import org.bitcoins.wallet.api.UnlockWalletError.MnemonicNotFound
import com.typesafe.config.ConfigFactory
import org.bitcoins.core.protocol.P2PKHAddress
import org.bitcoins.core.hd.HDPurposes
import org.bitcoins.core.protocol.Bech32Address
import org.bitcoins.core.protocol.P2SHAddress
import org.bitcoins.core.hd.AddressType
class LegacyWalletTest extends BitcoinSWalletTest {
@ -24,13 +27,41 @@ class LegacyWalletTest extends BitcoinSWalletTest {
addr <- wallet.getNewAddress()
account <- wallet.getDefaultAccount()
otherAddr <- wallet.getNewAddress()
thirdAddr <- wallet.getNewAddress(AddressType.Legacy)
allAddrs <- wallet.listAddresses()
} yield {
assert(account.hdAccount.purpose == HDPurposes.Legacy)
assert(allAddrs.forall(_.address.isInstanceOf[P2PKHAddress]))
assert(allAddrs.length == 2)
assert(allAddrs.length == 3)
assert(allAddrs.exists(_.address == addr))
assert(allAddrs.exists(_.address == otherAddr))
assert(allAddrs.exists(_.address == thirdAddr))
}
}
it should "generate segwit addresses" in { wallet =>
for {
account <- wallet.getDefaultAccountForType(AddressType.SegWit)
addr <- wallet.getNewAddress(AddressType.SegWit)
} yield {
assert(account.hdAccount.purpose == HDPurposes.SegWit)
assert(addr.isInstanceOf[Bech32Address])
}
}
it should "generate mixed addresses" in { wallet =>
for {
segwit <- wallet.getNewAddress(AddressType.SegWit)
legacy <- wallet.getNewAddress(AddressType.Legacy)
// TODO: uncomment this once nested segwit is implemented
// https://github.com/bitcoin-s/bitcoin-s/issues/407
// nested <- wallet.getNewAddress(AddressType.NestedSegWit)
} yield {
assert(segwit.isInstanceOf[Bech32Address])
assert(legacy.isInstanceOf[P2PKHAddress])
// TODO: uncomment this once nested segwit is implemented
// https://github.com/bitcoin-s/bitcoin-s/issues/407
// assert(nested.isInstanceOf[P2SHAddress])
}
}
}

View file

@ -12,6 +12,8 @@ import com.typesafe.config.ConfigFactory
import org.bitcoins.core.protocol.P2PKHAddress
import org.bitcoins.core.protocol.Bech32Address
import org.bitcoins.core.hd.HDPurposes
import org.bitcoins.core.hd.AddressType
import org.bitcoins.core.protocol.P2SHAddress
class SegwitWalletTest extends BitcoinSWalletTest {
@ -26,13 +28,39 @@ class SegwitWalletTest extends BitcoinSWalletTest {
addr <- wallet.getNewAddress()
account <- wallet.getDefaultAccount()
otherAddr <- wallet.getNewAddress()
thirdAddr <- wallet.getNewAddress(AddressType.SegWit)
allAddrs <- wallet.listAddresses()
} yield {
assert(account.hdAccount.purpose == HDPurposes.SegWit)
assert(allAddrs.forall(_.address.isInstanceOf[Bech32Address]))
assert(allAddrs.length == 2)
assert(allAddrs.length == 3)
assert(allAddrs.exists(_.address == addr))
assert(allAddrs.exists(_.address == otherAddr))
assert(allAddrs.exists(_.address == thirdAddr))
}
}
it should "generate legacy addresses" in { wallet =>
for {
account <- wallet.getDefaultAccountForType(AddressType.Legacy)
addr <- wallet.getNewAddress(AddressType.Legacy)
} yield {
assert(account.hdAccount.purpose == HDPurposes.Legacy)
assert(addr.isInstanceOf[P2PKHAddress])
}
}
it should "generate mixed addresses" in { wallet =>
for {
segwit <- wallet.getNewAddress(AddressType.SegWit)
legacy <- wallet.getNewAddress(AddressType.Legacy)
// TODO: uncomment this once nested segwit is implemented
// nested <- wallet.getNewAddress(AddressType.NestedSegWit)
} yield {
assert(segwit.isInstanceOf[Bech32Address])
assert(legacy.isInstanceOf[P2PKHAddress])
// TODO: uncomment this once nested segwit is implemented
// assert(nested.isInstanceOf[P2SHAddress])
}
}
}

View file

@ -17,6 +17,7 @@ import org.bitcoins.core.hd.HDChainType.External
import org.bitcoins.core.protocol.BitcoinAddress
import org.bitcoins.wallet.models.AddressDb
import org.bitcoins.core.hd.HDChain
import org.bitcoins.core.hd.HDPurpose
class WalletUnitTest extends BitcoinSWalletTest {
@ -53,12 +54,15 @@ class WalletUnitTest extends BitcoinSWalletTest {
val wallet = walletApi.asInstanceOf[Wallet]
def getMostRecent(
purpose: HDPurpose,
chain: HDChainType,
acctIndex: Int
): Future[AddressDb] = {
val recentOptFut: Future[Option[AddressDb]] = chain match {
case Change => wallet.addressDAO.findMostRecentChange(acctIndex)
case External => wallet.addressDAO.findMostRecentExternal(acctIndex)
case Change =>
wallet.addressDAO.findMostRecentChange(purpose, acctIndex)
case External =>
wallet.addressDAO.findMostRecentExternal(purpose, acctIndex)
}
recentOptFut.map {
@ -68,10 +72,11 @@ class WalletUnitTest extends BitcoinSWalletTest {
}
def assertIndexIs(
purpose: HDPurpose,
chain: HDChainType,
addrIndex: Int,
accountIndex: Int): Future[Assertion] = {
getMostRecent(chain, accountIndex) map { addr =>
getMostRecent(purpose, chain, accountIndex) map { addr =>
assert(addr.path.address.index == addrIndex)
}
}
@ -80,10 +85,13 @@ class WalletUnitTest extends BitcoinSWalletTest {
val addrRange = 0 to addressesToGenerate
/**
* Generate some addresse, and verify that the correct address index is
* Generate some addresses, and verify that the correct address index is
* being reported
*/
def testChain(chain: HDChainType, accIdx: Int): Future[Assertion] = {
def testChain(
purpose: HDPurpose,
chain: HDChainType,
accIdx: Int): Future[Assertion] = {
val getAddrFunc: () => Future[BitcoinAddress] = chain match {
case Change => wallet.getNewChangeAddress _
case External => wallet.getNewAddress _
@ -91,8 +99,10 @@ class WalletUnitTest extends BitcoinSWalletTest {
for {
_ <- {
val addrF = chain match {
case Change => wallet.addressDAO.findMostRecentChange(accIdx)
case External => wallet.addressDAO.findMostRecentExternal(accIdx)
case Change =>
wallet.addressDAO.findMostRecentChange(purpose, accIdx)
case External =>
wallet.addressDAO.findMostRecentExternal(purpose, accIdx)
}
addrF.map {
case Some(addr) =>
@ -102,11 +112,12 @@ class WalletUnitTest extends BitcoinSWalletTest {
}
}
_ <- FutureUtil.sequentially(addrRange)(_ => getAddrFunc())
_ <- assertIndexIs(chain,
_ <- assertIndexIs(purpose,
chain,
accountIndex = accIdx,
addrIndex = addressesToGenerate)
newest <- getAddrFunc()
res <- getMostRecent(chain, accIdx).map { found =>
res <- getMostRecent(purpose, chain, accIdx).map { found =>
assert(found.address == newest)
assert(found.path.address.index == addressesToGenerate + 1)
}
@ -116,8 +127,8 @@ class WalletUnitTest extends BitcoinSWalletTest {
for {
account <- wallet.getDefaultAccount()
accIdx = account.hdAccount.index
_ <- testChain(External, accIdx)
res <- testChain(Change, accIdx)
_ <- testChain(wallet.DEFAULT_HD_PURPOSE, External, accIdx)
res <- testChain(wallet.DEFAULT_HD_PURPOSE, Change, accIdx)
} yield res
}

View file

@ -15,6 +15,7 @@ import scala.concurrent.Future
import scala.concurrent.ExecutionContext
import org.bitcoins.wallet.config.WalletAppConfig
import org.bitcoins.core.bloom.BloomFilter
import org.bitcoins.core.hd.AddressType
/**
* API for the wallet project.
@ -83,12 +84,15 @@ trait LockedWalletApi extends WalletApi {
def listAddresses(): Future[Vector[AddressDb]]
/**
* Gets a new external address from the specified
* account. Calling this method multiple
* Gets a new external address with the specified
* type. Calling this method multiple
* times will return the same address, until it has
* received funds.
*
*/
def getNewAddress(account: AccountDb): Future[BitcoinAddress]
// TODO: Last sentence is not true, implement that
// https://github.com/bitcoin-s/bitcoin-s/issues/628
def getNewAddress(addressType: AddressType): Future[BitcoinAddress]
/**
* Gets a new external address from the default account.
@ -98,8 +102,7 @@ trait LockedWalletApi extends WalletApi {
*/
def getNewAddress(): Future[BitcoinAddress] = {
for {
account <- getDefaultAccount()
address <- getNewAddress(account)
address <- getNewAddress(walletConfig.defaultAddressType)
} yield address
}
@ -129,6 +132,10 @@ trait LockedWalletApi extends WalletApi {
*/
protected[wallet] def getDefaultAccount(): Future[AccountDb]
/** Fetches the default account for the given address/acount kind */
protected[wallet] def getDefaultAccountForType(
addressType: AddressType): Future[AccountDb]
/**
* Unlocks the wallet with the provided passphrase,
* making it possible to send transactions.

View file

@ -9,6 +9,7 @@ import scala.util.Success
import java.nio.file.Files
import org.bitcoins.core.hd.HDPurpose
import org.bitcoins.core.hd.HDPurposes
import org.bitcoins.core.hd.AddressType
case class WalletAppConfig(private val conf: Config*) extends AppConfig {
override val configOverrides: List[Config] = conf.toList
@ -27,6 +28,17 @@ case class WalletAppConfig(private val conf: Config*) extends AppConfig {
throw new RuntimeException(s"$other is not a valid account type!")
}
lazy val defaultAddressType: AddressType = {
defaultAccountKind match {
case HDPurposes.Legacy => AddressType.Legacy
case HDPurposes.NestedSegWit => AddressType.NestedSegWit
case HDPurposes.SegWit => AddressType.SegWit
// todo: validate this pre-app startup
case other =>
throw new RuntimeException(s"$other is not a valid account type!")
}
}
lazy val bloomFalsePositiveRate: Double =
config.getDouble("wallet.bloomFalsePositiveRate")

View file

@ -8,6 +8,10 @@ import org.bitcoins.core.hd.HDCoin
import org.bitcoins.core.protocol.blockchain.TestNetChainParams
import org.bitcoins.core.protocol.blockchain.RegTestNetChainParams
import org.bitcoins.core.protocol.blockchain.MainNetChainParams
import org.bitcoins.core.hd.HDPurpose
import org.bitcoins.core.hd.AddressType._
import org.bitcoins.core.hd.AddressType
import org.bitcoins.core.hd.HDPurposes
/**
* Provides functionality related enumerating accounts. Account
@ -19,23 +23,51 @@ private[wallet] trait AccountHandling { self: LockedWallet =>
override def listAccounts(): Future[Vector[AccountDb]] =
accountDAO.findAll()
private def getOrThrowAccount(account: Option[AccountDb]): AccountDb =
account.getOrElse(
throw new RuntimeException(
s"Could not find account with ${DEFAULT_HD_COIN.purpose.constant} " +
s"purpose field and ${DEFAULT_HD_COIN.coinType.toInt} coin field"))
/** @inheritdoc */
override protected[wallet] def getDefaultAccount(): Future[AccountDb] = {
for {
account <- accountDAO.read((DEFAULT_HD_COIN, 0))
} yield
account.getOrElse(
throw new RuntimeException(
s"Could not find account with ${DEFAULT_HD_COIN.purpose.constant} " +
s"purpose field and ${DEFAULT_HD_COIN.coinType.toInt} coin field"))
} yield getOrThrowAccount(account)
}
/** @inheritdoc */
override protected[wallet] def getDefaultAccountForType(
addressType: AddressType): Future[AccountDb] = {
val hdCoin = addressType match {
case Legacy => HDCoin(HDPurposes.Legacy, DEFAULT_HD_COIN_TYPE)
case NestedSegWit => HDCoin(HDPurposes.NestedSegWit, DEFAULT_HD_COIN_TYPE)
case SegWit => HDCoin(HDPurposes.SegWit, DEFAULT_HD_COIN_TYPE)
}
for {
account <- accountDAO.read((hdCoin, 0))
} yield getOrThrowAccount(account)
}
/** The default HD coin for this wallet, read from config */
protected[wallet] lazy val DEFAULT_HD_COIN: HDCoin = {
val coinType = chainParams match {
case MainNetChainParams => HDCoinType.Bitcoin
case RegTestNetChainParams | TestNetChainParams => HDCoinType.Testnet
}
val coinType = DEFAULT_HD_COIN_TYPE
HDCoin(walletConfig.defaultAccountKind, coinType)
}
/** The default HD coin type for this wallet, derived from
* the network we're on
*/
protected[wallet] lazy val DEFAULT_HD_COIN_TYPE: HDCoinType = {
chainParams match {
case MainNetChainParams => HDCoinType.Bitcoin
case RegTestNetChainParams | TestNetChainParams => HDCoinType.Testnet
}
}
/** The default HD purpose for this wallet, read from config */
protected[wallet] lazy val DEFAULT_HD_PURPOSE: HDPurpose =
walletConfig.defaultAccountKind
}

View file

@ -8,7 +8,6 @@ import org.bitcoins.wallet.models.AccountDb
import org.bitcoins.core.hd.HDChainType
import org.bitcoins.core.protocol.BitcoinAddress
import org.bitcoins.core.hd.HDPath
import org.bitcoins.core.hd.HDAccount
import org.bitcoins.core.hd.HDAddress
import scala.util.Failure
import scala.util.Success
@ -22,6 +21,7 @@ import org.bitcoins.core.protocol.transaction.TransactionOutput
import org.bitcoins.core.protocol.script.ScriptPubKey
import org.bitcoins.core.protocol.transaction.TransactionOutPoint
import org.bitcoins.core.number.UInt32
import org.bitcoins.core.hd.AddressType
/**
* Provides functionality related to addresses. This includes
@ -75,9 +75,10 @@ private[wallet] trait AddressHandling { self: LockedWallet =>
val lastAddrOptF = chainType match {
case HDChainType.External =>
addressDAO.findMostRecentExternal(accountIndex)
addressDAO.findMostRecentExternal(account.hdAccount.purpose,
accountIndex)
case HDChainType.Change =>
addressDAO.findMostRecentChange(accountIndex)
addressDAO.findMostRecentChange(account.hdAccount.purpose, accountIndex)
}
lastAddrOptF.flatMap { lastAddrOpt =>
@ -88,8 +89,7 @@ private[wallet] trait AddressHandling { self: LockedWallet =>
s"Found previous address at path=${addr.path}, next=$next")
next
case None =>
val account = HDAccount(DEFAULT_HD_COIN, accountIndex)
val chain = account.toChain(chainType)
val chain = account.hdAccount.toChain(chainType)
val address = HDAddress(chain, 0)
val path = address.toPath
logger.debug(s"Did not find previous address, next=$path")
@ -135,12 +135,21 @@ private[wallet] trait AddressHandling { self: LockedWallet =>
}
}
/** @inheritdoc */
override def getNewAddress(account: AccountDb): Future[BitcoinAddress] = {
val addrF = getNewAddressHelper(account, HDChainType.External)
def getNewAddress(account: AccountDb): Future[BitcoinAddress] = {
val addrF =
getNewAddressHelper(account, HDChainType.External)
addrF
}
/** @inheritdoc */
override def getNewAddress(
addressType: AddressType): Future[BitcoinAddress] = {
for {
account <- getDefaultAccountForType(addressType)
address <- getNewAddressHelper(account, HDChainType.External)
} yield address
}
/** Generates a new change address */
override protected[wallet] def getNewChangeAddress(
account: AccountDb): Future[BitcoinAddress] = {

View file

@ -12,6 +12,7 @@ import org.bitcoins.core.hd.HDChainType
import org.bitcoins.wallet.config.WalletAppConfig
import org.bitcoins.core.crypto.ECPublicKey
import org.bitcoins.core.protocol.script.ScriptPubKey
import org.bitcoins.core.hd.HDPurpose
case class AddressDAO()(
implicit val ec: ExecutionContext,
@ -44,8 +45,11 @@ case class AddressDAO()(
/**
* Finds the most recent change address in the wallet, if any
*/
def findMostRecentChange(accountIndex: Int): Future[Option[AddressDb]] = {
val query = findMostRecentForChain(accountIndex, HDChainType.Change)
def findMostRecentChange(
purpose: HDPurpose,
accountIndex: Int): Future[Option[AddressDb]] = {
val query =
findMostRecentForChain(purpose, accountIndex, HDChainType.Change)
database.run(query)
}
@ -63,9 +67,11 @@ case class AddressDAO()(
}
private def findMostRecentForChain(
purpose: HDPurpose,
accountIndex: Int,
chain: HDChainType): SqlAction[Option[AddressDb], NoStream, Effect.Read] = {
addressesForAccountQuery(accountIndex)
.filter(_.purpose === purpose)
.filter(_.accountChainType === chain)
.sortBy(_.addressIndex.desc)
.take(1)
@ -76,8 +82,11 @@ case class AddressDAO()(
/**
* Finds the most recent external address in the wallet, if any
*/
def findMostRecentExternal(accountIndex: Int): Future[Option[AddressDb]] = {
val query = findMostRecentForChain(accountIndex, HDChainType.External)
def findMostRecentExternal(
purpose: HDPurpose,
accountIndex: Int): Future[Option[AddressDb]] = {
val query =
findMostRecentForChain(purpose, accountIndex, HDChainType.External)
database.run(query)
}
}