Skip to content

Commit 791b279

Browse files
authored
[libc] Change the puts implementation on the GPU (#67189)
Summary: Normally, the implementation of `puts` simply writes a second newline character after printing the first string. However, because the GPU does everything in batches of the SIMT group size, this will end up with very poor output where you get the strings printed and then 1-64 newline characters all in a row. Optimizations like to turn `printf` calls into `puts` so it's a good idea to make this produce the expected output. The least invasive way I could do this was to add a new opcode. It's a little bloated, but it avoids an unnecessary and slow send operation to configure this.
1 parent e3d6a3a commit 791b279

File tree

4 files changed

+28
-25
lines changed

4 files changed

+28
-25
lines changed

libc/include/llvm-libc-types/rpc_opcodes_t.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@ typedef enum : unsigned short {
1515
RPC_WRITE_TO_STDOUT = 2,
1616
RPC_WRITE_TO_STDERR = 3,
1717
RPC_WRITE_TO_STREAM = 4,
18-
RPC_READ_FROM_STREAM = 5,
19-
RPC_OPEN_FILE = 6,
20-
RPC_CLOSE_FILE = 7,
21-
RPC_MALLOC = 8,
22-
RPC_FREE = 9,
23-
RPC_HOST_CALL = 10,
24-
RPC_ABORT = 11,
25-
RPC_FEOF = 12,
26-
RPC_FERROR = 13,
27-
RPC_CLEARERR = 14,
18+
RPC_WRITE_TO_STDOUT_NEWLINE = 5,
19+
RPC_READ_FROM_STREAM = 6,
20+
RPC_OPEN_FILE = 7,
21+
RPC_CLOSE_FILE = 8,
22+
RPC_MALLOC = 9,
23+
RPC_FREE = 10,
24+
RPC_HOST_CALL = 11,
25+
RPC_ABORT = 12,
26+
RPC_FEOF = 13,
27+
RPC_FERROR = 14,
28+
RPC_CLEARERR = 15,
2829
} rpc_opcode_t;
2930

3031
#endif // __LLVM_LIBC_TYPES_RPC_OPCODE_H__

libc/src/stdio/gpu/puts.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@ namespace __llvm_libc {
1717

1818
LLVM_LIBC_FUNCTION(int, puts, (const char *__restrict str)) {
1919
cpp::string_view str_view(str);
20-
auto written = file::write(stdout, str, str_view.size());
21-
if (written != str_view.size())
22-
return EOF;
23-
written = file::write(stdout, "\n", 1);
24-
if (written != 1)
20+
auto written = file::write_impl<RPC_WRITE_TO_STDOUT_NEWLINE>(stdout, str,
21+
str_view.size());
22+
if (written != str_view.size() + 1)
2523
return EOF;
2624
return 0;
2725
}

libc/utils/gpu/server/rpc_server.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,27 @@ struct Server {
5959
switch (port->get_opcode()) {
6060
case RPC_WRITE_TO_STREAM:
6161
case RPC_WRITE_TO_STDERR:
62-
case RPC_WRITE_TO_STDOUT: {
62+
case RPC_WRITE_TO_STDOUT:
63+
case RPC_WRITE_TO_STDOUT_NEWLINE: {
6364
uint64_t sizes[lane_size] = {0};
6465
void *strs[lane_size] = {nullptr};
6566
FILE *files[lane_size] = {nullptr};
66-
if (port->get_opcode() == RPC_WRITE_TO_STREAM)
67+
if (port->get_opcode() == RPC_WRITE_TO_STREAM) {
6768
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
6869
files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
6970
});
71+
} else if (port->get_opcode() == RPC_WRITE_TO_STDERR) {
72+
std::fill(files, files + lane_size, stderr);
73+
} else {
74+
std::fill(files, files + lane_size, stdout);
75+
}
76+
7077
port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
7178
port->send([&](rpc::Buffer *buffer, uint32_t id) {
72-
FILE *file =
73-
port->get_opcode() == RPC_WRITE_TO_STDOUT
74-
? stdout
75-
: (port->get_opcode() == RPC_WRITE_TO_STDERR ? stderr
76-
: files[id]);
77-
uint64_t ret = fwrite(strs[id], 1, sizes[id], file);
78-
std::memcpy(buffer->data, &ret, sizeof(uint64_t));
79+
buffer->data[0] = fwrite(strs[id], 1, sizes[id], files[id]);
80+
if (port->get_opcode() == RPC_WRITE_TO_STDOUT_NEWLINE &&
81+
buffer->data[0] == sizes[id])
82+
buffer->data[0] += fwrite("\n", 1, 1, files[id]);
7983
delete[] reinterpret_cast<uint8_t *>(strs[id]);
8084
});
8185
break;

openmp/libomptarget/test/libc/puts.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ int main() {
3131
// CHECK: PASS
3232
#pragma omp target teams num_teams(4)
3333
#pragma omp parallel num_threads(2)
34-
{ fputs("PASS\n", stdout); }
34+
{ puts("PASS\n"); }
3535
}

0 commit comments

Comments
 (0)