Skip to content

Commit 6f7f45b

Browse files
committed
better verification
1 parent 89d5fe9 commit 6f7f45b

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 25 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -531,33 +531,35 @@ LogicalResult WarpgroupMmaOp::verify() {
531 531
}
532 532

533 533
LogicalResult WarpgroupMmaStoreOp::verify() {
534-
Type stype = getMatrixD()
535-
.front()
536-
.getType()
537-
.cast<WarpgroupAccumulatorType>()
538-
.getFragmented();
539-
534+
MemRefType dstMemrefType = getDstMemref().getType();
535+
VectorType firstVtype = getMatrixD()
536+
.front()
537+
.getType()
538+
.cast<WarpgroupAccumulatorType>()
539+
.getFragmented();
540+
541+
int64_t totalFirstDimension = 0;
540 542
for (auto result : getMatrixD()) {
541-
auto resultStype = result.getType()
542-
.cast<WarpgroupAccumulatorType>()
543-
.getFragmented()
544-
.dyn_cast<LLVM::LLVMStructType>();
545-
if (!resultStype)
546-
return emitOpError() << "result is " << result.getType()
547-
<< " but must keep type of llvm struct";
548-
if (stype != resultStype)
549-
return emitOpError() << "all results must be the same type";
550-
551-
// todo improve this limitation
552-
if (!resultStype.getBody().front().isF32()) {
553-
return emitOpError() << "supporst only f32 results for the time being";
543+
VectorType vtype =
544+
result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
545+
if (vtype != firstVtype)
546+
return emitOpError() << "all fragmented types must be the same";
547+
// Limitation
548+
if (!vtype.getElementType().isF32()) {
549+
return emitOpError()
550+
<< "hit a limitation: only f32 results for the time being";
554 551
}
552+
totalFirstDimension += vtype.getDimSize(0);
555 553
}
556-
557-
if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
558-
return emitOpError() << "all element types must be equal ";
554+
if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
555+
firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
556+
return emitOpError() << "results [" << totalFirstDimension << "]["
557+
<< firstVtype.getDimSize(1)
558+
<< "] values. However, destination memref["
559+
<< dstMemrefType.getDimSize(0) << "]["
560+
<< dstMemrefType.getDimSize(1)
561+
<< "] does not have same size as results";
559 562
}
560-
561 563
return success();
562 564
}
563 565

0 commit comments

Comments
 (0)