Skip to content

[libc] Use RAII alloc in gpu rpc printf impl #110352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 22 additions & 27 deletions libc/utils/gpu/server/rpc_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,19 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
"Incorrect maximum port count");

namespace {
struct TempStorage {
char *alloc(size_t size) {
storage.emplace_back(std::make_unique<char[]>(size));
return storage.back().get();
}

std::vector<std::unique_ptr<char[]>> storage;
};
} // namespace

template <bool packed, uint32_t lane_size>
void handle_printf(rpc::Server::Port &port) {
static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
FILE *files[lane_size] = {nullptr};
// Get the appropriate output stream to use.
if (port.get_opcode() == RPC_PRINTF_TO_STREAM ||
Expand All @@ -65,7 +76,7 @@ void handle_printf(rpc::Server::Port &port) {

// Recieve the format string and arguments from the client.
port.recv_n(format, format_sizes,
[&](uint64_t size) { return new char[size]; });
[&](uint64_t size) { return temp_storage.alloc(size); });

// Parse the format string to get the expected size of the buffer.
for (uint32_t lane = 0; lane < lane_size; ++lane) {
Expand All @@ -88,7 +99,8 @@ void handle_printf(rpc::Server::Port &port) {
port.send([&](rpc::Buffer *buffer, uint32_t id) {
buffer->data[0] = args_sizes[id];
});
port.recv_n(args, args_sizes, [&](uint64_t size) { return new char[size]; });
port.recv_n(args, args_sizes,
[&](uint64_t size) { return temp_storage.alloc(size); });

// Identify any arguments that are actually pointers to strings on the client.
// Additionally we want to determine how much buffer space we need to print.
Expand Down Expand Up @@ -137,7 +149,8 @@ void handle_printf(rpc::Server::Port &port) {
});
uint64_t str_sizes[lane_size] = {0};
void *strs[lane_size] = {nullptr};
port.recv_n(strs, str_sizes, [](uint64_t size) { return new char[size]; });
port.recv_n(strs, str_sizes,
[&](uint64_t size) { return temp_storage.alloc(size); });
for (uint32_t lane = 0; lane < lane_size; ++lane) {
if (!strs[lane])
continue;
Expand All @@ -149,13 +162,12 @@ void handle_printf(rpc::Server::Port &port) {

// Perform the final formatting and printing using the LLVM C library printf.
int results[lane_size] = {0};
std::vector<void *> to_be_deleted;
for (uint32_t lane = 0; lane < lane_size; ++lane) {
if (!format[lane])
continue;

std::unique_ptr<char[]> buffer(new char[buffer_size[lane]]);
WriteBuffer wb(buffer.get(), buffer_size[lane]);
char *buffer = temp_storage.alloc(buffer_size[lane]);
WriteBuffer wb(buffer, buffer_size[lane]);
Writer writer(&wb);

internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
Expand All @@ -173,7 +185,6 @@ void handle_printf(rpc::Server::Port &port) {
if (cur_section.has_conv && cur_section.conv_name == 's') {
if (!copied_strs[lane].empty()) {
cur_section.conv_val_ptr = copied_strs[lane].back();
to_be_deleted.push_back(copied_strs[lane].back());
copied_strs[lane].pop_back();
} else {
cur_section.conv_val_ptr = nullptr;
Expand All @@ -188,8 +199,7 @@ void handle_printf(rpc::Server::Port &port) {
}
}

results[lane] =
fwrite(buffer.get(), 1, writer.get_chars_written(), files[lane]);
results[lane] = fwrite(buffer, 1, writer.get_chars_written(), files[lane]);
if (results[lane] != writer.get_chars_written() || ret == -1)
results[lane] = -1;
}
Expand All @@ -199,24 +209,9 @@ void handle_printf(rpc::Server::Port &port) {
port.send([&](rpc::Buffer *buffer, uint32_t id) {
buffer->data[0] = static_cast<uint64_t>(results[id]);
buffer->data[1] = reinterpret_cast<uintptr_t>(nullptr);
delete[] reinterpret_cast<char *>(format[id]);
delete[] reinterpret_cast<char *>(args[id]);
});
for (void *ptr : to_be_deleted)
delete[] reinterpret_cast<char *>(ptr);
}

namespace {
struct TempStorage {
char *alloc(size_t size) {
storage.emplace_back(std::make_unique<char[]>(size));
return storage.back().get();
}

std::vector<std::unique_ptr<char[]>> storage;
};
} // namespace

template <uint32_t lane_size>
rpc_status_t handle_server_impl(
rpc::Server &server,
Expand Down Expand Up @@ -381,13 +376,13 @@ rpc_status_t handle_server_impl(
case RPC_PRINTF_TO_STREAM_PACKED:
case RPC_PRINTF_TO_STDOUT_PACKED:
case RPC_PRINTF_TO_STDERR_PACKED: {
handle_printf<true, lane_size>(*port);
handle_printf<true, lane_size>(*port, temp_storage);
break;
}
case RPC_PRINTF_TO_STREAM:
case RPC_PRINTF_TO_STDOUT:
case RPC_PRINTF_TO_STDERR: {
handle_printf<false, lane_size>(*port);
handle_printf<false, lane_size>(*port, temp_storage);
break;
}
case RPC_REMOVE: {
Expand Down
Loading