Skip to content

Commit db0c59e

Browse files
jhuber6 and aaryanshukla
authored and committed
[NVPTX] Implement variadic functions using IR lowering (llvm#96015)
Summary: This patch implements support for variadic functions for NVPTX targets. The implementation here mainly follows what was done to implement it for AMDGPU in llvm#93362. We change the NVPTX codegen to lower all variadic arguments to functions by-value. This creates a flattened set of arguments that the IR lowering pass converts into a struct with the proper alignment. The behavior of this function was determined by iteratively checking what the NVCC compiler generates for its output. See examples like https://godbolt.org/z/KavfTGY93. I have noted the main methods that NVIDIA uses to lower variadic functions. 1. All arguments are passed in a pointer to aggregate. 2. The minimum alignment for a plain argument is 4 bytes. 3. Alignment is dictated by the underlying type. 4. Structs are flattened and do not have their alignment changed. 5. NVPTX never passes any arguments indirectly, even very large ones. This patch passes the tests in the `libc` project currently, including support for `sprintf`.
1 parent 6b6d778 commit db0c59e

File tree

9 files changed

+930
-32
lines changed

9 files changed

+930
-32
lines changed

clang/lib/Basic/Targets/NVPTX.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
119119
}
120120

121121
BuiltinVaListKind getBuiltinVaListKind() const override {
122-
// FIXME: implement
123-
return TargetInfo::CharPtrBuiltinVaList;
122+
return TargetInfo::VoidPtrBuiltinVaList;
124123
}
125124

126125
bool isValidCPUName(StringRef Name) const override {

clang/lib/CodeGen/Targets/NVPTX.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,11 @@ ABIArgInfo NVPTXABIInfo::classifyArgumentType(QualType Ty) const {
203203
void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {
204204
if (!getCXXABI().classifyReturnType(FI))
205205
FI.getReturnInfo() = classifyReturnType(FI.getReturnType());
206-
for (auto &I : FI.arguments())
207-
I.info = classifyArgumentType(I.type);
206+
207+
for (auto &&[ArgumentsCount, I] : llvm::enumerate(FI.arguments()))
208+
I.info = ArgumentsCount < FI.getNumRequiredArgs()
209+
? classifyArgumentType(I.type)
210+
: ABIArgInfo::getDirect();
208211

209212
// Always honor user-specified calling convention.
210213
if (FI.getCallingConvention() != llvm::CallingConv::C)
@@ -215,7 +218,10 @@ void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {
215218

216219
RValue NVPTXABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
217220
QualType Ty, AggValueSlot Slot) const {
218-
llvm_unreachable("NVPTX does not support varargs");
221+
return emitVoidPtrVAArg(CGF, VAListAddr, Ty, /*IsIndirect=*/false,
222+
getContext().getTypeInfoInChars(Ty),
223+
CharUnits::fromQuantity(1),
224+
/*AllowHigherAlign=*/true, Slot);
219225
}
220226

221227
void NVPTXTargetCodeGenInfo::setTargetAttributes(

clang/test/CodeGen/variadic-nvptx.c

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -emit-llvm -o - %s | FileCheck %s
3+
4+
extern void varargs_simple(int, ...);
5+
6+
// CHECK-LABEL: define dso_local void @foo(
7+
// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
8+
// CHECK-NEXT: [[ENTRY:.*:]]
9+
// CHECK-NEXT: [[C:%.*]] = alloca i8, align 1
10+
// CHECK-NEXT: [[S:%.*]] = alloca i16, align 2
11+
// CHECK-NEXT: [[I:%.*]] = alloca i32, align 4
12+
// CHECK-NEXT: [[L:%.*]] = alloca i64, align 8
13+
// CHECK-NEXT: [[F:%.*]] = alloca float, align 4
14+
// CHECK-NEXT: [[D:%.*]] = alloca double, align 8
15+
// CHECK-NEXT: [[A:%.*]] = alloca [[STRUCT_ANON:%.*]], align 4
16+
// CHECK-NEXT: [[V:%.*]] = alloca <4 x i32>, align 16
17+
// CHECK-NEXT: [[T:%.*]] = alloca [[STRUCT_ANON_0:%.*]], align 1
18+
// CHECK-NEXT: store i8 1, ptr [[C]], align 1
19+
// CHECK-NEXT: store i16 1, ptr [[S]], align 2
20+
// CHECK-NEXT: store i32 1, ptr [[I]], align 4
21+
// CHECK-NEXT: store i64 1, ptr [[L]], align 8
22+
// CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 4
23+
// CHECK-NEXT: store double 1.000000e+00, ptr [[D]], align 8
24+
// CHECK-NEXT: [[TMP0:%.*]] = load i8, ptr [[C]], align 1
25+
// CHECK-NEXT: [[CONV:%.*]] = sext i8 [[TMP0]] to i32
26+
// CHECK-NEXT: [[TMP1:%.*]] = load i16, ptr [[S]], align 2
27+
// CHECK-NEXT: [[CONV1:%.*]] = sext i16 [[TMP1]] to i32
28+
// CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[I]], align 4
29+
// CHECK-NEXT: [[TMP3:%.*]] = load i64, ptr [[L]], align 8
30+
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[F]], align 4
31+
// CHECK-NEXT: [[CONV2:%.*]] = fpext float [[TMP4]] to double
32+
// CHECK-NEXT: [[TMP5:%.*]] = load double, ptr [[D]], align 8
33+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 noundef [[CONV]], i32 noundef [[CONV1]], i32 noundef [[TMP2]], i64 noundef [[TMP3]], double noundef [[CONV2]], double noundef [[TMP5]])
34+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[A]], ptr align 4 @__const.foo.a, i64 12, i1 false)
35+
// CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 0
36+
// CHECK-NEXT: [[TMP7:%.*]] = load i32, ptr [[TMP6]], align 4
37+
// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 1
38+
// CHECK-NEXT: [[TMP9:%.*]] = load i8, ptr [[TMP8]], align 4
39+
// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 2
40+
// CHECK-NEXT: [[TMP11:%.*]] = load i32, ptr [[TMP10]], align 4
41+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 [[TMP7]], i8 [[TMP9]], i32 [[TMP11]])
42+
// CHECK-NEXT: store <4 x i32> <i32 1, i32 1, i32 1, i32 1>, ptr [[V]], align 16
43+
// CHECK-NEXT: [[TMP12:%.*]] = load <4 x i32>, ptr [[V]], align 16
44+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, <4 x i32> noundef [[TMP12]])
45+
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 0
46+
// CHECK-NEXT: [[TMP14:%.*]] = load i8, ptr [[TMP13]], align 1
47+
// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 1
48+
// CHECK-NEXT: [[TMP16:%.*]] = load i8, ptr [[TMP15]], align 1
49+
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 0
50+
// CHECK-NEXT: [[TMP18:%.*]] = load i8, ptr [[TMP17]], align 1
51+
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 1
52+
// CHECK-NEXT: [[TMP20:%.*]] = load i8, ptr [[TMP19]], align 1
53+
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 0
54+
// CHECK-NEXT: [[TMP22:%.*]] = load i8, ptr [[TMP21]], align 1
55+
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 1
56+
// CHECK-NEXT: [[TMP24:%.*]] = load i8, ptr [[TMP23]], align 1
57+
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i8 [[TMP14]], i8 [[TMP16]], i8 [[TMP18]], i8 [[TMP20]], i32 noundef 0, i8 [[TMP22]], i8 [[TMP24]])
58+
// CHECK-NEXT: ret void
59+
//
60+
void foo() {
61+
char c = '\x1';
62+
short s = 1;
63+
int i = 1;
64+
long l = 1;
65+
float f = 1.f;
66+
double d = 1.;
67+
varargs_simple(0, c, s, i, l, f, d);
68+
69+
struct {int x; char c; int y;} a = {1, '\x1', 1};
70+
varargs_simple(0, a);
71+
72+
typedef int __attribute__((ext_vector_type(4))) int4;
73+
int4 v = {1, 1, 1, 1};
74+
varargs_simple(0, v);
75+
76+
struct {char c, d;} t;
77+
varargs_simple(0, t, t, 0, t);
78+
}
79+
80+
typedef struct {long x; long y;} S;
81+
extern void varargs_complex(S, S, ...);
82+
83+
// CHECK-LABEL: define dso_local void @bar(
84+
// CHECK-SAME: ) #[[ATTR0]] {
85+
// CHECK-NEXT: [[ENTRY:.*:]]
86+
// CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 8
87+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[S]], ptr align 8 @__const.bar.s, i64 16, i1 false)
88+
// CHECK-NEXT: call void (ptr, ptr, ...) @varargs_complex(ptr noundef byval([[STRUCT_S]]) align 8 [[S]], ptr noundef byval([[STRUCT_S]]) align 8 [[S]], i32 noundef 1, i64 noundef 1, double noundef 1.000000e+00)
89+
// CHECK-NEXT: ret void
90+
//
91+
void bar() {
92+
S s = {1l, 1l};
93+
varargs_complex(s, s, 1, 1l, 1.0);
94+
}

libc/config/gpu/entrypoints.txt

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
if(LIBC_TARGET_ARCHITECTURE_IS_AMDGPU)
2-
set(extra_entrypoints
3-
# stdio.h entrypoints
4-
libc.src.stdio.snprintf
5-
libc.src.stdio.sprintf
6-
libc.src.stdio.vsnprintf
7-
libc.src.stdio.vsprintf
8-
)
9-
endif()
10-
111
set(TARGET_LIBC_ENTRYPOINTS
122
# assert.h entrypoints
133
libc.src.assert.__assert_fail
@@ -185,9 +175,12 @@ set(TARGET_LIBC_ENTRYPOINTS
185175
libc.src.errno.errno
186176

187177
# stdio.h entrypoints
188-
${extra_entrypoints}
189178
libc.src.stdio.clearerr
190179
libc.src.stdio.fclose
180+
libc.src.stdio.sprintf
181+
libc.src.stdio.snprintf
182+
libc.src.stdio.vsprintf
183+
libc.src.stdio.vsnprintf
191184
libc.src.stdio.feof
192185
libc.src.stdio.ferror
193186
libc.src.stdio.fflush

libc/test/src/__support/CMakeLists.txt

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -131,18 +131,15 @@ add_libc_test(
131131
libc.src.__support.uint128
132132
)
133133

134-
# NVPTX does not support varargs currently.
135-
if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
136-
add_libc_test(
137-
arg_list_test
138-
SUITE
139-
libc-support-tests
140-
SRCS
141-
arg_list_test.cpp
142-
DEPENDS
143-
libc.src.__support.arg_list
144-
)
145-
endif()
134+
add_libc_test(
135+
arg_list_test
136+
SUITE
137+
libc-support-tests
138+
SRCS
139+
arg_list_test.cpp
140+
DEPENDS
141+
libc.src.__support.arg_list
142+
)
146143

147144
if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
148145
add_libc_test(

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/Target/TargetMachine.h"
3434
#include "llvm/Target/TargetOptions.h"
3535
#include "llvm/TargetParser/Triple.h"
36+
#include "llvm/Transforms/IPO/ExpandVariadics.h"
3637
#include "llvm/Transforms/Scalar.h"
3738
#include "llvm/Transforms/Scalar/GVN.h"
3839
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
@@ -342,6 +343,7 @@ void NVPTXPassConfig::addIRPasses() {
342343
}
343344

344345
addPass(createAtomicExpandLegacyPass());
346+
addPass(createExpandVariadicsPass(ExpandVariadicsMode::Lowering));
345347
addPass(createNVPTXCtorDtorLoweringLegacyPass());
346348

347349
// === LSR and other generic IR passes ===

llvm/lib/Transforms/IPO/ExpandVariadics.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,8 @@ bool ExpandVariadics::runOnFunction(Module &M, IRBuilder<> &Builder,
456456
// Replace known calls to the variadic with calls to the va_list equivalent
457457
for (User *U : make_early_inc_range(VariadicWrapper->users())) {
458458
if (CallBase *CB = dyn_cast<CallBase>(U)) {
459-
Value *calledOperand = CB->getCalledOperand();
460-
if (VariadicWrapper == calledOperand)
459+
Value *CalledOperand = CB->getCalledOperand();
460+
if (VariadicWrapper == CalledOperand)
461461
Changed |=
462462
expandCall(M, Builder, CB, VariadicWrapper->getFunctionType(),
463463
FixedArityReplacement);
@@ -938,6 +938,33 @@ struct Amdgpu final : public VariadicABIInfo {
938938
}
939939
};
940940

941+
struct NVPTX final : public VariadicABIInfo {
942+
943+
bool enableForTarget() override { return true; }
944+
945+
bool vaListPassedInSSARegister() override { return true; }
946+
947+
Type *vaListType(LLVMContext &Ctx) override {
948+
return PointerType::getUnqual(Ctx);
949+
}
950+
951+
Type *vaListParameterType(Module &M) override {
952+
return PointerType::getUnqual(M.getContext());
953+
}
954+
955+
Value *initializeVaList(Module &M, LLVMContext &Ctx, IRBuilder<> &Builder,
956+
AllocaInst *, Value *Buffer) override {
957+
return Builder.CreateAddrSpaceCast(Buffer, vaListParameterType(M));
958+
}
959+
960+
VAArgSlotInfo slotInfo(const DataLayout &DL, Type *Parameter) override {
961+
// NVPTX expects natural alignment in all cases. The variadic call ABI will
962+
// handle promoting types to their appropriate size and alignment.
963+
Align A = DL.getABITypeAlign(Parameter);
964+
return {A, false};
965+
}
966+
};
967+
941968
struct Wasm final : public VariadicABIInfo {
942969

943970
bool enableForTarget() override {
@@ -967,8 +994,8 @@ struct Wasm final : public VariadicABIInfo {
967994
if (A < MinAlign)
968995
A = Align(MinAlign);
969996

970-
if (auto s = dyn_cast<StructType>(Parameter)) {
971-
if (s->getNumElements() > 1) {
997+
if (auto *S = dyn_cast<StructType>(Parameter)) {
998+
if (S->getNumElements() > 1) {
972999
return {DL.getABITypeAlign(PointerType::getUnqual(Ctx)), true};
9731000
}
9741001
}
@@ -988,6 +1015,11 @@ std::unique_ptr<VariadicABIInfo> VariadicABIInfo::create(const Triple &T) {
9881015
return std::make_unique<Wasm>();
9891016
}
9901017

1018+
case Triple::nvptx:
1019+
case Triple::nvptx64: {
1020+
return std::make_unique<NVPTX>();
1021+
}
1022+
9911023
default:
9921024
return {};
9931025
}

0 commit comments

Comments
 (0)