Skip to content

Commit bbe79a8

Browse files
authored
[libc] Use RAII alloc in gpu rpc printf impl (#110352)
1 parent 641b4d5 commit bbe79a8

File tree

1 file changed

+22
-27
lines changed

1 file changed

+22
-27
lines changed

libc/utils/gpu/server/rpc_server.cpp

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,19 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
4242
static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
4343
"Incorrect maximum port count");
4444

45+
namespace {
46+
struct TempStorage {
47+
char *alloc(size_t size) {
48+
storage.emplace_back(std::make_unique<char[]>(size));
49+
return storage.back().get();
50+
}
51+
52+
std::vector<std::unique_ptr<char[]>> storage;
53+
};
54+
} // namespace
55+
4556
template <bool packed, uint32_t lane_size>
46-
void handle_printf(rpc::Server::Port &port) {
57+
static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
4758
FILE *files[lane_size] = {nullptr};
4859
// Get the appropriate output stream to use.
4960
if (port.get_opcode() == RPC_PRINTF_TO_STREAM ||
@@ -65,7 +76,7 @@ void handle_printf(rpc::Server::Port &port) {
6576

6677
// Recieve the format string and arguments from the client.
6778
port.recv_n(format, format_sizes,
68-
[&](uint64_t size) { return new char[size]; });
79+
[&](uint64_t size) { return temp_storage.alloc(size); });
6980

7081
// Parse the format string to get the expected size of the buffer.
7182
for (uint32_t lane = 0; lane < lane_size; ++lane) {
@@ -88,7 +99,8 @@ void handle_printf(rpc::Server::Port &port) {
8899
port.send([&](rpc::Buffer *buffer, uint32_t id) {
89100
buffer->data[0] = args_sizes[id];
90101
});
91-
port.recv_n(args, args_sizes, [&](uint64_t size) { return new char[size]; });
102+
port.recv_n(args, args_sizes,
103+
[&](uint64_t size) { return temp_storage.alloc(size); });
92104

93105
// Identify any arguments that are actually pointers to strings on the client.
94106
// Additionally we want to determine how much buffer space we need to print.
@@ -137,7 +149,8 @@ void handle_printf(rpc::Server::Port &port) {
137149
});
138150
uint64_t str_sizes[lane_size] = {0};
139151
void *strs[lane_size] = {nullptr};
140-
port.recv_n(strs, str_sizes, [](uint64_t size) { return new char[size]; });
152+
port.recv_n(strs, str_sizes,
153+
[&](uint64_t size) { return temp_storage.alloc(size); });
141154
for (uint32_t lane = 0; lane < lane_size; ++lane) {
142155
if (!strs[lane])
143156
continue;
@@ -149,13 +162,12 @@ void handle_printf(rpc::Server::Port &port) {
149162

150163
// Perform the final formatting and printing using the LLVM C library printf.
151164
int results[lane_size] = {0};
152-
std::vector<void *> to_be_deleted;
153165
for (uint32_t lane = 0; lane < lane_size; ++lane) {
154166
if (!format[lane])
155167
continue;
156168

157-
std::unique_ptr<char[]> buffer(new char[buffer_size[lane]]);
158-
WriteBuffer wb(buffer.get(), buffer_size[lane]);
169+
char *buffer = temp_storage.alloc(buffer_size[lane]);
170+
WriteBuffer wb(buffer, buffer_size[lane]);
159171
Writer writer(&wb);
160172

161173
internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
@@ -173,7 +185,6 @@ void handle_printf(rpc::Server::Port &port) {
173185
if (cur_section.has_conv && cur_section.conv_name == 's') {
174186
if (!copied_strs[lane].empty()) {
175187
cur_section.conv_val_ptr = copied_strs[lane].back();
176-
to_be_deleted.push_back(copied_strs[lane].back());
177188
copied_strs[lane].pop_back();
178189
} else {
179190
cur_section.conv_val_ptr = nullptr;
@@ -188,8 +199,7 @@ void handle_printf(rpc::Server::Port &port) {
188199
}
189200
}
190201

191-
results[lane] =
192-
fwrite(buffer.get(), 1, writer.get_chars_written(), files[lane]);
202+
results[lane] = fwrite(buffer, 1, writer.get_chars_written(), files[lane]);
193203
if (results[lane] != writer.get_chars_written() || ret == -1)
194204
results[lane] = -1;
195205
}
@@ -199,24 +209,9 @@ void handle_printf(rpc::Server::Port &port) {
199209
port.send([&](rpc::Buffer *buffer, uint32_t id) {
200210
buffer->data[0] = static_cast<uint64_t>(results[id]);
201211
buffer->data[1] = reinterpret_cast<uintptr_t>(nullptr);
202-
delete[] reinterpret_cast<char *>(format[id]);
203-
delete[] reinterpret_cast<char *>(args[id]);
204212
});
205-
for (void *ptr : to_be_deleted)
206-
delete[] reinterpret_cast<char *>(ptr);
207213
}
208214

209-
namespace {
210-
struct TempStorage {
211-
char *alloc(size_t size) {
212-
storage.emplace_back(std::make_unique<char[]>(size));
213-
return storage.back().get();
214-
}
215-
216-
std::vector<std::unique_ptr<char[]>> storage;
217-
};
218-
} // namespace
219-
220215
template <uint32_t lane_size>
221216
rpc_status_t handle_server_impl(
222217
rpc::Server &server,
@@ -381,13 +376,13 @@ rpc_status_t handle_server_impl(
381376
case RPC_PRINTF_TO_STREAM_PACKED:
382377
case RPC_PRINTF_TO_STDOUT_PACKED:
383378
case RPC_PRINTF_TO_STDERR_PACKED: {
384-
handle_printf<true, lane_size>(*port);
379+
handle_printf<true, lane_size>(*port, temp_storage);
385380
break;
386381
}
387382
case RPC_PRINTF_TO_STREAM:
388383
case RPC_PRINTF_TO_STDOUT:
389384
case RPC_PRINTF_TO_STDERR: {
390-
handle_printf<false, lane_size>(*port);
385+
handle_printf<false, lane_size>(*port, temp_storage);
391386
break;
392387
}
393388
case RPC_REMOVE: {

0 commit comments

Comments
 (0)