[checkqueue] support user-defined return type through std::optional

The check type function now needs to return a std::optional<R> for some type R,
and the check queue overall will return std::nullopt if all individual checks
return that, or one of the non-nullopt values if there is at least one.

For most tests, we use R=int, but for the actual validation code, we make it return
the ScriptError.
This commit is contained in:
Pieter Wuille 2024-10-16 05:53:19 -04:00
parent ebe4cac38b
commit 1ac1c33f3f
8 changed files with 107 additions and 92 deletions

View File

@ -34,9 +34,9 @@ static void CCheckQueueSpeedPrevectorJob(benchmark::Bench& bench)
explicit PrevectorJob(FastRandomContext& insecure_rand){
p.resize(insecure_rand.randrange(PREVECTOR_SIZE*2));
}
bool operator()()
std::optional<int> operator()()
{
return true;
return std::nullopt;
}
};
@ -62,7 +62,7 @@ static void CCheckQueueSpeedPrevectorJob(benchmark::Bench& bench)
}
// control waits for completion by RAII, but
// it is done explicitly here for clarity
control.Wait();
control.Complete();
});
}
BENCHMARK(CCheckQueueSpeedPrevectorJob, benchmark::PriorityLevel::HIGH);

View File

@ -11,19 +11,24 @@
#include <algorithm>
#include <iterator>
#include <optional>
#include <vector>
/**
* Queue for verifications that have to be performed.
* The verifications are represented by a type T, which must provide an
* operator(), returning a bool.
* operator(), returning an std::optional<R>.
*
* The overall result of the computation is std::nullopt if all invocations
* return std::nullopt, or one of the other results otherwise.
*
* One thread (the master) is assumed to push batches of verifications
* onto the queue, where they are processed by N-1 worker threads. When
* the master is done adding work, it temporarily joins the worker pool
* as an N'th worker, until all jobs are done.
*
*/
template <typename T>
template <typename T, typename R = std::remove_cvref_t<decltype(std::declval<T>()().value())>>
class CCheckQueue
{
private:
@ -47,7 +52,7 @@ private:
int nTotal GUARDED_BY(m_mutex){0};
//! The temporary evaluation result.
bool fAllOk GUARDED_BY(m_mutex){true};
std::optional<R> m_result GUARDED_BY(m_mutex);
/**
* Number of verifications that haven't completed yet.
@ -62,24 +67,28 @@ private:
std::vector<std::thread> m_worker_threads;
bool m_request_stop GUARDED_BY(m_mutex){false};
/** Internal function that does bulk of the verification work. */
bool Loop(bool fMaster) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex)
/** Internal function that does bulk of the verification work. If fMaster, return the final result. */
std::optional<R> Loop(bool fMaster) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex)
{
std::condition_variable& cond = fMaster ? m_master_cv : m_worker_cv;
std::vector<T> vChecks;
vChecks.reserve(nBatchSize);
unsigned int nNow = 0;
bool fOk = true;
std::optional<R> local_result;
bool do_work;
do {
{
WAIT_LOCK(m_mutex, lock);
// first do the clean-up of the previous loop run (allowing us to do it in the same critsect)
if (nNow) {
fAllOk &= fOk;
if (local_result.has_value() && !m_result.has_value()) {
std::swap(local_result, m_result);
}
nTodo -= nNow;
if (nTodo == 0 && !fMaster)
if (nTodo == 0 && !fMaster) {
// We processed the last element; inform the master it can exit and return the result
m_master_cv.notify_one();
}
} else {
// first iteration
nTotal++;
@ -88,18 +97,19 @@ private:
while (queue.empty() && !m_request_stop) {
if (fMaster && nTodo == 0) {
nTotal--;
bool fRet = fAllOk;
std::optional<R> to_return = std::move(m_result);
// reset the status for new work later
fAllOk = true;
m_result = std::nullopt;
// return the current status
return fRet;
return to_return;
}
nIdle++;
cond.wait(lock); // wait
nIdle--;
}
if (m_request_stop) {
return false;
// return value does not matter, because m_request_stop is only set in the destructor.
return std::nullopt;
}
// Decide how many work units to process now.
@ -112,12 +122,15 @@ private:
vChecks.assign(std::make_move_iterator(start_it), std::make_move_iterator(queue.end()));
queue.erase(start_it, queue.end());
// Check whether we need to do work at all
fOk = fAllOk;
do_work = !m_result.has_value();
}
// execute work
for (T& check : vChecks)
if (fOk)
fOk = check();
if (do_work) {
for (T& check : vChecks) {
local_result = check();
if (local_result.has_value()) break;
}
}
vChecks.clear();
} while (true);
}
@ -146,8 +159,9 @@ public:
CCheckQueue(CCheckQueue&&) = delete;
CCheckQueue& operator=(CCheckQueue&&) = delete;
//! Wait until execution finishes, and return whether all evaluations were successful.
bool Wait() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex)
//! Join the execution until completion. If at least one evaluation wasn't successful, return
//! its error.
std::optional<R> Complete() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex)
{
return Loop(true /* master thread */);
}
@ -188,11 +202,11 @@ public:
* RAII-style controller object for a CCheckQueue that guarantees the passed
* queue is finished before continuing.
*/
template <typename T>
template <typename T, typename R = std::remove_cvref_t<decltype(std::declval<T>()().value())>>
class CCheckQueueControl
{
private:
CCheckQueue<T> * const pqueue;
CCheckQueue<T, R> * const pqueue;
bool fDone;
public:
@ -207,13 +221,12 @@ public:
}
}
bool Wait()
std::optional<R> Complete()
{
if (pqueue == nullptr)
return true;
bool fRet = pqueue->Wait();
if (pqueue == nullptr) return std::nullopt;
auto ret = pqueue->Complete();
fDone = true;
return fRet;
return ret;
}
void Add(std::vector<T>&& vChecks)
@ -226,7 +239,7 @@ public:
~CCheckQueueControl()
{
if (!fDone)
Wait();
Complete();
if (pqueue != nullptr) {
LEAVE_CRITICAL_SECTION(pqueue->m_control_mutex);
}

View File

@ -42,28 +42,26 @@ static const unsigned int QUEUE_BATCH_SIZE = 128;
static const int SCRIPT_CHECK_THREADS = 3;
struct FakeCheck {
bool operator()() const
std::optional<int> operator()() const
{
return true;
return std::nullopt;
}
};
struct FakeCheckCheckCompletion {
static std::atomic<size_t> n_calls;
bool operator()()
std::optional<int> operator()()
{
n_calls.fetch_add(1, std::memory_order_relaxed);
return true;
return std::nullopt;
}
};
struct FailingCheck {
bool fails;
FailingCheck(bool _fails) : fails(_fails){};
bool operator()() const
{
return !fails;
}
struct FixedCheck
{
std::optional<int> m_result;
FixedCheck(std::optional<int> result) : m_result(result){};
std::optional<int> operator()() const { return m_result; }
};
struct UniqueCheck {
@ -71,11 +69,11 @@ struct UniqueCheck {
static std::unordered_multiset<size_t> results GUARDED_BY(m);
size_t check_id;
UniqueCheck(size_t check_id_in) : check_id(check_id_in){};
bool operator()()
std::optional<int> operator()()
{
LOCK(m);
results.insert(check_id);
return true;
return std::nullopt;
}
};
@ -83,9 +81,9 @@ struct UniqueCheck {
struct MemoryCheck {
static std::atomic<size_t> fake_allocated_memory;
bool b {false};
bool operator()() const
std::optional<int> operator()() const
{
return true;
return std::nullopt;
}
MemoryCheck(const MemoryCheck& x)
{
@ -110,9 +108,9 @@ struct FrozenCleanupCheck {
static std::condition_variable cv;
static std::mutex m;
bool should_freeze{true};
bool operator()() const
std::optional<int> operator()() const
{
return true;
return std::nullopt;
}
FrozenCleanupCheck() = default;
~FrozenCleanupCheck()
@ -149,7 +147,7 @@ std::atomic<size_t> MemoryCheck::fake_allocated_memory{0};
// Queue Typedefs
typedef CCheckQueue<FakeCheckCheckCompletion> Correct_Queue;
typedef CCheckQueue<FakeCheck> Standard_Queue;
typedef CCheckQueue<FailingCheck> Failing_Queue;
typedef CCheckQueue<FixedCheck> Fixed_Queue;
typedef CCheckQueue<UniqueCheck> Unique_Queue;
typedef CCheckQueue<MemoryCheck> Memory_Queue;
typedef CCheckQueue<FrozenCleanupCheck> FrozenCleanup_Queue;
@ -174,7 +172,7 @@ void CheckQueueTest::Correct_Queue_range(std::vector<size_t> range)
total -= vChecks.size();
control.Add(std::move(vChecks));
}
BOOST_REQUIRE(control.Wait());
BOOST_REQUIRE(!control.Complete().has_value());
BOOST_REQUIRE_EQUAL(FakeCheckCheckCompletion::n_calls, i);
}
}
@ -217,27 +215,27 @@ BOOST_AUTO_TEST_CASE(test_CheckQueue_Correct_Random)
}
/** Test that failing checks are caught */
/** Test that distinct failing checks are caught */
BOOST_AUTO_TEST_CASE(test_CheckQueue_Catches_Failure)
{
auto fail_queue = std::make_unique<Failing_Queue>(QUEUE_BATCH_SIZE, SCRIPT_CHECK_THREADS);
auto fixed_queue = std::make_unique<Fixed_Queue>(QUEUE_BATCH_SIZE, SCRIPT_CHECK_THREADS);
for (size_t i = 0; i < 1001; ++i) {
CCheckQueueControl<FailingCheck> control(fail_queue.get());
CCheckQueueControl<FixedCheck> control(fixed_queue.get());
size_t remaining = i;
while (remaining) {
size_t r = m_rng.randrange(10);
std::vector<FailingCheck> vChecks;
std::vector<FixedCheck> vChecks;
vChecks.reserve(r);
for (size_t k = 0; k < r && remaining; k++, remaining--)
vChecks.emplace_back(remaining == 1);
vChecks.emplace_back(remaining == 1 ? std::make_optional<int>(17 * i) : std::nullopt);
control.Add(std::move(vChecks));
}
bool success = control.Wait();
auto result = control.Complete();
if (i > 0) {
BOOST_REQUIRE(!success);
} else if (i == 0) {
BOOST_REQUIRE(success);
BOOST_REQUIRE(result.has_value() && *result == static_cast<int>(17 * i));
} else {
BOOST_REQUIRE(!result.has_value());
}
}
}
@ -245,17 +243,17 @@ BOOST_AUTO_TEST_CASE(test_CheckQueue_Catches_Failure)
// future blocks, ie, the bad state is cleared.
BOOST_AUTO_TEST_CASE(test_CheckQueue_Recovers_From_Failure)
{
auto fail_queue = std::make_unique<Failing_Queue>(QUEUE_BATCH_SIZE, SCRIPT_CHECK_THREADS);
auto fail_queue = std::make_unique<Fixed_Queue>(QUEUE_BATCH_SIZE, SCRIPT_CHECK_THREADS);
for (auto times = 0; times < 10; ++times) {
for (const bool end_fails : {true, false}) {
CCheckQueueControl<FailingCheck> control(fail_queue.get());
CCheckQueueControl<FixedCheck> control(fail_queue.get());
{
std::vector<FailingCheck> vChecks;
vChecks.resize(100, false);
vChecks[99] = end_fails;
std::vector<FixedCheck> vChecks;
vChecks.resize(100, FixedCheck(std::nullopt));
vChecks[99] = FixedCheck(end_fails ? std::make_optional<int>(2) : std::nullopt);
control.Add(std::move(vChecks));
}
bool r =control.Wait();
bool r = !control.Complete().has_value();
BOOST_REQUIRE(r != end_fails);
}
}
@ -329,8 +327,8 @@ BOOST_AUTO_TEST_CASE(test_CheckQueue_FrozenCleanup)
CCheckQueueControl<FrozenCleanupCheck> control(queue.get());
std::vector<FrozenCleanupCheck> vChecks(1);
control.Add(std::move(vChecks));
bool waitResult = control.Wait(); // Hangs here
assert(waitResult);
auto result = control.Complete(); // Hangs here
assert(!result);
});
{
std::unique_lock<std::mutex> l(FrozenCleanupCheck::m);

View File

@ -19,9 +19,10 @@ struct DumbCheck {
{
}
bool operator()() const
std::optional<int> operator()() const
{
return result;
if (result) return std::nullopt;
return 1;
}
};
} // namespace
@ -45,7 +46,7 @@ FUZZ_TARGET(checkqueue)
check_queue_1.Add(std::move(checks_1));
}
if (fuzzed_data_provider.ConsumeBool()) {
(void)check_queue_1.Wait();
(void)check_queue_1.Complete();
}
CCheckQueueControl<DumbCheck> check_queue_control{&check_queue_2};
@ -53,6 +54,6 @@ FUZZ_TARGET(checkqueue)
check_queue_control.Add(std::move(checks_2));
}
if (fuzzed_data_provider.ConsumeBool()) {
(void)check_queue_control.Wait();
(void)check_queue_control.Complete();
}
}

View File

@ -121,7 +121,7 @@ BOOST_AUTO_TEST_CASE(sign)
{
CScript sigSave = txTo[i].vin[0].scriptSig;
txTo[i].vin[0].scriptSig = txTo[j].vin[0].scriptSig;
bool sigOK = CScriptCheck(txFrom.vout[txTo[i].vin[0].prevout.n], CTransaction(txTo[i]), signature_cache, 0, SCRIPT_VERIFY_P2SH | SCRIPT_VERIFY_STRICTENC, false, &txdata)();
bool sigOK = !CScriptCheck(txFrom.vout[txTo[i].vin[0].prevout.n], CTransaction(txTo[i]), signature_cache, 0, SCRIPT_VERIFY_P2SH | SCRIPT_VERIFY_STRICTENC, false, &txdata)().has_value();
if (i == j)
BOOST_CHECK_MESSAGE(sigOK, strprintf("VerifySignature %d %d", i, j));
else

View File

@ -588,7 +588,7 @@ BOOST_AUTO_TEST_CASE(test_big_witness_transaction)
control.Add(std::move(vChecks));
}
bool controlCheck = control.Wait();
bool controlCheck = !control.Complete().has_value();
assert(controlCheck);
}

View File

@ -2103,10 +2103,15 @@ void UpdateCoins(const CTransaction& tx, CCoinsViewCache& inputs, CTxUndo &txund
AddCoins(inputs, tx, nHeight);
}
bool CScriptCheck::operator()() {
std::optional<ScriptError> CScriptCheck::operator()() {
const CScript &scriptSig = ptxTo->vin[nIn].scriptSig;
const CScriptWitness *witness = &ptxTo->vin[nIn].scriptWitness;
return VerifyScript(scriptSig, m_tx_out.scriptPubKey, witness, nFlags, CachingTransactionSignatureChecker(ptxTo, nIn, m_tx_out.nValue, cacheStore, *m_signature_cache, *txdata), &error);
ScriptError error{SCRIPT_ERR_UNKNOWN_ERROR};
if (VerifyScript(scriptSig, m_tx_out.scriptPubKey, witness, nFlags, CachingTransactionSignatureChecker(ptxTo, nIn, m_tx_out.nValue, cacheStore, *m_signature_cache, *txdata), &error)) {
return std::nullopt;
} else {
return error;
}
}
ValidationCache::ValidationCache(const size_t script_execution_cache_bytes, const size_t signature_cache_bytes)
@ -2195,9 +2200,7 @@ bool CheckInputScripts(const CTransaction& tx, TxValidationState& state,
CScriptCheck check(txdata.m_spent_outputs[i], tx, validation_cache.m_signature_cache, i, flags, cacheSigStore, &txdata);
if (pvChecks) {
pvChecks->emplace_back(std::move(check));
} else if (!check()) {
ScriptError error{check.GetScriptError()};
} else if (auto result = check(); result.has_value()) {
if (flags & STANDARD_NOT_MANDATORY_VERIFY_FLAGS) {
// Check whether the failure was caused by a
// non-mandatory script verification check, such as
@ -2209,21 +2212,23 @@ bool CheckInputScripts(const CTransaction& tx, TxValidationState& state,
// data providers.
CScriptCheck check2(txdata.m_spent_outputs[i], tx, validation_cache.m_signature_cache, i,
flags & ~STANDARD_NOT_MANDATORY_VERIFY_FLAGS, cacheSigStore, &txdata);
if (check2())
return state.Invalid(TxValidationResult::TX_NOT_STANDARD, strprintf("non-mandatory-script-verify-flag (%s)", ScriptErrorString(check.GetScriptError())));
// If the second check failed, it failed due to a mandatory script verification
// flag, but the first check might have failed on a non-mandatory script
// verification flag.
//
// Avoid reporting a mandatory script check failure with a non-mandatory error
// string by reporting the error from the second check.
error = check2.GetScriptError();
auto mandatory_result = check2();
if (!mandatory_result.has_value()) {
return state.Invalid(TxValidationResult::TX_NOT_STANDARD, strprintf("non-mandatory-script-verify-flag (%s)", ScriptErrorString(*result)));
} else {
// If the second check failed, it failed due to a mandatory script verification
// flag, but the first check might have failed on a non-mandatory script
// verification flag.
//
// Avoid reporting a mandatory script check failure with a non-mandatory error
// string by reporting the error from the second check.
result = mandatory_result;
}
}
// MANDATORY flag failures correspond to
// TxValidationResult::TX_CONSENSUS.
return state.Invalid(TxValidationResult::TX_CONSENSUS, strprintf("mandatory-script-verify-flag-failed (%s)", ScriptErrorString(error)));
return state.Invalid(TxValidationResult::TX_CONSENSUS, strprintf("mandatory-script-verify-flag-failed (%s)", ScriptErrorString(*result)));
}
}
@ -2710,7 +2715,8 @@ bool Chainstate::ConnectBlock(const CBlock& block, BlockValidationState& state,
return state.Invalid(BlockValidationResult::BLOCK_CONSENSUS, "bad-cb-amount");
}
if (!control.Wait()) {
auto parallel_result = control.Complete();
if (parallel_result.has_value()) {
LogPrintf("ERROR: %s: CheckQueue failed\n", __func__);
return state.Invalid(BlockValidationResult::BLOCK_CONSENSUS, "block-validation-failed");
}

View File

@ -335,7 +335,6 @@ private:
unsigned int nIn;
unsigned int nFlags;
bool cacheStore;
ScriptError error{SCRIPT_ERR_UNKNOWN_ERROR};
PrecomputedTransactionData *txdata;
SignatureCache* m_signature_cache;
@ -348,9 +347,7 @@ public:
CScriptCheck(CScriptCheck&&) = default;
CScriptCheck& operator=(CScriptCheck&&) = default;
bool operator()();
ScriptError GetScriptError() const { return error; }
std::optional<ScriptError> operator()();
};
// CScriptCheck is used a lot in std::vector, make sure that's efficient