@@ -214,7 +214,8 @@ void MmaOp::print(OpAsmPrinter &p) {
214
214
p.printOptionalAttrDict (this ->getOperation ()->getAttrs (), ignoreAttrNames);
215
215
216
216
// Print the types of the operands and result.
217
- p << " : " << " (" ;
217
+ p << " : "
218
+ << " (" ;
218
219
llvm::interleaveComma (SmallVector<Type, 3 >{frags[0 ].regs [0 ].getType (),
219
220
frags[1 ].regs [0 ].getType (),
220
221
frags[2 ].regs [0 ].getType ()},
@@ -955,7 +956,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
955
956
ss << " }," ;
956
957
// Need to map read/write registers correctly.
957
958
regCnt = (regCnt * 2 );
958
- ss << " $" << (regCnt) << " ," << " $" << (regCnt + 1 ) << " ," << " p" ;
959
+ ss << " $" << (regCnt) << " ,"
960
+ << " $" << (regCnt + 1 ) << " ,"
961
+ << " p" ;
959
962
if (getTypeD () != WGMMATypes::s32) {
960
963
ss << " , $" << (regCnt + 3 ) << " , $" << (regCnt + 4 );
961
964
}
@@ -1053,10 +1056,14 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1053
1056
StringAttr attrName = attr.getName ();
1054
1057
// Kernel function attribute should be attached to functions.
1055
1058
if (attrName == NVVMDialect::getKernelFuncAttrName ()) {
1056
- if (!isa<LLVM::LLVMFuncOp>(op)) {
1059
+ auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
1060
+ if (!funcOp) {
1057
1061
return op->emitError () << " '" << NVVMDialect::getKernelFuncAttrName ()
1058
1062
<< " ' attribute attached to unexpected op" ;
1059
1063
}
1064
+ if (!funcOp.getResultTypes ().empty ()) {
1065
+ return op->emitError () << " kernel function cannot have results" ;
1066
+ }
1060
1067
}
1061
1068
// If maxntid and reqntid exist, it must be an array with max 3 dim
1062
1069
if (attrName == NVVMDialect::getMaxntidAttrName () ||
0 commit comments