Skip to content

Commit abe3119

Browse files
author
Christian Riis
committed
Add raw_socket_stream
1 parent 67268da commit abe3119

File tree

4 files changed

+233
-2
lines changed

4 files changed

+233
-2
lines changed

llvm/include/llvm/Support/raw_ostream.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,32 @@ class raw_fd_stream : public raw_fd_ostream {
630630
static bool classof(const raw_ostream *OS);
631631
};
632632

633+
//===----------------------------------------------------------------------===//
634+
// Socket Streams
635+
//===----------------------------------------------------------------------===//
636+
637+
/// A raw stream for sockets reading/writing
638+
639+
class raw_socket_stream : public raw_fd_ostream {
640+
StringRef SocketPath;
641+
bool ShouldUnlink;
642+
643+
uint64_t current_pos() const override { return 0; }
644+
645+
public:
646+
int get_socket() {
647+
return get_fd();
648+
}
649+
650+
static int MakeServerSocket(StringRef SocketPath, unsigned int MaxBacklog, std::error_code &EC);
651+
652+
raw_socket_stream(int SocketFD, StringRef SockPath, std::error_code &EC);
653+
raw_socket_stream(StringRef SockPath, std::error_code &EC);
654+
~raw_socket_stream();
655+
656+
Expected<std::string> read_impl();
657+
};
658+
633659
//===----------------------------------------------------------------------===//
634660
// Output Stream Adaptors
635661
//===----------------------------------------------------------------------===//

llvm/lib/Support/raw_ostream.cpp

Lines changed: 151 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===--- raw_ostream.cpp - Implement the raw_ostream classes --------------===//
1+
//===--- raw_ostream.cpp - Implement the raw_ostream classes --------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -23,11 +23,17 @@
2323
#include "llvm/Support/NativeFormatting.h"
2424
#include "llvm/Support/Process.h"
2525
#include "llvm/Support/Program.h"
26+
#include "llvm/Support/Threading.h"
27+
#include "llvm/Support/Error.h"
2628
#include <algorithm>
2729
#include <cerrno>
2830
#include <cstdio>
2931
#include <sys/stat.h>
3032

33+
#include <sys/socket.h>
34+
#include <sys/un.h>
35+
#include <iostream>
36+
3137
// <fcntl.h> may provide O_BINARY.
3238
#if defined(HAVE_FCNTL_H)
3339
# include <fcntl.h>
@@ -58,6 +64,9 @@
5864
#include "llvm/Support/ConvertUTF.h"
5965
#include "llvm/Support/Signals.h"
6066
#include "llvm/Support/Windows/WindowsSupport.h"
67+
#include "raw_ostream.h"
68+
#include <afunix.h>
69+
#include <io.h>
6170
#endif
6271

6372
using namespace llvm;
@@ -644,7 +653,7 @@ raw_fd_ostream::raw_fd_ostream(int fd, bool shouldClose, bool unbuffered,
644653
// Check if this is a console device. This is not equivalent to isatty.
645654
IsWindowsConsole =
646655
::GetFileType((HANDLE)::_get_osfhandle(fd)) == FILE_TYPE_CHAR;
647-
#endif
656+
#endif // _WIN32
648657

649658
// Get the starting position.
650659
off_t loc = ::lseek(FD, 0, SEEK_CUR);
@@ -942,6 +951,146 @@ bool raw_fd_stream::classof(const raw_ostream *OS) {
942951
return OS->get_kind() == OStreamKind::OK_FDStream;
943952
}
944953

954+
//===----------------------------------------------------------------------===//
955+
// raw_socket_stream
956+
//===----------------------------------------------------------------------===//
957+
958+
int raw_socket_stream::MakeServerSocket(StringRef SocketPath, unsigned int MaxBacklog, std::error_code &EC) {
959+
960+
#ifdef _WIN32
961+
SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
962+
#else
963+
int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
964+
#endif // defined(_WIN32)
965+
966+
#ifdef _WIN32
967+
if (MaybeWinsocket == INVALID_SOCKET) {
968+
#else
969+
if (MaybeWinsocket == -1) {
970+
#endif // _WIN32
971+
std::string Msg = "socket create error" + std::string(strerror(errno));
972+
std::perror(Msg.c_str());
973+
std::cout << Msg << std::endl;
974+
EC = std::make_error_code(std::errc::connection_aborted);
975+
return -1;
976+
}
977+
978+
struct sockaddr_un Addr;
979+
memset(&Addr, 0, sizeof(Addr));
980+
Addr.sun_family = AF_UNIX;
981+
strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
982+
983+
if (bind(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
984+
if (errno == EADDRINUSE) {
985+
::close(MaybeWinsocket);
986+
EC = std::make_error_code(std::errc::address_in_use);
987+
} else {
988+
EC = std::make_error_code(std::errc::inappropriate_io_control_operation);
989+
}
990+
return -1;
991+
}
992+
993+
if (listen(MaybeWinsocket, MaxBacklog) == -1) {
994+
EC = std::make_error_code(std::errc::address_not_available);
995+
return -1;
996+
}
997+
#ifdef _WIN32
998+
return _open_osfhandle(MaybeWinsocket, 0); // flags? https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/open-osfhandle?view=msvc-170
999+
#else
1000+
return MaybeWinsocket;
1001+
#endif // _WIN32
1002+
}
1003+
1004+
int GetSocketFD(StringRef SocketPath, std::error_code &EC) {
1005+
#ifdef _WIN32
1006+
SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
1007+
if (MaybeWinsocket == INVALID_SOCKET) {
1008+
#else
1009+
int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
1010+
if (MaybeWinsocket == -1) {
1011+
#endif // _WIN32
1012+
std::string Msg = "socket create error" + std::string(strerror(errno));
1013+
std::perror(Msg.c_str());
1014+
EC = std::make_error_code(std::errc::connection_aborted);
1015+
return -1;
1016+
}
1017+
1018+
struct sockaddr_un Addr;
1019+
memset(&Addr, 0, sizeof(Addr));
1020+
Addr.sun_family = AF_UNIX;
1021+
strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
1022+
1023+
int status = connect(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr));
1024+
if (status == -1) {
1025+
std::string Msg = "socket connect error" + std::string(strerror(errno));
1026+
std::perror(Msg.c_str());
1027+
EC = std::make_error_code(std::errc::connection_aborted);
1028+
return -1;
1029+
}
1030+
#ifdef _WIN32
1031+
return _open_osfhandle(MaybeWinsocket, 0);
1032+
#else
1033+
return MaybeWinsocket;
1034+
#endif // _WIN32
1035+
}
1036+
1037+
static int ServerAccept(int FD) {
1038+
int AcceptFD;
1039+
#ifdef _WIN32
1040+
SOCKET WinServerSock = _get_osfhandle(FD);
1041+
SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
1042+
AcceptFD = _open_osfhandle(WinAcceptSock, 0); // flags?
1043+
#else
1044+
AcceptFD = ::accept(FD, NULL, NULL);
1045+
#endif //_WIN32
1046+
return AcceptFD;
1047+
}
1048+
1049+
// Server
1050+
// Call raw_fd_ostream with ShouldClose=false
1051+
raw_socket_stream::raw_socket_stream(int SocketFD, StringRef SockPath, std::error_code &EC) : raw_fd_ostream(ServerAccept(SocketFD), true) {
1052+
SocketPath = SockPath;
1053+
ShouldUnlink = true;
1054+
}
1055+
1056+
// Client
1057+
raw_socket_stream::raw_socket_stream(StringRef SockPath, std::error_code &EC) : raw_fd_ostream(GetSocketFD(SockPath, EC), true, true, OStreamKind::OK_OStream ) {
1058+
SocketPath = SockPath;
1059+
ShouldUnlink = false;
1060+
}
1061+
1062+
raw_socket_stream::~raw_socket_stream() {
1063+
if (ShouldUnlink) {
1064+
unlink(SocketPath.str().c_str());
1065+
}
1066+
}
1067+
1068+
Expected<std::string> raw_socket_stream::read_impl() {
1069+
const size_t BUFFER_SIZE = 4096;
1070+
std::vector<char> Buffer(BUFFER_SIZE);
1071+
1072+
int Socket = get_socket();
1073+
assert(Socket >= 0 && "Socket not found.");
1074+
1075+
ssize_t n;
1076+
#ifdef _WIN32
1077+
SOCKET MaybeWinsocket = _get_osfhandle(Socket);
1078+
#else
1079+
int MaybeWinsocket = Socket;
1080+
#endif // _WIN32
1081+
n = ::read(MaybeWinsocket, Buffer.data(), Buffer.size());
1082+
1083+
if (n < 0) {
1084+
std::string Msg = "Buffer read error: " + std::string(strerror(errno));
1085+
return llvm::make_error<StringError>(Msg, inconvertibleErrorCode());
1086+
}
1087+
1088+
if (n == 0) {
1089+
return llvm::make_error<StringError>("EOF", inconvertibleErrorCode());
1090+
}
1091+
return std::string(Buffer.data());
1092+
}
1093+
9451094
//===----------------------------------------------------------------------===//
9461095
// raw_string_ostream
9471096
//===----------------------------------------------------------------------===//

llvm/unittests/Support/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ add_llvm_unittest(SupportTests
103103
raw_ostream_test.cpp
104104
raw_pwrite_stream_test.cpp
105105
raw_sha1_ostream_test.cpp
106+
raw_socket_stream_test.cpp
106107
xxhashTest.cpp
107108

108109
DEPENDS
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include <stdlib.h>
2+
#include <iostream>
3+
#include <future>
4+
#include "llvm/ADT/SmallString.h"
5+
#include "llvm/Config/llvm-config.h"
6+
#include "llvm/Support/Casting.h"
7+
#include "llvm/Support/FileSystem.h"
8+
#include "llvm/Support/FileUtilities.h"
9+
#include "llvm/Support/raw_ostream.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace llvm;
13+
14+
namespace {
15+
16+
TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
17+
18+
SmallString<100> SocketPath("/tmp/test_raw_socket_stream.sock");
19+
std::error_code ECServer, ECClient;
20+
21+
int ServerFD = raw_socket_stream::MakeServerSocket(SocketPath, 3, ECServer);
22+
23+
raw_socket_stream Client(SocketPath, ECClient);
24+
EXPECT_TRUE(!ECClient);
25+
26+
raw_socket_stream Client2(SocketPath, ECClient);
27+
28+
raw_socket_stream Server(ServerFD, SocketPath, ECServer);
29+
EXPECT_TRUE(!ECServer);
30+
31+
Client << "01234567";
32+
Client.flush();
33+
34+
Client2 << "abcdefgh";
35+
Client2.flush();
36+
37+
Expected<std::string> from_client = Server.read_impl();
38+
39+
if (auto E = from_client.takeError()) {
40+
return; // FIXME: Do something.
41+
}
42+
EXPECT_EQ("01234567", (*from_client));
43+
44+
Server << "76543210";
45+
Server.flush();
46+
47+
Expected<std::string> from_server = Client.read_impl();
48+
if (auto E = from_server.takeError()) {
49+
return;
50+
// YIKES! 😩
51+
}
52+
EXPECT_EQ("76543210", (*from_server));
53+
54+
}
55+
} // namespace

0 commit comments

Comments
 (0)