|
16 | 16 | #include "flang/Optimizer/Dialect/FIRAttr.h"
|
17 | 17 | #include "flang/Optimizer/Dialect/FIRType.h"
|
18 | 18 | #include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
| 19 | +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
19 | 20 | #include "mlir/IR/Attributes.h"
|
20 | 21 | #include "mlir/IR/BuiltinAttributes.h"
|
21 | 22 | #include "mlir/IR/BuiltinOps.h"
|
@@ -276,18 +277,26 @@ mlir::LogicalResult cuf::RegisterKernelOp::verify() {
|
276 | 277 |
|
277 | 278 | mlir::SymbolTable symTab(mod);
|
278 | 279 | auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName());
|
279 |
| - if (!gpuMod) |
| 280 | + if (!gpuMod) { |
| 281 | + // If already a gpu.binary then stop the check here. |
| 282 | + if (symTab.lookup<mlir::gpu::BinaryOp>(getKernelModuleName())) |
| 283 | + return mlir::success(); |
280 | 284 | return emitOpError("gpu module not found");
|
| 285 | + } |
281 | 286 |
|
282 | 287 | mlir::SymbolTable gpuSymTab(gpuMod);
|
283 |
| - auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName()); |
284 |
| - if (!func) |
285 |
| - return emitOpError("device function not found"); |
286 |
| - |
287 |
| - if (!func.isKernel()) |
288 |
| - return emitOpError("only kernel gpu.func can be registered"); |
289 |
| - |
290 |
| - return mlir::success(); |
| 288 | + if (auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName())) { |
| 289 | + if (!func.isKernel()) |
| 290 | + return emitOpError("only kernel gpu.func can be registered"); |
| 291 | + return mlir::success(); |
| 292 | + } else if (auto func = |
| 293 | + gpuSymTab.lookup<mlir::LLVM::LLVMFuncOp>(getKernelName())) { |
| 294 | + if (!func->getAttrOfType<mlir::UnitAttr>( |
| 295 | + mlir::gpu::GPUDialect::getKernelFuncAttrName())) |
| 296 | + return emitOpError("only gpu.kernel llvm.func can be registered"); |
| 297 | + return mlir::success(); |
| 298 | + } |
| 299 | + return emitOpError("device function not found"); |
291 | 300 | }
|
292 | 301 |
|
293 | 302 | // Tablegen operators
|
|
0 commit comments