Skip to content

Commit e777407

Browse files
matthias-springerAlexisPerry
authored andcommitted
[mlir][NVVM] Disallow results on kernel functions (llvm#96399)
Functions that have the `nvvm.kernel` attribute should have 0 results.
1 parent 2216386 commit e777407

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ void MmaOp::print(OpAsmPrinter &p) {
214214
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
215215

216216
// Print the types of the operands and result.
217-
p << " : " << "(";
217+
p << " : "
218+
<< "(";
218219
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
219220
frags[1].regs[0].getType(),
220221
frags[2].regs[0].getType()},
@@ -955,7 +956,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
955956
ss << "},";
956957
// Need to map read/write registers correctly.
957958
regCnt = (regCnt * 2);
958-
ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
959+
ss << " $" << (regCnt) << ","
960+
<< " $" << (regCnt + 1) << ","
961+
<< " p";
959962
if (getTypeD() != WGMMATypes::s32) {
960963
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
961964
}
@@ -1053,10 +1056,14 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
10531056
StringAttr attrName = attr.getName();
10541057
// Kernel function attribute should be attached to functions.
10551058
if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1056-
if (!isa<LLVM::LLVMFuncOp>(op)) {
1059+
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
1060+
if (!funcOp) {
10571061
return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
10581062
<< "' attribute attached to unexpected op";
10591063
}
1064+
if (!funcOp.getResultTypes().empty()) {
1065+
return op->emitError() << "kernel function cannot have results";
1066+
}
10601067
}
10611068
// If maxntid and reqntid exist, it must be an array with max 3 dim
10621069
if (attrName == NVVMDialect::getMaxntidAttrName() ||

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,3 +574,10 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant})
574574
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
575575
llvm.return
576576
}
577+
578+
// -----
579+
580+
// expected-error @below{{kernel function cannot have results}}
581+
llvm.func @kernel_with_result(%i: i32) -> i32 attributes {nvvm.kernel} {
582+
llvm.return %i : i32
583+
}

0 commit comments

Comments
 (0)