Skip to content

Commit 6b167f8

Browse files
AdUhTkJmlanza
authored andcommitted
[CIR][CUDA] Generate attribute for kernel name of device stubs (#1317)
Now a `__global__` function on host will be generated to a device stub, with an attribute recording the corresponding kernel name (mangled name on device of the same function). The dynamic registration phase will be implemented in LLVM lowering. For example, CIR generated for `__global__ void global_fn();` looks like ``` #fn_attr1 = #cir<extra({cuda_kernel_name = #cir.cuda_kernel_name<_Z9global_fnv>})> cir.func private @_Z24__device_stub__global_fnv() extra(#fn_attr1) ```
1 parent d38092f commit 6b167f8

File tree

6 files changed

+79
-31
lines changed

6 files changed

+79
-31
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,5 +1327,6 @@ def CIR_TBAAAttr : CIR_Attr<"TBAA", "tbaa", []> {
13271327
}
13281328

13291329
include "clang/CIR/Dialect/IR/CIROpenCLAttrs.td"
1330+
include "clang/CIR/Dialect/IR/CIRCUDAAttrs.td"
13301331

13311332
#endif // MLIR_CIR_DIALECT_CIR_ATTRS
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===---- CIRCUDAAttrs.td - CIR dialect attrs for CUDA -----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the CIR dialect attributes for OpenCL.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CIR_DIALECT_CIR_CUDA_ATTRS
14+
#define MLIR_CIR_DIALECT_CIR_CUDA_ATTRS
15+
16+
//===----------------------------------------------------------------------===//
17+
// CUDAKernelNameAttr
18+
//===----------------------------------------------------------------------===//
19+
20+
def CUDAKernelNameAttr : CIR_Attr<"CUDAKernelName",
21+
"cuda_kernel_name"> {
22+
let summary = "Device-side function name for this stub.";
23+
let description =
24+
[{
25+
This attribute is attached to function definitions and records the
26+
mangled name of the kernel function used on the device.
27+
28+
In CUDA, global functions (kernels) are processed differently for host
29+
and device. On host, Clang generates device stubs; on device, they are
30+
treated as normal functions. As they probably have different mangled
31+
names, we must record the corresponding device-side name for a stub.
32+
}];
33+
34+
let parameters = (ins "std::string":$kernel_name);
35+
let assemblyFormat = "`<` $kernel_name `>`";
36+
}
37+
38+
#endif // MLIR_CIR_DIALECT_CIR_CUDA_ATTRS

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,16 @@ void CIRGenModule::constructAttributeList(
460460
getLangOpts().OffloadUniformBlock)
461461
assert(!cir::MissingFeatures::CUDA());
462462

463+
if (langOpts.CUDA && !langOpts.CUDAIsDevice &&
464+
TargetDecl->hasAttr<CUDAGlobalAttr>()) {
465+
GlobalDecl kernel(CalleeInfo.getCalleeDecl());
466+
llvm::StringRef kernelName = getMangledName(
467+
kernel.getWithKernelReferenceKind(KernelReferenceKind::Kernel));
468+
auto attr =
469+
cir::CUDAKernelNameAttr::get(&getMLIRContext(), kernelName.str());
470+
funcAttrs.set(attr.getMnemonic(), attr);
471+
}
472+
463473
if (TargetDecl->hasAttr<ArmLocallyStreamingAttr>())
464474
;
465475
}

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
// This is the internal per-translation-unit state used for CIR translation.
1010
//
1111
//===----------------------------------------------------------------------===//
12-
#include "CIRGenModule.h"
1312
#include "CIRGenCXXABI.h"
1413
#include "CIRGenCstEmitter.h"
1514
#include "CIRGenFunction.h"
@@ -528,10 +527,9 @@ void CIRGenModule::emitGlobal(GlobalDecl GD) {
528527
if (langOpts.HIPStdPar)
529528
llvm_unreachable("NYI");
530529

531-
if (Global->hasAttr<CUDAGlobalAttr>())
532-
llvm_unreachable("NYI");
533-
534-
if (!Global->hasAttr<CUDADeviceAttr>())
530+
// Global functions reside on device, so it shouldn't be skipped.
531+
if (!Global->hasAttr<CUDAGlobalAttr>() &&
532+
!Global->hasAttr<CUDADeviceAttr>())
535533
return;
536534
} else {
537535
// We must skip __device__ functions when compiling for host.
@@ -2352,10 +2350,10 @@ cir::FuncOp CIRGenModule::GetAddrOfFunction(clang::GlobalDecl GD, mlir::Type Ty,
23522350
auto F = GetOrCreateCIRFunction(MangledName, Ty, GD, ForVTable, DontDefer,
23532351
/*IsThunk=*/false, IsForDefinition);
23542352

2355-
// As __global__ functions always reside on device,
2356-
// we need special care when accessing them from host;
2357-
// otherwise, CUDA functions behave as normal functions
2358-
if (langOpts.CUDA && !langOpts.CUDAIsDevice &&
2353+
// As __global__ functions (kernels) always reside on device,
2354+
// when we access them from host, we must refer to the kernel handle.
2355+
// For CUDA, it's just the device stub. For HIP, it's something different.
2356+
if (langOpts.CUDA && !langOpts.CUDAIsDevice && langOpts.HIP &&
23592357
cast<FunctionDecl>(GD.getDecl())->hasAttr<CUDAGlobalAttr>()) {
23602358
llvm_unreachable("NYI");
23612359
}
@@ -2398,7 +2396,7 @@ static std::string getMangledNameImpl(CIRGenModule &CGM, GlobalDecl GD,
23982396
assert(0 && "NYI");
23992397
} else if (FD && FD->hasAttr<CUDAGlobalAttr>() &&
24002398
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
2401-
assert(0 && "NYI");
2399+
Out << "__device_stub__";
24022400
} else {
24032401
Out << II->getName();
24042402
}

clang/test/CIR/CodeGen/CUDA/simple-device.cu

Lines changed: 0 additions & 14 deletions
This file was deleted.

clang/test/CIR/CodeGen/CUDA/simple.cu

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,31 @@
11
#include "../Inputs/cuda.h"
22

3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
4+
// RUN: -x cuda -emit-cir %s -o %t.cir
5+
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
6+
37
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
4-
// RUN: -emit-cir %s -o %t.cir
5-
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
8+
// RUN: -fcuda-is-device -emit-cir %s -o %t.cir
9+
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
610

11+
// Attribute for global_fn
12+
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fnv>{{.*}}
713

8-
// This should emit as a normal C++ function.
914
__host__ void host_fn(int *a, int *b, int *c) {}
15+
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
16+
// CIR-DEVICE-NOT: cir.func @_Z7host_fnPiS_S_
1017

11-
// CIR: cir.func @_Z7host_fnPiS_S_
12-
13-
// This shouldn't emit.
1418
__device__ void device_fn(int* a, double b, float c) {}
19+
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
20+
// CIR-DEVICE: cir.func @_Z9device_fnPidf
21+
22+
#ifdef __CUDA_ARCH__
23+
__global__ void global_fn() {}
24+
#else
25+
__global__ void global_fn();
26+
#endif
27+
// CIR-HOST: @_Z24__device_stub__global_fnv(){{.*}}extra([[Kernel]])
28+
// CIR-DEVICE: @_Z9global_fnv
1529

16-
// CHECK-NOT: cir.func @_Z9device_fnPidf
30+
// Make sure `global_fn` indeed gets emitted
31+
__host__ void x() { auto v = global_fn; }

0 commit comments

Comments
 (0)