Skip to content

[llvm][Support] Add UNIX socket support #73603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions llvm/include/llvm/Support/raw_ostream.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DataTypes.h"
#include "llvm/Support/Threading.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -615,6 +616,8 @@ class raw_fd_stream : public raw_fd_ostream {
/// immediately destroyed.
raw_fd_stream(StringRef Filename, std::error_code &EC);

raw_fd_stream(int fd, bool shouldClose);

/// This reads the \p Size bytes into a buffer pointed by \p Ptr.
///
/// \param Ptr The start of the buffer to hold data to be read.
Expand All @@ -630,6 +633,54 @@ class raw_fd_stream : public raw_fd_ostream {
static bool classof(const raw_ostream *OS);
};

//===----------------------------------------------------------------------===//
// Socket Streams
//===----------------------------------------------------------------------===//

/// A raw stream for sockets reading/writing

class raw_socket_stream;

// Make sure that calls to WSAStartup and WSACleanup are balanced.
#ifdef _WIN32
class WSABalancer {
public:
WSABalancer();
~WSABalancer();
};
#endif // _WIN32

class ListeningSocket {
int FD;
std::string SocketPath;
ListeningSocket(int SocketFD, StringRef SocketPath);
#ifdef _WIN32
WSABalancer _;
#endif // _WIN32

public:
static Expected<ListeningSocket> createUnix(
StringRef SocketPath,
int MaxBacklog = llvm::hardware_concurrency().compute_thread_count());
Expected<std::unique_ptr<raw_socket_stream>> accept();
ListeningSocket(ListeningSocket &&LS);
~ListeningSocket();
};
class raw_socket_stream : public raw_fd_stream {
uint64_t current_pos() const override { return 0; }
#ifdef _WIN32
WSABalancer _;
#endif // _WIN32

public:
raw_socket_stream(int SocketFD);
/// Create a \p raw_socket_stream connected to the Unix domain socket at \p
/// SocketPath.
static Expected<std::unique_ptr<raw_socket_stream>>
createConnectedUnix(StringRef SocketPath);
~raw_socket_stream();
};

//===----------------------------------------------------------------------===//
// Output Stream Adaptors
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Support/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ endif()
if( MSVC OR MINGW )
# libuuid required for FOLDERID_Profile usage in lib/Support/Windows/Path.inc.
# advapi32 required for CryptAcquireContextW in lib/Support/Windows/Path.inc.
set(system_libs ${system_libs} psapi shell32 ole32 uuid advapi32)
set(system_libs ${system_libs} psapi shell32 ole32 uuid advapi32 Ws2_32)
elseif( CMAKE_HOST_UNIX )
if( HAVE_LIBRT )
set(system_libs ${system_libs} rt)
Expand Down
158 changes: 157 additions & 1 deletion llvm/lib/Support/raw_ostream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/Config/config.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Duration.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
Expand All @@ -23,11 +24,17 @@
#include "llvm/Support/NativeFormatting.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/Threading.h"
#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <sys/stat.h>

#ifndef _WIN32
#include <sys/socket.h>
#include <sys/un.h>
#endif // _WIN32

// <fcntl.h> may provide O_BINARY.
#if defined(HAVE_FCNTL_H)
# include <fcntl.h>
Expand Down Expand Up @@ -58,6 +65,13 @@
#include "llvm/Support/ConvertUTF.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/Windows/WindowsSupport.h"
// winsock2.h must be included before afunix.h. Briefly turn off clang-format to
// avoid error.
// clang-format off
#include <winsock2.h>
#include <afunix.h>
// clang-format on
#include <io.h>
#endif

using namespace llvm;
Expand Down Expand Up @@ -644,7 +658,7 @@ raw_fd_ostream::raw_fd_ostream(int fd, bool shouldClose, bool unbuffered,
// Check if this is a console device. This is not equivalent to isatty.
IsWindowsConsole =
::GetFileType((HANDLE)::_get_osfhandle(fd)) == FILE_TYPE_CHAR;
#endif
#endif // _WIN32

// Get the starting position.
off_t loc = ::lseek(FD, 0, SEEK_CUR);
Expand Down Expand Up @@ -928,6 +942,9 @@ raw_fd_stream::raw_fd_stream(StringRef Filename, std::error_code &EC)
EC = std::make_error_code(std::errc::invalid_argument);
}

raw_fd_stream::raw_fd_stream(int fd, bool shouldClose)
: raw_fd_ostream(fd, shouldClose, false, OStreamKind::OK_FDStream) {}

ssize_t raw_fd_stream::read(char *Ptr, size_t Size) {
assert(get_fd() >= 0 && "File already closed.");
ssize_t Ret = ::read(get_fd(), (void *)Ptr, Size);
Expand All @@ -942,6 +959,145 @@ bool raw_fd_stream::classof(const raw_ostream *OS) {
return OS->get_kind() == OStreamKind::OK_FDStream;
}

//===----------------------------------------------------------------------===//
// raw_socket_stream
//===----------------------------------------------------------------------===//

#ifdef _WIN32
WSABalancer::WSABalancer() {
WSADATA WsaData = {0};
if (WSAStartup(MAKEWORD(2, 2), &WsaData) != 0) {
llvm::report_fatal_error("WSAStartup failed");
}
}

WSABalancer::~WSABalancer() { WSACleanup(); }

#endif // _WIN32

static std::error_code getLastSocketErrorCode() {
#ifdef _WIN32
return std::error_code(::WSAGetLastError(), std::system_category());
#else
return std::error_code(errno, std::system_category());
#endif
}

ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath)
: FD(SocketFD), SocketPath(SocketPath) {}

ListeningSocket::ListeningSocket(ListeningSocket &&LS)
: FD(LS.FD), SocketPath(LS.SocketPath) {
LS.FD = -1;
}

Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
int MaxBacklog) {

#ifdef _WIN32
WSABalancer _;
SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
if (MaybeWinsocket == INVALID_SOCKET) {
#else
int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
if (MaybeWinsocket == -1) {
#endif
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"socket create failed");
}

struct sockaddr_un Addr;
memset(&Addr, 0, sizeof(Addr));
Addr.sun_family = AF_UNIX;
strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);

if (bind(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
std::error_code Err = getLastSocketErrorCode();
if (Err == std::errc::address_in_use)
::close(MaybeWinsocket);
return llvm::make_error<StringError>(Err, "Bind error");
}
if (listen(MaybeWinsocket, MaxBacklog) == -1) {
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"Listen error");
}
int UnixSocket;
#ifdef _WIN32
UnixSocket = _open_osfhandle(MaybeWinsocket, 0);
#else
UnixSocket = MaybeWinsocket;
#endif // _WIN32
ListeningSocket ListenSocket(UnixSocket, SocketPath);
return ListenSocket;
}

Expected<std::unique_ptr<raw_socket_stream>> ListeningSocket::accept() {
int AcceptFD;
#ifdef _WIN32
SOCKET WinServerSock = _get_osfhandle(FD);
SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
AcceptFD = _open_osfhandle(WinAcceptSock, 0);
#else
AcceptFD = ::accept(FD, NULL, NULL);
#endif //_WIN32
if (AcceptFD == -1)
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"Accept failed");
return std::make_unique<raw_socket_stream>(AcceptFD);
}

ListeningSocket::~ListeningSocket() {
if (FD == -1)
return;
::close(FD);
unlink(SocketPath.c_str());
}

static Expected<int> GetSocketFD(StringRef SocketPath) {
#ifdef _WIN32
SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
if (MaybeWinsocket == INVALID_SOCKET) {
#else
int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
if (MaybeWinsocket == -1) {
#endif // _WIN32
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"Create socket failed");
}

struct sockaddr_un Addr;
memset(&Addr, 0, sizeof(Addr));
Addr.sun_family = AF_UNIX;
strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);

int status = connect(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr));
if (status == -1) {
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"Connect socket failed");
}
#ifdef _WIN32
return _open_osfhandle(MaybeWinsocket, 0);
#else
return MaybeWinsocket;
#endif // _WIN32
}

raw_socket_stream::raw_socket_stream(int SocketFD)
: raw_fd_stream(SocketFD, true) {}

Expected<std::unique_ptr<raw_socket_stream>>
raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
#ifdef _WIN32
WSABalancer _;
#endif // _WIN32
Expected<int> FD = GetSocketFD(SocketPath);
if (!FD)
return FD.takeError();
return std::make_unique<raw_socket_stream>(*FD);
}

raw_socket_stream::~raw_socket_stream() {}

//===----------------------------------------------------------------------===//
// raw_string_ostream
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/Support/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ add_llvm_unittest(SupportTests
raw_ostream_test.cpp
raw_pwrite_stream_test.cpp
raw_sha1_ostream_test.cpp
raw_socket_stream_test.cpp
xxhashTest.cpp

DEPENDS
Expand Down
52 changes: 52 additions & 0 deletions llvm/unittests/Support/raw_socket_stream_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Testing/Support/Error.h"
#include "gtest/gtest.h"
#include <future>
#include <iostream>
#include <stdlib.h>

using namespace llvm;

namespace {

TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("test_raw_socket_stream.sock", SocketPath,
true);

char Bytes[8];

Expected<ListeningSocket> MaybeServerListener =
ListeningSocket::createUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());

ListeningSocket ServerListener = std::move(*MaybeServerListener);

Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
raw_socket_stream::createConnectedUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());

raw_socket_stream &Client = **MaybeClient;

Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept();
ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());

raw_socket_stream &Server = **MaybeServer;

Client << "01234567";
Client.flush();

ssize_t BytesRead = Server.read(Bytes, 8);

std::string string(Bytes, 8);

ASSERT_EQ(8, BytesRead);
ASSERT_EQ("01234567", string);
}
} // namespace