Skip to content

Commit 834d001

Browse files
authored
[flang][cuda] Relax the verifier for cuf.register_kernel op (#112585)
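Relax the verifier, since the `gpu.func` might already have been converted to an `llvm.func` before `cuf.register_kernel` itself is converted. The verifier now also accepts an `llvm.func` that carries the `gpu.kernel` unit attribute, and it stops checking early (reporting success) when the kernel module symbol already refers to a `gpu.binary` instead of a `gpu.module`.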
1 parent 98b419c commit 834d001

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "flang/Optimizer/Dialect/FIRAttr.h"
1717
#include "flang/Optimizer/Dialect/FIRType.h"
1818
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1920
#include "mlir/IR/Attributes.h"
2021
#include "mlir/IR/BuiltinAttributes.h"
2122
#include "mlir/IR/BuiltinOps.h"
@@ -276,18 +277,26 @@ mlir::LogicalResult cuf::RegisterKernelOp::verify() {
276277

277278
mlir::SymbolTable symTab(mod);
278279
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName());
279-
if (!gpuMod)
280+
if (!gpuMod) {
281+
// If already a gpu.binary then stop the check here.
282+
if (symTab.lookup<mlir::gpu::BinaryOp>(getKernelModuleName()))
283+
return mlir::success();
280284
return emitOpError("gpu module not found");
285+
}
281286

282287
mlir::SymbolTable gpuSymTab(gpuMod);
283-
auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName());
284-
if (!func)
285-
return emitOpError("device function not found");
286-
287-
if (!func.isKernel())
288-
return emitOpError("only kernel gpu.func can be registered");
289-
290-
return mlir::success();
288+
if (auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName())) {
289+
if (!func.isKernel())
290+
return emitOpError("only kernel gpu.func can be registered");
291+
return mlir::success();
292+
} else if (auto func =
293+
gpuSymTab.lookup<mlir::LLVM::LLVMFuncOp>(getKernelName())) {
294+
if (!func->getAttrOfType<mlir::UnitAttr>(
295+
mlir::gpu::GPUDialect::getKernelFuncAttrName()))
296+
return emitOpError("only gpu.kernel llvm.func can be registered");
297+
return mlir::success();
298+
}
299+
return emitOpError("device function not found");
291300
}
292301

293302
// Tablegen operators

flang/test/Fir/cuf-invalid.fir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,18 @@ module attributes {gpu.container_module} {
175175
llvm.return
176176
}
177177
}
178+
179+
// -----
180+
181+
module attributes {gpu.container_module} {
182+
gpu.module @cuda_device_mod {
183+
llvm.func @_QPsub_device1() {
184+
llvm.return
185+
}
186+
}
187+
llvm.func internal @__cudaFortranConstructor() {
188+
// expected-error@+1{{'cuf.register_kernel' op only gpu.kernel llvm.func can be registered}}
189+
cuf.register_kernel @cuda_device_mod::@_QPsub_device1
190+
llvm.return
191+
}
192+
}

0 commit comments

Comments
 (0)