Skip to content

Commit 684b928

Browse files
author
Jeff Niu
authored
Revert "[mlir][NVVM] Disallow results on kernel functions (#96399)"
This reverts commit 346c4a8.
1 parent b468804 commit 684b928

File tree

2 files changed

+3
-17
lines changed

2 files changed

+3
-17
lines changed

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,7 @@ void MmaOp::print(OpAsmPrinter &p) {
214214
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
215215

216216
// Print the types of the operands and result.
217-
p << " : "
218-
<< "(";
217+
p << " : " << "(";
219218
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
220219
frags[1].regs[0].getType(),
221220
frags[2].regs[0].getType()},
@@ -956,9 +955,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
956955
ss << "},";
957956
// Need to map read/write registers correctly.
958957
regCnt = (regCnt * 2);
959-
ss << " $" << (regCnt) << ","
960-
<< " $" << (regCnt + 1) << ","
961-
<< " p";
958+
ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
962959
if (getTypeD() != WGMMATypes::s32) {
963960
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
964961
}
@@ -1056,14 +1053,10 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
10561053
StringAttr attrName = attr.getName();
10571054
// Kernel function attribute should be attached to functions.
10581055
if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1059-
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
1060-
if (!funcOp) {
1056+
if (!isa<LLVM::LLVMFuncOp>(op)) {
10611057
return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
10621058
<< "' attribute attached to unexpected op";
10631059
}
1064-
if (!funcOp.getResultTypes().empty()) {
1065-
return op->emitError() << "kernel function cannot have results";
1066-
}
10671060
}
10681061
// If maxntid and reqntid exist, it must be an array with max 3 dim
10691062
if (attrName == NVVMDialect::getMaxntidAttrName() ||

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -574,10 +574,3 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant})
574574
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
575575
llvm.return
576576
}
577-
578-
// -----
579-
580-
// expected-error @below{{kernel function cannot have results}}
581-
llvm.func @kernel_with_result(%i: i32) -> i32 attributes {nvvm.kernel} {
582-
llvm.return %i : i32
583-
}

0 commit comments

Comments
 (0)