net: add V1Transport lock protecting receive state

Rather than relying on the caller to prevent concurrent calls to the
various receive-side functions of Transport, introduce a private m_cs_recv
inside the implementation to protect the lock state.

Of course, this does not remove the need for callers to synchronize calls
entirely, as it is a stateful object, and e.g. the order in which Receive(),
Complete(), and GetMessage() are called matters. It seems impossible to use
a Transport object in a meaningful way in a multi-threaded way without some
form of external synchronization, but it still feels safer to make the
transport object itself responsible for protecting its internal state.
This commit is contained in:
Pieter Wuille 2023-07-26 13:19:31 -04:00
parent 93594e42c3
commit 27f9ba23ef
2 changed files with 41 additions and 22 deletions

View file

@ -719,6 +719,7 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
{
AssertLockHeld(m_recv_mutex);
// copy data to temporary parsing buffer
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
@ -759,6 +760,7 @@ int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
int V1Transport::readData(Span<const uint8_t> msg_bytes)
{
AssertLockHeld(m_recv_mutex);
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
@ -776,7 +778,8 @@ int V1Transport::readData(Span<const uint8_t> msg_bytes)
const uint256& V1Transport::GetMessageHash() const
{
assert(Complete());
AssertLockHeld(m_recv_mutex);
assert(CompleteInternal());
if (data_hash.IsNull())
hasher.Finalize(data_hash);
return data_hash;
@ -784,9 +787,11 @@ const uint256& V1Transport::GetMessageHash() const
CNetMessage V1Transport::GetMessage(const std::chrono::microseconds time, bool& reject_message)
{
AssertLockNotHeld(m_recv_mutex);
// Initialize out parameter
reject_message = false;
// decompose a single CNetMessage from the TransportDeserializer
LOCK(m_recv_mutex);
CNetMessage msg(std::move(vRecv));
// store message type string, time, and sizes

View file

@ -259,8 +259,7 @@ public:
virtual ~Transport() {}
// 1. Receiver side functions, for decoding bytes received on the wire into transport protocol
// agnostic CNetMessage (message type & payload) objects. Callers must guarantee that none of
// these functions are called concurrently w.r.t. one another.
// agnostic CNetMessage (message type & payload) objects.
// returns true if the current deserialization is complete
virtual bool Complete() const = 0;
@ -282,20 +281,22 @@ class V1Transport final : public Transport
private:
const CChainParams& m_chain_params;
const NodeId m_node_id; // Only for logging
mutable CHash256 hasher;
mutable uint256 data_hash;
bool in_data; // parsing header (false) or data (true)
CDataStream hdrbuf; // partially received header
CMessageHeader hdr; // complete header
CDataStream vRecv; // received message data
unsigned int nHdrPos;
unsigned int nDataPos;
mutable Mutex m_recv_mutex; //!< Lock for receive state
mutable CHash256 hasher GUARDED_BY(m_recv_mutex);
mutable uint256 data_hash GUARDED_BY(m_recv_mutex);
bool in_data GUARDED_BY(m_recv_mutex); // parsing header (false) or data (true)
CDataStream hdrbuf GUARDED_BY(m_recv_mutex); // partially received header
CMessageHeader hdr GUARDED_BY(m_recv_mutex); // complete header
CDataStream vRecv GUARDED_BY(m_recv_mutex); // received message data
unsigned int nHdrPos GUARDED_BY(m_recv_mutex);
unsigned int nDataPos GUARDED_BY(m_recv_mutex);
const uint256& GetMessageHash() const;
int readHeader(Span<const uint8_t> msg_bytes);
int readData(Span<const uint8_t> msg_bytes);
const uint256& GetMessageHash() const EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
int readHeader(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
int readData(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
void Reset() {
void Reset() EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex) {
AssertLockHeld(m_recv_mutex);
vRecv.clear();
hdrbuf.clear();
hdrbuf.resize(24);
@ -306,6 +307,13 @@ private:
hasher.Reset();
}
bool CompleteInternal() const noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex)
{
AssertLockHeld(m_recv_mutex);
if (!in_data) return false;
return hdr.nMessageSize == nDataPos;
}
public:
V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
: m_chain_params(chain_params),
@ -313,22 +321,28 @@ public:
hdrbuf(nTypeIn, nVersionIn),
vRecv(nTypeIn, nVersionIn)
{
LOCK(m_recv_mutex);
Reset();
}
bool Complete() const override
bool Complete() const override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
{
if (!in_data)
return false;
return (hdr.nMessageSize == nDataPos);
AssertLockNotHeld(m_recv_mutex);
return WITH_LOCK(m_recv_mutex, return CompleteInternal());
}
void SetVersion(int nVersionIn) override
void SetVersion(int nVersionIn) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
{
AssertLockNotHeld(m_recv_mutex);
LOCK(m_recv_mutex);
hdrbuf.SetVersion(nVersionIn);
vRecv.SetVersion(nVersionIn);
}
int Read(Span<const uint8_t>& msg_bytes) override
int Read(Span<const uint8_t>& msg_bytes) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
{
AssertLockNotHeld(m_recv_mutex);
LOCK(m_recv_mutex);
int ret = in_data ? readData(msg_bytes) : readHeader(msg_bytes);
if (ret < 0) {
Reset();
@ -337,7 +351,7 @@ public:
}
return ret;
}
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override;
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex);
void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override;
};