Skip to content

Commit 76321b9

Browse files
authored
[llvm][Support] Implement raw_socket_stream::read with optional timeout (#92308)
This PR implements `raw_socket_stream::read`, which overloads the base class `raw_fd_stream::read`. `raw_socket_stream::read` provides a way to timeout the underlying `::read`. The timeout functionality was not added to `raw_fd_stream::read` to avoid needlessly increasing compile times and allow for convenient code reuse with `raw_socket_stream::accept`, which also requires timeout functionality. This PR supports the module build daemon and will help guarantee it never becomes a zombie process.
1 parent 324fea9 commit 76321b9

File tree

3 files changed

+144
-61
lines changed

3 files changed

+144
-61
lines changed

llvm/include/llvm/Support/raw_socket_stream.h

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,14 @@ class ListeningSocket {
9292
/// Accepts an incoming connection on the listening socket. This method can
9393
/// optionally either block until a connection is available or timeout after a
9494
/// specified amount of time has passed. By default the method will block
95-
/// until the socket has recieved a connection.
95+
/// until the socket has recieved a connection. If the accept timesout this
96+
/// method will return std::errc:timed_out
9697
///
9798
/// \param Timeout An optional timeout duration in milliseconds. Setting
98-
/// Timeout to -1 causes accept to block indefinitely
99+
/// Timeout to a negative number causes ::accept to block indefinitely
99100
///
100-
Expected<std::unique_ptr<raw_socket_stream>>
101-
accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
101+
Expected<std::unique_ptr<raw_socket_stream>> accept(
102+
const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));
102103

103104
/// Creates a listening socket bound to the specified file system path.
104105
/// Handles the socket creation, binding, and immediately starts listening for
@@ -124,11 +125,28 @@ class raw_socket_stream : public raw_fd_stream {
124125

125126
public:
126127
raw_socket_stream(int SocketFD);
128+
~raw_socket_stream();
129+
127130
/// Create a \p raw_socket_stream connected to the UNIX domain socket at \p
128131
/// SocketPath.
129132
static Expected<std::unique_ptr<raw_socket_stream>>
130133
createConnectedUnix(StringRef SocketPath);
131-
~raw_socket_stream();
134+
135+
/// Attempt to read from the raw_socket_stream's file descriptor.
136+
///
137+
/// This method can optionally either block until data is read or an error has
138+
/// occurred or timeout after a specified amount of time has passed. By
139+
/// default the method will block until the socket has read data or
140+
/// encountered an error. If the read times out this method will return
141+
/// std::errc:timed_out
142+
///
143+
/// \param Ptr The start of the buffer that will hold any read data
144+
/// \param Size The number of bytes to be read
145+
/// \param Timeout An optional timeout duration in milliseconds
146+
///
147+
ssize_t read(
148+
char *Ptr, size_t Size,
149+
const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));
132150
};
133151

134152
} // end namespace llvm

llvm/lib/Support/raw_socket_stream.cpp

Lines changed: 82 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <atomic>
2020
#include <fcntl.h>
21+
#include <functional>
2122
#include <thread>
2223

2324
#ifndef _WIN32
@@ -177,70 +178,89 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
177178
#endif // _WIN32
178179
}
179180

180-
Expected<std::unique_ptr<raw_socket_stream>>
181-
ListeningSocket::accept(std::chrono::milliseconds Timeout) {
182-
183-
struct pollfd FDs[2];
184-
FDs[0].events = POLLIN;
181+
// If a file descriptor being monitored by ::poll is closed by another thread,
182+
// the result is unspecified. In the case ::poll does not unblock and return,
183+
// when ActiveFD is closed, you can provide another file descriptor via CancelFD
184+
// that when written to will cause poll to return. Typically CancelFD is the
185+
// read end of a unidirectional pipe.
186+
//
187+
// Timeout should be -1 to block indefinitly
188+
//
189+
// getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int
190+
static std::error_code
191+
manageTimeout(const std::chrono::milliseconds &Timeout,
192+
const std::function<int()> &getActiveFD,
193+
const std::optional<int> &CancelFD = std::nullopt) {
194+
struct pollfd FD[2];
195+
FD[0].events = POLLIN;
185196
#ifdef _WIN32
186-
SOCKET WinServerSock = _get_osfhandle(FD);
187-
FDs[0].fd = WinServerSock;
197+
SOCKET WinServerSock = _get_osfhandle(getActiveFD());
198+
FD[0].fd = WinServerSock;
188199
#else
189-
FDs[0].fd = FD;
200+
FD[0].fd = getActiveFD();
190201
#endif
191-
FDs[1].events = POLLIN;
192-
FDs[1].fd = PipeFD[0];
193-
194-
// Keep track of how much time has passed in case poll is interupted by a
195-
// signal and needs to be recalled
196-
int RemainingTime = Timeout.count();
197-
std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
198-
int PollStatus = -1;
199-
200-
while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
201-
if (Timeout.count() != -1)
202-
RemainingTime -= ElapsedTime.count();
202+
uint8_t FDCount = 1;
203+
if (CancelFD.has_value()) {
204+
FD[1].events = POLLIN;
205+
FD[1].fd = CancelFD.value();
206+
FDCount++;
207+
}
203208

204-
auto Start = std::chrono::steady_clock::now();
209+
// Keep track of how much time has passed in case ::poll or WSAPoll are
210+
// interupted by a signal and need to be recalled
211+
auto Start = std::chrono::steady_clock::now();
212+
auto RemainingTimeout = Timeout;
213+
int PollStatus = 0;
214+
do {
215+
// If Timeout is -1 then poll should block and RemainingTimeout does not
216+
// need to be recalculated
217+
if (PollStatus != 0 && Timeout != std::chrono::milliseconds(-1)) {
218+
auto TotalElapsedTime =
219+
std::chrono::duration_cast<std::chrono::milliseconds>(
220+
std::chrono::steady_clock::now() - Start);
221+
222+
if (TotalElapsedTime >= Timeout)
223+
return std::make_error_code(std::errc::operation_would_block);
224+
225+
RemainingTimeout = Timeout - TotalElapsedTime;
226+
}
205227
#ifdef _WIN32
206-
PollStatus = WSAPoll(FDs, 2, RemainingTime);
228+
PollStatus = WSAPoll(FD, FDCount, RemainingTimeout.count());
229+
} while (PollStatus == SOCKET_ERROR &&
230+
getLastSocketErrorCode() == std::errc::interrupted);
207231
#else
208-
PollStatus = ::poll(FDs, 2, RemainingTime);
232+
PollStatus = ::poll(FD, FDCount, RemainingTimeout.count());
233+
} while (PollStatus == -1 &&
234+
getLastSocketErrorCode() == std::errc::interrupted);
209235
#endif
210-
// If FD equals -1 then ListeningSocket::shutdown has been called and it is
211-
// appropriate to return operation_canceled
212-
if (FD.load() == -1)
213-
return llvm::make_error<StringError>(
214-
std::make_error_code(std::errc::operation_canceled),
215-
"Accept canceled");
216236

237+
// If ActiveFD equals -1 or CancelFD has data to be read then the operation
238+
// has been canceled by another thread
239+
if (getActiveFD() == -1 || (CancelFD.has_value() && FD[1].revents & POLLIN))
240+
return std::make_error_code(std::errc::operation_canceled);
217241
#if _WIN32
218-
if (PollStatus == SOCKET_ERROR) {
242+
if (PollStatus == SOCKET_ERROR)
219243
#else
220-
if (PollStatus == -1) {
244+
if (PollStatus == -1)
221245
#endif
222-
std::error_code PollErrCode = getLastSocketErrorCode();
223-
// Ignore EINTR (signal occured before any request event) and retry
224-
if (PollErrCode != std::errc::interrupted)
225-
return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
226-
}
227-
if (PollStatus == 0)
228-
return llvm::make_error<StringError>(
229-
std::make_error_code(std::errc::timed_out),
230-
"No client requests within timeout window");
231-
232-
if (FDs[0].revents & POLLNVAL)
233-
return llvm::make_error<StringError>(
234-
std::make_error_code(std::errc::bad_file_descriptor));
246+
return getLastSocketErrorCode();
247+
if (PollStatus == 0)
248+
return std::make_error_code(std::errc::timed_out);
249+
if (FD[0].revents & POLLNVAL)
250+
return std::make_error_code(std::errc::bad_file_descriptor);
251+
return std::error_code();
252+
}
235253

236-
auto Stop = std::chrono::steady_clock::now();
237-
ElapsedTime +=
238-
std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
239-
}
254+
Expected<std::unique_ptr<raw_socket_stream>>
255+
ListeningSocket::accept(const std::chrono::milliseconds &Timeout) {
256+
auto getActiveFD = [this]() -> int { return FD; };
257+
std::error_code TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]);
258+
if (TimeoutErr)
259+
return llvm::make_error<StringError>(TimeoutErr, "Timeout error");
240260

241261
int AcceptFD;
242262
#ifdef _WIN32
243-
SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
263+
SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL);
244264
AcceptFD = _open_osfhandle(WinAcceptSock, 0);
245265
#else
246266
AcceptFD = ::accept(FD, NULL, NULL);
@@ -295,6 +315,8 @@ ListeningSocket::~ListeningSocket() {
295315
raw_socket_stream::raw_socket_stream(int SocketFD)
296316
: raw_fd_stream(SocketFD, true) {}
297317

318+
raw_socket_stream::~raw_socket_stream() {}
319+
298320
Expected<std::unique_ptr<raw_socket_stream>>
299321
raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
300322
#ifdef _WIN32
@@ -306,4 +328,14 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
306328
return std::make_unique<raw_socket_stream>(*FD);
307329
}
308330

309-
raw_socket_stream::~raw_socket_stream() {}
331+
ssize_t raw_socket_stream::read(char *Ptr, size_t Size,
332+
const std::chrono::milliseconds &Timeout) {
333+
auto getActiveFD = [this]() -> int { return this->get_fd(); };
334+
std::error_code Err = manageTimeout(Timeout, getActiveFD);
335+
// Mimic raw_fd_stream::read error handling behavior
336+
if (Err) {
337+
raw_fd_stream::error_detected(Err);
338+
return -1;
339+
}
340+
return raw_fd_stream::read(Ptr, Size);
341+
}

llvm/unittests/Support/raw_socket_stream_test.cpp

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,50 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
6262
ssize_t BytesRead = Server.read(Bytes, 8);
6363

6464
std::string string(Bytes, 8);
65+
ASSERT_EQ(Server.has_error(), false);
6566

6667
ASSERT_EQ(8, BytesRead);
6768
ASSERT_EQ("01234567", string);
6869
}
6970

70-
TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
71+
TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) {
7172
if (!hasUnixSocketSupport())
7273
GTEST_SKIP();
7374

7475
SmallString<100> SocketPath;
75-
llvm::sys::fs::createUniquePath("timout_provided.sock", SocketPath, true);
76+
llvm::sys::fs::createUniquePath("read_with_timeout.sock", SocketPath, true);
77+
78+
// Make sure socket file does not exist. May still be there from the last test
79+
std::remove(SocketPath.c_str());
80+
81+
Expected<ListeningSocket> MaybeServerListener =
82+
ListeningSocket::createUnix(SocketPath);
83+
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
84+
ListeningSocket ServerListener = std::move(*MaybeServerListener);
85+
86+
Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
87+
raw_socket_stream::createConnectedUnix(SocketPath);
88+
ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());
89+
90+
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
91+
ServerListener.accept();
92+
ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
93+
raw_socket_stream &Server = **MaybeServer;
94+
95+
char Bytes[8];
96+
ssize_t BytesRead = Server.read(Bytes, 8, std::chrono::milliseconds(100));
97+
ASSERT_EQ(BytesRead, -1);
98+
ASSERT_EQ(Server.has_error(), true);
99+
ASSERT_EQ(Server.error(), std::errc::timed_out);
100+
Server.clear_error();
101+
}
102+
103+
TEST(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) {
104+
if (!hasUnixSocketSupport())
105+
GTEST_SKIP();
106+
107+
SmallString<100> SocketPath;
108+
llvm::sys::fs::createUniquePath("accept_with_timeout.sock", SocketPath, true);
76109

77110
// Make sure socket file does not exist. May still be there from the last test
78111
std::remove(SocketPath.c_str());
@@ -82,19 +115,19 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
82115
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
83116
ListeningSocket ServerListener = std::move(*MaybeServerListener);
84117

85-
std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
86118
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
87-
ServerListener.accept(Timeout);
119+
ServerListener.accept(std::chrono::milliseconds(100));
88120
ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
89121
std::errc::timed_out);
90122
}
91123

92-
TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
124+
TEST(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) {
93125
if (!hasUnixSocketSupport())
94126
GTEST_SKIP();
95127

96128
SmallString<100> SocketPath;
97-
llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true);
129+
llvm::sys::fs::createUniquePath("accept_with_shutdown.sock", SocketPath,
130+
true);
98131

99132
// Make sure socket file does not exist. May still be there from the last test
100133
std::remove(SocketPath.c_str());

0 commit comments

Comments
 (0)