diff --git a/chain-test/src/test/scala/org/bitcoins/chain/models/BlockHeaderDAOTest.scala b/chain-test/src/test/scala/org/bitcoins/chain/models/BlockHeaderDAOTest.scala index 6c3a0185d6..0a34ceffb5 100644 --- a/chain-test/src/test/scala/org/bitcoins/chain/models/BlockHeaderDAOTest.scala +++ b/chain-test/src/test/scala/org/bitcoins/chain/models/BlockHeaderDAOTest.scala @@ -350,12 +350,30 @@ class BlockHeaderDAOTest extends ChainDbUnitTest { it must "successfully getBlockchainsBetweenHeights" in { blockerHeaderDAO: BlockHeaderDAO => - val duplicate = BlockHeader( + val duplicate3 = BlockHeader( + version = Int32.one, + previousBlockHash = ChainTestUtil.blockHeader562463.hash, + merkleRootHash = DoubleSha256Digest.empty, + time = UInt32.zero, + nBits = ChainTestUtil.blockHeader562464.nBits, + nonce = UInt32.zero + ) + + val duplicate2 = BlockHeader( version = Int32.one, previousBlockHash = ChainTestUtil.blockHeader562462.hash, merkleRootHash = DoubleSha256Digest.empty, time = UInt32.zero, - nBits = UInt32.zero, + nBits = ChainTestUtil.blockHeader562463.nBits, + nonce = UInt32.zero + ) + + val duplicate1 = BlockHeader( + version = Int32.one, + previousBlockHash = genesisHeaderDb.hashBE.flip, + merkleRootHash = DoubleSha256Digest.empty, + time = UInt32.zero, + nBits = ChainTestUtil.blockHeader562462.nBits, nonce = UInt32.zero ) @@ -372,21 +390,40 @@ class BlockHeaderDAOTest extends ChainDbUnitTest { ) val chain2 = Vector( - BlockHeaderDbHelper.fromBlockHeader(2, BigInt(1), duplicate), + BlockHeaderDbHelper.fromBlockHeader(2, BigInt(1), duplicate2), BlockHeaderDbHelper.fromBlockHeader(1, BigInt(0), ChainTestUtil.blockHeader562462) ) - val headers = (chain1 ++ chain2).distinct + val chain3 = Vector( + BlockHeaderDbHelper.fromBlockHeader(3, BigInt(2), duplicate3), + BlockHeaderDbHelper.fromBlockHeader(2, + BigInt(1), + ChainTestUtil.blockHeader562463), + BlockHeaderDbHelper.fromBlockHeader(1, + BigInt(0), + ChainTestUtil.blockHeader562462) + ) - val expectedChains = Vector(Blockchain(chain1), Blockchain(chain2)) + val chain4 = Vector( + BlockHeaderDbHelper.fromBlockHeader(1, BigInt(0), duplicate1) + ) + + val expectedChains = + Vector(Blockchain(chain1), + Blockchain(chain2), + Blockchain(chain3), + Blockchain(chain4)) + + val headers = expectedChains.flatMap(_.headers).distinct for { _ <- blockerHeaderDAO.createAll(headers) chains <- blockerHeaderDAO.getBlockchainsBetweenHeights(1, 3) } yield { - assert(chains.forall(expectedChains.contains)) + assert(chains.nonEmpty) + assert(expectedChains.forall(chains.contains)) } } } diff --git a/chain/src/main/scala/org/bitcoins/chain/models/BlockHeaderDAO.scala b/chain/src/main/scala/org/bitcoins/chain/models/BlockHeaderDAO.scala index ff03956e66..dcf30ebc96 100644 --- a/chain/src/main/scala/org/bitcoins/chain/models/BlockHeaderDAO.scala +++ b/chain/src/main/scala/org/bitcoins/chain/models/BlockHeaderDAO.scala @@ -334,19 +334,47 @@ case class BlockHeaderDAO()(implicit headersF.map(headers => Blockchain.fromHeaders(headers.reverse)) } + @tailrec + private def loop( + chains: Vector[Blockchain], + allHeaders: Vector[BlockHeaderDb]): Vector[Blockchain] = { + val usedHeaders = chains.flatMap(_.headers).distinct + val diff = allHeaders.filter(header => + !usedHeaders.exists(_.hashBE == header.hashBE)) + if (diff.isEmpty) { + chains + } else { + val sortedDiff = diff.sortBy(_.height)(Ordering.Int.reverse) + + val newChainHeaders = + Blockchain.connectWalkBackwards(sortedDiff.head, allHeaders) + val newChain = Blockchain( + newChainHeaders.sortBy(_.height)(Ordering.Int.reverse)) + loop(chains :+ newChain, allHeaders) + } + } + /** Retrieves a blockchain with the best tip being the given header */ def getBlockchainsBetweenHeights(from: Int, to: Int)(implicit ec: ExecutionContext): Future[Vector[Blockchain]] = { getBetweenHeights(from = from, to = to).map { headers => if (headers.map(_.height).distinct.size == headers.size) { - Vector(Blockchain.fromHeaders(headers.reverse)) + Vector( + Blockchain.fromHeaders( + headers.sortBy(_.height)(Ordering.Int.reverse))) } else { - val headersByHeight = headers.groupBy(_.height).toVector - val sortedHeaders = headersByHeight.sortBy(_._1).reverse.map(_._2) - val chains = sortedHeaders.map { headers => - Blockchain.reconstructFromHeaders(headers.head, headers.tail) + val headersByHeight: Vector[(Int, Vector[BlockHeaderDb])] = + headers.groupBy(_.height).toVector + val tips: Vector[BlockHeaderDb] = headersByHeight.maxBy(_._1)._2 + + val chains = tips.map { tip => + Blockchain + .connectWalkBackwards(tip, headers) + .sortBy(_.height)(Ordering.Int.reverse) } - chains.flatten.distinct + val init = chains.map(Blockchain(_)) + + loop(init, headers).distinct } } }