Skip to content

Commit 8e7148a

Browse files
committed
Address feedback
1 parent 4865936 commit 8e7148a

File tree

3 files changed

+30
-20
lines changed

3 files changed

+30
-20
lines changed

llvm/include/llvm/Support/raw_socket_stream.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class WSABalancer {
4747
/// Usage example:
4848
/// \code{.cpp}
4949
/// std::string Path = "/path/to/socket"
50-
/// Expected<ListeningSocket> S = ListeningSocket::createListeningSocket(Path);
50+
/// Expected<ListeningSocket> S = ListeningSocket::createUnix(Path);
5151
///
5252
/// if (S) {
5353
/// Expected<std::unique_ptr<raw_socket_stream>> connection = S->accept();
@@ -95,10 +95,11 @@ class ListeningSocket {
9595
/// specified amount of time has passed. By default the method will block
9696
/// until the socket has recieved a connection.
9797
///
98-
/// \param Timeout An optional timeout duration in milliseconds
98+
/// \param Timeout An optional timeout duration in milliseconds. Setting
99+
/// Timeout to -1 causes accept to block indefinitely
99100
///
100101
Expected<std::unique_ptr<raw_socket_stream>>
101-
accept(std::optional<std::chrono::milliseconds> Timeout = std::nullopt);
102+
accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
102103

103104
/// Creates a listening socket bound to the specified file system path.
104105
/// Handles the socket creation, binding, and immediately starts listening for

llvm/lib/Support/raw_socket_stream.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
178178
}
179179

180180
Expected<std::unique_ptr<raw_socket_stream>>
181-
ListeningSocket::accept(std::optional<std::chrono::milliseconds> Timeout) {
181+
ListeningSocket::accept(std::chrono::milliseconds Timeout) {
182182

183183
struct pollfd FDs[2];
184184
FDs[0].events = POLLIN;
@@ -191,15 +191,14 @@ ListeningSocket::accept(std::optional<std::chrono::milliseconds> Timeout) {
191191
FDs[1].events = POLLIN;
192192
FDs[1].fd = PipeFD[0];
193193

194-
std::chrono::milliseconds OriginalTimeout =
195-
Timeout.value_or(std::chrono::milliseconds(-1));
196-
int RemainingTime = OriginalTimeout.count();
194+
// Keep track of how much time has passed in case poll is interupted by a
195+
// signal and needs to be recalled
196+
int RemainingTime = Timeout.count();
197197
std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
198-
199198
int PollStatus = -1;
200-
while (PollStatus == -1 &&
201-
(RemainingTime == -1 || ElapsedTime < OriginalTimeout)) {
202-
if (RemainingTime != -1)
199+
200+
while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
201+
if (Timeout.count() != -1)
203202
RemainingTime -= ElapsedTime.count();
204203

205204
auto Start = std::chrono::steady_clock::now();
@@ -213,7 +212,7 @@ ListeningSocket::accept(std::optional<std::chrono::milliseconds> Timeout) {
213212
// Ignore error if caused by interupting signal
214213
std::error_code PollErrCode = getLastSocketErrorCode();
215214
if (PollErrCode != std::errc::interrupted)
216-
return llvm::make_error<StringError>(PollErrCode, "poll failed");
215+
return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
217216
}
218217

219218
if (PollStatus == 0)
@@ -226,9 +225,14 @@ ListeningSocket::accept(std::optional<std::chrono::milliseconds> Timeout) {
226225
std::make_error_code(std::errc::bad_file_descriptor),
227226
"File descriptor closed by another thread");
228227

229-
auto End = std::chrono::steady_clock::now();
228+
if (FDs[1].revents & POLLIN)
229+
return llvm::make_error<StringError>(
230+
std::make_error_code(std::errc::operation_canceled),
231+
"Accept canceled");
232+
233+
auto Stop = std::chrono::steady_clock::now();
230234
ElapsedTime +=
231-
std::chrono::duration_cast<std::chrono::milliseconds>(End - Start);
235+
std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
232236
}
233237

234238
int AcceptFD;
@@ -241,24 +245,27 @@ ListeningSocket::accept(std::optional<std::chrono::milliseconds> Timeout) {
241245

242246
if (AcceptFD == -1)
243247
return llvm::make_error<StringError>(getLastSocketErrorCode(),
244-
"accept failed");
248+
"Socket accept failed");
245249
return std::make_unique<raw_socket_stream>(AcceptFD);
246250
}
247251

248252
void ListeningSocket::shutdown() {
249253
int ObservedFD = FD.load();
254+
250255
if (ObservedFD == -1)
251256
return;
257+
258+
// If FD equals ObservedFD set FD to -1; If FD doesn't equal ObservedFD then
259+
// another thread is responsible for shutdown so return
252260
if (!FD.compare_exchange_strong(ObservedFD, -1))
253261
return;
254-
::close(FD);
262+
263+
::close(ObservedFD);
255264
::unlink(SocketPath.c_str());
256265

257266
// Ensure ::poll returns if shutdown is called by a seperate thread
258267
char Byte = 'A';
259268
::write(PipeFD[1], &Byte, 1);
260-
261-
FD = -1;
262269
}
263270

264271
ListeningSocket::~ListeningSocket() {
@@ -267,6 +274,8 @@ ListeningSocket::~ListeningSocket() {
267274
// Close the pipe's FDs in the destructor instead of within
268275
// ListeningSocket::shutdown to avoid unnecessary synchronization issues that
269276
// would occur as PipeFD's values would have to be changed to -1
277+
//
278+
// The move constructor sets PipeFD to -1
270279
if (PipeFD[0] != -1)
271280
::close(PipeFD[0]);
272281
if (PipeFD[1] != -1)

llvm/unittests/Support/raw_socket_stream_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
118118
// Create a separate thread to close the socket after a delay. Simulates a
119119
// signal handler calling ServerListener::shutdown
120120
std::thread CloseThread([&]() {
121-
std::this_thread::sleep_for(std::chrono::seconds(2));
121+
std::this_thread::sleep_for(std::chrono::milliseconds(500));
122122
ServerListener.shutdown();
123123
});
124124

@@ -132,7 +132,7 @@ TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
132132
llvm::Error Err = MaybeServer.takeError();
133133
llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
134134
std::error_code EC = SE.convertToErrorCode();
135-
ASSERT_EQ(EC, std::errc::bad_file_descriptor);
135+
ASSERT_EQ(EC, std::errc::operation_canceled);
136136
});
137137
}
138138
} // namespace

0 commit comments

Comments
 (0)