
[llvm][Support] Implement raw_socket_stream::read with optional timeout #92308


Merged: 19 commits merged on Jul 22, 2024


cpsughrue
Contributor

@cpsughrue cpsughrue commented May 15, 2024

This PR implements raw_socket_stream::read, which overloads the base class method raw_fd_stream::read. raw_socket_stream::read adds an optional timeout to the underlying ::read. The timeout was not added to raw_fd_stream::read itself, both to avoid needlessly increasing compile times and to allow convenient code reuse with ListeningSocket::accept, which also needs a timeout. This PR supports the module build daemon and will help guarantee it never becomes a zombie process.
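
The idea of timing out a read can be illustrated outside of LLVM with plain POSIX calls. The following is a minimal sketch, not the code in this PR: `readWithTimeout` is a hypothetical helper name, it assumes a POSIX system, and it returns a `std::error_code` instead of `llvm::Expected` so that it compiles without LLVM headers. It waits with `::poll` until the descriptor is readable or the timeout elapses, then issues a single `::read`.

```cpp
#include <cassert>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <string>
#include <system_error>

#include <poll.h>
#include <unistd.h>

// Hypothetical helper (not the LLVM API): wait until FD is readable or Timeout
// elapses, then issue a single ::read. A negative Timeout blocks indefinitely.
std::error_code readWithTimeout(int FD, std::string &Out,
                                std::chrono::milliseconds Timeout) {
  using Clock = std::chrono::steady_clock;
  assert(FD >= 0 && "expected an open file descriptor");

  const bool Forever = Timeout.count() < 0;
  const Clock::time_point Deadline = Clock::now() + Timeout;

  struct pollfd PFD = {};
  PFD.fd = FD;
  PFD.events = POLLIN;

  for (;;) {
    // Recompute the remaining wait on every pass so that a retry after EINTR
    // does not extend the overall timeout.
    int WaitMs = -1;
    if (!Forever) {
      auto Left = std::chrono::duration_cast<std::chrono::milliseconds>(
          Deadline - Clock::now());
      WaitMs = Left.count() > 0 ? static_cast<int>(Left.count()) : 0;
    }

    int Status = ::poll(&PFD, 1, WaitMs);
    if (Status > 0)
      break; // FD is readable, hung up, or invalid.
    if (Status == 0)
      return std::make_error_code(std::errc::timed_out);
    if (errno != EINTR)
      return std::error_code(errno, std::generic_category());
  }

  if (PFD.revents & POLLNVAL)
    return std::make_error_code(std::errc::bad_file_descriptor);

  // Data (or end-of-file) is available, so this read should not block.
  char Buffer[1024];
  ssize_t BytesRead = ::read(FD, Buffer, sizeof(Buffer));
  if (BytesRead < 0)
    return std::error_code(errno, std::generic_category());
  Out.assign(Buffer, static_cast<std::size_t>(BytesRead));
  return std::error_code();
}
```

A caller can compare the result against `std::errc::timed_out` to distinguish a timeout from other failures, which mirrors how the unit tests in this PR check for a timed-out read.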


github-actions bot commented May 29, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@cpsughrue cpsughrue marked this pull request as ready for review June 3, 2024 04:36
@llvmbot
Member

llvmbot commented Jun 3, 2024

@llvm/pr-subscribers-llvm-support

Author: Connor Sughrue (cpsughrue)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/92308.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Support/raw_socket_stream.h (+18-3)
  • (modified) llvm/lib/Support/raw_socket_stream.cpp (+70-23)
  • (modified) llvm/unittests/Support/raw_socket_stream_test.cpp (+96-11)
diff --git a/llvm/include/llvm/Support/raw_socket_stream.h b/llvm/include/llvm/Support/raw_socket_stream.h
index bddd47eb75e1a..225980cb28a42 100644
--- a/llvm/include/llvm/Support/raw_socket_stream.h
+++ b/llvm/include/llvm/Support/raw_socket_stream.h
@@ -92,10 +92,11 @@ class ListeningSocket {
   /// Accepts an incoming connection on the listening socket. This method can
   /// optionally either block until a connection is available or timeout after a
   /// specified amount of time has passed. By default the method will block
-  /// until the socket has recieved a connection.
+  /// until the socket has received a connection. If the accept times out this
+  /// method will return std::errc::timed_out
   ///
   /// \param Timeout An optional timeout duration in milliseconds. Setting
-  /// Timeout to -1 causes accept to block indefinitely
+  /// Timeout to a negative number causes ::accept to block indefinitely
   ///
   Expected<std::unique_ptr<raw_socket_stream>>
   accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
@@ -124,11 +125,25 @@ class raw_socket_stream : public raw_fd_stream {
 
 public:
   raw_socket_stream(int SocketFD);
+  ~raw_socket_stream();
+
   /// Create a \p raw_socket_stream connected to the UNIX domain socket at \p
   /// SocketPath.
   static Expected<std::unique_ptr<raw_socket_stream>>
   createConnectedUnix(StringRef SocketPath);
-  ~raw_socket_stream();
+
+  /// Attempt to read from the raw_socket_stream's file descriptor. This method
+  /// can optionally either block until data is read or an error has occurred or
+  /// timeout after a specified amount of time has passed. By default the method
+  /// will block until the socket has read data or encountered an error. If the
+  /// read times out this method will return std::errc::timed_out
+  ///
+  /// \param Timeout An optional timeout duration in milliseconds
+  /// \param Ptr The start of the buffer that will hold any read data
+  /// \param Size The number of bytes to be read
+  ///
+  Expected<std::string> readFromSocket(
+      std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp
index 549d537709bf2..063f6fc366da9 100644
--- a/llvm/lib/Support/raw_socket_stream.cpp
+++ b/llvm/lib/Support/raw_socket_stream.cpp
@@ -18,6 +18,7 @@
 
 #include <atomic>
 #include <fcntl.h>
+#include <functional>
 #include <thread>
 
 #ifndef _WIN32
@@ -177,22 +178,31 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
 #endif // _WIN32
 }
 
-Expected<std::unique_ptr<raw_socket_stream>>
-ListeningSocket::accept(std::chrono::milliseconds Timeout) {
-
-  struct pollfd FDs[2];
-  FDs[0].events = POLLIN;
+// If a file descriptor being monitored by poll is closed by another thread, the
+// result is unspecified. In the case poll does not unblock and return when
+// ActiveFD is closed you can provide another file descriptor via CancelFD that
+// when written to will cause poll to return. Typically CancelFD is the read end
+// of a unidirectional pipe.
+static llvm::Error manageTimeout(std::chrono::milliseconds Timeout,
+                                 std::function<int()> getActiveFD,
+                                 std::optional<int> CancelFD = std::nullopt) {
+  struct pollfd FD[2];
+  FD[0].events = POLLIN;
 #ifdef _WIN32
-  SOCKET WinServerSock = _get_osfhandle(FD);
-  FDs[0].fd = WinServerSock;
+  SOCKET WinServerSock = _get_osfhandle(getActiveFD());
+  FD[0].fd = WinServerSock;
 #else
-  FDs[0].fd = FD;
+  FD[0].fd = getActiveFD();
 #endif
-  FDs[1].events = POLLIN;
-  FDs[1].fd = PipeFD[0];
+  uint8_t FDCount = 1;
+  if (CancelFD.has_value()) {
+    FD[1].events = POLLIN;
+    FD[1].fd = CancelFD.value();
+    FDCount++;
+  }
 
-  // Keep track of how much time has passed in case poll is interupted by a
-  // signal and needs to be recalled
+  // Keep track of how much time has passed in case ::poll or WSAPoll is
+  // interrupted by a signal and needs to be called again
   int RemainingTime = Timeout.count();
   std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
   int PollStatus = -1;
@@ -200,20 +210,20 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
   while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
     if (Timeout.count() != -1)
       RemainingTime -= ElapsedTime.count();
-
     auto Start = std::chrono::steady_clock::now();
+
 #ifdef _WIN32
-    PollStatus = WSAPoll(FDs, 2, RemainingTime);
+    PollStatus = WSAPoll(FD, FDCount, RemainingTime);
 #else
-    PollStatus = ::poll(FDs, 2, RemainingTime);
+    PollStatus = ::poll(FD, FDCount, RemainingTime);
 #endif
-    // If FD equals -1 then ListeningSocket::shutdown has been called and it is
-    // appropriate to return operation_canceled
-    if (FD.load() == -1)
+
+    // If ActiveFD equals -1 or CancelFD has data to be read then the operation
+    // has been canceled by another thread
+    if (getActiveFD() == -1 || (CancelFD && FD[1].revents & POLLIN))
       return llvm::make_error<StringError>(
           std::make_error_code(std::errc::operation_canceled),
           "Accept canceled");
-
 #if _WIN32
     if (PollStatus == SOCKET_ERROR) {
 #else
@@ -222,14 +232,14 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
       std::error_code PollErrCode = getLastSocketErrorCode();
       // Ignore EINTR (signal occured before any request event) and retry
       if (PollErrCode != std::errc::interrupted)
-        return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
+        return llvm::make_error<StringError>(PollErrCode, "poll failed");
     }
     if (PollStatus == 0)
       return llvm::make_error<StringError>(
           std::make_error_code(std::errc::timed_out),
           "No client requests within timeout window");
 
-    if (FDs[0].revents & POLLNVAL)
+    if (FD[0].revents & POLLNVAL)
       return llvm::make_error<StringError>(
           std::make_error_code(std::errc::bad_file_descriptor));
 
@@ -237,10 +247,19 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
     ElapsedTime +=
         std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
   }
+  return llvm::Error::success();
+}
+
+Expected<std::unique_ptr<raw_socket_stream>>
+ListeningSocket::accept(std::chrono::milliseconds Timeout) {
+  auto getActiveFD = [this]() -> int { return FD; };
+  llvm::Error TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]);
+  if (TimeoutErr)
+    return TimeoutErr;
 
   int AcceptFD;
 #ifdef _WIN32
-  SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
+  SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL);
   AcceptFD = _open_osfhandle(WinAcceptSock, 0);
 #else
   AcceptFD = ::accept(FD, NULL, NULL);
@@ -295,6 +314,8 @@ ListeningSocket::~ListeningSocket() {
 raw_socket_stream::raw_socket_stream(int SocketFD)
     : raw_fd_stream(SocketFD, true) {}
 
+raw_socket_stream::~raw_socket_stream() {}
+
 Expected<std::unique_ptr<raw_socket_stream>>
 raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
 #ifdef _WIN32
@@ -306,4 +327,30 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
   return std::make_unique<raw_socket_stream>(*FD);
 }
 
-raw_socket_stream::~raw_socket_stream() {}
+Expected<std::string>
+raw_socket_stream::readFromSocket(std::chrono::milliseconds Timeout) {
+  auto getActiveFD = [this]() -> int { return this->get_fd(); };
+  llvm::Error TimeoutErr = manageTimeout(Timeout, getActiveFD);
+  if (TimeoutErr)
+    return TimeoutErr;
+
+  std::vector<char> Buffer;
+  constexpr ssize_t TmpBufferSize = 1024;
+  char TmpBuffer[TmpBufferSize];
+
+  while (true) {
+    std::memset(TmpBuffer, 0, TmpBufferSize);
+    ssize_t BytesRead = this->read(TmpBuffer, TmpBufferSize);
+    if (BytesRead == -1)
+      return llvm::make_error<StringError>(this->error(), "read failed");
+    else if (BytesRead == 0)
+      break;
+    else
+      Buffer.insert(Buffer.end(), TmpBuffer, TmpBuffer + BytesRead);
+    // All available bytes have been read. Another call to read will block
+    if (BytesRead < TmpBufferSize)
+      break;
+  }
+
+  return std::string(Buffer.begin(), Buffer.end());
+}
diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp
index c4e8cfbbe7e6a..1b8f85f88f1af 100644
--- a/llvm/unittests/Support/raw_socket_stream_test.cpp
+++ b/llvm/unittests/Support/raw_socket_stream_test.cpp
@@ -58,21 +58,106 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
   Client << "01234567";
   Client.flush();
 
-  char Bytes[8];
-  ssize_t BytesRead = Server.read(Bytes, 8);
+  llvm::Expected<std::string> MaybeText = Server.readFromSocket();
+  ASSERT_THAT_EXPECTED(MaybeText, llvm::Succeeded());
+  ASSERT_EQ("01234567", *MaybeText);
+}
+
+TEST(raw_socket_streamTest, LARGE_READ) {
+  if (!hasUnixSocketSupport())
+    GTEST_SKIP();
+
+  SmallString<100> SocketPath;
+  llvm::sys::fs::createUniquePath("large_read.sock", SocketPath, true);
+
+  // Make sure socket file does not exist. May still be there from the last test
+  std::remove(SocketPath.c_str());
+
+  Expected<ListeningSocket> MaybeServerListener =
+      ListeningSocket::createUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
+  ListeningSocket ServerListener = std::move(*MaybeServerListener);
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
+      raw_socket_stream::createConnectedUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());
+  raw_socket_stream &Client = **MaybeClient;
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
+      ServerListener.accept();
+  ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
+  raw_socket_stream &Server = **MaybeServer;
+
+  // raw_socket_stream::readFromSocket pre-allocates a buffer 1024 bytes large.
+  // Test to make sure readFromSocket can handle messages larger than the size
+  // of the pre-allocated block
+  constexpr int TextLength = 1342;
+  constexpr char Text[TextLength] =
+      "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do "
+      "eiusmod tempor incididunt ut labore et dolore magna aliqua. Vel orci "
+      "porta non pulvinar neque laoreet suspendisse interdum consectetur. "
+      "Nulla facilisi etiam dignissim diam quis. Porttitor massa id neque "
+      "aliquam vestibulum morbi blandit cursus. Purus viverra accumsan in "
+      "nisl. Nunc non blandit massa enim nec dui nunc mattis enim. Rhoncus "
+      "dolor purus non enim praesent elementum facilisis leo. Parturient "
+      "montes nascetur ridiculus mus mauris. Urna condimentum mattis "
+      "pellentesque id nibh tortor id aliquet lectus. Orci eu lobortis "
+      "elementum nibh. Sagittis eu volutpat odio facilisis. Molestie a "
+      "iaculis at erat pellentesque adipiscing. Tincidunt augue interdum "
+      "velit euismod in pellentesque massa placerat. Cras ornare arcu dui "
+      "vivamus arcu felis bibendum ut tristique. Tellus elementum sagittis "
+      "vitae et leo duis. Scelerisque fermentum dui faucibus in ornare "
+      "quam. Ipsum a arcu cursus vitae congue. Sit amet nisl suscipit "
+      "adipiscing. Sociis natoque penatibus et magnis. Cras semper auctor "
+      "neque vitae tempus quam pellentesque. Neque gravida in fermentum et "
+      "sollicitudin ac orci phasellus egestas. Vitae suscipit tellus mauris "
+      "a diam maecenas sed. Lectus arcu bibendum at varius vel pharetra. "
+      "Dignissim sodales ut eu sem integer vitae justo. Id cursus metus "
+      "aliquam eleifend mi.";
+  Client << Text;
+  Client.flush();
+
+  llvm::Expected<std::string> MaybeText = Server.readFromSocket();
+  ASSERT_THAT_EXPECTED(MaybeText, llvm::Succeeded());
+  ASSERT_EQ(Text, *MaybeText);
+}
 
-  std::string string(Bytes, 8);
+TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) {
+  if (!hasUnixSocketSupport())
+    GTEST_SKIP();
+
+  SmallString<100> SocketPath;
+  llvm::sys::fs::createUniquePath("read_with_timeout.sock", SocketPath, true);
 
-  ASSERT_EQ(8, BytesRead);
-  ASSERT_EQ("01234567", string);
+  // Make sure socket file does not exist. May still be there from the last test
+  std::remove(SocketPath.c_str());
+
+  Expected<ListeningSocket> MaybeServerListener =
+      ListeningSocket::createUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
+  ListeningSocket ServerListener = std::move(*MaybeServerListener);
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
+      raw_socket_stream::createConnectedUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
+      ServerListener.accept();
+  ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
+  raw_socket_stream &Server = **MaybeServer;
+
+  llvm::Expected<std::string> MaybeBytesRead =
+      Server.readFromSocket(std::chrono::milliseconds(100));
+  ASSERT_EQ(llvm::errorToErrorCode(MaybeBytesRead.takeError()),
+            std::errc::timed_out);
 }
 
-TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
+TEST(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) {
   if (!hasUnixSocketSupport())
     GTEST_SKIP();
 
   SmallString<100> SocketPath;
-  llvm::sys::fs::createUniquePath("timout_provided.sock", SocketPath, true);
+  llvm::sys::fs::createUniquePath("accept_with_timeout.sock", SocketPath, true);
 
   // Make sure socket file does not exist. May still be there from the last test
   std::remove(SocketPath.c_str());
@@ -82,19 +167,19 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
   ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
   ListeningSocket ServerListener = std::move(*MaybeServerListener);
 
-  std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
   Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
-      ServerListener.accept(Timeout);
+      ServerListener.accept(std::chrono::milliseconds(100));
   ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
             std::errc::timed_out);
 }
 
-TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
+TEST(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) {
   if (!hasUnixSocketSupport())
     GTEST_SKIP();
 
   SmallString<100> SocketPath;
-  llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true);
+  llvm::sys::fs::createUniquePath("accept_with_shutdown.sock", SocketPath,
+                                  true);
 
   // Make sure socket file does not exist. May still be there from the last test
   std::remove(SocketPath.c_str());
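
The comment above `manageTimeout` in the diff describes passing a second descriptor (`CancelFD`, typically the read end of a pipe) so that another thread can wake a blocked poll. The following is a minimal sketch of that pattern under stated assumptions, not the code in this PR: `waitOrCancel` is a hypothetical name, it assumes a POSIX system, and for brevity it restarts the full timeout after `EINTR` rather than tracking elapsed time as the PR does.

```cpp
#include <cassert>
#include <cerrno>
#include <chrono>
#include <system_error>

#include <poll.h>
#include <unistd.h>

// Hypothetical helper: block until ActiveFD is readable, until something is
// written to the pipe whose read end is CancelFD, or until Timeout elapses.
// A negative Timeout blocks indefinitely.
std::error_code waitOrCancel(int ActiveFD, int CancelFD,
                             std::chrono::milliseconds Timeout) {
  assert(ActiveFD >= 0 && CancelFD >= 0 && "expected open file descriptors");

  struct pollfd FDs[2] = {};
  FDs[0].fd = ActiveFD;
  FDs[0].events = POLLIN;
  FDs[1].fd = CancelFD;
  FDs[1].events = POLLIN;

  const int WaitMs =
      Timeout.count() < 0 ? -1 : static_cast<int>(Timeout.count());

  // Simplification: a retry after EINTR waits for the full timeout again.
  int Status;
  do {
    Status = ::poll(FDs, 2, WaitMs);
  } while (Status == -1 && errno == EINTR);

  if (Status == -1)
    return std::error_code(errno, std::generic_category());
  if (Status == 0)
    return std::make_error_code(std::errc::timed_out);
  // Cancellation takes priority over data on ActiveFD.
  if (FDs[1].revents & POLLIN)
    return std::make_error_code(std::errc::operation_canceled);
  if (FDs[0].revents & POLLNVAL)
    return std::make_error_code(std::errc::bad_file_descriptor);
  return std::error_code();
}
```

A thread that wants to cancel the wait writes a single byte to the write end of the pipe. Relying on a dedicated pipe avoids depending on what poll does when a monitored descriptor is closed from another thread, which the comment in the diff notes is unspecified.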

@cpsughrue changed the title from "[llvm][Support] Add function to read from raw_socket_stream file descriptor with timeout" to "[llvm][Support] Implement raw_socket_stream::readFromSocket with optional timeout" on Jun 3, 2024
@cpsughrue changed the title from "[llvm][Support] Implement raw_socket_stream::readFromSocket with optional timeout" to "[llvm][Support] Implement raw_socket_stream::read with optional timeout" on Jun 17, 2024
Contributor

@Bigcheese Bigcheese left a comment


Looks good with the suggested changes and a good commit message.

@cpsughrue cpsughrue merged commit 76321b9 into llvm:main Jul 22, 2024
7 checks passed
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…ut (#92308)

Summary:
This PR implements `raw_socket_stream::read`, which overloads the base
class `raw_fd_stream::read`. `raw_socket_stream::read` provides a way to
timeout the underlying `::read`. The timeout functionality was not added
to `raw_fd_stream::read` to avoid needlessly increasing compile times
and allow for convenient code reuse with `raw_socket_stream::accept`,
which also requires timeout functionality. This PR supports the module
build daemon and will help guarantee it never becomes a zombie process.

Differential Revision: https://phabricator.intern.facebook.com/D60251217
@cpsughrue cpsughrue deleted the read_timeout branch August 15, 2024 12:15