netbase: extend CreateSock() to support creating arbitrary sockets

Allow the callers of `CreateSock()` to pass all 3 arguments to the
`socket(2)` syscall. This makes it possible to create sockets of
any domain/type/protocol.
This commit is contained in:
Vasil Dimov 2024-05-30 13:38:26 +02:00
parent 0b94fb8720
commit 1245d1388b
No known key found for this signature in database
GPG key ID: 54DF06F64B55CBBF
6 changed files with 33 additions and 33 deletions

View file

@ -3029,7 +3029,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
return false; return false;
} }
std::unique_ptr<Sock> sock = CreateSock(addrBind.GetSAFamily()); std::unique_ptr<Sock> sock = CreateSock(addrBind.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP);
if (!sock) { if (!sock) {
strError = strprintf(Untranslated("Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError())); strError = strprintf(Untranslated("Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError()));
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original);

View file

@ -487,24 +487,23 @@ bool Socks5(const std::string& strDest, uint16_t port, const ProxyCredentials* a
} }
} }
std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family) std::unique_ptr<Sock> CreateSockOS(int domain, int type, int protocol)
{ {
// Not IPv4, IPv6 or UNIX // Not IPv4, IPv6 or UNIX
if (address_family == AF_UNSPEC) return nullptr; if (domain == AF_UNSPEC) return nullptr;
int protocol{IPPROTO_TCP};
#if HAVE_SOCKADDR_UN
if (address_family == AF_UNIX) protocol = 0;
#endif
// Create a socket in the specified address family. // Create a socket in the specified address family.
SOCKET hSocket = socket(address_family, SOCK_STREAM, protocol); SOCKET hSocket = socket(domain, type, protocol);
if (hSocket == INVALID_SOCKET) { if (hSocket == INVALID_SOCKET) {
return nullptr; return nullptr;
} }
auto sock = std::make_unique<Sock>(hSocket); auto sock = std::make_unique<Sock>(hSocket);
if (domain != AF_INET && domain != AF_INET6 && domain != AF_UNIX) {
return sock;
}
// Ensure that waiting for I/O on this socket won't result in undefined // Ensure that waiting for I/O on this socket won't result in undefined
// behavior. // behavior.
if (!sock->IsSelectable()) { if (!sock->IsSelectable()) {
@ -529,18 +528,21 @@ std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family)
} }
#if HAVE_SOCKADDR_UN #if HAVE_SOCKADDR_UN
if (address_family == AF_UNIX) return sock; if (domain == AF_UNIX) return sock;
#endif #endif
if (protocol == IPPROTO_TCP) {
// Set the no-delay option (disable Nagle's algorithm) on the TCP socket. // Set the no-delay option (disable Nagle's algorithm) on the TCP socket.
const int on{1}; const int on{1};
if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) {
LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n"); LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n");
} }
}
return sock; return sock;
} }
std::function<std::unique_ptr<Sock>(const sa_family_t&)> CreateSock = CreateSockOS; std::function<std::unique_ptr<Sock>(int, int, int)> CreateSock = CreateSockOS;
template<typename... Args> template<typename... Args>
static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) { static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) {
@ -609,7 +611,7 @@ static bool ConnectToSocket(const Sock& sock, struct sockaddr* sockaddr, socklen
std::unique_ptr<Sock> ConnectDirectly(const CService& dest, bool manual_connection) std::unique_ptr<Sock> ConnectDirectly(const CService& dest, bool manual_connection)
{ {
auto sock = CreateSock(dest.GetSAFamily()); auto sock = CreateSock(dest.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP);
if (!sock) { if (!sock) {
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", dest.ToStringAddrPort()); LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", dest.ToStringAddrPort());
return {}; return {};
@ -637,7 +639,7 @@ std::unique_ptr<Sock> Proxy::Connect() const
if (!m_is_unix_socket) return ConnectDirectly(proxy, /*manual_connection=*/true); if (!m_is_unix_socket) return ConnectDirectly(proxy, /*manual_connection=*/true);
#if HAVE_SOCKADDR_UN #if HAVE_SOCKADDR_UN
auto sock = CreateSock(AF_UNIX); auto sock = CreateSock(AF_UNIX, SOCK_STREAM, 0);
if (!sock) { if (!sock) {
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", m_unix_socket_path); LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", m_unix_socket_path);
return {}; return {};

View file

@ -262,16 +262,18 @@ CService LookupNumeric(const std::string& name, uint16_t portDefault = 0, DNSLoo
CSubNet LookupSubNet(const std::string& subnet_str); CSubNet LookupSubNet(const std::string& subnet_str);
/** /**
* Create a TCP or UNIX socket in the given address family. * Create a real socket from the operating system.
* @param[in] address_family to use for the socket. * @param[in] domain Communications domain, first argument to the socket(2) syscall.
* @param[in] type Type of the socket, second argument to the socket(2) syscall.
* @param[in] protocol The particular protocol to be used with the socket, third argument to the socket(2) syscall.
* @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure * @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure
*/ */
std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family); std::unique_ptr<Sock> CreateSockOS(int domain, int type, int protocol);
/** /**
* Socket factory. Defaults to `CreateSockOS()`, but can be overridden by unit tests. * Socket factory. Defaults to `CreateSockOS()`, but can be overridden by unit tests.
*/ */
extern std::function<std::unique_ptr<Sock>(const sa_family_t&)> CreateSock; extern std::function<std::unique_ptr<Sock>(int, int, int)> CreateSock;
/** /**
* Create a socket and try to connect to the specified service. * Create a socket and try to connect to the specified service.

View file

@ -101,8 +101,9 @@ void ResetCoverageCounters() {}
void initialize() void initialize()
{ {
// Terminate immediately if a fuzzing harness ever tries to create a TCP socket. // Terminate immediately if a fuzzing harness ever tries to create a socket.
CreateSock = [](const sa_family_t&) -> std::unique_ptr<Sock> { std::terminate(); }; // Individual tests can override this by pointing CreateSock to a mocked alternative.
CreateSock = [](int, int, int) -> std::unique_ptr<Sock> { std::terminate(); };
// Terminate immediately if a fuzzing harness ever tries to perform a DNS lookup. // Terminate immediately if a fuzzing harness ever tries to perform a DNS lookup.
g_dns_lookup = [](const std::string& name, bool allow_lookup) { g_dns_lookup = [](const std::string& name, bool allow_lookup) {

View file

@ -27,7 +27,7 @@ FUZZ_TARGET(i2p, .init = initialize_i2p)
// Mock CreateSock() to create FuzzedSock. // Mock CreateSock() to create FuzzedSock.
auto CreateSockOrig = CreateSock; auto CreateSockOrig = CreateSock;
CreateSock = [&fuzzed_data_provider](const sa_family_t&) { CreateSock = [&fuzzed_data_provider](int, int, int) {
return std::make_unique<FuzzedSock>(fuzzed_data_provider); return std::make_unique<FuzzedSock>(fuzzed_data_provider);
}; };

View file

@ -39,15 +39,14 @@ public:
private: private:
const BCLog::Level m_prev_log_level; const BCLog::Level m_prev_log_level;
const std::function<std::unique_ptr<Sock>(const sa_family_t&)> m_create_sock_orig; const decltype(CreateSock) m_create_sock_orig;
}; };
BOOST_FIXTURE_TEST_SUITE(i2p_tests, EnvTestingSetup) BOOST_FIXTURE_TEST_SUITE(i2p_tests, EnvTestingSetup)
BOOST_AUTO_TEST_CASE(unlimited_recv) BOOST_AUTO_TEST_CASE(unlimited_recv)
{ {
// Mock CreateSock() to create MockSock. CreateSock = [](int, int, int) {
CreateSock = [](const sa_family_t&) {
return std::make_unique<StaticContentsSock>(std::string(i2p::sam::MAX_MSG_SIZE + 1, 'a')); return std::make_unique<StaticContentsSock>(std::string(i2p::sam::MAX_MSG_SIZE + 1, 'a'));
}; };
@ -69,7 +68,7 @@ BOOST_AUTO_TEST_CASE(unlimited_recv)
BOOST_AUTO_TEST_CASE(listen_ok_accept_fail) BOOST_AUTO_TEST_CASE(listen_ok_accept_fail)
{ {
size_t num_sockets{0}; size_t num_sockets{0};
CreateSock = [&num_sockets](const sa_family_t&) { CreateSock = [&num_sockets](int, int, int) {
// clang-format off // clang-format off
++num_sockets; ++num_sockets;
// First socket is the control socket for creating the session. // First socket is the control socket for creating the session.
@ -133,9 +132,7 @@ BOOST_AUTO_TEST_CASE(listen_ok_accept_fail)
BOOST_AUTO_TEST_CASE(damaged_private_key) BOOST_AUTO_TEST_CASE(damaged_private_key)
{ {
const auto CreateSockOrig = CreateSock; CreateSock = [](int, int, int) {
CreateSock = [](const sa_family_t&) {
return std::make_unique<StaticContentsSock>("HELLO REPLY RESULT=OK VERSION=3.1\n" return std::make_unique<StaticContentsSock>("HELLO REPLY RESULT=OK VERSION=3.1\n"
"SESSION STATUS RESULT=OK DESTINATION=\n"); "SESSION STATUS RESULT=OK DESTINATION=\n");
}; };
@ -172,8 +169,6 @@ BOOST_AUTO_TEST_CASE(damaged_private_key)
BOOST_CHECK(!session.Connect(CService{}, conn, proxy_error)); BOOST_CHECK(!session.Connect(CService{}, conn, proxy_error));
} }
} }
CreateSock = CreateSockOrig;
} }
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()