Skip to content

Commit a55d698

Browse files
committed
WIP (not protable) - Add self-pipe trick to ListeningServer
1 parent 8546986 commit a55d698

File tree

3 files changed

+61
-3
lines changed

3 files changed

+61
-3
lines changed

llvm/include/llvm/Support/raw_socket_stream.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,15 @@ class WSABalancer {
5858
/// \endcode
5959
///
6060
class ListeningSocket {
61+
62+
/// If ListeningSocket::shutdown is used by a signal handler to clean up
63+
/// ListeningSocket resources FD may be closed while ::poll is waiting for
64+
/// FD to become ready to perform I/O. When FD is closed ::poll will
65+
/// continue to block so use the self-pipe trick to get ::poll to return
66+
int PipeFD[2];
67+
std::mutex PipeMutex;
6168
std::atomic<int> FD;
62-
std::string SocketPath;
69+
std::string SocketPath; // Never modified
6370
ListeningSocket(int SocketFD, StringRef SocketPath);
6471

6572
#ifdef _WIN32

llvm/lib/Support/raw_socket_stream.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
#include "llvm/Support/FileSystem.h"
1818

1919
#include <atomic>
20+
#include <fcntl.h>
2021
#include <poll.h>
22+
#include <thread>
2123

2224
#ifndef _WIN32
2325
#include <sys/socket.h>
@@ -168,7 +170,7 @@ ListeningSocket::createListeningUnixSocket(StringRef SocketPath,
168170
Expected<std::unique_ptr<raw_socket_stream>>
169171
ListeningSocket::accept(std::optional<std::chrono::milliseconds> Timeout) {
170172

171-
struct pollfd FDs[1];
173+
struct pollfd FDs[2];
172174
FDs[0].events = POLLIN;
173175
#ifdef _WIN32
174176
SOCKET WinServerSock = _get_osfhandle(FD);
@@ -177,8 +179,16 @@ ListeningSocket::accept(std::optional<std::chrono::milliseconds> Timeout) {
177179
FDs[0].fd = FD;
178180
#endif
179181

182+
FDs[1].events = POLLIN;
183+
PipeMutex.lock();
184+
if (::pipe(PipeFD) == -1)
185+
return llvm::make_error<StringError>(getLastSocketErrorCode(),
186+
"pipe failed");
187+
FDs[1].fd = PipeFD[0];
188+
PipeMutex.unlock();
189+
180190
int TimeoutCount = Timeout.value_or(std::chrono::milliseconds(-1)).count();
181-
int PollStatus = ::poll(FDs, 1, TimeoutCount);
191+
int PollStatus = ::poll(FDs, 2, TimeoutCount);
182192

183193
if (PollStatus == -1)
184194
return llvm::make_error<StringError>(getLastSocketErrorCode(),
@@ -212,6 +222,11 @@ void ListeningSocket::shutdown() {
212222
return;
213223
::close(FD);
214224
::unlink(SocketPath.c_str());
225+
226+
char Byte = 'A';
227+
PipeMutex.lock();
228+
write(PipeFD[1], &Byte, 1);
229+
PipeMutex.unlock();
215230
FD = -1;
216231
}
217232

llvm/unittests/Support/raw_socket_stream_test.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,40 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
9797
ASSERT_EQ(EC, std::errc::timed_out);
9898
});
9999
}
100+
101+
TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
102+
if (!hasUnixSocketSupport())
103+
GTEST_SKIP();
104+
105+
SmallString<100> SocketPath;
106+
llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true);
107+
108+
// Make sure socket file does not exist. May still be there from the last test
109+
std::remove(SocketPath.c_str());
110+
111+
Expected<ListeningSocket> MaybeServerListener =
112+
ListeningSocket::createListeningUnixSocket(SocketPath);
113+
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
114+
ListeningSocket ServerListener = std::move(*MaybeServerListener);
115+
116+
// Create a separate thread to close the socket after a delay. Simulates a
117+
// signal handler calling ServerListener::shutdown
118+
std::thread CloseThread([&]() {
119+
std::this_thread::sleep_for(std::chrono::seconds(2));
120+
ServerListener.shutdown();
121+
});
122+
123+
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
124+
ServerListener.accept();
125+
126+
// Wait for the CloseThread to finish
127+
CloseThread.join();
128+
129+
ASSERT_THAT_EXPECTED(MaybeServer, Failed());
130+
llvm::Error Err = MaybeServer.takeError();
131+
llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
132+
std::error_code EC = SE.convertToErrorCode();
133+
ASSERT_EQ(EC, std::errc::bad_file_descriptor);
134+
});
135+
}
100136
} // namespace

0 commit comments

Comments
 (0)