Skip to content

[llvm][Support] Implement raw_socket_stream::read with optional timeout #92308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions llvm/include/llvm/Support/raw_socket_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,14 @@ class ListeningSocket {
/// Accepts an incoming connection on the listening socket. This method can
/// optionally either block until a connection is available or timeout after a
/// specified amount of time has passed. By default the method will block
/// until the socket has recieved a connection.
/// until the socket has recieved a connection. If the accept timesout this
/// method will return std::errc:timed_out
///
/// \param Timeout An optional timeout duration in milliseconds. Setting
/// Timeout to -1 causes accept to block indefinitely
/// Timeout to a negative number causes ::accept to block indefinitely
///
Expected<std::unique_ptr<raw_socket_stream>>
accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
Expected<std::unique_ptr<raw_socket_stream>> accept(
const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));

/// Creates a listening socket bound to the specified file system path.
/// Handles the socket creation, binding, and immediately starts listening for
Expand All @@ -124,11 +125,28 @@ class raw_socket_stream : public raw_fd_stream {

public:
raw_socket_stream(int SocketFD);
~raw_socket_stream();

/// Create a \p raw_socket_stream connected to the UNIX domain socket at \p
/// SocketPath.
static Expected<std::unique_ptr<raw_socket_stream>>
createConnectedUnix(StringRef SocketPath);
~raw_socket_stream();

/// Attempt to read from the raw_socket_stream's file descriptor.
///
/// This method can optionally either block until data is read or an error has
/// occurred or timeout after a specified amount of time has passed. By
/// default the method will block until the socket has read data or
/// encountered an error. If the read times out this method will return
/// std::errc:timed_out
///
/// \param Ptr The start of the buffer that will hold any read data
/// \param Size The number of bytes to be read
/// \param Timeout An optional timeout duration in milliseconds
///
ssize_t read(
char *Ptr, size_t Size,
const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));
};

} // end namespace llvm
Expand Down
132 changes: 82 additions & 50 deletions llvm/lib/Support/raw_socket_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <atomic>
#include <fcntl.h>
#include <functional>
#include <thread>

#ifndef _WIN32
Expand Down Expand Up @@ -177,70 +178,89 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
#endif // _WIN32
}

Expected<std::unique_ptr<raw_socket_stream>>
ListeningSocket::accept(std::chrono::milliseconds Timeout) {

struct pollfd FDs[2];
FDs[0].events = POLLIN;
// If a file descriptor being monitored by ::poll is closed by another thread,
// the result is unspecified. In the case ::poll does not unblock and return,
// when ActiveFD is closed, you can provide another file descriptor via CancelFD
// that when written to will cause poll to return. Typically CancelFD is the
// read end of a unidirectional pipe.
//
// Timeout should be -1 to block indefinitly
//
// getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int
static std::error_code
manageTimeout(const std::chrono::milliseconds &Timeout,
const std::function<int()> &getActiveFD,
const std::optional<int> &CancelFD = std::nullopt) {
struct pollfd FD[2];
FD[0].events = POLLIN;
#ifdef _WIN32
SOCKET WinServerSock = _get_osfhandle(FD);
FDs[0].fd = WinServerSock;
SOCKET WinServerSock = _get_osfhandle(getActiveFD());
FD[0].fd = WinServerSock;
#else
FDs[0].fd = FD;
FD[0].fd = getActiveFD();
#endif
FDs[1].events = POLLIN;
FDs[1].fd = PipeFD[0];

// Keep track of how much time has passed in case poll is interupted by a
// signal and needs to be recalled
int RemainingTime = Timeout.count();
std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
int PollStatus = -1;

while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
if (Timeout.count() != -1)
RemainingTime -= ElapsedTime.count();
uint8_t FDCount = 1;
if (CancelFD.has_value()) {
FD[1].events = POLLIN;
FD[1].fd = CancelFD.value();
FDCount++;
}

auto Start = std::chrono::steady_clock::now();
// Keep track of how much time has passed in case ::poll or WSAPoll are
// interupted by a signal and need to be recalled
auto Start = std::chrono::steady_clock::now();
auto RemainingTimeout = Timeout;
int PollStatus = 0;
do {
// If Timeout is -1 then poll should block and RemainingTimeout does not
// need to be recalculated
if (PollStatus != 0 && Timeout != std::chrono::milliseconds(-1)) {
auto TotalElapsedTime =
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - Start);

if (TotalElapsedTime >= Timeout)
return std::make_error_code(std::errc::operation_would_block);

RemainingTimeout = Timeout - TotalElapsedTime;
}
#ifdef _WIN32
PollStatus = WSAPoll(FDs, 2, RemainingTime);
PollStatus = WSAPoll(FD, FDCount, RemainingTimeout.count());
} while (PollStatus == SOCKET_ERROR &&
getLastSocketErrorCode() == std::errc::interrupted);
#else
PollStatus = ::poll(FDs, 2, RemainingTime);
PollStatus = ::poll(FD, FDCount, RemainingTimeout.count());
} while (PollStatus == -1 &&
getLastSocketErrorCode() == std::errc::interrupted);
#endif
// If FD equals -1 then ListeningSocket::shutdown has been called and it is
// appropriate to return operation_canceled
if (FD.load() == -1)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::operation_canceled),
"Accept canceled");

// If ActiveFD equals -1 or CancelFD has data to be read then the operation
// has been canceled by another thread
if (getActiveFD() == -1 || (CancelFD.has_value() && FD[1].revents & POLLIN))
return std::make_error_code(std::errc::operation_canceled);
#if _WIN32
if (PollStatus == SOCKET_ERROR) {
if (PollStatus == SOCKET_ERROR)
#else
if (PollStatus == -1) {
if (PollStatus == -1)
#endif
std::error_code PollErrCode = getLastSocketErrorCode();
// Ignore EINTR (signal occured before any request event) and retry
if (PollErrCode != std::errc::interrupted)
return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
}
if (PollStatus == 0)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::timed_out),
"No client requests within timeout window");

if (FDs[0].revents & POLLNVAL)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::bad_file_descriptor));
return getLastSocketErrorCode();
if (PollStatus == 0)
return std::make_error_code(std::errc::timed_out);
if (FD[0].revents & POLLNVAL)
return std::make_error_code(std::errc::bad_file_descriptor);
return std::error_code();
}

auto Stop = std::chrono::steady_clock::now();
ElapsedTime +=
std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
}
Expected<std::unique_ptr<raw_socket_stream>>
ListeningSocket::accept(const std::chrono::milliseconds &Timeout) {
auto getActiveFD = [this]() -> int { return FD; };
std::error_code TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]);
if (TimeoutErr)
return llvm::make_error<StringError>(TimeoutErr, "Timeout error");

int AcceptFD;
#ifdef _WIN32
SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL);
AcceptFD = _open_osfhandle(WinAcceptSock, 0);
#else
AcceptFD = ::accept(FD, NULL, NULL);
Expand Down Expand Up @@ -295,6 +315,8 @@ ListeningSocket::~ListeningSocket() {
raw_socket_stream::raw_socket_stream(int SocketFD)
: raw_fd_stream(SocketFD, true) {}

raw_socket_stream::~raw_socket_stream() {}

Expected<std::unique_ptr<raw_socket_stream>>
raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
#ifdef _WIN32
Expand All @@ -306,4 +328,14 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
return std::make_unique<raw_socket_stream>(*FD);
}

raw_socket_stream::~raw_socket_stream() {}
ssize_t raw_socket_stream::read(char *Ptr, size_t Size,
const std::chrono::milliseconds &Timeout) {
auto getActiveFD = [this]() -> int { return this->get_fd(); };
std::error_code Err = manageTimeout(Timeout, getActiveFD);
// Mimic raw_fd_stream::read error handling behavior
if (Err) {
raw_fd_stream::error_detected(Err);
return -1;
}
return raw_fd_stream::read(Ptr, Size);
}
45 changes: 39 additions & 6 deletions llvm/unittests/Support/raw_socket_stream_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,50 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
ssize_t BytesRead = Server.read(Bytes, 8);

std::string string(Bytes, 8);
ASSERT_EQ(Server.has_error(), false);

ASSERT_EQ(8, BytesRead);
ASSERT_EQ("01234567", string);
}

TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) {
if (!hasUnixSocketSupport())
GTEST_SKIP();

SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("timout_provided.sock", SocketPath, true);
llvm::sys::fs::createUniquePath("read_with_timeout.sock", SocketPath, true);

// Make sure socket file does not exist. May still be there from the last test
std::remove(SocketPath.c_str());

Expected<ListeningSocket> MaybeServerListener =
ListeningSocket::createUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
ListeningSocket ServerListener = std::move(*MaybeServerListener);

Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
raw_socket_stream::createConnectedUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());

Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept();
ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
raw_socket_stream &Server = **MaybeServer;

char Bytes[8];
ssize_t BytesRead = Server.read(Bytes, 8, std::chrono::milliseconds(100));
ASSERT_EQ(BytesRead, -1);
ASSERT_EQ(Server.has_error(), true);
ASSERT_EQ(Server.error(), std::errc::timed_out);
Server.clear_error();
}

TEST(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) {
if (!hasUnixSocketSupport())
GTEST_SKIP();

SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("accept_with_timeout.sock", SocketPath, true);

// Make sure socket file does not exist. May still be there from the last test
std::remove(SocketPath.c_str());
Expand All @@ -82,19 +115,19 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
ListeningSocket ServerListener = std::move(*MaybeServerListener);

std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept(Timeout);
ServerListener.accept(std::chrono::milliseconds(100));
ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
std::errc::timed_out);
}

TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
TEST(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) {
if (!hasUnixSocketSupport())
GTEST_SKIP();

SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true);
llvm::sys::fs::createUniquePath("accept_with_shutdown.sock", SocketPath,
true);

// Make sure socket file does not exist. May still be there from the last test
std::remove(SocketPath.c_str());
Expand Down