mirror of
https://github.com/bitcoin/bitcoin.git
synced 2025-02-22 06:52:36 +01:00
net: add RAII socket and use it instead of bare SOCKET
Introduce a class to manage the lifetime of a socket - when the object that contains the socket goes out of scope, the underlying socket will be closed. In addition, the new `Sock` class has a `Send()`, `Recv()` and `Wait()` methods that can be overridden by unit tests to mock the socket operations. The `Wait()` method also hides the `#ifdef USE_POLL poll() #else select() #endif` technique from higher level code.
This commit is contained in:
parent
dec9b5e850
commit
ba9d73268f
7 changed files with 250 additions and 41 deletions
47
src/net.cpp
47
src/net.cpp
|
@ -429,24 +429,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
|
|||
|
||||
// Connect
|
||||
bool connected = false;
|
||||
SOCKET hSocket = INVALID_SOCKET;
|
||||
std::unique_ptr<Sock> sock;
|
||||
proxyType proxy;
|
||||
if (addrConnect.IsValid()) {
|
||||
bool proxyConnectionFailed = false;
|
||||
|
||||
if (GetProxy(addrConnect.GetNetwork(), proxy)) {
|
||||
hSocket = CreateSocket(proxy.proxy);
|
||||
if (hSocket == INVALID_SOCKET) {
|
||||
sock = CreateSock(proxy.proxy);
|
||||
if (!sock) {
|
||||
return nullptr;
|
||||
}
|
||||
connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), hSocket, nConnectTimeout, proxyConnectionFailed);
|
||||
connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(),
|
||||
sock->Get(), nConnectTimeout, proxyConnectionFailed);
|
||||
} else {
|
||||
// no proxy needed (none set for target network)
|
||||
hSocket = CreateSocket(addrConnect);
|
||||
if (hSocket == INVALID_SOCKET) {
|
||||
sock = CreateSock(addrConnect);
|
||||
if (!sock) {
|
||||
return nullptr;
|
||||
}
|
||||
connected = ConnectSocketDirectly(addrConnect, hSocket, nConnectTimeout, conn_type == ConnectionType::MANUAL);
|
||||
connected = ConnectSocketDirectly(addrConnect, sock->Get(), nConnectTimeout,
|
||||
conn_type == ConnectionType::MANUAL);
|
||||
}
|
||||
if (!proxyConnectionFailed) {
|
||||
// If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to
|
||||
|
@ -454,26 +456,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
|
|||
addrman.Attempt(addrConnect, fCountFailure);
|
||||
}
|
||||
} else if (pszDest && GetNameProxy(proxy)) {
|
||||
hSocket = CreateSocket(proxy.proxy);
|
||||
if (hSocket == INVALID_SOCKET) {
|
||||
sock = CreateSock(proxy.proxy);
|
||||
if (!sock) {
|
||||
return nullptr;
|
||||
}
|
||||
std::string host;
|
||||
int port = default_port;
|
||||
SplitHostPort(std::string(pszDest), port, host);
|
||||
bool proxyConnectionFailed;
|
||||
connected = ConnectThroughProxy(proxy, host, port, hSocket, nConnectTimeout, proxyConnectionFailed);
|
||||
connected = ConnectThroughProxy(proxy, host, port, sock->Get(), nConnectTimeout,
|
||||
proxyConnectionFailed);
|
||||
}
|
||||
if (!connected) {
|
||||
CloseSocket(hSocket);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Add node
|
||||
NodeId id = GetNewNodeId();
|
||||
uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize();
|
||||
CAddress addr_bind = GetBindAddress(hSocket);
|
||||
CNode* pnode = new CNode(id, nLocalServices, hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type);
|
||||
CAddress addr_bind = GetBindAddress(sock->Get());
|
||||
CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type);
|
||||
pnode->AddRef();
|
||||
|
||||
// We're making a new connection, harvest entropy from the time (and our peer count)
|
||||
|
@ -2177,9 +2179,8 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
|
|||
return false;
|
||||
}
|
||||
|
||||
SOCKET hListenSocket = CreateSocket(addrBind);
|
||||
if (hListenSocket == INVALID_SOCKET)
|
||||
{
|
||||
std::unique_ptr<Sock> sock = CreateSock(addrBind);
|
||||
if (!sock) {
|
||||
strError = strprintf(Untranslated("Error: Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError()));
|
||||
LogPrintf("%s\n", strError.original);
|
||||
return false;
|
||||
|
@ -2187,21 +2188,21 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
|
|||
|
||||
// Allow binding if the port is still in TIME_WAIT state after
|
||||
// the program was closed and restarted.
|
||||
setsockopt(hListenSocket, SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int));
|
||||
setsockopt(sock->Get(), SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int));
|
||||
|
||||
// some systems don't have IPV6_V6ONLY but are always v6only; others do have the option
|
||||
// and enable it by default or not. Try to enable it, if possible.
|
||||
if (addrBind.IsIPv6()) {
|
||||
#ifdef IPV6_V6ONLY
|
||||
setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int));
|
||||
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int));
|
||||
#endif
|
||||
#ifdef WIN32
|
||||
int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED;
|
||||
setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int));
|
||||
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int));
|
||||
#endif
|
||||
}
|
||||
|
||||
if (::bind(hListenSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR)
|
||||
if (::bind(sock->Get(), (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR)
|
||||
{
|
||||
int nErr = WSAGetLastError();
|
||||
if (nErr == WSAEADDRINUSE)
|
||||
|
@ -2209,21 +2210,19 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
|
|||
else
|
||||
strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToString(), NetworkErrorString(nErr));
|
||||
LogPrintf("%s\n", strError.original);
|
||||
CloseSocket(hListenSocket);
|
||||
return false;
|
||||
}
|
||||
LogPrintf("Bound to %s\n", addrBind.ToString());
|
||||
|
||||
// Listen for incoming connections
|
||||
if (listen(hListenSocket, SOMAXCONN) == SOCKET_ERROR)
|
||||
if (listen(sock->Get(), SOMAXCONN) == SOCKET_ERROR)
|
||||
{
|
||||
strError = strprintf(_("Error: Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError()));
|
||||
LogPrintf("%s\n", strError.original);
|
||||
CloseSocket(hListenSocket);
|
||||
return false;
|
||||
}
|
||||
|
||||
vhListenSocket.push_back(ListenSocket(hListenSocket, permissions));
|
||||
vhListenSocket.push_back(ListenSocket(sock->Release(), permissions));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
|
||||
#ifndef WIN32
|
||||
#include <fcntl.h>
|
||||
|
@ -559,34 +561,28 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
|
|||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to create a socket file descriptor with specific properties in the
|
||||
* communications domain (address family) of the specified service.
|
||||
*
|
||||
* For details on the desired properties, see the inline comments in the source
|
||||
* code.
|
||||
*/
|
||||
SOCKET CreateSocket(const CService &addrConnect)
|
||||
std::unique_ptr<Sock> CreateSockTCP(const CService& address_family)
|
||||
{
|
||||
// Create a sockaddr from the specified service.
|
||||
struct sockaddr_storage sockaddr;
|
||||
socklen_t len = sizeof(sockaddr);
|
||||
if (!addrConnect.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
|
||||
LogPrintf("Cannot create socket for %s: unsupported network\n", addrConnect.ToString());
|
||||
return INVALID_SOCKET;
|
||||
if (!address_family.GetSockAddr((struct sockaddr*)&sockaddr, &len)) {
|
||||
LogPrintf("Cannot create socket for %s: unsupported network\n", address_family.ToString());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Create a TCP socket in the address family of the specified service.
|
||||
SOCKET hSocket = socket(((struct sockaddr*)&sockaddr)->sa_family, SOCK_STREAM, IPPROTO_TCP);
|
||||
if (hSocket == INVALID_SOCKET)
|
||||
return INVALID_SOCKET;
|
||||
if (hSocket == INVALID_SOCKET) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Ensure that waiting for I/O on this socket won't result in undefined
|
||||
// behavior.
|
||||
if (!IsSelectableSocket(hSocket)) {
|
||||
CloseSocket(hSocket);
|
||||
LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n");
|
||||
return INVALID_SOCKET;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#ifdef SO_NOSIGPIPE
|
||||
|
@ -602,11 +598,14 @@ SOCKET CreateSocket(const CService &addrConnect)
|
|||
// Set the non-blocking option on the socket.
|
||||
if (!SetSocketNonBlocking(hSocket, true)) {
|
||||
CloseSocket(hSocket);
|
||||
LogPrintf("CreateSocket: Setting socket to non-blocking failed, error %s\n", NetworkErrorString(WSAGetLastError()));
|
||||
LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError()));
|
||||
return nullptr;
|
||||
}
|
||||
return hSocket;
|
||||
return std::make_unique<Sock>(hSocket);
|
||||
}
|
||||
|
||||
std::function<std::unique_ptr<Sock>(const CService&)> CreateSock = CreateSockTCP;
|
||||
|
||||
template<typename... Args>
|
||||
static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) {
|
||||
std::string error_message = tfm::format(fmt, args...);
|
||||
|
|
|
@ -12,7 +12,10 @@
|
|||
#include <compat.h>
|
||||
#include <netaddress.h>
|
||||
#include <serialize.h>
|
||||
#include <util/sock.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -51,7 +54,19 @@ bool Lookup(const std::string& name, CService& addr, int portDefault, bool fAllo
|
|||
bool Lookup(const std::string& name, std::vector<CService>& vAddr, int portDefault, bool fAllowLookup, unsigned int nMaxSolutions);
|
||||
CService LookupNumeric(const std::string& name, int portDefault = 0);
|
||||
bool LookupSubNet(const std::string& strSubnet, CSubNet& subnet);
|
||||
SOCKET CreateSocket(const CService &addrConnect);
|
||||
|
||||
/**
|
||||
* Create a TCP socket in the given address family.
|
||||
* @param[in] address_family The socket is created in the same address family as this address.
|
||||
* @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure
|
||||
*/
|
||||
std::unique_ptr<Sock> CreateSockTCP(const CService& address_family);
|
||||
|
||||
/**
|
||||
* Socket factory. Defaults to `CreateSockTCP()`, but can be overridden by unit tests.
|
||||
*/
|
||||
extern std::function<std::unique_ptr<Sock>(const CService&)> CreateSock;
|
||||
|
||||
bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection);
|
||||
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool& outProxyConnectionFailed);
|
||||
/** Disable or enable blocking-mode for a socket */
|
||||
|
|
|
@ -6,12 +6,97 @@
|
|||
#include <logging.h>
|
||||
#include <tinyformat.h>
|
||||
#include <util/sock.h>
|
||||
#include <util/system.h>
|
||||
#include <util/time.h>
|
||||
|
||||
#include <codecvt>
|
||||
#include <cwchar>
|
||||
#include <locale>
|
||||
#include <string>
|
||||
|
||||
#ifdef USE_POLL
|
||||
#include <poll.h>
|
||||
#endif
|
||||
|
||||
Sock::Sock() : m_socket(INVALID_SOCKET) {}
|
||||
|
||||
Sock::Sock(SOCKET s) : m_socket(s) {}
|
||||
|
||||
Sock::Sock(Sock&& other)
|
||||
{
|
||||
m_socket = other.m_socket;
|
||||
other.m_socket = INVALID_SOCKET;
|
||||
}
|
||||
|
||||
Sock::~Sock() { Reset(); }
|
||||
|
||||
Sock& Sock::operator=(Sock&& other)
|
||||
{
|
||||
Reset();
|
||||
m_socket = other.m_socket;
|
||||
other.m_socket = INVALID_SOCKET;
|
||||
return *this;
|
||||
}
|
||||
|
||||
SOCKET Sock::Get() const { return m_socket; }
|
||||
|
||||
SOCKET Sock::Release()
|
||||
{
|
||||
const SOCKET s = m_socket;
|
||||
m_socket = INVALID_SOCKET;
|
||||
return s;
|
||||
}
|
||||
|
||||
void Sock::Reset() { CloseSocket(m_socket); }
|
||||
|
||||
ssize_t Sock::Send(const void* data, size_t len, int flags) const
|
||||
{
|
||||
return send(m_socket, static_cast<const char*>(data), len, flags);
|
||||
}
|
||||
|
||||
ssize_t Sock::Recv(void* buf, size_t len, int flags) const
|
||||
{
|
||||
return recv(m_socket, static_cast<char*>(buf), len, flags);
|
||||
}
|
||||
|
||||
bool Sock::Wait(std::chrono::milliseconds timeout, Event requested) const
|
||||
{
|
||||
#ifdef USE_POLL
|
||||
pollfd fd;
|
||||
fd.fd = m_socket;
|
||||
fd.events = 0;
|
||||
if (requested & RECV) {
|
||||
fd.events |= POLLIN;
|
||||
}
|
||||
if (requested & SEND) {
|
||||
fd.events |= POLLOUT;
|
||||
}
|
||||
|
||||
return poll(&fd, 1, count_milliseconds(timeout)) != SOCKET_ERROR;
|
||||
#else
|
||||
if (!IsSelectableSocket(m_socket)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
fd_set fdset_recv;
|
||||
fd_set fdset_send;
|
||||
FD_ZERO(&fdset_recv);
|
||||
FD_ZERO(&fdset_send);
|
||||
|
||||
if (requested & RECV) {
|
||||
FD_SET(m_socket, &fdset_recv);
|
||||
}
|
||||
|
||||
if (requested & SEND) {
|
||||
FD_SET(m_socket, &fdset_send);
|
||||
}
|
||||
|
||||
timeval timeout_struct = MillisToTimeval(timeout);
|
||||
|
||||
return select(m_socket + 1, &fdset_recv, &fdset_send, nullptr, &timeout_struct) != SOCKET_ERROR;
|
||||
#endif /* USE_POLL */
|
||||
}
|
||||
|
||||
#ifdef WIN32
|
||||
std::string NetworkErrorString(int err)
|
||||
{
|
||||
|
|
100
src/util/sock.h
100
src/util/sock.h
|
@ -7,8 +7,108 @@
|
|||
|
||||
#include <compat.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
|
||||
/**
|
||||
* RAII helper class that manages a socket. Mimics `std::unique_ptr`, but instead of a pointer it
|
||||
* contains a socket and closes it automatically when it goes out of scope.
|
||||
*/
|
||||
class Sock
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* Default constructor, creates an empty object that does nothing when destroyed.
|
||||
*/
|
||||
Sock();
|
||||
|
||||
/**
|
||||
* Take ownership of an existent socket.
|
||||
*/
|
||||
explicit Sock(SOCKET s);
|
||||
|
||||
/**
|
||||
* Copy constructor, disabled because closing the same socket twice is undesirable.
|
||||
*/
|
||||
Sock(const Sock&) = delete;
|
||||
|
||||
/**
|
||||
* Move constructor, grab the socket from another object and close ours (if set).
|
||||
*/
|
||||
Sock(Sock&& other);
|
||||
|
||||
/**
|
||||
* Destructor, close the socket or do nothing if empty.
|
||||
*/
|
||||
virtual ~Sock();
|
||||
|
||||
/**
|
||||
* Copy assignment operator, disabled because closing the same socket twice is undesirable.
|
||||
*/
|
||||
Sock& operator=(const Sock&) = delete;
|
||||
|
||||
/**
|
||||
* Move assignment operator, grab the socket from another object and close ours (if set).
|
||||
*/
|
||||
virtual Sock& operator=(Sock&& other);
|
||||
|
||||
/**
|
||||
* Get the value of the contained socket.
|
||||
* @return socket or INVALID_SOCKET if empty
|
||||
*/
|
||||
virtual SOCKET Get() const;
|
||||
|
||||
/**
|
||||
* Get the value of the contained socket and drop ownership. It will not be closed by the
|
||||
* destructor after this call.
|
||||
* @return socket or INVALID_SOCKET if empty
|
||||
*/
|
||||
virtual SOCKET Release();
|
||||
|
||||
/**
|
||||
* Close if non-empty.
|
||||
*/
|
||||
virtual void Reset();
|
||||
|
||||
/**
|
||||
* send(2) wrapper. Equivalent to `send(this->Get(), data, len, flags);`. Code that uses this
|
||||
* wrapper can be unit-tested if this method is overridden by a mock Sock implementation.
|
||||
*/
|
||||
virtual ssize_t Send(const void* data, size_t len, int flags) const;
|
||||
|
||||
/**
|
||||
* recv(2) wrapper. Equivalent to `recv(this->Get(), buf, len, flags);`. Code that uses this
|
||||
* wrapper can be unit-tested if this method is overridden by a mock Sock implementation.
|
||||
*/
|
||||
virtual ssize_t Recv(void* buf, size_t len, int flags) const;
|
||||
|
||||
using Event = uint8_t;
|
||||
|
||||
/**
|
||||
* If passed to `Wait()`, then it will wait for readiness to read from the socket.
|
||||
*/
|
||||
static constexpr Event RECV = 0b01;
|
||||
|
||||
/**
|
||||
* If passed to `Wait()`, then it will wait for readiness to send to the socket.
|
||||
*/
|
||||
static constexpr Event SEND = 0b10;
|
||||
|
||||
/**
|
||||
* Wait for readiness for input (recv) or output (send).
|
||||
* @param[in] timeout Wait this much for at least one of the requested events to occur.
|
||||
* @param[in] requested Wait for those events, bitwise-or of `RECV` and `SEND`.
|
||||
* @return true on success and false otherwise
|
||||
*/
|
||||
virtual bool Wait(std::chrono::milliseconds timeout, Event requested) const;
|
||||
|
||||
private:
|
||||
/**
|
||||
* Contained socket. `INVALID_SOCKET` designates the object is empty.
|
||||
*/
|
||||
SOCKET m_socket;
|
||||
};
|
||||
|
||||
/** Return readable error string for a network error code */
|
||||
std::string NetworkErrorString(int err);
|
||||
|
||||
|
|
|
@ -123,3 +123,8 @@ struct timeval MillisToTimeval(int64_t nTimeout)
|
|||
timeout.tv_usec = (nTimeout % 1000) * 1000;
|
||||
return timeout;
|
||||
}
|
||||
|
||||
struct timeval MillisToTimeval(std::chrono::milliseconds ms)
|
||||
{
|
||||
return MillisToTimeval(count_milliseconds(ms));
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ void UninterruptibleSleep(const std::chrono::microseconds& n);
|
|||
* interface that doesn't support std::chrono (e.g. RPC, debug log, or the GUI)
|
||||
*/
|
||||
inline int64_t count_seconds(std::chrono::seconds t) { return t.count(); }
|
||||
inline int64_t count_milliseconds(std::chrono::milliseconds t) { return t.count(); }
|
||||
inline int64_t count_microseconds(std::chrono::microseconds t) { return t.count(); }
|
||||
|
||||
/**
|
||||
|
@ -64,4 +65,9 @@ int64_t ParseISO8601DateTime(const std::string& str);
|
|||
*/
|
||||
struct timeval MillisToTimeval(int64_t nTimeout);
|
||||
|
||||
/**
|
||||
* Convert milliseconds to a struct timeval for e.g. select.
|
||||
*/
|
||||
struct timeval MillisToTimeval(std::chrono::milliseconds ms);
|
||||
|
||||
#endif // BITCOIN_UTIL_TIME_H
|
||||
|
|
Loading…
Add table
Reference in a new issue