Fix getBlockchainsBetweenHeights (#1710)

* Fix getBlockchainsBetweenHeights

* Fix getBlockchainsBetweenHeights

* Fix compile issue for older version

* Improve test

* Optimizations
This commit is contained in:
Ben Carman 2020-07-29 16:36:22 -05:00 committed by GitHub
parent 666d53d94a
commit cad6fbeaaf
2 changed files with 77 additions and 12 deletions

View file

@ -350,12 +350,30 @@ class BlockHeaderDAOTest extends ChainDbUnitTest {
it must "successfully getBlockchainsBetweenHeights" in {
blockerHeaderDAO: BlockHeaderDAO =>
val duplicate = BlockHeader(
val duplicate3 = BlockHeader(
version = Int32.one,
previousBlockHash = ChainTestUtil.blockHeader562463.hash,
merkleRootHash = DoubleSha256Digest.empty,
time = UInt32.zero,
nBits = ChainTestUtil.blockHeader562464.nBits,
nonce = UInt32.zero
)
val duplicate2 = BlockHeader(
version = Int32.one,
previousBlockHash = ChainTestUtil.blockHeader562462.hash,
merkleRootHash = DoubleSha256Digest.empty,
time = UInt32.zero,
nBits = UInt32.zero,
nBits = ChainTestUtil.blockHeader562463.nBits,
nonce = UInt32.zero
)
val duplicate1 = BlockHeader(
version = Int32.one,
previousBlockHash = genesisHeaderDb.hashBE.flip,
merkleRootHash = DoubleSha256Digest.empty,
time = UInt32.zero,
nBits = ChainTestUtil.blockHeader562462.nBits,
nonce = UInt32.zero
)
@ -372,21 +390,40 @@ class BlockHeaderDAOTest extends ChainDbUnitTest {
)
val chain2 = Vector(
BlockHeaderDbHelper.fromBlockHeader(2, BigInt(1), duplicate),
BlockHeaderDbHelper.fromBlockHeader(2, BigInt(1), duplicate2),
BlockHeaderDbHelper.fromBlockHeader(1,
BigInt(0),
ChainTestUtil.blockHeader562462)
)
val headers = (chain1 ++ chain2).distinct
val chain3 = Vector(
BlockHeaderDbHelper.fromBlockHeader(3, BigInt(2), duplicate3),
BlockHeaderDbHelper.fromBlockHeader(2,
BigInt(1),
ChainTestUtil.blockHeader562463),
BlockHeaderDbHelper.fromBlockHeader(1,
BigInt(0),
ChainTestUtil.blockHeader562462)
)
val expectedChains = Vector(Blockchain(chain1), Blockchain(chain2))
val chain4 = Vector(
BlockHeaderDbHelper.fromBlockHeader(1, BigInt(0), duplicate1)
)
val expectedChains =
Vector(Blockchain(chain1),
Blockchain(chain2),
Blockchain(chain3),
Blockchain(chain4))
val headers = expectedChains.flatMap(_.headers).distinct
for {
_ <- blockerHeaderDAO.createAll(headers)
chains <- blockerHeaderDAO.getBlockchainsBetweenHeights(1, 3)
} yield {
assert(chains.forall(expectedChains.contains))
assert(chains.nonEmpty)
assert(expectedChains.forall(chains.contains))
}
}
}

View file

@ -334,19 +334,47 @@ case class BlockHeaderDAO()(implicit
headersF.map(headers => Blockchain.fromHeaders(headers.reverse))
}
@tailrec
private def loop(
chains: Vector[Blockchain],
allHeaders: Vector[BlockHeaderDb]): Vector[Blockchain] = {
val usedHeaders = chains.flatMap(_.headers).distinct
val diff = allHeaders.filter(header =>
!usedHeaders.exists(_.hashBE == header.hashBE))
if (diff.isEmpty) {
chains
} else {
val sortedDiff = diff.sortBy(_.height)(Ordering.Int.reverse)
val newChainHeaders =
Blockchain.connectWalkBackwards(sortedDiff.head, allHeaders)
val newChain = Blockchain(
newChainHeaders.sortBy(_.height)(Ordering.Int.reverse))
loop(chains :+ newChain, allHeaders)
}
}
/** Retrieves a blockchain with the best tip being the given header */
def getBlockchainsBetweenHeights(from: Int, to: Int)(implicit
ec: ExecutionContext): Future[Vector[Blockchain]] = {
getBetweenHeights(from = from, to = to).map { headers =>
if (headers.map(_.height).distinct.size == headers.size) {
Vector(Blockchain.fromHeaders(headers.reverse))
Vector(
Blockchain.fromHeaders(
headers.sortBy(_.height)(Ordering.Int.reverse)))
} else {
val headersByHeight = headers.groupBy(_.height).toVector
val sortedHeaders = headersByHeight.sortBy(_._1).reverse.map(_._2)
val chains = sortedHeaders.map { headers =>
Blockchain.reconstructFromHeaders(headers.head, headers.tail)
val headersByHeight: Vector[(Int, Vector[BlockHeaderDb])] =
headers.groupBy(_.height).toVector
val tips: Vector[BlockHeaderDb] = headersByHeight.maxBy(_._1)._2
val chains = tips.map { tip =>
Blockchain
.connectWalkBackwards(tip, headers)
.sortBy(_.height)(Ordering.Int.reverse)
}
chains.flatten.distinct
val init = chains.map(Blockchain(_))
loop(init, headers).distinct
}
}
}