Skip to content

Commit 8588014

Browse files
authored
[flang][cuda] Add kernel registration in CUF constructor (#112416)
Update the CUF constructor with the cuf.register_kernel operations.
1 parent 23da169 commit 8588014

File tree

4 files changed

+22
-9
lines changed

4 files changed

+22
-9
lines changed

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def CufImplicitDeviceGlobal :
439439
def CUFAddConstructor : Pass<"cuf-add-constructor", "mlir::ModuleOp"> {
440440
let summary = "Add constructor to register CUDA Fortran allocators";
441441
let dependentDialects = [
442-
"mlir::func::FuncDialect"
442+
"cuf::CUFDialect", "mlir::func::FuncDialect"
443443
];
444444
}
445445

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ add_flang_library(FIRTransforms
4949
HLFIRDialect
5050
MLIRAffineUtils
5151
MLIRFuncDialect
52+
MLIRGPUDialect
5253
MLIRLLVMDialect
5354
MLIRLLVMCommonConversion
5455
MLIRMathTransforms

flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "flang/Optimizer/Dialect/FIRDialect.h"
1313
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
1414
#include "flang/Runtime/entry-names.h"
15+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1516
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1617
#include "mlir/Pass/Pass.h"
1718
#include "llvm/ADT/SmallVector.h"
@@ -23,6 +24,8 @@ namespace fir {
2324

2425
namespace {
2526

27+
static constexpr llvm::StringRef cudaModName{"cuda_device_mod"};
28+
2629
static constexpr llvm::StringRef cudaFortranCtorName{
2730
"__cudaFortranConstructor"};
2831

@@ -31,6 +34,7 @@ struct CUFAddConstructor
3134

3235
void runOnOperation() override {
3336
mlir::ModuleOp mod = getOperation();
37+
mlir::SymbolTable symTab(mod);
3438
mlir::OpBuilder builder{mod.getBodyRegion()};
3539
builder.setInsertionPointToEnd(mod.getBody());
3640
mlir::Location loc = mod.getLoc();
@@ -48,13 +52,25 @@ struct CUFAddConstructor
4852
mod.getContext(), RTNAME_STRING(CUFRegisterAllocator));
4953
builder.setInsertionPointToEnd(mod.getBody());
5054

51-
// Create the constructor function that cal CUFRegisterAllocator.
52-
builder.setInsertionPointToEnd(mod.getBody());
55+
// Create the constructor function that call CUFRegisterAllocator.
5356
auto func = builder.create<mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName,
5457
funcTy);
5558
func.setLinkage(mlir::LLVM::Linkage::Internal);
5659
builder.setInsertionPointToStart(func.addEntryBlock(builder));
5760
builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
61+
62+
// Register kernels
63+
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
64+
if (gpuMod) {
65+
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
66+
if (func.isKernel()) {
67+
auto kernelName = mlir::SymbolRefAttr::get(
68+
builder.getStringAttr(cudaModName),
69+
{mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
70+
builder.create<cuf::RegisterKernelOp>(loc, kernelName);
71+
}
72+
}
73+
}
5874
builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
5975

6076
// Create the llvm.global_ctor with the function.
Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: fir-opt %s | FileCheck %s
1+
// RUN: fir-opt --cuf-add-constructor %s | FileCheck %s
22

33
module attributes {gpu.container_module} {
44
gpu.module @cuda_device_mod {
@@ -9,12 +9,8 @@ module attributes {gpu.container_module} {
99
gpu.return
1010
}
1111
}
12-
llvm.func internal @__cudaFortranConstructor() {
13-
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
14-
cuf.register_kernel @cuda_device_mod::@_QPsub_device2
15-
llvm.return
16-
}
1712
}
1813

14+
// CHECK-LABEL: llvm.func internal @__cudaFortranConstructor()
1915
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
2016
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2

0 commit comments

Comments
 (0)