Skip to content

Commit 8545557

Browse files
committed
Comments
1 parent 6e0be2e commit 8545557

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1063,8 +1063,10 @@ struct ScalableTransposeTransferWriteConversion
10631063
writeOp, "lowering tensor transfers is disabled");
10641064
}
10651065

1066-
Value vector = writeOp.getVector();
10671066
VectorType vectorType = writeOp.getVectorType();
1067+
1068+
// Note: By comparing the scalable dims to an ArrayRef of length two this
1069+
// implicitly checks the rank (is also two).
10681070
ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
10691071
if (scalableFlags != ArrayRef<bool>{true, false}) {
10701072
return rewriter.notifyMatchFailure(
@@ -1077,11 +1079,15 @@ struct ScalableTransposeTransferWriteConversion
10771079
writeOp, "non-identity permutations are unsupported (lower first)");
10781080
}
10791081

1082+
// Note: This pattern is only lowering the leading dimension (to a loop),
1083+
// so we only check if the leading dimension is in bounds. The in-bounds
1084+
// attribute for the trailing dimension will be propagated.
10801085
if (!writeOp.isDimInBounds(0)) {
10811086
return rewriter.notifyMatchFailure(
10821087
writeOp, "out-of-bounds dims are unsupported (use masking)");
10831088
}
10841089

1090+
Value vector = writeOp.getVector();
10851091
auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
10861092
if (!transposeOp ||
10871093
transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {

0 commit comments

Comments
 (0)