Skip to content

Commit c54d8d4

Browse files
committed
[libc] Implement (v|f)printf on the GPU
Summary: This patch implements the `printf` family of functions on the GPU using the new variadic support. This patch adapts the old handling in the `rpc_fprintf` placeholder, but adds an extra RPC call to get the size of the buffer to copy. This prevents the GPU from needing to parse the string. While it's theoretically possible for the pass to know the size of the struct, it's prohibitively difficult to do while maintaining ABI compatibility with NVIDIA's varargs. Depends on #96015.
1 parent 25c2bad commit c54d8d4

File tree

16 files changed

+318
-67
lines changed

16 files changed

+318
-67
lines changed

libc/config/gpu/entrypoints.txt

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
if(LIBC_TARGET_ARCHITECTURE_IS_AMDGPU)
2-
set(extra_entrypoints
3-
# stdio.h entrypoints
4-
libc.src.stdio.sprintf
5-
libc.src.stdio.snprintf
6-
libc.src.stdio.vsprintf
7-
libc.src.stdio.vsnprintf
8-
)
9-
endif()
10-
111
set(TARGET_LIBC_ENTRYPOINTS
122
# assert.h entrypoints
133
libc.src.assert.__assert_fail
@@ -185,7 +175,14 @@ set(TARGET_LIBC_ENTRYPOINTS
185175
libc.src.errno.errno
186176

187177
# stdio.h entrypoints
188-
${extra_entrypoints}
178+
libc.src.stdio.printf
179+
libc.src.stdio.vprintf
180+
libc.src.stdio.fprintf
181+
libc.src.stdio.vfprintf
182+
libc.src.stdio.sprintf
183+
libc.src.stdio.snprintf
184+
libc.src.stdio.vsprintf
185+
libc.src.stdio.vsnprintf
189186
libc.src.stdio.feof
190187
libc.src.stdio.ferror
191188
libc.src.stdio.fseek

libc/src/__support/arg_list.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ class MockArgList {
5454
}
5555

5656
template <class T> LIBC_INLINE T next_var() {
57-
++arg_counter;
57+
arg_counter =
58+
((arg_counter + alignof(T) - 1) / alignof(T)) * alignof(T) + sizeof(T);
5859
return T(arg_counter);
5960
}
6061

libc/src/gpu/rpc_fprintf.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ int fprintf_impl(::FILE *__restrict file, const char *__restrict format,
2929
}
3030

3131
port.send_n(format, format_size);
32+
port.recv([&](rpc::Buffer *buffer) {
33+
args_size = static_cast<size_t>(buffer->data[0]);
34+
});
3235
port.send_n(args, args_size);
3336

3437
uint32_t ret = 0;
@@ -50,7 +53,7 @@ int fprintf_impl(::FILE *__restrict file, const char *__restrict format,
5053
return ret;
5154
}
5255

53-
// TODO: This is a stand-in function that uses a struct pointer and size in
56+
// TODO: Delete this and port OpenMP to use `printf`.
5457
// place of varargs. Once varargs support is added we will use that to
5558
// implement the real version.
5659
LLVM_LIBC_FUNCTION(int, rpc_fprintf,

libc/src/stdio/CMakeLists.txt

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,6 @@ add_entrypoint_object(
159159
libc.src.stdio.printf_core.writer
160160
)
161161

162-
add_entrypoint_object(
163-
fprintf
164-
SRCS
165-
fprintf.cpp
166-
HDRS
167-
fprintf.h
168-
DEPENDS
169-
libc.src.__support.arg_list
170-
libc.src.stdio.printf_core.vfprintf_internal
171-
)
172-
173162
add_entrypoint_object(
174163
vsprintf
175164
SRCS
@@ -192,17 +181,6 @@ add_entrypoint_object(
192181
libc.src.stdio.printf_core.writer
193182
)
194183

195-
add_entrypoint_object(
196-
vfprintf
197-
SRCS
198-
vfprintf.cpp
199-
HDRS
200-
vfprintf.h
201-
DEPENDS
202-
libc.src.__support.arg_list
203-
libc.src.stdio.printf_core.vfprintf_internal
204-
)
205-
206184
add_stdio_entrypoint_object(
207185
fileno
208186
SRCS
@@ -261,6 +239,7 @@ add_stdio_entrypoint_object(fputc)
261239
add_stdio_entrypoint_object(putc)
262240
add_stdio_entrypoint_object(putchar)
263241
add_stdio_entrypoint_object(printf)
242+
add_stdio_entrypoint_object(fprintf)
264243
add_stdio_entrypoint_object(fgetc)
265244
add_stdio_entrypoint_object(fgetc_unlocked)
266245
add_stdio_entrypoint_object(getc)
@@ -273,3 +252,4 @@ add_stdio_entrypoint_object(stdin)
273252
add_stdio_entrypoint_object(stdout)
274253
add_stdio_entrypoint_object(stderr)
275254
add_stdio_entrypoint_object(vprintf)
255+
add_stdio_entrypoint_object(vfprintf)

libc/src/stdio/generic/CMakeLists.txt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,31 @@ add_entrypoint_object(
396396
${printf_deps}
397397
)
398398

399+
add_entrypoint_object(
400+
fprintf
401+
SRCS
402+
fprintf.cpp
403+
HDRS
404+
../fprintf.h
405+
DEPENDS
406+
libc.src.__support.arg_list
407+
libc.src.stdio.printf_core.vfprintf_internal
408+
${printf_deps}
409+
)
410+
411+
add_entrypoint_object(
412+
vfprintf
413+
SRCS
414+
vfprintf.cpp
415+
HDRS
416+
../vfprintf.h
417+
DEPENDS
418+
libc.src.__support.arg_list
419+
libc.src.stdio.printf_core.vfprintf_internal
420+
${printf_deps}
421+
)
422+
423+
399424
add_entrypoint_object(
400425
fgets
401426
SRCS
File renamed without changes.
File renamed without changes.

libc/src/stdio/gpu/CMakeLists.txt

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ add_header_library(
1010
.stderr
1111
)
1212

13+
add_header_library(
14+
vfprintf_utils
15+
HDRS
16+
vfprintf_utils.h
17+
DEPENDS
18+
.gpu_file
19+
)
20+
1321
add_entrypoint_object(
1422
feof
1523
SRCS
@@ -262,6 +270,46 @@ add_entrypoint_object(
262270
.gpu_file
263271
)
264272

273+
add_entrypoint_object(
274+
printf
275+
SRCS
276+
printf.cpp
277+
HDRS
278+
../printf.h
279+
DEPENDS
280+
.vfprintf_utils
281+
)
282+
283+
add_entrypoint_object(
284+
vprintf
285+
SRCS
286+
vprintf.cpp
287+
HDRS
288+
../vprintf.h
289+
DEPENDS
290+
.vfprintf_utils
291+
)
292+
293+
add_entrypoint_object(
294+
fprintf
295+
SRCS
296+
fprintf.cpp
297+
HDRS
298+
../fprintf.h
299+
DEPENDS
300+
.vfprintf_utils
301+
)
302+
303+
add_entrypoint_object(
304+
vfprintf
305+
SRCS
306+
vfprintf.cpp
307+
HDRS
308+
../vfprintf.h
309+
DEPENDS
310+
.vfprintf_utils
311+
)
312+
265313
add_entrypoint_object(
266314
stdin
267315
SRCS

libc/src/stdio/gpu/fprintf.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===-- GPU Implementation of fprintf -------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/stdio/fprintf.h"
10+
11+
#include "src/__support/CPP/string_view.h"
12+
#include "src/__support/arg_list.h"
13+
#include "src/errno/libc_errno.h"
14+
#include "src/stdio/gpu/vfprintf_utils.h"
15+
16+
#include <stdio.h>
17+
18+
namespace LIBC_NAMESPACE {
19+
20+
LLVM_LIBC_FUNCTION(int, fprintf,
21+
(::FILE *__restrict stream, const char *__restrict format,
22+
...)) {
23+
va_list vlist;
24+
va_start(vlist, format);
25+
cpp::string_view str_view(format);
26+
int ret_val =
27+
file::vfprintf_internal(stream, format, str_view.size() + 1, vlist);
28+
va_end(vlist);
29+
return ret_val;
30+
}
31+
32+
} // namespace LIBC_NAMESPACE

libc/src/stdio/gpu/printf.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===-- GPU Implementation of printf --------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/stdio/printf.h"
10+
11+
#include "src/__support/CPP/string_view.h"
12+
#include "src/__support/arg_list.h"
13+
#include "src/errno/libc_errno.h"
14+
#include "src/stdio/gpu/vfprintf_utils.h"
15+
16+
#include <stdio.h>
17+
18+
namespace LIBC_NAMESPACE {
19+
20+
LLVM_LIBC_FUNCTION(int, printf, (const char *__restrict format, ...)) {
21+
va_list vlist;
22+
va_start(vlist, format);
23+
cpp::string_view str_view(format);
24+
int ret_val =
25+
file::vfprintf_internal(stdout, format, str_view.size() + 1, vlist);
26+
va_end(vlist);
27+
return ret_val;
28+
}
29+
30+
} // namespace LIBC_NAMESPACE

libc/src/stdio/gpu/vfprintf.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===-- GPU Implementation of vfprintf ------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/stdio/vfprintf.h"
10+
11+
#include "src/__support/CPP/string_view.h"
12+
#include "src/__support/arg_list.h"
13+
#include "src/errno/libc_errno.h"
14+
#include "src/stdio/gpu/vfprintf_utils.h"
15+
16+
#include <stdio.h>
17+
18+
namespace LIBC_NAMESPACE {
19+
20+
LLVM_LIBC_FUNCTION(int, vfprintf,
21+
(::FILE *__restrict stream, const char *__restrict format,
22+
va_list vlist)) {
23+
cpp::string_view str_view(format);
24+
int ret_val =
25+
file::vfprintf_internal(stream, format, str_view.size() + 1, vlist);
26+
return ret_val;
27+
}
28+
29+
} // namespace LIBC_NAMESPACE

libc/src/stdio/gpu/vfprintf_utils.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//===--- GPU helper functions for printf using RPC ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/__support/RPC/rpc_client.h"
10+
#include "src/__support/arg_list.h"
11+
#include "src/stdio/gpu/file.h"
12+
#include "src/string/string_utils.h"
13+
14+
#include <stdio.h>
15+
16+
namespace LIBC_NAMESPACE {
17+
namespace file {
18+
19+
template <uint16_t opcode>
20+
LIBC_INLINE int vfprintf_impl(::FILE *__restrict file,
21+
const char *__restrict format, size_t format_size,
22+
va_list vlist) {
23+
uint64_t mask = gpu::get_lane_mask();
24+
rpc::Client::Port port = rpc::client.open<opcode>();
25+
26+
if constexpr (opcode == RPC_PRINTF_TO_STREAM) {
27+
port.send([&](rpc::Buffer *buffer) {
28+
buffer->data[0] = reinterpret_cast<uintptr_t>(file);
29+
});
30+
}
31+
32+
size_t args_size = 0;
33+
port.send_n(format, format_size);
34+
port.recv([&](rpc::Buffer *buffer) {
35+
args_size = static_cast<size_t>(buffer->data[0]);
36+
});
37+
port.send_n(vlist, args_size);
38+
39+
uint32_t ret = 0;
40+
for (;;) {
41+
const char *str = nullptr;
42+
port.recv([&](rpc::Buffer *buffer) {
43+
ret = static_cast<uint32_t>(buffer->data[0]);
44+
str = reinterpret_cast<const char *>(buffer->data[1]);
45+
});
46+
// If any lanes have a string argument it needs to be copied back.
47+
if (!gpu::ballot(mask, str))
48+
break;
49+
50+
uint64_t size = str ? internal::string_length(str) + 1 : 0;
51+
port.send_n(str, size);
52+
}
53+
54+
port.close();
55+
return ret;
56+
}
57+
58+
LIBC_INLINE int vfprintf_internal(::FILE *__restrict stream,
59+
const char *__restrict format,
60+
size_t format_size, va_list vlist) {
61+
if (stream == stdout)
62+
return vfprintf_impl<RPC_PRINTF_TO_STDOUT>(stream, format, format_size,
63+
vlist);
64+
else if (stream == stderr)
65+
return vfprintf_impl<RPC_PRINTF_TO_STDERR>(stream, format, format_size,
66+
vlist);
67+
else
68+
return vfprintf_impl<RPC_PRINTF_TO_STREAM>(stream, format, format_size,
69+
vlist);
70+
}
71+
72+
} // namespace file
73+
} // namespace LIBC_NAMESPACE

libc/src/stdio/gpu/vprintf.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===-- GPU Implementation of vprintf -------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/stdio/vprintf.h"
10+
11+
#include "src/__support/CPP/string_view.h"
12+
#include "src/__support/arg_list.h"
13+
#include "src/errno/libc_errno.h"
14+
#include "src/stdio/gpu/vfprintf_utils.h"
15+
16+
#include <stdio.h>
17+
18+
namespace LIBC_NAMESPACE {
19+
20+
LLVM_LIBC_FUNCTION(int, vprintf,
21+
(const char *__restrict format, va_list vlist)) {
22+
cpp::string_view str_view(format);
23+
int ret_val =
24+
file::vfprintf_internal(stdout, format, str_view.size() + 1, vlist);
25+
return ret_val;
26+
}
27+
28+
} // namespace LIBC_NAMESPACE

0 commit comments

Comments
 (0)