Fix getBlockchainsBetweenHeights (#1710)

* Fix getBlockchainsBetweenHeights

* Fix getBlockchainsBetweenHeights

* Fix compile issue for older version

* Improve test

* Optimizations
This commit is contained in:
Ben Carman 2020-07-29 16:36:22 -05:00 committed by GitHub
parent 666d53d94a
commit cad6fbeaaf
2 changed files with 77 additions and 12 deletions

View file

@ -350,12 +350,30 @@ class BlockHeaderDAOTest extends ChainDbUnitTest {
it must "successfully getBlockchainsBetweenHeights" in { it must "successfully getBlockchainsBetweenHeights" in {
blockerHeaderDAO: BlockHeaderDAO => blockerHeaderDAO: BlockHeaderDAO =>
val duplicate = BlockHeader( val duplicate3 = BlockHeader(
version = Int32.one,
previousBlockHash = ChainTestUtil.blockHeader562463.hash,
merkleRootHash = DoubleSha256Digest.empty,
time = UInt32.zero,
nBits = ChainTestUtil.blockHeader562464.nBits,
nonce = UInt32.zero
)
val duplicate2 = BlockHeader(
version = Int32.one, version = Int32.one,
previousBlockHash = ChainTestUtil.blockHeader562462.hash, previousBlockHash = ChainTestUtil.blockHeader562462.hash,
merkleRootHash = DoubleSha256Digest.empty, merkleRootHash = DoubleSha256Digest.empty,
time = UInt32.zero, time = UInt32.zero,
nBits = UInt32.zero, nBits = ChainTestUtil.blockHeader562463.nBits,
nonce = UInt32.zero
)
val duplicate1 = BlockHeader(
version = Int32.one,
previousBlockHash = genesisHeaderDb.hashBE.flip,
merkleRootHash = DoubleSha256Digest.empty,
time = UInt32.zero,
nBits = ChainTestUtil.blockHeader562462.nBits,
nonce = UInt32.zero nonce = UInt32.zero
) )
@ -372,21 +390,40 @@ class BlockHeaderDAOTest extends ChainDbUnitTest {
) )
val chain2 = Vector( val chain2 = Vector(
BlockHeaderDbHelper.fromBlockHeader(2, BigInt(1), duplicate), BlockHeaderDbHelper.fromBlockHeader(2, BigInt(1), duplicate2),
BlockHeaderDbHelper.fromBlockHeader(1, BlockHeaderDbHelper.fromBlockHeader(1,
BigInt(0), BigInt(0),
ChainTestUtil.blockHeader562462) ChainTestUtil.blockHeader562462)
) )
val headers = (chain1 ++ chain2).distinct val chain3 = Vector(
BlockHeaderDbHelper.fromBlockHeader(3, BigInt(2), duplicate3),
BlockHeaderDbHelper.fromBlockHeader(2,
BigInt(1),
ChainTestUtil.blockHeader562463),
BlockHeaderDbHelper.fromBlockHeader(1,
BigInt(0),
ChainTestUtil.blockHeader562462)
)
val expectedChains = Vector(Blockchain(chain1), Blockchain(chain2)) val chain4 = Vector(
BlockHeaderDbHelper.fromBlockHeader(1, BigInt(0), duplicate1)
)
val expectedChains =
Vector(Blockchain(chain1),
Blockchain(chain2),
Blockchain(chain3),
Blockchain(chain4))
val headers = expectedChains.flatMap(_.headers).distinct
for { for {
_ <- blockerHeaderDAO.createAll(headers) _ <- blockerHeaderDAO.createAll(headers)
chains <- blockerHeaderDAO.getBlockchainsBetweenHeights(1, 3) chains <- blockerHeaderDAO.getBlockchainsBetweenHeights(1, 3)
} yield { } yield {
assert(chains.forall(expectedChains.contains)) assert(chains.nonEmpty)
assert(expectedChains.forall(chains.contains))
} }
} }
} }

View file

@ -334,19 +334,47 @@ case class BlockHeaderDAO()(implicit
headersF.map(headers => Blockchain.fromHeaders(headers.reverse)) headersF.map(headers => Blockchain.fromHeaders(headers.reverse))
} }
@tailrec
private def loop(
chains: Vector[Blockchain],
allHeaders: Vector[BlockHeaderDb]): Vector[Blockchain] = {
val usedHeaders = chains.flatMap(_.headers).distinct
val diff = allHeaders.filter(header =>
!usedHeaders.exists(_.hashBE == header.hashBE))
if (diff.isEmpty) {
chains
} else {
val sortedDiff = diff.sortBy(_.height)(Ordering.Int.reverse)
val newChainHeaders =
Blockchain.connectWalkBackwards(sortedDiff.head, allHeaders)
val newChain = Blockchain(
newChainHeaders.sortBy(_.height)(Ordering.Int.reverse))
loop(chains :+ newChain, allHeaders)
}
}
/** Retrieves a blockchain with the best tip being the given header */ /** Retrieves a blockchain with the best tip being the given header */
def getBlockchainsBetweenHeights(from: Int, to: Int)(implicit def getBlockchainsBetweenHeights(from: Int, to: Int)(implicit
ec: ExecutionContext): Future[Vector[Blockchain]] = { ec: ExecutionContext): Future[Vector[Blockchain]] = {
getBetweenHeights(from = from, to = to).map { headers => getBetweenHeights(from = from, to = to).map { headers =>
if (headers.map(_.height).distinct.size == headers.size) { if (headers.map(_.height).distinct.size == headers.size) {
Vector(Blockchain.fromHeaders(headers.reverse)) Vector(
Blockchain.fromHeaders(
headers.sortBy(_.height)(Ordering.Int.reverse)))
} else { } else {
val headersByHeight = headers.groupBy(_.height).toVector val headersByHeight: Vector[(Int, Vector[BlockHeaderDb])] =
val sortedHeaders = headersByHeight.sortBy(_._1).reverse.map(_._2) headers.groupBy(_.height).toVector
val chains = sortedHeaders.map { headers => val tips: Vector[BlockHeaderDb] = headersByHeight.maxBy(_._1)._2
Blockchain.reconstructFromHeaders(headers.head, headers.tail)
val chains = tips.map { tip =>
Blockchain
.connectWalkBackwards(tip, headers)
.sortBy(_.height)(Ordering.Int.reverse)
} }
chains.flatten.distinct val init = chains.map(Blockchain(_))
loop(init, headers).distinct
} }
} }
} }