Skip to content

Commit 7e72e5b

Browse files
authored
Reland '[flang][cuda] Add cuf.register_kernel operation' (#112389)
The operation will be used in the CUF constructor to register the kernel functions. This allows the registration to be delayed until codegen, when the gpu.binary will be available. Reland of #112268 with correct shared library build support.
1 parent 583fa4f commit 7e72e5b

File tree

6 files changed

+128
-0
lines changed

6 files changed

+128
-0
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,4 +288,23 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
288288
let hasVerifier = 1;
289289
}
290290

291+
def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
  let summary = "Register a CUDA kernel";

  let description = [{
    Register a kernel function designated by a nested symbol reference of the
    form `@module::@kernel`, where the root reference names a `gpu.module` and
    the leaf reference names a `gpu.func` inside it.

    The verifier rejects the operation when the reference does not name both
    a module and a kernel, when the `gpu.module` or the `gpu.func` cannot be
    found from the enclosing module, or when the `gpu.func` is not a kernel.

    Example:

    ```
      cuf.register_kernel @cuda_device_mod::@_QPsub_device1
    ```
  }];

  let arguments = (ins
    SymbolRefAttr:$name
  );

  let assemblyFormat = [{
    $name attr-dict
  }];

  let hasVerifier = 1;

  let extraClassDeclaration = [{
    mlir::StringAttr getKernelName();
    mlir::StringAttr getKernelModuleName();
  }];
}
309+
291310
#endif // FORTRAN_DIALECT_CUF_CUF_OPS

flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_flang_library(CUFDialect
1414
FIRDialect
1515
FIRDialectSupport
1616
MLIRIR
17+
MLIRGPUDialect
1718
MLIRTargetLLVMIRExport
1819

1920
LINK_COMPONENTS

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
1616
#include "flang/Optimizer/Dialect/FIRAttr.h"
1717
#include "flang/Optimizer/Dialect/FIRType.h"
18+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1819
#include "mlir/IR/Attributes.h"
1920
#include "mlir/IR/BuiltinAttributes.h"
2021
#include "mlir/IR/BuiltinOps.h"
@@ -253,6 +254,42 @@ llvm::LogicalResult cuf::KernelOp::verify() {
253254
return mlir::success();
254255
}
255256

257+
//===----------------------------------------------------------------------===//
258+
// RegisterKernelOp
259+
//===----------------------------------------------------------------------===//
260+
261+
/// Return the root component of the `name` symbol reference, i.e. the part
/// that `verify()` looks up as the enclosing `gpu.module`.
mlir::StringAttr cuf::RegisterKernelOp::getKernelModuleName() {
  mlir::SymbolRefAttr kernelRef = getName();
  return kernelRef.getRootReference();
}
264+
265+
/// Return the leaf component of the `name` symbol reference, i.e. the part
/// that `verify()` looks up as the `gpu.func` inside the GPU module.
mlir::StringAttr cuf::RegisterKernelOp::getKernelName() {
  mlir::SymbolRefAttr kernelRef = getName();
  return kernelRef.getLeafReference();
}
268+
269+
/// Verify that the `name` attribute designates an existing GPU kernel.
///
/// The reference must have the form `@module::@kernel`. The root component is
/// looked up as a `gpu.module` in the closest enclosing `builtin.module`, and
/// the leaf component is looked up as a `gpu.func` inside that GPU module. The
/// function found must be a kernel.
///
/// \returns success when all checks pass; otherwise an op error has been
/// emitted and failure is returned.
mlir::LogicalResult cuf::RegisterKernelOp::verify() {
  // Check the shape of the reference rather than comparing the root and leaf
  // names: a name comparison wrongly rejects a kernel that shares its name
  // with its module (`@k::@k`) and lets deeper references (`@a::@b::@c`)
  // through with their middle components ignored.
  if (getName().getNestedReferences().size() != 1)
    return emitOpError("expect a module and a kernel name");

  auto mod = getOperation()->getParentOfType<mlir::ModuleOp>();
  if (!mod)
    return emitOpError("expect to be in a module");

  // Resolve the root reference to the GPU module holding the kernel.
  mlir::SymbolTable symTab(mod);
  auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName());
  if (!gpuMod)
    return emitOpError("gpu module not found");

  // Resolve the leaf reference inside the GPU module's own symbol table.
  mlir::SymbolTable gpuSymTab(gpuMod);
  auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName());
  if (!func)
    return emitOpError("device function not found");

  // A gpu.func that is not marked as a kernel cannot be registered.
  if (!func.isKernel())
    return emitOpError("only kernel gpu.func can be registered");

  return mlir::success();
}
292+
256293
// Tablegen operators
257294

258295
#define GET_OP_CLASSES
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: fir-opt %s | FileCheck %s
2+
3+
module attributes {gpu.container_module} {
4+
gpu.module @cuda_device_mod {
5+
gpu.func @_QPsub_device1() kernel {
6+
gpu.return
7+
}
8+
gpu.func @_QPsub_device2(%arg0: !fir.ref<f32>) kernel {
9+
gpu.return
10+
}
11+
}
12+
llvm.func internal @__cudaFortranConstructor() {
13+
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
14+
cuf.register_kernel @cuda_device_mod::@_QPsub_device2
15+
llvm.return
16+
}
17+
}
18+
19+
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
20+
// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2

flang/test/Fir/cuf-invalid.fir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,53 @@ func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda
125125
cuf.data_transfer %20#0 to %11#0, %19 : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
126126
return
127127
}
128+
129+
// -----
130+
131+
module attributes {gpu.container_module} {
132+
gpu.module @cuda_device_mod {
133+
gpu.func @_QPsub_device1() {
134+
gpu.return
135+
}
136+
}
137+
llvm.func internal @__cudaFortranConstructor() {
138+
// expected-error@+1{{'cuf.register_kernel' op only kernel gpu.func can be registered}}
139+
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
140+
llvm.return
141+
}
142+
}
143+
144+
// -----
145+
146+
module attributes {gpu.container_module} {
147+
gpu.module @cuda_device_mod {
148+
gpu.func @_QPsub_device1() {
149+
gpu.return
150+
}
151+
}
152+
llvm.func internal @__cudaFortranConstructor() {
153+
// expected-error@+1{{'cuf.register_kernel' op device function not found}}
154+
cuf.register_kernel @cuda_device_mod::@_QPsub_device2
155+
llvm.return
156+
}
157+
}
158+
159+
// -----
160+
161+
module attributes {gpu.container_module} {
162+
llvm.func internal @__cudaFortranConstructor() {
163+
// expected-error@+1{{'cuf.register_kernel' op gpu module not found}}
164+
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
165+
llvm.return
166+
}
167+
}
168+
169+
// -----
170+
171+
module attributes {gpu.container_module} {
172+
llvm.func internal @__cudaFortranConstructor() {
173+
// expected-error@+1{{'cuf.register_kernel' op expect a module and a kernel name}}
174+
cuf.register_kernel @_QPsub_device1
175+
llvm.return
176+
}
177+
}

flang/tools/fir-opt/fir-opt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ int main(int argc, char **argv) {
4242
#endif
4343
DialectRegistry registry;
4444
fir::support::registerDialects(registry);
45+
registry.insert<mlir::gpu::GPUDialect>();
4546
fir::support::addFIRExtensions(registry);
4647
return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n",
4748
registry));

0 commit comments

Comments
 (0)