Skip to content

Commit d7e1932

Browse files
authored
[HIP] Fix comdat of template kernel handle (llvm#66283)
Currently, clang emits LLVM IR that fails verifier for the following code: ``` template<typename T> __global__ void foo(T x); void bar() { foo<<<1, 1>>>(0); } ``` This is due to clang putting the kernel handle for foo into comdat, which is not allowed, since the kernel handle is a declaration. The siutation is similar to calling a declaration-only template function. The callee will be a declaration in LLVM IR and won't be put into comdat. This is in contrast to calling a template function with body, which will be put into comdat. Fixes: SWDEV-419769
1 parent cafb628 commit d7e1932

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

clang/lib/CodeGen/CGCUDANV.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,10 @@ llvm::GlobalValue *CGNVCUDARuntime::getKernelHandle(llvm::Function *F,
12341234
Var->setAlignment(CGM.getPointerAlign().getAsAlign());
12351235
Var->setDSOLocal(F->isDSOLocal());
12361236
Var->setVisibility(F->getVisibility());
1237-
CGM.maybeSetTrivialComdat(*GD.getDecl(), *Var);
1237+
auto *FD = cast<FunctionDecl>(GD.getDecl());
1238+
auto *FT = FD->getPrimaryTemplate();
1239+
if (!FT || FT->isThisDeclarationADefinition())
1240+
CGM.maybeSetTrivialComdat(*FD, *Var);
12381241
KernelHandles[F->getName()] = Var;
12391242
KernelStubs[Var] = F;
12401243
return Var;

clang/test/CodeGenCUDA/kernel-stub-name.cu

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626
// GNU: @[[HNSKERN:_ZN2ns8nskernelEv]] = constant ptr @[[NSSTUB:_ZN2ns23__device_stub__nskernelEv]], align 8
2727
// GNU: @[[HTKERN:_Z10kernelfuncIiEvv]] = linkonce_odr constant ptr @[[TSTUB:_Z25__device_stub__kernelfuncIiEvv]], comdat, align 8
2828
// GNU: @[[HDKERN:_Z11kernel_declv]] = external constant ptr, align 8
29+
// GNU: @[[HTDKERN:_Z20template_kernel_declIiEvT_]] = external constant ptr, align 8
2930

3031
// MSVC: @[[HCKERN:ckernel]] = dso_local constant ptr @[[CSTUB:__device_stub__ckernel]], align 8
3132
// MSVC: @[[HNSKERN:"\?nskernel@ns@@YAXXZ.*"]] = dso_local constant ptr @[[NSSTUB:"\?__device_stub__nskernel@ns@@YAXXZ"]], align 8
3233
// MSVC: @[[HTKERN:"\?\?\$kernelfunc@H@@YAXXZ.*"]] = linkonce_odr dso_local constant ptr @[[TSTUB:"\?\?\$__device_stub__kernelfunc@H@@YAXXZ.*"]], comdat, align 8
3334
// MSVC: @[[HDKERN:"\?kernel_decl@@YAXXZ.*"]] = external dso_local constant ptr, align 8
34-
35+
// MSVC: @[[HTDKERN:"\?\?\$template_kernel_decl@H@@YAXH.*"]] = external dso_local constant ptr, align 8
3536
extern "C" __global__ void ckernel() {}
3637

3738
namespace ns {
@@ -43,6 +44,9 @@ __global__ void kernelfunc() {}
4344

4445
__global__ void kernel_decl();
4546

47+
template<class T>
48+
__global__ void template_kernel_decl(T x);
49+
4650
extern "C" void (*kernel_ptr)();
4751
extern "C" void *void_ptr;
4852

@@ -69,13 +73,16 @@ extern "C" void launch(void *kern);
6973
// CHECK: call void @[[NSSTUB]]()
7074
// CHECK: call void @[[TSTUB]]()
7175
// GNU: call void @[[DSTUB:_Z26__device_stub__kernel_declv]]()
76+
// GNU: call void @[[TDSTUB:_Z35__device_stub__template_kernel_declIiEvT_]](
7277
// MSVC: call void @[[DSTUB:"\?__device_stub__kernel_decl@@YAXXZ"]]()
78+
// MSVC: call void @[[TDSTUB:"\?\?\$__device_stub__template_kernel_decl@H@@YAXH@Z"]](
7379

7480
extern "C" void fun1(void) {
7581
ckernel<<<1, 1>>>();
7682
ns::nskernel<<<1, 1>>>();
7783
kernelfunc<int><<<1, 1>>>();
7884
kernel_decl<<<1, 1>>>();
85+
template_kernel_decl<<<1, 1>>>(1);
7986
}
8087

8188
// Template kernel stub functions
@@ -86,6 +93,7 @@ extern "C" void fun1(void) {
8693
// Check declaration of stub function for external kernel.
8794

8895
// CHECK: declare{{.*}}@[[DSTUB]]
96+
// CHECK: declare{{.*}}@[[TDSTUB]]
8997

9098
// Check kernel handle is used for passing the kernel as a function pointer.
9199

@@ -94,11 +102,13 @@ extern "C" void fun1(void) {
94102
// CHECK: call void @launch({{.*}}[[HNSKERN]]
95103
// CHECK: call void @launch({{.*}}[[HTKERN]]
96104
// CHECK: call void @launch({{.*}}[[HDKERN]]
105+
// CHECK: call void @launch({{.*}}[[HTDKERN]]
97106
extern "C" void fun2() {
98107
launch((void *)ckernel);
99108
launch((void *)ns::nskernel);
100109
launch((void *)kernelfunc<int>);
101110
launch((void *)kernel_decl);
111+
launch((void *)template_kernel_decl<int>);
102112
}
103113

104114
// Check kernel handle is used for assigning a kernel to a function pointer.
@@ -148,3 +158,4 @@ extern "C" void fun5() {
148158
// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[HTKERN]]{{.*}}@[[TKERN]]
149159
// NEG-NOT: call{{.*}}@__hipRegisterFunction{{.*}}__device_stub
150160
// NEG-NOT: call{{.*}}@__hipRegisterFunction{{.*}}kernel_decl
161+
// NEG-NOT: call{{.*}}@__hipRegisterFunction{{.*}}template_kernel_decl

0 commit comments

Comments
 (0)