diff --git a/core/src/main/scala/org/bitcoins/core/hd/AddressType.scala b/core/src/main/scala/org/bitcoins/core/hd/AddressType.scala new file mode 100644 index 0000000000..e7b12bdc0b --- /dev/null +++ b/core/src/main/scala/org/bitcoins/core/hd/AddressType.scala @@ -0,0 +1,18 @@ +package org.bitcoins.core.hd + +/** The address types covered by BIP44, BIP49 and BIP84 */ +sealed abstract class AddressType + +object AddressType { + + /** Uses BIP84 address derivation, gives bech32 address (`bc1...`) */ + final case object SegWit extends AddressType + + /** Uses BIP49 address derivation, gives SegWit addresses wrapped + * in P2SH addresses (`3...`) + */ + final case object NestedSegWit extends AddressType + + /** Uses BIP44 address derivation (`1...`) */ + final case object Legacy extends AddressType +} diff --git a/wallet-test/src/test/scala/org/bitcoins/wallet/LegacyWalletTest.scala b/wallet-test/src/test/scala/org/bitcoins/wallet/LegacyWalletTest.scala index 1610563fa0..63f258d93d 100644 --- a/wallet-test/src/test/scala/org/bitcoins/wallet/LegacyWalletTest.scala +++ b/wallet-test/src/test/scala/org/bitcoins/wallet/LegacyWalletTest.scala @@ -11,6 +11,9 @@ import org.bitcoins.wallet.api.UnlockWalletError.MnemonicNotFound import com.typesafe.config.ConfigFactory import org.bitcoins.core.protocol.P2PKHAddress import org.bitcoins.core.hd.HDPurposes +import org.bitcoins.core.protocol.Bech32Address +import org.bitcoins.core.protocol.P2SHAddress +import org.bitcoins.core.hd.AddressType class LegacyWalletTest extends BitcoinSWalletTest { @@ -24,13 +27,41 @@ class LegacyWalletTest extends BitcoinSWalletTest { addr <- wallet.getNewAddress() account <- wallet.getDefaultAccount() otherAddr <- wallet.getNewAddress() + thirdAddr <- wallet.getNewAddress(AddressType.Legacy) allAddrs <- wallet.listAddresses() } yield { assert(account.hdAccount.purpose == HDPurposes.Legacy) assert(allAddrs.forall(_.address.isInstanceOf[P2PKHAddress])) - assert(allAddrs.length == 2) + assert(allAddrs.length == 3) assert(allAddrs.exists(_.address == addr)) assert(allAddrs.exists(_.address == otherAddr)) + assert(allAddrs.exists(_.address == thirdAddr)) + } + } + + it should "generate segwit addresses" in { wallet => + for { + account <- wallet.getDefaultAccountForType(AddressType.SegWit) + addr <- wallet.getNewAddress(AddressType.SegWit) + } yield { + assert(account.hdAccount.purpose == HDPurposes.SegWit) + assert(addr.isInstanceOf[Bech32Address]) + } + } + + it should "generate mixed addresses" in { wallet => + for { + segwit <- wallet.getNewAddress(AddressType.SegWit) + legacy <- wallet.getNewAddress(AddressType.Legacy) + // TODO: uncomment this once nested segwit is implemented + // https://github.com/bitcoin-s/bitcoin-s/issues/407 + // nested <- wallet.getNewAddress(AddressType.NestedSegWit) + } yield { + assert(segwit.isInstanceOf[Bech32Address]) + assert(legacy.isInstanceOf[P2PKHAddress]) + // TODO: uncomment this once nested segwit is implemented + // https://github.com/bitcoin-s/bitcoin-s/issues/407 + // assert(nested.isInstanceOf[P2SHAddress]) } } } diff --git a/wallet-test/src/test/scala/org/bitcoins/wallet/SegwitWalletTest.scala b/wallet-test/src/test/scala/org/bitcoins/wallet/SegwitWalletTest.scala index 5f406b7009..a09e270ccb 100644 --- a/wallet-test/src/test/scala/org/bitcoins/wallet/SegwitWalletTest.scala +++ b/wallet-test/src/test/scala/org/bitcoins/wallet/SegwitWalletTest.scala @@ -12,6 +12,8 @@ import com.typesafe.config.ConfigFactory import org.bitcoins.core.protocol.P2PKHAddress import org.bitcoins.core.protocol.Bech32Address import org.bitcoins.core.hd.HDPurposes +import org.bitcoins.core.hd.AddressType +import org.bitcoins.core.protocol.P2SHAddress class SegwitWalletTest extends BitcoinSWalletTest { @@ -26,13 +28,39 @@ class SegwitWalletTest extends BitcoinSWalletTest { addr <- wallet.getNewAddress() account <- wallet.getDefaultAccount() otherAddr <- wallet.getNewAddress() + thirdAddr <- wallet.getNewAddress(AddressType.SegWit) allAddrs <- wallet.listAddresses() } yield { assert(account.hdAccount.purpose == HDPurposes.SegWit) assert(allAddrs.forall(_.address.isInstanceOf[Bech32Address])) - assert(allAddrs.length == 2) + assert(allAddrs.length == 3) assert(allAddrs.exists(_.address == addr)) assert(allAddrs.exists(_.address == otherAddr)) + assert(allAddrs.exists(_.address == thirdAddr)) + } + } + + it should "generate legacy addresses" in { wallet => + for { + account <- wallet.getDefaultAccountForType(AddressType.Legacy) + addr <- wallet.getNewAddress(AddressType.Legacy) + } yield { + assert(account.hdAccount.purpose == HDPurposes.Legacy) + assert(addr.isInstanceOf[P2PKHAddress]) + } + } + + it should "generate mixed addresses" in { wallet => + for { + segwit <- wallet.getNewAddress(AddressType.SegWit) + legacy <- wallet.getNewAddress(AddressType.Legacy) + // TODO: uncomment this once nested segwit is implemented + // nested <- wallet.getNewAddress(AddressType.NestedSegWit) + } yield { + assert(segwit.isInstanceOf[Bech32Address]) + assert(legacy.isInstanceOf[P2PKHAddress]) + // TODO: uncomment this once nested segwit is implemented + // assert(nested.isInstanceOf[P2SHAddress]) } } } diff --git a/wallet-test/src/test/scala/org/bitcoins/wallet/WalletUnitTest.scala b/wallet-test/src/test/scala/org/bitcoins/wallet/WalletUnitTest.scala index 80dd6b9faf..29133e6752 100644 --- a/wallet-test/src/test/scala/org/bitcoins/wallet/WalletUnitTest.scala +++ b/wallet-test/src/test/scala/org/bitcoins/wallet/WalletUnitTest.scala @@ -17,6 +17,7 @@ import org.bitcoins.core.hd.HDChainType.External import org.bitcoins.core.protocol.BitcoinAddress import org.bitcoins.wallet.models.AddressDb import org.bitcoins.core.hd.HDChain +import org.bitcoins.core.hd.HDPurpose class WalletUnitTest extends BitcoinSWalletTest { @@ -53,12 +54,15 @@ class WalletUnitTest extends BitcoinSWalletTest { val wallet = walletApi.asInstanceOf[Wallet] def getMostRecent( + purpose: HDPurpose, chain: HDChainType, acctIndex: Int ): Future[AddressDb] = { val recentOptFut: Future[Option[AddressDb]] = chain match { - case Change => wallet.addressDAO.findMostRecentChange(acctIndex) - case External => wallet.addressDAO.findMostRecentExternal(acctIndex) + case Change => + wallet.addressDAO.findMostRecentChange(purpose, acctIndex) + case External => + wallet.addressDAO.findMostRecentExternal(purpose, acctIndex) } recentOptFut.map { @@ -68,10 +72,11 @@ class WalletUnitTest extends BitcoinSWalletTest { } def assertIndexIs( + purpose: HDPurpose, chain: HDChainType, addrIndex: Int, accountIndex: Int): Future[Assertion] = { - getMostRecent(chain, accountIndex) map { addr => + getMostRecent(purpose, chain, accountIndex) map { addr => assert(addr.path.address.index == addrIndex) } } @@ -80,10 +85,13 @@ class WalletUnitTest extends BitcoinSWalletTest { val addrRange = 0 to addressesToGenerate /** - * Generate some addresse, and verify that the correct address index is + * Generate some addresses, and verify that the correct address index is * being reported */ - def testChain(chain: HDChainType, accIdx: Int): Future[Assertion] = { + def testChain( + purpose: HDPurpose, + chain: HDChainType, + accIdx: Int): Future[Assertion] = { val getAddrFunc: () => Future[BitcoinAddress] = chain match { case Change => wallet.getNewChangeAddress _ case External => wallet.getNewAddress _ @@ -91,8 +99,10 @@ class WalletUnitTest extends BitcoinSWalletTest { for { _ <- { val addrF = chain match { - case Change => wallet.addressDAO.findMostRecentChange(accIdx) - case External => wallet.addressDAO.findMostRecentExternal(accIdx) + case Change => + wallet.addressDAO.findMostRecentChange(purpose, accIdx) + case External => + wallet.addressDAO.findMostRecentExternal(purpose, accIdx) } addrF.map { case Some(addr) => @@ -102,11 +112,12 @@ class WalletUnitTest extends BitcoinSWalletTest { } } _ <- FutureUtil.sequentially(addrRange)(_ => getAddrFunc()) - _ <- assertIndexIs(chain, + _ <- assertIndexIs(purpose, + chain, accountIndex = accIdx, addrIndex = addressesToGenerate) newest <- getAddrFunc() - res <- getMostRecent(chain, accIdx).map { found => + res <- getMostRecent(purpose, chain, accIdx).map { found => assert(found.address == newest) assert(found.path.address.index == addressesToGenerate + 1) } @@ -116,8 +127,8 @@ class WalletUnitTest extends BitcoinSWalletTest { for { account <- wallet.getDefaultAccount() accIdx = account.hdAccount.index - _ <- testChain(External, accIdx) - res <- testChain(Change, accIdx) + _ <- testChain(wallet.DEFAULT_HD_PURPOSE, External, accIdx) + res <- testChain(wallet.DEFAULT_HD_PURPOSE, Change, accIdx) } yield res } diff --git a/wallet/src/main/scala/org/bitcoins/wallet/api/WalletApi.scala b/wallet/src/main/scala/org/bitcoins/wallet/api/WalletApi.scala index b6acc38f2f..ef7372820e 100644 --- a/wallet/src/main/scala/org/bitcoins/wallet/api/WalletApi.scala +++ b/wallet/src/main/scala/org/bitcoins/wallet/api/WalletApi.scala @@ -15,6 +15,7 @@ import scala.concurrent.Future import scala.concurrent.ExecutionContext import org.bitcoins.wallet.config.WalletAppConfig import org.bitcoins.core.bloom.BloomFilter +import org.bitcoins.core.hd.AddressType /** * API for the wallet project. @@ -83,12 +84,15 @@ trait LockedWalletApi extends WalletApi { def listAddresses(): Future[Vector[AddressDb]] /** - * Gets a new external address from the specified - * account. Calling this method multiple + * Gets a new external address with the specified + * type. Calling this method multiple * times will return the same address, until it has * received funds. + * */ - def getNewAddress(account: AccountDb): Future[BitcoinAddress] + // TODO: Last sentence is not true, implement that + // https://github.com/bitcoin-s/bitcoin-s/issues/628 + def getNewAddress(addressType: AddressType): Future[BitcoinAddress] /** * Gets a new external address from the default account. @@ -98,8 +102,7 @@ trait LockedWalletApi extends WalletApi { */ def getNewAddress(): Future[BitcoinAddress] = { for { - account <- getDefaultAccount() - address <- getNewAddress(account) + address <- getNewAddress(walletConfig.defaultAddressType) } yield address } @@ -129,6 +132,10 @@ trait LockedWalletApi extends WalletApi { */ protected[wallet] def getDefaultAccount(): Future[AccountDb] + /** Fetches the default account for the given address/acount kind */ + protected[wallet] def getDefaultAccountForType( + addressType: AddressType): Future[AccountDb] + /** * Unlocks the wallet with the provided passphrase, * making it possible to send transactions. diff --git a/wallet/src/main/scala/org/bitcoins/wallet/config/WalletAppConfig.scala b/wallet/src/main/scala/org/bitcoins/wallet/config/WalletAppConfig.scala index 52e893f4ca..d1328fb959 100644 --- a/wallet/src/main/scala/org/bitcoins/wallet/config/WalletAppConfig.scala +++ b/wallet/src/main/scala/org/bitcoins/wallet/config/WalletAppConfig.scala @@ -9,6 +9,7 @@ import scala.util.Success import java.nio.file.Files import org.bitcoins.core.hd.HDPurpose import org.bitcoins.core.hd.HDPurposes +import org.bitcoins.core.hd.AddressType case class WalletAppConfig(private val conf: Config*) extends AppConfig { override val configOverrides: List[Config] = conf.toList @@ -27,6 +28,17 @@ case class WalletAppConfig(private val conf: Config*) extends AppConfig { throw new RuntimeException(s"$other is not a valid account type!") } + lazy val defaultAddressType: AddressType = { + defaultAccountKind match { + case HDPurposes.Legacy => AddressType.Legacy + case HDPurposes.NestedSegWit => AddressType.NestedSegWit + case HDPurposes.SegWit => AddressType.SegWit + // todo: validate this pre-app startup + case other => + throw new RuntimeException(s"$other is not a valid account type!") + } + } + lazy val bloomFalsePositiveRate: Double = config.getDouble("wallet.bloomFalsePositiveRate") diff --git a/wallet/src/main/scala/org/bitcoins/wallet/internal/AccountHandling.scala b/wallet/src/main/scala/org/bitcoins/wallet/internal/AccountHandling.scala index ef90e51402..593ddde007 100644 --- a/wallet/src/main/scala/org/bitcoins/wallet/internal/AccountHandling.scala +++ b/wallet/src/main/scala/org/bitcoins/wallet/internal/AccountHandling.scala @@ -8,6 +8,10 @@ import org.bitcoins.core.hd.HDCoin import org.bitcoins.core.protocol.blockchain.TestNetChainParams import org.bitcoins.core.protocol.blockchain.RegTestNetChainParams import org.bitcoins.core.protocol.blockchain.MainNetChainParams +import org.bitcoins.core.hd.HDPurpose +import org.bitcoins.core.hd.AddressType._ +import org.bitcoins.core.hd.AddressType +import org.bitcoins.core.hd.HDPurposes /** * Provides functionality related enumerating accounts. Account @@ -19,23 +23,51 @@ private[wallet] trait AccountHandling { self: LockedWallet => override def listAccounts(): Future[Vector[AccountDb]] = accountDAO.findAll() + private def getOrThrowAccount(account: Option[AccountDb]): AccountDb = + account.getOrElse( + throw new RuntimeException( + s"Could not find account with ${DEFAULT_HD_COIN.purpose.constant} " + + s"purpose field and ${DEFAULT_HD_COIN.coinType.toInt} coin field")) + /** @inheritdoc */ override protected[wallet] def getDefaultAccount(): Future[AccountDb] = { for { account <- accountDAO.read((DEFAULT_HD_COIN, 0)) - } yield - account.getOrElse( - throw new RuntimeException( - s"Could not find account with ${DEFAULT_HD_COIN.purpose.constant} " + - s"purpose field and ${DEFAULT_HD_COIN.coinType.toInt} coin field")) + } yield getOrThrowAccount(account) + } + + /** @inheritdoc */ + override protected[wallet] def getDefaultAccountForType( + addressType: AddressType): Future[AccountDb] = { + val hdCoin = addressType match { + case Legacy => HDCoin(HDPurposes.Legacy, DEFAULT_HD_COIN_TYPE) + case NestedSegWit => HDCoin(HDPurposes.NestedSegWit, DEFAULT_HD_COIN_TYPE) + case SegWit => HDCoin(HDPurposes.SegWit, DEFAULT_HD_COIN_TYPE) + } + for { + account <- accountDAO.read((hdCoin, 0)) + } yield getOrThrowAccount(account) } /** The default HD coin for this wallet, read from config */ protected[wallet] lazy val DEFAULT_HD_COIN: HDCoin = { - val coinType = chainParams match { - case MainNetChainParams => HDCoinType.Bitcoin - case RegTestNetChainParams | TestNetChainParams => HDCoinType.Testnet - } + val coinType = DEFAULT_HD_COIN_TYPE HDCoin(walletConfig.defaultAccountKind, coinType) } + + /** The default HD coin type for this wallet, derived from + * the network we're on + */ + protected[wallet] lazy val DEFAULT_HD_COIN_TYPE: HDCoinType = { + chainParams match { + case MainNetChainParams => HDCoinType.Bitcoin + case RegTestNetChainParams | TestNetChainParams => HDCoinType.Testnet + + } + + } + + /** The default HD purpose for this wallet, read from config */ + protected[wallet] lazy val DEFAULT_HD_PURPOSE: HDPurpose = + walletConfig.defaultAccountKind } diff --git a/wallet/src/main/scala/org/bitcoins/wallet/internal/AddressHandling.scala b/wallet/src/main/scala/org/bitcoins/wallet/internal/AddressHandling.scala index cbdbf900b1..ed3e3e3566 100644 --- a/wallet/src/main/scala/org/bitcoins/wallet/internal/AddressHandling.scala +++ b/wallet/src/main/scala/org/bitcoins/wallet/internal/AddressHandling.scala @@ -8,7 +8,6 @@ import org.bitcoins.wallet.models.AccountDb import org.bitcoins.core.hd.HDChainType import org.bitcoins.core.protocol.BitcoinAddress import org.bitcoins.core.hd.HDPath -import org.bitcoins.core.hd.HDAccount import org.bitcoins.core.hd.HDAddress import scala.util.Failure import scala.util.Success @@ -22,6 +21,7 @@ import org.bitcoins.core.protocol.transaction.TransactionOutput import org.bitcoins.core.protocol.script.ScriptPubKey import org.bitcoins.core.protocol.transaction.TransactionOutPoint import org.bitcoins.core.number.UInt32 +import org.bitcoins.core.hd.AddressType /** * Provides functionality related to addresses. This includes @@ -75,9 +75,10 @@ private[wallet] trait AddressHandling { self: LockedWallet => val lastAddrOptF = chainType match { case HDChainType.External => - addressDAO.findMostRecentExternal(accountIndex) + addressDAO.findMostRecentExternal(account.hdAccount.purpose, + accountIndex) case HDChainType.Change => - addressDAO.findMostRecentChange(accountIndex) + addressDAO.findMostRecentChange(account.hdAccount.purpose, accountIndex) } lastAddrOptF.flatMap { lastAddrOpt => @@ -88,8 +89,7 @@ private[wallet] trait AddressHandling { self: LockedWallet => s"Found previous address at path=${addr.path}, next=$next") next case None => - val account = HDAccount(DEFAULT_HD_COIN, accountIndex) - val chain = account.toChain(chainType) + val chain = account.hdAccount.toChain(chainType) val address = HDAddress(chain, 0) val path = address.toPath logger.debug(s"Did not find previous address, next=$path") @@ -135,12 +135,21 @@ private[wallet] trait AddressHandling { self: LockedWallet => } } - /** @inheritdoc */ - override def getNewAddress(account: AccountDb): Future[BitcoinAddress] = { - val addrF = getNewAddressHelper(account, HDChainType.External) + def getNewAddress(account: AccountDb): Future[BitcoinAddress] = { + val addrF = + getNewAddressHelper(account, HDChainType.External) addrF } + /** @inheritdoc */ + override def getNewAddress( + addressType: AddressType): Future[BitcoinAddress] = { + for { + account <- getDefaultAccountForType(addressType) + address <- getNewAddressHelper(account, HDChainType.External) + } yield address + } + /** Generates a new change address */ override protected[wallet] def getNewChangeAddress( account: AccountDb): Future[BitcoinAddress] = { diff --git a/wallet/src/main/scala/org/bitcoins/wallet/models/AddressDAO.scala b/wallet/src/main/scala/org/bitcoins/wallet/models/AddressDAO.scala index 9388e69abe..3a887820c6 100644 --- a/wallet/src/main/scala/org/bitcoins/wallet/models/AddressDAO.scala +++ b/wallet/src/main/scala/org/bitcoins/wallet/models/AddressDAO.scala @@ -12,6 +12,7 @@ import org.bitcoins.core.hd.HDChainType import org.bitcoins.wallet.config.WalletAppConfig import org.bitcoins.core.crypto.ECPublicKey import org.bitcoins.core.protocol.script.ScriptPubKey +import org.bitcoins.core.hd.HDPurpose case class AddressDAO()( implicit val ec: ExecutionContext, @@ -44,8 +45,11 @@ case class AddressDAO()( /** * Finds the most recent change address in the wallet, if any */ - def findMostRecentChange(accountIndex: Int): Future[Option[AddressDb]] = { - val query = findMostRecentForChain(accountIndex, HDChainType.Change) + def findMostRecentChange( + purpose: HDPurpose, + accountIndex: Int): Future[Option[AddressDb]] = { + val query = + findMostRecentForChain(purpose, accountIndex, HDChainType.Change) database.run(query) } @@ -63,9 +67,11 @@ case class AddressDAO()( } private def findMostRecentForChain( + purpose: HDPurpose, accountIndex: Int, chain: HDChainType): SqlAction[Option[AddressDb], NoStream, Effect.Read] = { addressesForAccountQuery(accountIndex) + .filter(_.purpose === purpose) .filter(_.accountChainType === chain) .sortBy(_.addressIndex.desc) .take(1) @@ -76,8 +82,11 @@ case class AddressDAO()( /** * Finds the most recent external address in the wallet, if any */ - def findMostRecentExternal(accountIndex: Int): Future[Option[AddressDb]] = { - val query = findMostRecentForChain(accountIndex, HDChainType.External) + def findMostRecentExternal( + purpose: HDPurpose, + accountIndex: Int): Future[Option[AddressDb]] = { + val query = + findMostRecentForChain(purpose, accountIndex, HDChainType.External) database.run(query) } }