mirror of
https://github.com/bitcoin/bitcoin.git
synced 2025-02-22 06:52:36 +01:00
net: add V1Transport lock protecting receive state
Rather than relying on the caller to prevent concurrent calls to the various receive-side functions of Transport, introduce a private m_cs_recv inside the implementation to protect the lock state. Of course, this does not remove the need for callers to synchronize calls entirely, as it is a stateful object, and e.g. the order in which Receive(), Complete(), and GetMessage() are called matters. It seems impossible to use a Transport object in a meaningful way in a multi-threaded way without some form of external synchronization, but it still feels safer to make the transport object itself responsible for protecting its internal state.
This commit is contained in:
parent
93594e42c3
commit
27f9ba23ef
2 changed files with 41 additions and 22 deletions
|
@ -719,6 +719,7 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
|
|||
|
||||
int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
|
||||
{
|
||||
AssertLockHeld(m_recv_mutex);
|
||||
// copy data to temporary parsing buffer
|
||||
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
|
||||
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
|
||||
|
@ -759,6 +760,7 @@ int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
|
|||
|
||||
int V1Transport::readData(Span<const uint8_t> msg_bytes)
|
||||
{
|
||||
AssertLockHeld(m_recv_mutex);
|
||||
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
|
||||
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
|
||||
|
||||
|
@ -776,7 +778,8 @@ int V1Transport::readData(Span<const uint8_t> msg_bytes)
|
|||
|
||||
const uint256& V1Transport::GetMessageHash() const
|
||||
{
|
||||
assert(Complete());
|
||||
AssertLockHeld(m_recv_mutex);
|
||||
assert(CompleteInternal());
|
||||
if (data_hash.IsNull())
|
||||
hasher.Finalize(data_hash);
|
||||
return data_hash;
|
||||
|
@ -784,9 +787,11 @@ const uint256& V1Transport::GetMessageHash() const
|
|||
|
||||
CNetMessage V1Transport::GetMessage(const std::chrono::microseconds time, bool& reject_message)
|
||||
{
|
||||
AssertLockNotHeld(m_recv_mutex);
|
||||
// Initialize out parameter
|
||||
reject_message = false;
|
||||
// decompose a single CNetMessage from the TransportDeserializer
|
||||
LOCK(m_recv_mutex);
|
||||
CNetMessage msg(std::move(vRecv));
|
||||
|
||||
// store message type string, time, and sizes
|
||||
|
|
56
src/net.h
56
src/net.h
|
@ -259,8 +259,7 @@ public:
|
|||
virtual ~Transport() {}
|
||||
|
||||
// 1. Receiver side functions, for decoding bytes received on the wire into transport protocol
|
||||
// agnostic CNetMessage (message type & payload) objects. Callers must guarantee that none of
|
||||
// these functions are called concurrently w.r.t. one another.
|
||||
// agnostic CNetMessage (message type & payload) objects.
|
||||
|
||||
// returns true if the current deserialization is complete
|
||||
virtual bool Complete() const = 0;
|
||||
|
@ -282,20 +281,22 @@ class V1Transport final : public Transport
|
|||
private:
|
||||
const CChainParams& m_chain_params;
|
||||
const NodeId m_node_id; // Only for logging
|
||||
mutable CHash256 hasher;
|
||||
mutable uint256 data_hash;
|
||||
bool in_data; // parsing header (false) or data (true)
|
||||
CDataStream hdrbuf; // partially received header
|
||||
CMessageHeader hdr; // complete header
|
||||
CDataStream vRecv; // received message data
|
||||
unsigned int nHdrPos;
|
||||
unsigned int nDataPos;
|
||||
mutable Mutex m_recv_mutex; //!< Lock for receive state
|
||||
mutable CHash256 hasher GUARDED_BY(m_recv_mutex);
|
||||
mutable uint256 data_hash GUARDED_BY(m_recv_mutex);
|
||||
bool in_data GUARDED_BY(m_recv_mutex); // parsing header (false) or data (true)
|
||||
CDataStream hdrbuf GUARDED_BY(m_recv_mutex); // partially received header
|
||||
CMessageHeader hdr GUARDED_BY(m_recv_mutex); // complete header
|
||||
CDataStream vRecv GUARDED_BY(m_recv_mutex); // received message data
|
||||
unsigned int nHdrPos GUARDED_BY(m_recv_mutex);
|
||||
unsigned int nDataPos GUARDED_BY(m_recv_mutex);
|
||||
|
||||
const uint256& GetMessageHash() const;
|
||||
int readHeader(Span<const uint8_t> msg_bytes);
|
||||
int readData(Span<const uint8_t> msg_bytes);
|
||||
const uint256& GetMessageHash() const EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
|
||||
int readHeader(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
|
||||
int readData(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
|
||||
|
||||
void Reset() {
|
||||
void Reset() EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex) {
|
||||
AssertLockHeld(m_recv_mutex);
|
||||
vRecv.clear();
|
||||
hdrbuf.clear();
|
||||
hdrbuf.resize(24);
|
||||
|
@ -306,6 +307,13 @@ private:
|
|||
hasher.Reset();
|
||||
}
|
||||
|
||||
bool CompleteInternal() const noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex)
|
||||
{
|
||||
AssertLockHeld(m_recv_mutex);
|
||||
if (!in_data) return false;
|
||||
return hdr.nMessageSize == nDataPos;
|
||||
}
|
||||
|
||||
public:
|
||||
V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
|
||||
: m_chain_params(chain_params),
|
||||
|
@ -313,22 +321,28 @@ public:
|
|||
hdrbuf(nTypeIn, nVersionIn),
|
||||
vRecv(nTypeIn, nVersionIn)
|
||||
{
|
||||
LOCK(m_recv_mutex);
|
||||
Reset();
|
||||
}
|
||||
|
||||
bool Complete() const override
|
||||
bool Complete() const override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
|
||||
{
|
||||
if (!in_data)
|
||||
return false;
|
||||
return (hdr.nMessageSize == nDataPos);
|
||||
AssertLockNotHeld(m_recv_mutex);
|
||||
return WITH_LOCK(m_recv_mutex, return CompleteInternal());
|
||||
}
|
||||
void SetVersion(int nVersionIn) override
|
||||
|
||||
void SetVersion(int nVersionIn) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
|
||||
{
|
||||
AssertLockNotHeld(m_recv_mutex);
|
||||
LOCK(m_recv_mutex);
|
||||
hdrbuf.SetVersion(nVersionIn);
|
||||
vRecv.SetVersion(nVersionIn);
|
||||
}
|
||||
int Read(Span<const uint8_t>& msg_bytes) override
|
||||
|
||||
int Read(Span<const uint8_t>& msg_bytes) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
|
||||
{
|
||||
AssertLockNotHeld(m_recv_mutex);
|
||||
LOCK(m_recv_mutex);
|
||||
int ret = in_data ? readData(msg_bytes) : readHeader(msg_bytes);
|
||||
if (ret < 0) {
|
||||
Reset();
|
||||
|
@ -337,7 +351,7 @@ public:
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override;
|
||||
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex);
|
||||
|
||||
void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override;
|
||||
};
|
||||
|
|
Loading…
Add table
Reference in a new issue