From 021feb3187b207d511561c1f0ffd7f9e5e0c9c1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Barbosa?= Date: Thu, 17 Sep 2020 22:23:45 +0100 Subject: [PATCH 1/2] refactor: Drop redudant CWallet::GetDBHandle --- src/wallet/wallet.h | 7 ------- src/wallet/walletdb.cpp | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 74de55dcb52..6ad75d3d69b 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -732,13 +732,6 @@ public: */ mutable RecursiveMutex cs_wallet; - /** Get database handle used by this wallet. Ideally this function would - * not be necessary. - */ - WalletDatabase& GetDBHandle() - { - return *database; - } WalletDatabase& GetDatabase() const override { return *database; } /** diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index aa3b3c10b08..9ed98184793 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -943,7 +943,7 @@ void MaybeCompactWalletDB() } for (const std::shared_ptr& pwallet : GetWallets()) { - WalletDatabase& dbh = pwallet->GetDBHandle(); + WalletDatabase& dbh = pwallet->GetDatabase(); unsigned int nUpdateCounter = dbh.nUpdateCounter; From 9b74461fa293453a9eb0b1717b30b3f7fa778d91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Barbosa?= Date: Sun, 20 Sep 2020 00:25:45 +0100 Subject: [PATCH 2/2] refactor: Assert before dereference in CWallet::GetDatabase --- src/wallet/wallet.cpp | 56 +++++++++++++++++++++---------------------- src/wallet/wallet.h | 10 +++++--- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index d1cde6aa896..8609520c1b6 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -419,7 +419,7 @@ bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase, return false; if (!crypter.Encrypt(_vMasterKey, pMasterKey.second.vchCryptedKey)) return false; - WalletBatch(*database).WriteMasterKey(pMasterKey.first, pMasterKey.second); + WalletBatch(GetDatabase()).WriteMasterKey(pMasterKey.first, pMasterKey.second); if (fWasLocked) Lock(); return true; @@ -432,7 +432,7 @@ bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase, void CWallet::chainStateFlushed(const CBlockLocator& loc) { - WalletBatch batch(*database); + WalletBatch batch(GetDatabase()); batch.WriteBestBlock(loc); } @@ -452,7 +452,7 @@ void CWallet::SetMinVersion(enum WalletFeature nVersion, WalletBatch* batch_in, nWalletMaxVersion = nVersion; { - WalletBatch* batch = batch_in ? batch_in : new WalletBatch(*database); + WalletBatch* batch = batch_in ? batch_in : new WalletBatch(GetDatabase()); if (nWalletVersion > 40000) batch->WriteMinVersion(nWalletVersion); if (!batch_in) @@ -504,12 +504,12 @@ bool CWallet::HasWalletSpend(const uint256& txid) const void CWallet::Flush() { - database->Flush(); + GetDatabase().Flush(); } void CWallet::Close() { - database->Close(); + GetDatabase().Close(); } void CWallet::SyncMetaData(std::pair range) @@ -635,7 +635,7 @@ bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) { LOCK(cs_wallet); mapMasterKeys[++nMasterKeyMaxID] = kMasterKey; - WalletBatch* encrypted_batch = new WalletBatch(*database); + WalletBatch* encrypted_batch = new WalletBatch(GetDatabase()); if (!encrypted_batch->TxnBegin()) { delete encrypted_batch; encrypted_batch = nullptr; @@ -687,12 +687,12 @@ bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) // Need to completely rewrite the wallet file; if we don't, bdb might keep // bits of the unencrypted private key in slack space in the database file. - database->Rewrite(); + GetDatabase().Rewrite(); // BDB seems to have a bad habit of writing old data into // slack space in .dat files; that is bad if the old data is // unencrypted private keys. So: - database->ReloadDbEnv(); + GetDatabase().ReloadDbEnv(); } NotifyStatusChanged(this); @@ -703,7 +703,7 @@ bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) DBErrors CWallet::ReorderTransactions() { LOCK(cs_wallet); - WalletBatch batch(*database); + WalletBatch batch(GetDatabase()); // Old wallets didn't have any defined order for transactions // Probably a bad idea to change the output of this @@ -764,7 +764,7 @@ int64_t CWallet::IncOrderPosNext(WalletBatch* batch) if (batch) { batch->WriteOrderPosNext(nOrderPosNext); } else { - WalletBatch(*database).WriteOrderPosNext(nOrderPosNext); + WalletBatch(GetDatabase()).WriteOrderPosNext(nOrderPosNext); } return nRet; } @@ -794,7 +794,7 @@ bool CWallet::MarkReplaced(const uint256& originalHash, const uint256& newHash) wtx.mapValue["replaced_by_txid"] = newHash.ToString(); - WalletBatch batch(*database); + WalletBatch batch(GetDatabase()); bool success = true; if (!batch.WriteTx(wtx)) { @@ -866,7 +866,7 @@ CWalletTx* CWallet::AddToWallet(CTransactionRef tx, const CWalletTx::Confirmatio { LOCK(cs_wallet); - WalletBatch batch(*database, fFlushOnClose); + WalletBatch batch(GetDatabase(), fFlushOnClose); uint256 hash = tx->GetHash(); @@ -1065,7 +1065,7 @@ bool CWallet::AbandonTransaction(const uint256& hashTx) { LOCK(cs_wallet); - WalletBatch batch(*database); + WalletBatch batch(GetDatabase()); std::set todo; std::set done; @@ -1128,7 +1128,7 @@ void CWallet::MarkConflicted(const uint256& hashBlock, int conflicting_height, c return; // Do not flush the wallet here for performance reasons - WalletBatch batch(*database, false); + WalletBatch batch(GetDatabase(), false); std::set todo; std::set done; @@ -1466,13 +1466,13 @@ void CWallet::SetWalletFlag(uint64_t flags) { LOCK(cs_wallet); m_wallet_flags |= flags; - if (!WalletBatch(*database).WriteWalletFlags(m_wallet_flags)) + if (!WalletBatch(GetDatabase()).WriteWalletFlags(m_wallet_flags)) throw std::runtime_error(std::string(__func__) + ": writing wallet flags failed"); } void CWallet::UnsetWalletFlag(uint64_t flag) { - WalletBatch batch(*database); + WalletBatch batch(GetDatabase()); UnsetWalletFlagWithDB(batch, flag); } @@ -1511,7 +1511,7 @@ bool CWallet::AddWalletFlags(uint64_t flags) LOCK(cs_wallet); // We should never be writing unknown non-tolerable wallet flags assert(((flags & KNOWN_WALLET_FLAGS) >> 32) == (flags >> 32)); - if (!WalletBatch(*database).WriteWalletFlags(flags)) { + if (!WalletBatch(GetDatabase()).WriteWalletFlags(flags)) { throw std::runtime_error(std::string(__func__) + ": writing wallet flags failed"); } @@ -1602,7 +1602,7 @@ bool CWallet::ImportScriptPubKeys(const std::string& label, const std::setRewrite("\x04pool")) + if (GetDatabase().Rewrite("\x04pool")) { for (const auto& spk_man_pair : m_spk_managers) { spk_man_pair.second->RewriteDB(); @@ -3220,7 +3220,7 @@ DBErrors CWallet::LoadWallet(bool& fFirstRunRet) DBErrors CWallet::ZapSelectTx(std::vector& vHashIn, std::vector& vHashOut) { AssertLockHeld(cs_wallet); - DBErrors nZapSelectTxRet = WalletBatch(*database).ZapSelectTx(vHashIn, vHashOut); + DBErrors nZapSelectTxRet = WalletBatch(GetDatabase()).ZapSelectTx(vHashIn, vHashOut); for (const uint256& hash : vHashOut) { const auto& it = mapWallet.find(hash); wtxOrdered.erase(it->second.m_it_wtxOrdered); @@ -3232,7 +3232,7 @@ DBErrors CWallet::ZapSelectTx(std::vector& vHashIn, std::vectorRewrite("\x04pool")) + if (GetDatabase().Rewrite("\x04pool")) { for (const auto& spk_man_pair : m_spk_managers) { spk_man_pair.second->RewriteDB(); @@ -3270,14 +3270,14 @@ bool CWallet::SetAddressBookWithDB(WalletBatch& batch, const CTxDestination& add bool CWallet::SetAddressBook(const CTxDestination& address, const std::string& strName, const std::string& strPurpose) { - WalletBatch batch(*database); + WalletBatch batch(GetDatabase()); return SetAddressBookWithDB(batch, address, strName, strPurpose); } bool CWallet::DelAddressBook(const CTxDestination& address) { bool is_mine; - WalletBatch batch(*database); + WalletBatch batch(GetDatabase()); { LOCK(cs_wallet); // If we want to delete receiving addresses, we need to take care that DestData "used" (and possibly newer DestData) gets preserved (and the "deleted" address transformed into a change entry instead of actually being deleted) @@ -4024,7 +4024,7 @@ std::shared_ptr CWallet::Create(interfaces::Chain& chain, const std::st int rescan_height = 0; if (!gArgs.GetBoolArg("-rescan", false)) { - WalletBatch batch(*walletInstance->database); + WalletBatch batch(walletInstance->GetDatabase()); CBlockLocator locator; if (batch.ReadBestBlock(locator)) { if (const Optional fork_height = chain.findLocatorFork(locator)) { @@ -4087,7 +4087,7 @@ std::shared_ptr CWallet::Create(interfaces::Chain& chain, const std::st } } walletInstance->chainStateFlushed(chain.getTipLocator()); - walletInstance->database->IncrementUpdateCounter(); + walletInstance->GetDatabase().IncrementUpdateCounter(); } { @@ -4168,7 +4168,7 @@ void CWallet::postInitProcess() bool CWallet::BackupWallet(const std::string& strDest) const { - return database->Backup(strDest); + return GetDatabase().Backup(strDest); } CKeyPool::CKeyPool() @@ -4471,7 +4471,7 @@ void CWallet::SetupDescriptorScriptPubKeyMans() void CWallet::AddActiveScriptPubKeyMan(uint256 id, OutputType type, bool internal) { - WalletBatch batch(*database); + WalletBatch batch(GetDatabase()); if (!batch.WriteActiveScriptPubKeyMan(static_cast(type), id, internal)) { throw std::runtime_error(std::string(__func__) + ": writing active ScriptPubKeyMan id failed"); } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 6ad75d3d69b..5e23622e237 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -698,7 +698,7 @@ private: std::string m_name; /** Internal database handle. */ - std::unique_ptr database; + std::unique_ptr const m_database; /** * The following is used to keep track of how far behind the wallet is @@ -732,7 +732,11 @@ public: */ mutable RecursiveMutex cs_wallet; - WalletDatabase& GetDatabase() const override { return *database; } + WalletDatabase& GetDatabase() const override + { + assert(static_cast(m_database)); + return *m_database; + } /** * Select a set of coins such that nValueRet >= nTargetValue and at least @@ -754,7 +758,7 @@ public: CWallet(interfaces::Chain* chain, const std::string& name, std::unique_ptr database) : m_chain(chain), m_name(name), - database(std::move(database)) + m_database(std::move(database)) { }