diff --git a/src/core_io.h b/src/core_io.h index 6f87161f46a..ae377eb6e81 100644 --- a/src/core_io.h +++ b/src/core_io.h @@ -37,7 +37,11 @@ bool DecodeHexBlockHeader(CBlockHeader&, const std::string& hex_header); */ bool ParseHashStr(const std::string& strHex, uint256& result); std::vector ParseHexUV(const UniValue& v, const std::string& strName); -NODISCARD bool DecodePSBT(PartiallySignedTransaction& psbt, const std::string& base64_tx, std::string& error); + +//! Decode a base64ed PSBT into a PartiallySignedTransaction +NODISCARD bool DecodeBase64PSBT(PartiallySignedTransaction& decoded_psbt, const std::string& base64_psbt, std::string& error); +//! Decode a raw (binary blob) PSBT into a PartiallySignedTransaction +NODISCARD bool DecodeRawPSBT(PartiallySignedTransaction& decoded_psbt, const std::string& raw_psbt, std::string& error); int ParseSighashString(const UniValue& sighash); // core_write.cpp diff --git a/src/core_read.cpp b/src/core_read.cpp index 3b82b2853c8..0379098712f 100644 --- a/src/core_read.cpp +++ b/src/core_read.cpp @@ -176,10 +176,20 @@ bool DecodeHexBlk(CBlock& block, const std::string& strHexBlk) return true; } -bool DecodePSBT(PartiallySignedTransaction& psbt, const std::string& base64_tx, std::string& error) +bool DecodeBase64PSBT(PartiallySignedTransaction& psbt, const std::string& base64_tx, std::string& error) { - std::vector tx_data = DecodeBase64(base64_tx.c_str()); - CDataStream ss_data(tx_data, SER_NETWORK, PROTOCOL_VERSION); + bool invalid; + std::string tx_data = DecodeBase64(base64_tx, &invalid); + if (invalid) { + error = "invalid base64"; + return false; + } + return DecodeRawPSBT(psbt, tx_data, error); +} + +bool DecodeRawPSBT(PartiallySignedTransaction& psbt, const std::string& tx_data, std::string& error) +{ + CDataStream ss_data(tx_data.data(), tx_data.data() + tx_data.size(), SER_NETWORK, PROTOCOL_VERSION); try { ss_data >> psbt; if (!ss_data.empty()) { diff --git a/src/rpc/rawtransaction.cpp b/src/rpc/rawtransaction.cpp index cce62bacef7..9dac989b977 100644 --- a/src/rpc/rawtransaction.cpp +++ b/src/rpc/rawtransaction.cpp @@ -1323,7 +1323,7 @@ UniValue decodepsbt(const JSONRPCRequest& request) // Unserialize the transactions PartiallySignedTransaction psbtx; std::string error; - if (!DecodePSBT(psbtx, request.params[0].get_str(), error)) { + if (!DecodeBase64PSBT(psbtx, request.params[0].get_str(), error)) { throw JSONRPCError(RPC_DESERIALIZATION_ERROR, strprintf("TX decode failed %s", error)); } @@ -1524,7 +1524,7 @@ UniValue combinepsbt(const JSONRPCRequest& request) for (unsigned int i = 0; i < txs.size(); ++i) { PartiallySignedTransaction psbtx; std::string error; - if (!DecodePSBT(psbtx, txs[i].get_str(), error)) { + if (!DecodeBase64PSBT(psbtx, txs[i].get_str(), error)) { throw JSONRPCError(RPC_DESERIALIZATION_ERROR, strprintf("TX decode failed %s", error)); } psbtxs.push_back(psbtx); @@ -1581,7 +1581,7 @@ UniValue finalizepsbt(const JSONRPCRequest& request) // Unserialize the transactions PartiallySignedTransaction psbtx; std::string error; - if (!DecodePSBT(psbtx, request.params[0].get_str(), error)) { + if (!DecodeBase64PSBT(psbtx, request.params[0].get_str(), error)) { throw JSONRPCError(RPC_DESERIALIZATION_ERROR, strprintf("TX decode failed %s", error)); } diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index 9bbbdc61322..b01f40b1251 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -4046,7 +4046,7 @@ UniValue walletprocesspsbt(const JSONRPCRequest& request) // Unserialize the transaction PartiallySignedTransaction psbtx; std::string error; - if (!DecodePSBT(psbtx, request.params[0].get_str(), error)) { + if (!DecodeBase64PSBT(psbtx, request.params[0].get_str(), error)) { throw JSONRPCError(RPC_DESERIALIZATION_ERROR, strprintf("TX decode failed %s", error)); } diff --git a/test/functional/rpc_psbt.py b/test/functional/rpc_psbt.py index a82a5d0208f..c98f1058282 100755 --- a/test/functional/rpc_psbt.py +++ b/test/functional/rpc_psbt.py @@ -293,5 +293,8 @@ class PSBTTest(BitcoinTestFramework): psbt = self.nodes[1].walletcreatefundedpsbt([], [{p2pkh : 1}], 0, {"includeWatching" : True}, True) self.nodes[0].decodepsbt(psbt['psbt']) + # Test decoding error: invalid base64 + assert_raises_rpc_error(-22, "TX decode failed invalid base64", self.nodes[0].decodepsbt, ";definitely not base64;") + if __name__ == '__main__': PSBTTest().main()