Skip to content

Lift raw_socket_stream implementation out into own files #75653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 0 additions & 49 deletions llvm/include/llvm/Support/raw_ostream.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DataTypes.h"
#include "llvm/Support/Threading.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -633,54 +632,6 @@ class raw_fd_stream : public raw_fd_ostream {
static bool classof(const raw_ostream *OS);
};

//===----------------------------------------------------------------------===//
// Socket Streams
//===----------------------------------------------------------------------===//

/// A raw stream for sockets reading/writing

class raw_socket_stream;

// Make sure that calls to WSAStartup and WSACleanup are balanced.
#ifdef _WIN32
class WSABalancer {
public:
WSABalancer();
~WSABalancer();
};
#endif // _WIN32

class ListeningSocket {
int FD;
std::string SocketPath;
ListeningSocket(int SocketFD, StringRef SocketPath);
#ifdef _WIN32
WSABalancer _;
#endif // _WIN32

public:
static Expected<ListeningSocket> createUnix(
StringRef SocketPath,
int MaxBacklog = llvm::hardware_concurrency().compute_thread_count());
Expected<std::unique_ptr<raw_socket_stream>> accept();
ListeningSocket(ListeningSocket &&LS);
~ListeningSocket();
};
class raw_socket_stream : public raw_fd_stream {
uint64_t current_pos() const override { return 0; }
#ifdef _WIN32
WSABalancer _;
#endif // _WIN32

public:
raw_socket_stream(int SocketFD);
/// Create a \p raw_socket_stream connected to the Unix domain socket at \p
/// SocketPath.
static Expected<std::unique_ptr<raw_socket_stream>>
createConnectedUnix(StringRef SocketPath);
~raw_socket_stream();
};

//===----------------------------------------------------------------------===//
// Output Stream Adaptors
//===----------------------------------------------------------------------===//
Expand Down
66 changes: 66 additions & 0 deletions llvm/include/llvm/Support/raw_socket_stream.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===-- llvm/Support/raw_socket_stream.h - Socket streams --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains raw_ostream implementations for streams to communicate
// via UNIX sockets
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_SUPPORT_RAW_SOCKET_STREAM_H
#define LLVM_SUPPORT_RAW_SOCKET_STREAM_H

#include "llvm/Support/Threading.h"
#include "llvm/Support/raw_ostream.h"

namespace llvm {

class raw_socket_stream;

// Make sure that calls to WSAStartup and WSACleanup are balanced.
#ifdef _WIN32
class WSABalancer {
public:
WSABalancer();
~WSABalancer();
};
#endif // _WIN32

class ListeningSocket {
int FD;
std::string SocketPath;
ListeningSocket(int SocketFD, StringRef SocketPath);
#ifdef _WIN32
WSABalancer _;
#endif // _WIN32

public:
static Expected<ListeningSocket> createUnix(
StringRef SocketPath,
int MaxBacklog = llvm::hardware_concurrency().compute_thread_count());
Expected<std::unique_ptr<raw_socket_stream>> accept();
ListeningSocket(ListeningSocket &&LS);
~ListeningSocket();
};
class raw_socket_stream : public raw_fd_stream {
uint64_t current_pos() const override { return 0; }
#ifdef _WIN32
WSABalancer _;
#endif // _WIN32

public:
raw_socket_stream(int SocketFD);
/// Create a \p raw_socket_stream connected to the Unix domain socket at \p
/// SocketPath.
static Expected<std::unique_ptr<raw_socket_stream>>
createConnectedUnix(StringRef SocketPath);
~raw_socket_stream();
};

} // end namespace llvm

#endif
1 change: 1 addition & 0 deletions llvm/lib/Support/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ add_llvm_component_library(LLVMSupport
YAMLTraits.cpp
raw_os_ostream.cpp
raw_ostream.cpp
raw_socket_stream.cpp
regcomp.c
regerror.c
regexec.c
Expand Down
163 changes: 1 addition & 162 deletions llvm/lib/Support/raw_ostream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "llvm/Support/AutoConvert.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Duration.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
Expand All @@ -25,17 +24,11 @@
#include "llvm/Support/NativeFormatting.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/Threading.h"
#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <sys/stat.h>

#ifndef _WIN32
#include <sys/socket.h>
#include <sys/un.h>
#endif // _WIN32

// <fcntl.h> may provide O_BINARY.
#if defined(HAVE_FCNTL_H)
# include <fcntl.h>
Expand Down Expand Up @@ -66,13 +59,6 @@
#include "llvm/Support/ConvertUTF.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/Windows/WindowsSupport.h"
// winsock2.h must be included before afunix.h. Briefly turn off clang-format to
// avoid error.
// clang-format off
#include <winsock2.h>
#include <afunix.h>
// clang-format on
#include <io.h>
#endif

using namespace llvm;
Expand Down Expand Up @@ -659,7 +645,7 @@ raw_fd_ostream::raw_fd_ostream(int fd, bool shouldClose, bool unbuffered,
// Check if this is a console device. This is not equivalent to isatty.
IsWindowsConsole =
::GetFileType((HANDLE)::_get_osfhandle(fd)) == FILE_TYPE_CHAR;
#endif // _WIN32
#endif

// Get the starting position.
off_t loc = ::lseek(FD, 0, SEEK_CUR);
Expand Down Expand Up @@ -968,153 +954,6 @@ bool raw_fd_stream::classof(const raw_ostream *OS) {
return OS->get_kind() == OStreamKind::OK_FDStream;
}

//===----------------------------------------------------------------------===//
// raw_socket_stream
//===----------------------------------------------------------------------===//

#ifdef _WIN32
WSABalancer::WSABalancer() {
WSADATA WsaData;
::memset(&WsaData, 0, sizeof(WsaData));
if (WSAStartup(MAKEWORD(2, 2), &WsaData) != 0) {
llvm::report_fatal_error("WSAStartup failed");
}
}

WSABalancer::~WSABalancer() { WSACleanup(); }

#endif // _WIN32

static std::error_code getLastSocketErrorCode() {
#ifdef _WIN32
return std::error_code(::WSAGetLastError(), std::system_category());
#else
return std::error_code(errno, std::system_category());
#endif
}

ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath)
: FD(SocketFD), SocketPath(SocketPath) {}

ListeningSocket::ListeningSocket(ListeningSocket &&LS)
: FD(LS.FD), SocketPath(LS.SocketPath) {
LS.FD = -1;
}

Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
int MaxBacklog) {

#ifdef _WIN32
WSABalancer _;
SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
if (MaybeWinsocket == INVALID_SOCKET) {
#else
int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
if (MaybeWinsocket == -1) {
#endif
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"socket create failed");
}

struct sockaddr_un Addr;
memset(&Addr, 0, sizeof(Addr));
Addr.sun_family = AF_UNIX;
strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);

if (bind(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
std::error_code Err = getLastSocketErrorCode();
if (Err == std::errc::address_in_use)
::close(MaybeWinsocket);
return llvm::make_error<StringError>(Err, "Bind error");
}
if (listen(MaybeWinsocket, MaxBacklog) == -1) {
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"Listen error");
}
int UnixSocket;
#ifdef _WIN32
UnixSocket = _open_osfhandle(MaybeWinsocket, 0);
#else
UnixSocket = MaybeWinsocket;
#endif // _WIN32
return ListeningSocket{UnixSocket, SocketPath};
}

Expected<std::unique_ptr<raw_socket_stream>> ListeningSocket::accept() {
int AcceptFD;
#ifdef _WIN32
SOCKET WinServerSock = _get_osfhandle(FD);
SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
AcceptFD = _open_osfhandle(WinAcceptSock, 0);
#else
AcceptFD = ::accept(FD, NULL, NULL);
#endif //_WIN32
if (AcceptFD == -1)
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"Accept failed");
return std::make_unique<raw_socket_stream>(AcceptFD);
}

ListeningSocket::~ListeningSocket() {
if (FD == -1)
return;
::close(FD);
unlink(SocketPath.c_str());
}

static Expected<int> GetSocketFD(StringRef SocketPath) {
#ifdef _WIN32
SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
if (MaybeWinsocket == INVALID_SOCKET) {
#else
int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0);
if (MaybeWinsocket == -1) {
#endif // _WIN32
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"Create socket failed");
}

struct sockaddr_un Addr;
memset(&Addr, 0, sizeof(Addr));
Addr.sun_family = AF_UNIX;
strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);

int status = connect(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr));
if (status == -1) {
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"Connect socket failed");
}
#ifdef _WIN32
return _open_osfhandle(MaybeWinsocket, 0);
#else
return MaybeWinsocket;
#endif // _WIN32
}

raw_socket_stream::raw_socket_stream(int SocketFD)
: raw_fd_stream(SocketFD, true) {}

Expected<std::unique_ptr<raw_socket_stream>>
raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
#ifdef _WIN32
WSABalancer _;
#endif // _WIN32
Expected<int> FD = GetSocketFD(SocketPath);
if (!FD)
return FD.takeError();
return std::make_unique<raw_socket_stream>(*FD);
}

raw_socket_stream::~raw_socket_stream() {}

//===----------------------------------------------------------------------===//
// raw_string_ostream
//===----------------------------------------------------------------------===//

void raw_string_ostream::write_impl(const char *Ptr, size_t Size) {
OS.append(Ptr, Size);
}

//===----------------------------------------------------------------------===//
// raw_svector_ostream
//===----------------------------------------------------------------------===//
Expand Down
Loading