net: Treat raw message bytes as uint8_t

This commit is contained in:
MarcoFalke 2020-11-20 10:16:10 +01:00
parent fdd068507d
commit fabecce719
No known key found for this signature in database
GPG key ID: CE2B75697E69A548
6 changed files with 19 additions and 19 deletions

View file

@ -629,7 +629,7 @@ void CNode::copyStats(CNodeStats &stats, const std::vector<bool> &m_asmap)
} }
#undef X #undef X
bool CNode::ReceiveMsgBytes(Span<const char> msg_bytes, bool& complete) bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
{ {
complete = false; complete = false;
const auto time = GetTime<std::chrono::microseconds>(); const auto time = GetTime<std::chrono::microseconds>();
@ -673,7 +673,7 @@ bool CNode::ReceiveMsgBytes(Span<const char> msg_bytes, bool& complete)
return true; return true;
} }
int V1TransportDeserializer::readHeader(Span<const char> msg_bytes) int V1TransportDeserializer::readHeader(Span<const uint8_t> msg_bytes)
{ {
// copy data to temporary parsing buffer // copy data to temporary parsing buffer
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos; unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
@ -713,7 +713,7 @@ int V1TransportDeserializer::readHeader(Span<const char> msg_bytes)
return nCopy; return nCopy;
} }
int V1TransportDeserializer::readData(Span<const char> msg_bytes) int V1TransportDeserializer::readData(Span<const uint8_t> msg_bytes)
{ {
unsigned int nRemaining = hdr.nMessageSize - nDataPos; unsigned int nRemaining = hdr.nMessageSize - nDataPos;
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size()); unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
@ -723,7 +723,7 @@ int V1TransportDeserializer::readData(Span<const char> msg_bytes)
vRecv.resize(std::min(hdr.nMessageSize, nDataPos + nCopy + 256 * 1024)); vRecv.resize(std::min(hdr.nMessageSize, nDataPos + nCopy + 256 * 1024));
} }
hasher.Write(MakeUCharSpan(msg_bytes.first(nCopy))); hasher.Write(msg_bytes.first(nCopy));
memcpy(&vRecv[nDataPos], msg_bytes.data(), nCopy); memcpy(&vRecv[nDataPos], msg_bytes.data(), nCopy);
nDataPos += nCopy; nDataPos += nCopy;
@ -1463,18 +1463,18 @@ void CConnman::SocketHandler()
if (recvSet || errorSet) if (recvSet || errorSet)
{ {
// typical socket buffer is 8K-64K // typical socket buffer is 8K-64K
char pchBuf[0x10000]; uint8_t pchBuf[0x10000];
int nBytes = 0; int nBytes = 0;
{ {
LOCK(pnode->cs_hSocket); LOCK(pnode->cs_hSocket);
if (pnode->hSocket == INVALID_SOCKET) if (pnode->hSocket == INVALID_SOCKET)
continue; continue;
nBytes = recv(pnode->hSocket, pchBuf, sizeof(pchBuf), MSG_DONTWAIT); nBytes = recv(pnode->hSocket, (char*)pchBuf, sizeof(pchBuf), MSG_DONTWAIT);
} }
if (nBytes > 0) if (nBytes > 0)
{ {
bool notify = false; bool notify = false;
if (!pnode->ReceiveMsgBytes(Span<const char>(pchBuf, nBytes), notify)) if (!pnode->ReceiveMsgBytes(Span<const uint8_t>(pchBuf, nBytes), notify))
pnode->CloseSocketDisconnect(); pnode->CloseSocketDisconnect();
RecordBytesRecv(nBytes); RecordBytesRecv(nBytes);
if (notify) { if (notify) {

View file

@ -758,7 +758,7 @@ public:
// set the serialization context version // set the serialization context version
virtual void SetVersion(int version) = 0; virtual void SetVersion(int version) = 0;
/** read and deserialize data, advances msg_bytes data pointer */ /** read and deserialize data, advances msg_bytes data pointer */
virtual int Read(Span<const char>& msg_bytes) = 0; virtual int Read(Span<const uint8_t>& msg_bytes) = 0;
// decomposes a message from the context // decomposes a message from the context
virtual Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err) = 0; virtual Optional<CNetMessage> GetMessage(std::chrono::microseconds time, uint32_t& out_err) = 0;
virtual ~TransportDeserializer() {} virtual ~TransportDeserializer() {}
@ -779,8 +779,8 @@ private:
unsigned int nDataPos; unsigned int nDataPos;
const uint256& GetMessageHash() const; const uint256& GetMessageHash() const;
int readHeader(Span<const char> msg_bytes); int readHeader(Span<const uint8_t> msg_bytes);
int readData(Span<const char> msg_bytes); int readData(Span<const uint8_t> msg_bytes);
void Reset() { void Reset() {
vRecv.clear(); vRecv.clear();
@ -814,7 +814,7 @@ public:
hdrbuf.SetVersion(nVersionIn); hdrbuf.SetVersion(nVersionIn);
vRecv.SetVersion(nVersionIn); vRecv.SetVersion(nVersionIn);
} }
int Read(Span<const char>& msg_bytes) override int Read(Span<const uint8_t>& msg_bytes) override
{ {
int ret = in_data ? readData(msg_bytes) : readHeader(msg_bytes); int ret = in_data ? readData(msg_bytes) : readHeader(msg_bytes);
if (ret < 0) { if (ret < 0) {
@ -1132,7 +1132,7 @@ public:
* @return True if the peer should stay connected, * @return True if the peer should stay connected,
* False if the peer should be disconnected from. * False if the peer should be disconnected from.
*/ */
bool ReceiveMsgBytes(Span<const char> msg_bytes, bool& complete); bool ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete);
void SetCommonVersion(int greatest_common_version) void SetCommonVersion(int greatest_common_version)
{ {

View file

@ -128,7 +128,7 @@ void test_one_input(const std::vector<uint8_t>& buffer)
case 11: { case 11: {
const std::vector<uint8_t> b = ConsumeRandomLengthByteVector(fuzzed_data_provider); const std::vector<uint8_t> b = ConsumeRandomLengthByteVector(fuzzed_data_provider);
bool complete; bool complete;
node.ReceiveMsgBytes({(const char*)b.data(), b.size()}, complete); node.ReceiveMsgBytes(b, complete);
break; break;
} }
} }

View file

@ -21,7 +21,7 @@ void test_one_input(const std::vector<uint8_t>& buffer)
{ {
// Construct deserializer, with a dummy NodeId // Construct deserializer, with a dummy NodeId
V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION};
Span<const char> msg_bytes{(const char*)buffer.data(), buffer.size()}; Span<const uint8_t> msg_bytes{buffer};
while (msg_bytes.size() > 0) { while (msg_bytes.size() > 0) {
const int handled = deserializer.Read(msg_bytes); const int handled = deserializer.Read(msg_bytes);
if (handled < 0) { if (handled < 0) {

View file

@ -7,7 +7,7 @@
#include <chainparams.h> #include <chainparams.h>
#include <net.h> #include <net.h>
void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span<const char> msg_bytes, bool& complete) const void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span<const uint8_t> msg_bytes, bool& complete) const
{ {
assert(node.ReceiveMsgBytes(msg_bytes, complete)); assert(node.ReceiveMsgBytes(msg_bytes, complete));
if (complete) { if (complete) {
@ -29,11 +29,11 @@ void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span<const char> msg_bytes
bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const
{ {
std::vector<unsigned char> ser_msg_header; std::vector<uint8_t> ser_msg_header;
node.m_serializer->prepareForTransport(ser_msg, ser_msg_header); node.m_serializer->prepareForTransport(ser_msg, ser_msg_header);
bool complete; bool complete;
NodeReceiveMsgBytes(node, {(const char*)ser_msg_header.data(), ser_msg_header.size()}, complete); NodeReceiveMsgBytes(node, ser_msg_header, complete);
NodeReceiveMsgBytes(node, {(const char*)ser_msg.data.data(), ser_msg.data.size()}, complete); NodeReceiveMsgBytes(node, ser_msg.data, complete);
return complete; return complete;
} }

View file

@ -25,7 +25,7 @@ struct ConnmanTestMsg : public CConnman {
void ProcessMessagesOnce(CNode& node) { m_msgproc->ProcessMessages(&node, flagInterruptMsgProc); } void ProcessMessagesOnce(CNode& node) { m_msgproc->ProcessMessages(&node, flagInterruptMsgProc); }
void NodeReceiveMsgBytes(CNode& node, Span<const char> msg_bytes, bool& complete) const; void NodeReceiveMsgBytes(CNode& node, Span<const uint8_t> msg_bytes, bool& complete) const;
bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const; bool ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const;
}; };