Skip to content

Commit 4569e17

Browse files
committed
Improvements to raw_socket_stream functionality and documentation
1 parent 60dda1f commit 4569e17

File tree

3 files changed

+218
-71
lines changed

3 files changed

+218
-71
lines changed

llvm/include/llvm/Support/raw_socket_stream.h

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,47 +17,105 @@
1717
#include "llvm/Support/Threading.h"
1818
#include "llvm/Support/raw_ostream.h"
1919

20+
#include <atomic>
21+
#include <chrono>
22+
2023
namespace llvm {
2124

2225
class raw_socket_stream;
2326

24-
// Make sure that calls to WSAStartup and WSACleanup are balanced.
2527
#ifdef _WIN32
28+
/// @brief Ensures proper initialization and cleanup of winsock resources
29+
///
30+
/// @details
31+
/// Make sure that calls to WSAStartup and WSACleanup are balanced.
2632
class WSABalancer {
2733
public:
2834
WSABalancer();
2935
~WSABalancer();
3036
};
3137
#endif // _WIN32
3238

39+
/// @class ListeningSocket
40+
/// @brief Manages a passive (i.e., listening) UNIX domain socket
41+
///
42+
/// The ListeningSocket class encapsulates a UNIX domain socket that can listen
43+
/// and accept incoming connections. ListeningSocket is portable and supports
44+
/// Windows builds begining with Insider Build 17063. ListeningSocket is
45+
/// designed for server-side operations, working alongside raw_socket_streams
46+
/// that function as client connections.
47+
///
48+
/// Usage example:
49+
/// @code{.cpp}
50+
/// std::string Path = "/path/to/socket"
51+
/// Expected<ListeningSocket> S = ListeningSocket::createListeningSocket(Path);
52+
///
53+
/// if (listeningSocket) {
54+
/// auto connection = S->accept();
55+
/// if (connection) {
56+
/// // Use the accepted raw_socket_stream for communication.
57+
/// }
58+
/// }
59+
/// @endcode
60+
///
3361
class ListeningSocket {
34-
int FD;
62+
std::atomic<int> FD;
3563
std::string SocketPath;
3664
ListeningSocket(int SocketFD, StringRef SocketPath);
65+
3766
#ifdef _WIN32
3867
WSABalancer _;
3968
#endif // _WIN32
4069

4170
public:
42-
static Expected<ListeningSocket> createUnix(
71+
~ListeningSocket();
72+
ListeningSocket(ListeningSocket &&LS);
73+
ListeningSocket(const ListeningSocket &LS) = delete;
74+
ListeningSocket &operator=(const ListeningSocket &) = delete;
75+
76+
/// Closes the socket's FD and unlinks the socket file from the file system.
77+
/// The method is idempotent
78+
void shutdown();
79+
80+
/// Accepts an incoming connection on the listening socket. This method can
81+
/// optionally either block until a connection is available or timeout after a
82+
/// specified amount of time has passed. By default the method will block
83+
/// until the socket has recieved a connection
84+
///
85+
/// @param Timeout An optional timeout duration in microseconds
86+
///
87+
Expected<std::unique_ptr<raw_socket_stream>>
88+
accept(std::optional<std::chrono::microseconds> Timeout = std::nullopt);
89+
90+
/// Creates a listening socket bound to the specified file system path.
91+
/// Handles the socket creation, binding, and immediately starts listening for
92+
/// incoming connections.
93+
///
94+
/// @param SocketPath The file system path where the socket will be created
95+
/// @param MaxBacklog The max number of connections in a socket's backlog
96+
///
97+
static Expected<ListeningSocket> createListeningSocket(
4398
StringRef SocketPath,
4499
int MaxBacklog = llvm::hardware_concurrency().compute_thread_count());
45-
Expected<std::unique_ptr<raw_socket_stream>> accept();
46-
ListeningSocket(ListeningSocket &&LS);
47-
~ListeningSocket();
48100
};
101+
102+
//===----------------------------------------------------------------------===//
103+
// raw_socket_stream
104+
//===----------------------------------------------------------------------===//
105+
49106
class raw_socket_stream : public raw_fd_stream {
50107
uint64_t current_pos() const override { return 0; }
51108
#ifdef _WIN32
52109
WSABalancer _;
53110
#endif // _WIN32
54111

55112
public:
113+
// TODO: Should probably be private
56114
raw_socket_stream(int SocketFD);
57115
/// Create a \p raw_socket_stream connected to the Unix domain socket at \p
58116
/// SocketPath.
59117
static Expected<std::unique_ptr<raw_socket_stream>>
60-
createConnectedUnix(StringRef SocketPath);
118+
createConnectedSocket(StringRef SocketPath);
61119
~raw_socket_stream();
62120
};
63121

llvm/lib/Support/raw_socket_stream.cpp

Lines changed: 148 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
#include "llvm/Support/raw_socket_stream.h"
1515
#include "llvm/Config/config.h"
1616
#include "llvm/Support/Error.h"
17+
#include "llvm/Support/FileSystem.h"
18+
19+
#include <atomic>
1720

1821
#ifndef _WIN32
1922
#include <sys/socket.h>
@@ -45,7 +48,6 @@ WSABalancer::WSABalancer() {
4548
}
4649

4750
WSABalancer::~WSABalancer() { WSACleanup(); }
48-
4951
#endif // _WIN32
5052

5153
static std::error_code getLastSocketErrorCode() {
@@ -56,117 +58,201 @@ static std::error_code getLastSocketErrorCode() {
5658
#endif
5759
}
5860

61+
static void closeFD(int FD) {
62+
#ifdef _WIN32
63+
// on windows ::close is a deprecated alias for _close
64+
_close(FD);
65+
#else
66+
::close(FD);
67+
#endif
68+
}
69+
70+
static void unlinkFile(StringRef Path) {
71+
#ifdef _WIN32
72+
// on windows ::unlink is a deprecated alias for _unlink
73+
_unlink(Path.str().c_str());
74+
#else
75+
::unlink(Path.str().c_str());
76+
#endif
77+
}
78+
79+
static sockaddr_un setSocketAddr(StringRef SocketPath) {
80+
struct sockaddr_un Addr;
81+
memset(&Addr, 0, sizeof(Addr));
82+
Addr.sun_family = AF_UNIX;
83+
strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
84+
return Addr;
85+
}
86+
87+
static Expected<int> getSocketFD(StringRef SocketPath) {
88+
#ifdef _WIN32
89+
SOCKET MaybeSocket = socket(AF_UNIX, SOCK_STREAM, 0);
90+
if (MaybeSocket == INVALID_SOCKET) {
91+
#else
92+
int MaybeSocket = socket(AF_UNIX, SOCK_STREAM, 0);
93+
if (MaybeSocket == -1) {
94+
#endif // _WIN32
95+
return llvm::make_error<StringError>(getLastSocketErrorCode(),
96+
"Create socket failed");
97+
}
98+
99+
struct sockaddr_un Addr = setSocketAddr(SocketPath);
100+
if (::connect(MaybeSocket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1)
101+
return llvm::make_error<StringError>(getLastSocketErrorCode(),
102+
"Connect socket failed");
103+
104+
#ifdef _WIN32
105+
return _open_osfhandle(MaybeWinsocket, 0);
106+
#else
107+
return MaybeSocket;
108+
#endif // _WIN32
109+
}
110+
59111
ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath)
60112
: FD(SocketFD), SocketPath(SocketPath) {}
61113

62114
ListeningSocket::ListeningSocket(ListeningSocket &&LS)
63-
: FD(LS.FD), SocketPath(LS.SocketPath) {
115+
: FD(LS.FD.load()), SocketPath(LS.SocketPath) {
116+
117+
LS.SocketPath.clear();
64118
LS.FD = -1;
65119
}
66120

67-
Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
68-
int MaxBacklog) {
121+
Expected<ListeningSocket>
122+
ListeningSocket::createListeningSocket(StringRef SocketPath, int MaxBacklog) {
123+
124+
// Handle instances where the target socket address already exists and
125+
// differentiate between a preexisting file with and without a bound socket
126+
//
127+
// ::bind will return std::errc:address_in_use if a file at the socket address
128+
// already exists (e.g., the file was not properly unlinked due to a crash)
129+
// even if another socket has not yet binded to that address
130+
if (llvm::sys::fs::exists(SocketPath)) {
131+
Expected<int> MaybeFD = getSocketFD(SocketPath);
132+
if (!MaybeFD) {
133+
134+
// Regardless of the error, notify the caller that a file already exists
135+
// at the desired socket address. The file must be removed before ::bind
136+
// can use the socket address
137+
consumeError(MaybeFD.takeError());
138+
return llvm::make_error<StringError>(
139+
std::make_error_code(std::errc::file_exists),
140+
"Socket address unavailable");
141+
}
142+
closeFD(std::move(*MaybeFD));
143+
144+
// Notify caller that the provided socket address already has a bound socket
145+
return llvm::make_error<StringError>(
146+
std::make_error_code(std::errc::address_in_use),
147+
"Socket address unavailable");
148+
}
69149

70150
#ifdef _WIN32
71151
WSABalancer _;
72-
SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
73-
if (MaybeWinsocket == INVALID_SOCKET) {
152+
SOCKET MaybeSocket = socket(AF_UNIX, SOCK_STREAM, 0);
153+
if (MaybeSocket == INVALID_SOCKET) {
74154
#else
75-
int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
76-
if (MaybeWinsocket == -1) {
155+
int MaybeSocket = socket(AF_UNIX, SOCK_STREAM, 0);
156+
if (MaybeSocket == -1) {
77157
#endif
78158
return llvm::make_error<StringError>(getLastSocketErrorCode(),
79159
"socket create failed");
80160
}
81161

82-
struct sockaddr_un Addr;
83-
memset(&Addr, 0, sizeof(Addr));
84-
Addr.sun_family = AF_UNIX;
85-
strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
86-
87-
if (bind(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
88-
std::error_code Err = getLastSocketErrorCode();
89-
if (Err == std::errc::address_in_use)
90-
::close(MaybeWinsocket);
91-
return llvm::make_error<StringError>(Err, "Bind error");
162+
struct sockaddr_un Addr = setSocketAddr(SocketPath);
163+
if (::bind(MaybeSocket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
164+
// Grab error code from call to ::bind before calling ::close
165+
std::error_code EC = getLastSocketErrorCode();
166+
::close(MaybeSocket);
167+
return llvm::make_error<StringError>(EC, "Bind error");
92168
}
93-
if (listen(MaybeWinsocket, MaxBacklog) == -1) {
169+
170+
// Mark socket as passive so incoming connections can be accepted
171+
if (::listen(MaybeSocket, MaxBacklog) == -1)
94172
return llvm::make_error<StringError>(getLastSocketErrorCode(),
95173
"Listen error");
96-
}
97-
int UnixSocket;
174+
175+
int Socket;
98176
#ifdef _WIN32
99-
UnixSocket = _open_osfhandle(MaybeWinsocket, 0);
177+
Socket = _open_osfhandle(MaybeWinsocket, 0);
100178
#else
101-
UnixSocket = MaybeWinsocket;
179+
Socket = MaybeSocket;
102180
#endif // _WIN32
103-
return ListeningSocket{UnixSocket, SocketPath};
181+
return ListeningSocket{Socket, SocketPath};
104182
}
105183

106-
Expected<std::unique_ptr<raw_socket_stream>> ListeningSocket::accept() {
184+
Expected<std::unique_ptr<raw_socket_stream>>
185+
ListeningSocket::accept(std::optional<std::chrono::microseconds> Timeout) {
186+
187+
int SelectStatus;
107188
int AcceptFD;
189+
108190
#ifdef _WIN32
109191
SOCKET WinServerSock = _get_osfhandle(FD);
110-
SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
111-
AcceptFD = _open_osfhandle(WinAcceptSock, 0);
192+
#endif
193+
194+
fd_set Readfds;
195+
if (Timeout.has_value()) {
196+
timeval TV = {0, Timeout.value().count()};
197+
FD_ZERO(&Readfds);
198+
#ifdef _WIN32
199+
FD_SET(WinServerSock, &Readfds);
200+
#else
201+
FD_SET(FD, &Readfds);
202+
#endif
203+
SelectStatus = ::select(FD + 1, &Readfds, NULL, NULL, &TV);
204+
} else {
205+
SelectStatus = 1;
206+
}
207+
208+
if (SelectStatus == -1)
209+
return llvm::make_error<StringError>(getLastSocketErrorCode(),
210+
"Select failed");
211+
212+
if (SelectStatus) {
213+
#ifdef _WIN32
214+
SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
215+
AcceptFD = _open_osfhandle(WinAcceptSock, 0);
112216
#else
113-
AcceptFD = ::accept(FD, NULL, NULL);
114-
#endif //_WIN32
217+
AcceptFD = ::accept(FD, NULL, NULL);
218+
#endif
219+
} else
220+
return llvm::make_error<StringError>(
221+
std::make_error_code(std::errc::timed_out), "Accept timed out");
222+
115223
if (AcceptFD == -1)
116224
return llvm::make_error<StringError>(getLastSocketErrorCode(),
117225
"Accept failed");
226+
118227
return std::make_unique<raw_socket_stream>(AcceptFD);
119228
}
120229

121-
ListeningSocket::~ListeningSocket() {
230+
void ListeningSocket::shutdown() {
122231
if (FD == -1)
123232
return;
124-
::close(FD);
125-
unlink(SocketPath.c_str());
233+
closeFD(FD);
234+
unlinkFile(SocketPath);
235+
FD = -1;
126236
}
127237

128-
static Expected<int> GetSocketFD(StringRef SocketPath) {
129-
#ifdef _WIN32
130-
SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
131-
if (MaybeWinsocket == INVALID_SOCKET) {
132-
#else
133-
int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
134-
if (MaybeWinsocket == -1) {
135-
#endif // _WIN32
136-
return llvm::make_error<StringError>(getLastSocketErrorCode(),
137-
"Create socket failed");
138-
}
238+
ListeningSocket::~ListeningSocket() { shutdown(); }
139239

140-
struct sockaddr_un Addr;
141-
memset(&Addr, 0, sizeof(Addr));
142-
Addr.sun_family = AF_UNIX;
143-
strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
144-
145-
int status = connect(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr));
146-
if (status == -1) {
147-
return llvm::make_error<StringError>(getLastSocketErrorCode(),
148-
"Connect socket failed");
149-
}
150-
#ifdef _WIN32
151-
return _open_osfhandle(MaybeWinsocket, 0);
152-
#else
153-
return MaybeWinsocket;
154-
#endif // _WIN32
155-
}
240+
//===----------------------------------------------------------------------===//
241+
// raw_socket_stream
242+
//===----------------------------------------------------------------------===//
156243

157244
raw_socket_stream::raw_socket_stream(int SocketFD)
158245
: raw_fd_stream(SocketFD, true) {}
159246

160247
Expected<std::unique_ptr<raw_socket_stream>>
161-
raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
248+
raw_socket_stream::createConnectedSocket(StringRef SocketPath) {
162249
#ifdef _WIN32
163250
WSABalancer _;
164251
#endif // _WIN32
165-
Expected<int> FD = GetSocketFD(SocketPath);
252+
Expected<int> FD = getSocketFD(SocketPath);
166253
if (!FD)
167254
return FD.takeError();
168255
return std::make_unique<raw_socket_stream>(*FD);
169256
}
170257

171258
raw_socket_stream::~raw_socket_stream() {}
172-

0 commit comments

Comments
 (0)