18
18
19
19
#include < atomic>
20
20
#include < fcntl.h>
21
+ #include < functional>
21
22
#include < thread>
22
23
23
24
#ifndef _WIN32
@@ -177,70 +178,89 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
177
178
#endif // _WIN32
178
179
}
179
180
180
- Expected<std::unique_ptr<raw_socket_stream>>
181
- ListeningSocket::accept (std::chrono::milliseconds Timeout) {
182
-
183
- struct pollfd FDs[2 ];
184
- FDs[0 ].events = POLLIN;
181
+ // If a file descriptor being monitored by ::poll is closed by another thread,
182
+ // the result is unspecified. In the case ::poll does not unblock and return,
183
+ // when ActiveFD is closed, you can provide another file descriptor via CancelFD
184
+ // that when written to will cause poll to return. Typically CancelFD is the
185
+ // read end of a unidirectional pipe.
186
+ //
187
+ // Timeout should be -1 to block indefinitly
188
+ //
189
+ // getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int
190
+ static std::error_code
191
+ manageTimeout (const std::chrono::milliseconds &Timeout,
192
+ const std::function<int ()> &getActiveFD,
193
+ const std::optional<int > &CancelFD = std::nullopt) {
194
+ struct pollfd FD[2 ];
195
+ FD[0 ].events = POLLIN;
185
196
#ifdef _WIN32
186
- SOCKET WinServerSock = _get_osfhandle (FD );
187
- FDs [0 ].fd = WinServerSock;
197
+ SOCKET WinServerSock = _get_osfhandle (getActiveFD () );
198
+ FD [0 ].fd = WinServerSock;
188
199
#else
189
- FDs [0 ].fd = FD ;
200
+ FD [0 ].fd = getActiveFD () ;
190
201
#endif
191
- FDs[1 ].events = POLLIN;
192
- FDs[1 ].fd = PipeFD[0 ];
193
-
194
- // Keep track of how much time has passed in case poll is interupted by a
195
- // signal and needs to be recalled
196
- int RemainingTime = Timeout.count ();
197
- std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds (0 );
198
- int PollStatus = -1 ;
199
-
200
- while (PollStatus == -1 && (Timeout.count () == -1 || ElapsedTime < Timeout)) {
201
- if (Timeout.count () != -1 )
202
- RemainingTime -= ElapsedTime.count ();
202
+ uint8_t FDCount = 1 ;
203
+ if (CancelFD.has_value ()) {
204
+ FD[1 ].events = POLLIN;
205
+ FD[1 ].fd = CancelFD.value ();
206
+ FDCount++;
207
+ }
203
208
204
- auto Start = std::chrono::steady_clock::now ();
209
+ // Keep track of how much time has passed in case ::poll or WSAPoll are
210
+ // interupted by a signal and need to be recalled
211
+ auto Start = std::chrono::steady_clock::now ();
212
+ auto RemainingTimeout = Timeout;
213
+ int PollStatus = 0 ;
214
+ do {
215
+ // If Timeout is -1 then poll should block and RemainingTimeout does not
216
+ // need to be recalculated
217
+ if (PollStatus != 0 && Timeout != std::chrono::milliseconds (-1 )) {
218
+ auto TotalElapsedTime =
219
+ std::chrono::duration_cast<std::chrono::milliseconds>(
220
+ std::chrono::steady_clock::now () - Start);
221
+
222
+ if (TotalElapsedTime >= Timeout)
223
+ return std::make_error_code (std::errc::operation_would_block);
224
+
225
+ RemainingTimeout = Timeout - TotalElapsedTime;
226
+ }
205
227
#ifdef _WIN32
206
- PollStatus = WSAPoll (FDs, 2 , RemainingTime);
228
+ PollStatus = WSAPoll (FD, FDCount, RemainingTimeout.count ());
229
+ } while (PollStatus == SOCKET_ERROR &&
230
+ getLastSocketErrorCode () == std::errc::interrupted);
207
231
#else
208
- PollStatus = ::poll (FDs, 2 , RemainingTime);
232
+ PollStatus = ::poll (FD, FDCount, RemainingTimeout.count ());
233
+ } while (PollStatus == -1 &&
234
+ getLastSocketErrorCode () == std::errc::interrupted);
209
235
#endif
210
- // If FD equals -1 then ListeningSocket::shutdown has been called and it is
211
- // appropriate to return operation_canceled
212
- if (FD.load () == -1 )
213
- return llvm::make_error<StringError>(
214
- std::make_error_code (std::errc::operation_canceled),
215
- " Accept canceled" );
216
236
237
+ // If ActiveFD equals -1 or CancelFD has data to be read then the operation
238
+ // has been canceled by another thread
239
+ if (getActiveFD () == -1 || (CancelFD.has_value () && FD[1 ].revents & POLLIN))
240
+ return std::make_error_code (std::errc::operation_canceled);
217
241
#if _WIN32
218
- if (PollStatus == SOCKET_ERROR) {
242
+ if (PollStatus == SOCKET_ERROR)
219
243
#else
220
- if (PollStatus == -1 ) {
244
+ if (PollStatus == -1 )
221
245
#endif
222
- std::error_code PollErrCode = getLastSocketErrorCode ();
223
- // Ignore EINTR (signal occured before any request event) and retry
224
- if (PollErrCode != std::errc::interrupted)
225
- return llvm::make_error<StringError>(PollErrCode, " FD poll failed" );
226
- }
227
- if (PollStatus == 0 )
228
- return llvm::make_error<StringError>(
229
- std::make_error_code (std::errc::timed_out),
230
- " No client requests within timeout window" );
231
-
232
- if (FDs[0 ].revents & POLLNVAL)
233
- return llvm::make_error<StringError>(
234
- std::make_error_code (std::errc::bad_file_descriptor));
246
+ return getLastSocketErrorCode ();
247
+ if (PollStatus == 0 )
248
+ return std::make_error_code (std::errc::timed_out);
249
+ if (FD[0 ].revents & POLLNVAL)
250
+ return std::make_error_code (std::errc::bad_file_descriptor);
251
+ return std::error_code ();
252
+ }
235
253
236
- auto Stop = std::chrono::steady_clock::now ();
237
- ElapsedTime +=
238
- std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
239
- }
254
+ Expected<std::unique_ptr<raw_socket_stream>>
255
+ ListeningSocket::accept (const std::chrono::milliseconds &Timeout) {
256
+ auto getActiveFD = [this ]() -> int { return FD; };
257
+ std::error_code TimeoutErr = manageTimeout (Timeout, getActiveFD, PipeFD[0 ]);
258
+ if (TimeoutErr)
259
+ return llvm::make_error<StringError>(TimeoutErr, " Timeout error" );
240
260
241
261
int AcceptFD;
242
262
#ifdef _WIN32
243
- SOCKET WinAcceptSock = ::accept (WinServerSock , NULL , NULL );
263
+ SOCKET WinAcceptSock = ::accept (_get_osfhandle (FD) , NULL , NULL );
244
264
AcceptFD = _open_osfhandle (WinAcceptSock, 0 );
245
265
#else
246
266
AcceptFD = ::accept (FD, NULL , NULL );
@@ -295,6 +315,8 @@ ListeningSocket::~ListeningSocket() {
295
315
raw_socket_stream::raw_socket_stream (int SocketFD)
296
316
: raw_fd_stream (SocketFD, true ) {}
297
317
318
+ raw_socket_stream::~raw_socket_stream () {}
319
+
298
320
Expected<std::unique_ptr<raw_socket_stream>>
299
321
raw_socket_stream::createConnectedUnix (StringRef SocketPath) {
300
322
#ifdef _WIN32
@@ -306,4 +328,14 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
306
328
return std::make_unique<raw_socket_stream>(*FD);
307
329
}
308
330
309
- raw_socket_stream::~raw_socket_stream () {}
331
+ ssize_t raw_socket_stream::read (char *Ptr, size_t Size,
332
+ const std::chrono::milliseconds &Timeout) {
333
+ auto getActiveFD = [this ]() -> int { return this ->get_fd (); };
334
+ std::error_code Err = manageTimeout (Timeout, getActiveFD);
335
+ // Mimic raw_fd_stream::read error handling behavior
336
+ if (Err) {
337
+ raw_fd_stream::error_detected (Err);
338
+ return -1 ;
339
+ }
340
+ return raw_fd_stream::read (Ptr, Size);
341
+ }
0 commit comments