@@ -531,33 +531,35 @@ LogicalResult WarpgroupMmaOp::verify() {
531
531
}
532
532
533
533
LogicalResult WarpgroupMmaStoreOp::verify () {
534
- Type stype = getMatrixD ()
535
- .front ()
536
- .getType ()
537
- .cast <WarpgroupAccumulatorType>()
538
- .getFragmented ();
539
-
534
+ MemRefType dstMemrefType = getDstMemref ().getType ();
535
+ VectorType firstVtype = getMatrixD ()
536
+ .front ()
537
+ .getType ()
538
+ .cast <WarpgroupAccumulatorType>()
539
+ .getFragmented ();
540
+
541
+ int64_t totalFirstDimension = 0 ;
540
542
for (auto result : getMatrixD ()) {
541
- auto resultStype = result.getType ()
542
- .cast <WarpgroupAccumulatorType>()
543
- .getFragmented ()
544
- .dyn_cast <LLVM::LLVMStructType>();
545
- if (!resultStype)
546
- return emitOpError () << " result is " << result.getType ()
547
- << " but must keep type of llvm struct" ;
548
- if (stype != resultStype)
549
- return emitOpError () << " all results must be the same type" ;
550
-
551
- // todo improve this limitation
552
- if (!resultStype.getBody ().front ().isF32 ()) {
553
- return emitOpError () << " supporst only f32 results for the time being" ;
543
+ VectorType vtype =
544
+ result.getType ().cast <WarpgroupAccumulatorType>().getFragmented ();
545
+ if (vtype != firstVtype)
546
+ return emitOpError () << " all fragmented types must be the same" ;
547
+ // Limitation
548
+ if (!vtype.getElementType ().isF32 ()) {
549
+ return emitOpError ()
550
+ << " hit a limitation: only f32 results for the time being" ;
554
551
}
552
+ totalFirstDimension += vtype.getDimSize (0 );
555
553
}
556
-
557
- if (!llvm::all_equal (stype.cast <LLVM::LLVMStructType>().getBody ())) {
558
- return emitOpError () << " all element types must be equal " ;
554
+ if (totalFirstDimension != dstMemrefType.getDimSize (0 ) ||
555
+ firstVtype.getDimSize (1 ) != dstMemrefType.getDimSize (1 )) {
556
+ return emitOpError () << " results [" << totalFirstDimension << " ]["
557
+ << firstVtype.getDimSize (1 )
558
+ << " ] values. However, destination memref["
559
+ << dstMemrefType.getDimSize (0 ) << " ]["
560
+ << dstMemrefType.getDimSize (1 )
561
+ << " ] does not have same size as results" ;
559
562
}
560
-
561
563
return success ();
562
564
}
563
565
0 commit comments