Skip to content

[NVPTX] Implement variadic functions using IR lowering #96015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged: 1 commit merged on Jul 12, 2024.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions clang/lib/Basic/Targets/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
}

BuiltinVaListKind getBuiltinVaListKind() const override {
// FIXME: implement
return TargetInfo::CharPtrBuiltinVaList;
return TargetInfo::VoidPtrBuiltinVaList;
}

bool isValidCPUName(StringRef Name) const override {
Expand Down
12 changes: 9 additions & 3 deletions clang/lib/CodeGen/Targets/NVPTX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,11 @@ ABIArgInfo NVPTXABIInfo::classifyArgumentType(QualType Ty) const {
void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {
if (!getCXXABI().classifyReturnType(FI))
FI.getReturnInfo() = classifyReturnType(FI.getReturnType());
for (auto &I : FI.arguments())
I.info = classifyArgumentType(I.type);

for (auto &&[ArgumentsCount, I] : llvm::enumerate(FI.arguments()))
I.info = ArgumentsCount < FI.getNumRequiredArgs()
? classifyArgumentType(I.type)
: ABIArgInfo::getDirect();

// Always honor user-specified calling convention.
if (FI.getCallingConvention() != llvm::CallingConv::C)
Expand All @@ -215,7 +218,10 @@ void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {

RValue NVPTXABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
QualType Ty, AggValueSlot Slot) const {
llvm_unreachable("NVPTX does not support varargs");
return emitVoidPtrVAArg(CGF, VAListAddr, Ty, /*IsIndirect=*/false,
getContext().getTypeInfoInChars(Ty),
CharUnits::fromQuantity(1),
/*AllowHigherAlign=*/true, Slot);
}

void NVPTXTargetCodeGenInfo::setTargetAttributes(
Expand Down
94 changes: 94 additions & 0 deletions clang/test/CodeGen/variadic-nvptx.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -emit-llvm -o - %s | FileCheck %s

extern void varargs_simple(int, ...);

// CHECK-LABEL: define dso_local void @foo(
// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[C:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[S:%.*]] = alloca i16, align 2
// CHECK-NEXT: [[I:%.*]] = alloca i32, align 4
// CHECK-NEXT: [[L:%.*]] = alloca i64, align 8
// CHECK-NEXT: [[F:%.*]] = alloca float, align 4
// CHECK-NEXT: [[D:%.*]] = alloca double, align 8
// CHECK-NEXT: [[A:%.*]] = alloca [[STRUCT_ANON:%.*]], align 4
// CHECK-NEXT: [[V:%.*]] = alloca <4 x i32>, align 16
// CHECK-NEXT: [[T:%.*]] = alloca [[STRUCT_ANON_0:%.*]], align 1
// CHECK-NEXT: store i8 1, ptr [[C]], align 1
// CHECK-NEXT: store i16 1, ptr [[S]], align 2
// CHECK-NEXT: store i32 1, ptr [[I]], align 4
// CHECK-NEXT: store i64 1, ptr [[L]], align 8
// CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 4
// CHECK-NEXT: store double 1.000000e+00, ptr [[D]], align 8
// CHECK-NEXT: [[TMP0:%.*]] = load i8, ptr [[C]], align 1
// CHECK-NEXT: [[CONV:%.*]] = sext i8 [[TMP0]] to i32
// CHECK-NEXT: [[TMP1:%.*]] = load i16, ptr [[S]], align 2
// CHECK-NEXT: [[CONV1:%.*]] = sext i16 [[TMP1]] to i32
// CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[I]], align 4
// CHECK-NEXT: [[TMP3:%.*]] = load i64, ptr [[L]], align 8
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[F]], align 4
// CHECK-NEXT: [[CONV2:%.*]] = fpext float [[TMP4]] to double
// CHECK-NEXT: [[TMP5:%.*]] = load double, ptr [[D]], align 8
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 noundef [[CONV]], i32 noundef [[CONV1]], i32 noundef [[TMP2]], i64 noundef [[TMP3]], double noundef [[CONV2]], double noundef [[TMP5]])
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[A]], ptr align 4 @__const.foo.a, i64 12, i1 false)
// CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 0
// CHECK-NEXT: [[TMP7:%.*]] = load i32, ptr [[TMP6]], align 4
// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 1
// CHECK-NEXT: [[TMP9:%.*]] = load i8, ptr [[TMP8]], align 4
// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 2
// CHECK-NEXT: [[TMP11:%.*]] = load i32, ptr [[TMP10]], align 4
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 [[TMP7]], i8 [[TMP9]], i32 [[TMP11]])
// CHECK-NEXT: store <4 x i32> <i32 1, i32 1, i32 1, i32 1>, ptr [[V]], align 16
// CHECK-NEXT: [[TMP12:%.*]] = load <4 x i32>, ptr [[V]], align 16
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, <4 x i32> noundef [[TMP12]])
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 0
// CHECK-NEXT: [[TMP14:%.*]] = load i8, ptr [[TMP13]], align 1
// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 1
// CHECK-NEXT: [[TMP16:%.*]] = load i8, ptr [[TMP15]], align 1
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 0
// CHECK-NEXT: [[TMP18:%.*]] = load i8, ptr [[TMP17]], align 1
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 1
// CHECK-NEXT: [[TMP20:%.*]] = load i8, ptr [[TMP19]], align 1
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 0
// CHECK-NEXT: [[TMP22:%.*]] = load i8, ptr [[TMP21]], align 1
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 1
// CHECK-NEXT: [[TMP24:%.*]] = load i8, ptr [[TMP23]], align 1
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i8 [[TMP14]], i8 [[TMP16]], i8 [[TMP18]], i8 [[TMP20]], i32 noundef 0, i8 [[TMP22]], i8 [[TMP24]])
// CHECK-NEXT: ret void
//
// Exercises variadic calls with scalars of each fundamental width. Per the
// autogenerated CHECK lines above: char/short are sign-extended to i32 and
// float is promoted to double at the call site (default argument promotions).
// NOTE(review): code is intentionally left byte-identical — the CHECK lines
// were produced by update_cc_test_checks.py and must match this body exactly.
void foo() {
char c = '\x1';
short s = 1;
int i = 1;
long l = 1;
float f = 1.f;
double d = 1.;
varargs_simple(0, c, s, i, l, f, d);

// Struct with internal padding ({int, char, int}) passed variadically; the
// CHECK lines show it is decomposed into three field loads at the call.
struct {int x; char c; int y;} a = {1, '\x1', 1};
varargs_simple(0, a);

// 16-byte vector argument; the CHECK lines show it stays a single <4 x i32>.
typedef int __attribute__((ext_vector_type(4))) int4;
int4 v = {1, 1, 1, 1};
varargs_simple(0, v);

// Small two-char struct passed repeatedly, interleaved with an int, to
// exercise repeated sub-word variadic slots.
struct {char c, d;} t;
varargs_simple(0, t, t, 0, t);
}

typedef struct {long x; long y;} S;
extern void varargs_complex(S, S, ...);

// CHECK-LABEL: define dso_local void @bar(
// CHECK-SAME: ) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 8
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[S]], ptr align 8 @__const.bar.s, i64 16, i1 false)
// CHECK-NEXT: call void (ptr, ptr, ...) @varargs_complex(ptr noundef byval([[STRUCT_S]]) align 8 [[S]], ptr noundef byval([[STRUCT_S]]) align 8 [[S]], i32 noundef 1, i64 noundef 1, double noundef 1.000000e+00)
// CHECK-NEXT: ret void
//
// Mixes fixed byval aggregate parameters with promoted variadic scalars: the
// CHECK lines show both S copies passed byval while 1/1l/1.0 are passed
// directly through the variadic tail.
void bar() {
S s = {1l, 1l};
varargs_complex(s, s, 1, 1l, 1.0);
}
15 changes: 4 additions & 11 deletions libc/config/gpu/entrypoints.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
if(LIBC_TARGET_ARCHITECTURE_IS_AMDGPU)
set(extra_entrypoints
# stdio.h entrypoints
libc.src.stdio.snprintf
libc.src.stdio.sprintf
libc.src.stdio.vsnprintf
libc.src.stdio.vsprintf
)
endif()

set(TARGET_LIBC_ENTRYPOINTS
# assert.h entrypoints
libc.src.assert.__assert_fail
Expand Down Expand Up @@ -185,9 +175,12 @@ set(TARGET_LIBC_ENTRYPOINTS
libc.src.errno.errno

# stdio.h entrypoints
${extra_entrypoints}
libc.src.stdio.clearerr
libc.src.stdio.fclose
libc.src.stdio.sprintf
libc.src.stdio.snprintf
libc.src.stdio.vsprintf
libc.src.stdio.vsnprintf
libc.src.stdio.feof
libc.src.stdio.ferror
libc.src.stdio.fflush
Expand Down
21 changes: 9 additions & 12 deletions libc/test/src/__support/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,15 @@ add_libc_test(
libc.src.__support.uint128
)

# NVPTX does not support varargs currently.
if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
add_libc_test(
arg_list_test
SUITE
libc-support-tests
SRCS
arg_list_test.cpp
DEPENDS
libc.src.__support.arg_list
)
endif()
# arg_list_test now runs on every GPU target: the previous
# LIBC_TARGET_ARCHITECTURE_IS_NVPTX guard was removed because NVPTX gained
# variadic-function support via IR lowering.
add_libc_test(
arg_list_test
SUITE
libc-support-tests
SRCS
arg_list_test.cpp
DEPENDS
libc.src.__support.arg_list
)

if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
add_libc_test(
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO/ExpandVariadics.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
Expand Down Expand Up @@ -342,6 +343,7 @@ void NVPTXPassConfig::addIRPasses() {
}

addPass(createAtomicExpandLegacyPass());
addPass(createExpandVariadicsPass(ExpandVariadicsMode::Lowering));
addPass(createNVPTXCtorDtorLoweringLegacyPass());

// === LSR and other generic IR passes ===
Expand Down
40 changes: 36 additions & 4 deletions llvm/lib/Transforms/IPO/ExpandVariadics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,8 @@ bool ExpandVariadics::runOnFunction(Module &M, IRBuilder<> &Builder,
// Replace known calls to the variadic with calls to the va_list equivalent
for (User *U : make_early_inc_range(VariadicWrapper->users())) {
if (CallBase *CB = dyn_cast<CallBase>(U)) {
Value *calledOperand = CB->getCalledOperand();
if (VariadicWrapper == calledOperand)
Value *CalledOperand = CB->getCalledOperand();
if (VariadicWrapper == CalledOperand)
Changed |=
expandCall(M, Builder, CB, VariadicWrapper->getFunctionType(),
FixedArityReplacement);
Expand Down Expand Up @@ -938,6 +938,33 @@ struct Amdgpu final : public VariadicABIInfo {
}
};

// VariadicABIInfo implementation for NVPTX: the va_list is a bare pointer
// that is carried in an SSA register rather than spilled through an alloca.
struct NVPTX final : public VariadicABIInfo {

// IR-level variadic lowering is enabled unconditionally on NVPTX.
bool enableForTarget() override { return true; }

// The va_list value itself (not a pointer to it) is forwarded to callees.
bool vaListPassedInSSARegister() override { return true; }

// va_list is an opaque pointer in the default address space.
Type *vaListType(LLVMContext &Ctx) override {
return PointerType::getUnqual(Ctx);
}

Type *vaListParameterType(Module &M) override {
return PointerType::getUnqual(M.getContext());
}

// The argument buffer is an alloca; cast it to the generic pointer type used
// for va_list parameters. NOTE(review): presumably this casts out of the
// alloca/local address space into generic — confirm against the base-class
// contract, which is not visible in this excerpt.
Value *initializeVaList(Module &M, LLVMContext &Ctx, IRBuilder<> &Builder,
AllocaInst *, Value *Buffer) override {
return Builder.CreateAddrSpaceCast(Buffer, vaListParameterType(M));
}

// Each slot uses the type's natural ABI alignment and is passed directly
// (second field false => not passed indirectly via pointer).
VAArgSlotInfo slotInfo(const DataLayout &DL, Type *Parameter) override {
// NVPTX expects natural alignment in all cases. The variadic call ABI will
// handle promoting types to their appropriate size and alignment.
Align A = DL.getABITypeAlign(Parameter);
return {A, false};
}
};

struct Wasm final : public VariadicABIInfo {

bool enableForTarget() override {
Expand Down Expand Up @@ -967,8 +994,8 @@ struct Wasm final : public VariadicABIInfo {
if (A < MinAlign)
A = Align(MinAlign);

if (auto s = dyn_cast<StructType>(Parameter)) {
if (s->getNumElements() > 1) {
if (auto *S = dyn_cast<StructType>(Parameter)) {
if (S->getNumElements() > 1) {
return {DL.getABITypeAlign(PointerType::getUnqual(Ctx)), true};
}
}
Expand All @@ -988,6 +1015,11 @@ std::unique_ptr<VariadicABIInfo> VariadicABIInfo::create(const Triple &T) {
return std::make_unique<Wasm>();
}

case Triple::nvptx:
case Triple::nvptx64: {
return std::make_unique<NVPTX>();
}

default:
return {};
}
Expand Down
Loading
Loading