diff --git a/wallet-test/src/test/scala/org/bitcoins/wallet/RescanHandlingTest.scala b/wallet-test/src/test/scala/org/bitcoins/wallet/RescanHandlingTest.scala index 71772ed5b1..41eb832ff2 100644 --- a/wallet-test/src/test/scala/org/bitcoins/wallet/RescanHandlingTest.scala +++ b/wallet-test/src/test/scala/org/bitcoins/wallet/RescanHandlingTest.scala @@ -3,7 +3,6 @@ package org.bitcoins.wallet import org.bitcoins.asyncutil.AsyncUtil import org.bitcoins.core.currency.{Bitcoins, CurrencyUnits, Satoshis} import org.bitcoins.core.protocol.BlockStamp -import org.bitcoins.core.protocol.script.ScriptPubKey import org.bitcoins.core.protocol.transaction.TransactionOutput import org.bitcoins.core.util.FutureUtil import org.bitcoins.core.wallet.rescan.RescanState @@ -161,13 +160,15 @@ class RescanHandlingTest extends BitcoinSWalletTestCachedBitcoindNewest { val amt = Bitcoins.one val numBlocks = 1 + val initBalanceF = wallet.getBalance() val defaultAccountF = wallet.getDefaultAccount() //send funds to a fresh wallet address val addrF = wallet.getNewAddress() val bitcoindAddrF = bitcoind.getNewAddress - val newTxWalletF = for { + val balanceAfterPayment1F = for { addr <- addrF + _ <- initBalanceF txid <- bitcoind.sendToAddress(addr, amt) tx <- bitcoind.getRawTransactionRaw(txid) bitcoindAddr <- bitcoindAddrF @@ -183,40 +184,31 @@ class RescanHandlingTest extends BitcoinSWalletTestCachedBitcoindNewest { //wallet before hand. assert(balance >= amt) assert(amt == unconfirmedBalance) - newTxWallet + balance } for { - newTxWallet <- newTxWalletF + _ <- initBalanceF + balanceAfterPayment1 <- balanceAfterPayment1F account <- defaultAccountF txIds <- wallet .listUtxos(account.hdAccount) .map(_.map(_.txid)) - _ <- newTxWallet + _ <- wallet .findByTxIds(txIds) .map(_.flatMap(_.blockHashOpt)) - _ <- newTxWallet.clearAllUtxos() - _ <- newTxWallet.clearAllAddresses() - _ <- - 1.to(10).foldLeft(Future.successful(Vector.empty[ScriptPubKey])) { - (prevFuture, _) => - for { - prev <- prevFuture - address <- wallet.getNewAddress(account) - changeAddress <- wallet.getNewChangeAddress(account) - } yield prev :+ address.scriptPubKey :+ changeAddress.scriptPubKey - } - _ <- wallet.rescanNeutrinoWallet(startOpt = None, - endOpt = None, - addressBatchSize = 1, - useCreationTime = false, - force = true) + _ <- wallet.clearAllUtxos() + _ <- wallet.clearAllAddresses() + balanceAfterClear <- wallet.getBalance() + rescanState <- wallet.fullRescanNeutrinoWallet(1, true) + _ <- RescanState.awaitRescanDone(rescanState) + balanceAfterRescan <- wallet.getBalance() } yield { - - succeed + assert(balanceAfterClear == CurrencyUnits.zero) + assert(balanceAfterPayment1 == balanceAfterRescan) } } diff --git a/wallet/src/main/scala/org/bitcoins/wallet/internal/RescanHandling.scala b/wallet/src/main/scala/org/bitcoins/wallet/internal/RescanHandling.scala index 258e530313..422e208111 100644 --- a/wallet/src/main/scala/org/bitcoins/wallet/internal/RescanHandling.scala +++ b/wallet/src/main/scala/org/bitcoins/wallet/internal/RescanHandling.scala @@ -129,15 +129,26 @@ private[wallet] trait RescanHandling extends WalletLogger { .epochSecondToBlockHeight(creationTime.getEpochSecond) .map(BlockHeight) - private def buildFilterMatchFlow( + private def buildRescanFlow( + account: HDAccount, + addressBatchSize: Int, range: Range, - scripts: Vector[ScriptPubKey], parallelism: Int, - batchSize: Int): RescanState.RescanStarted = { + filterBatchSize: Int): RescanState.RescanStarted = { + val scriptsF = generateScriptPubKeys(account, addressBatchSize) + + //by completing the promise returned by this sink + //we will be able to arbitrarily terminate the stream + //see: https://doc.akka.io/docs/akka/current/stream/operators/Source/maybe.html val maybe = Source.maybe[Int] + + //combine the Source.maybe with the Source providing filter heights + //this is needed so we can arbitrarily kill the stream with + //the promise returned by Source.maybe val combine: Source[Int, Promise[Option[Int]]] = { Source.combineMat(maybe, Source(range))(Merge(_))(Keep.left) } + val seed: Int => Vector[Int] = { case int => Vector(int) } @@ -150,14 +161,18 @@ private[wallet] trait RescanHandling extends WalletLogger { val rescanCompletePromise: Promise[Unit] = Promise() //fetches filters, matches filters against our wallet, and then request blocks - //for the wallet to process + //for the wallet to process. This sink takes as input filter heights + //to fetch for rescanning. val rescanSink: Sink[Int, Future[Seq[Vector[BlockMatchingResponse]]]] = { Flow[Int] - .batch[Vector[Int]](batchSize, seed)(aggregate) + .batch[Vector[Int]](filterBatchSize, seed)(aggregate) .via(fetchFiltersFlow) .mapAsync(1) { case filterResponse => - val f = searchFiltersForMatches(scripts, filterResponse, parallelism)( - ExecutionContext.fromExecutor(walletConfig.rescanThreadPool)) + val f = + scriptsF.flatMap { scripts => + searchFiltersForMatches(scripts, filterResponse, parallelism)( + ExecutionContext.fromExecutor(walletConfig.rescanThreadPool)) + } val heightRange = filterResponse.map(_.blockHeight) @@ -207,38 +222,35 @@ private[wallet] trait RescanHandling extends WalletLogger { * @return a list of matching block hashes */ def getMatchingBlocks( - scripts: Vector[ScriptPubKey], startOpt: Option[BlockStamp] = None, endOpt: Option[BlockStamp] = None, - batchSize: Int = 100, - parallelismLevel: Int = Runtime.getRuntime.availableProcessors())(implicit + addressBatchSize: Int = 100, + parallelismLevel: Int = Runtime.getRuntime.availableProcessors(), + account: HDAccount)(implicit ec: ExecutionContext): Future[RescanState] = { - require(batchSize > 0, "batch size must be greater than zero") + require(addressBatchSize > 0, "batch size must be greater than zero") require(parallelismLevel > 0, "parallelism level must be greater than zero") - if (scripts.isEmpty) { - Future.successful(RescanState.RescanDone) - } else { - for { - startHeight <- startOpt.fold(Future.successful(0))( - chainQueryApi.getHeightByBlockStamp) - _ = if (startHeight < 0) - throw InvalidBlockRange(s"Start position cannot negative") - endHeight <- endOpt.fold(chainQueryApi.getFilterCount())( - chainQueryApi.getHeightByBlockStamp) - _ = if (startHeight > endHeight) - throw InvalidBlockRange( - s"End position cannot precede start: $startHeight:$endHeight") - _ = logger.info( - s"Beginning to search for matches between ${startHeight}:${endHeight} against ${scripts.length} spks") - range = startHeight.to(endHeight) + for { + startHeight <- startOpt.fold(Future.successful(0))( + chainQueryApi.getHeightByBlockStamp) + _ = if (startHeight < 0) + throw InvalidBlockRange(s"Start position cannot negative") + endHeight <- endOpt.fold(chainQueryApi.getFilterCount())( + chainQueryApi.getHeightByBlockStamp) + _ = if (startHeight > endHeight) + throw InvalidBlockRange( + s"End position cannot precede start: $startHeight:$endHeight") + _ = logger.info( + s"Beginning to search for matches between ${startHeight}:${endHeight}") + range = startHeight.to(endHeight) - rescanStarted = buildFilterMatchFlow(range, - scripts, - parallelismLevel, - batchSize) - } yield { - rescanStarted - } + rescanStarted = buildRescanFlow(account = account, + addressBatchSize = addressBatchSize, + range = range, + parallelism = parallelismLevel, + filterBatchSize = addressBatchSize) + } yield { + rescanStarted } } @@ -251,11 +263,10 @@ private[wallet] trait RescanHandling extends WalletLogger { endOpt: Option[BlockStamp], addressBatchSize: Int): Future[RescanState] = { for { - scriptPubKeys <- generateScriptPubKeys(account, addressBatchSize) addressCount <- addressDAO.count() - inProgress <- matchBlocks(scriptPubKeys = scriptPubKeys, - endOpt = endOpt, - startOpt = startOpt) + inProgress <- matchBlocks(endOpt = endOpt, + startOpt = startOpt, + account = account) externalGap <- calcAddressGap(HDChainType.External, account) changeGap <- calcAddressGap(HDChainType.Change, account) _ <- { @@ -321,14 +332,13 @@ private[wallet] trait RescanHandling extends WalletLogger { } private def matchBlocks( - scriptPubKeys: Vector[ScriptPubKey], endOpt: Option[BlockStamp], - startOpt: Option[BlockStamp]): Future[RescanState] = { - + startOpt: Option[BlockStamp], + account: HDAccount): Future[RescanState] = { val rescanStateF = for { - rescanState <- getMatchingBlocks(scripts = scriptPubKeys, - startOpt = startOpt, - endOpt = endOpt)( + rescanState <- getMatchingBlocks(startOpt = startOpt, + endOpt = endOpt, + account = account)( ExecutionContext.fromExecutor(walletConfig.rescanThreadPool)) } yield { rescanState