Skip to content

Commit aa2dc79

Browse files
[mlir][vector] Fix rewrite pattern API violation in VectorToSCF (#77909)
A rewrite pattern is not allowed to change the IR if it returns "failure". This commit fixes `test/Conversion/VectorToSCF/vector-to-scf.mlir` when running with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`. ``` Processing operation : 'vector.transfer_read'(0x55823a409a60) { %5 = "vector.transfer_read"(%arg0, %0, %0, %2, %4) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 1>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<?x4xf32>, index, index, f32, vector<[4]x4xi1>) -> vector<[4]x4xf32> * Pattern (anonymous namespace)::lowering_n_d_unrolled::UnrollTransferReadConversion : 'vector.transfer_read -> ()' { Trying to match "(anonymous namespace)::lowering_n_d_unrolled::UnrollTransferReadConversion" ** Insert : 'vector.splat'(0x55823a445640) "(anonymous namespace)::lowering_n_d_unrolled::UnrollTransferReadConversion" result 0 } -> failure : pattern failed to match LLVM ERROR: pattern returned failure but IR did change ```
1 parent 35708b0 commit aa2dc79

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,10 +1060,10 @@ struct UnrollTransferReadConversion
10601060
setHasBoundedRewriteRecursion();
10611061
}
10621062

1063-
/// Return the vector into which the newly created TransferReadOp results
1064-
/// are inserted.
1065-
Value getResultVector(TransferReadOp xferOp,
1066-
PatternRewriter &rewriter) const {
1063+
/// Get or build the vector into which the newly created TransferReadOp
1064+
/// results are inserted.
1065+
Value buildResultVector(PatternRewriter &rewriter,
1066+
TransferReadOp xferOp) const {
10671067
if (auto insertOp = getInsertOp(xferOp))
10681068
return insertOp.getDest();
10691069
Location loc = xferOp.getLoc();
@@ -1098,24 +1098,27 @@ struct UnrollTransferReadConversion
10981098
LogicalResult matchAndRewrite(TransferReadOp xferOp,
10991099
PatternRewriter &rewriter) const override {
11001100
if (xferOp.getVectorType().getRank() <= options.targetRank)
1101-
return failure();
1101+
return rewriter.notifyMatchFailure(
1102+
xferOp, "vector rank is less or equal to target rank");
11021103
if (isTensorOp(xferOp) && !options.lowerTensors)
1103-
return failure();
1104+
return rewriter.notifyMatchFailure(
1105+
xferOp, "transfers operating on tensors are excluded");
11041106
// Transfer ops that modify the element type are not supported atm.
11051107
if (xferOp.getVectorType().getElementType() !=
11061108
xferOp.getShapedType().getElementType())
1107-
return failure();
1108-
1109-
auto insertOp = getInsertOp(xferOp);
1110-
auto vec = getResultVector(xferOp, rewriter);
1111-
auto vecType = dyn_cast<VectorType>(vec.getType());
1109+
return rewriter.notifyMatchFailure(
1110+
xferOp, "not yet supported: element type mismatch");
11121111
auto xferVecType = xferOp.getVectorType();
1113-
11141112
if (xferVecType.getScalableDims()[0]) {
11151113
// Cannot unroll a scalable dimension at compile time.
1116-
return failure();
1114+
return rewriter.notifyMatchFailure(
1115+
xferOp, "scalable dimensions cannot be unrolled");
11171116
}
11181117

1118+
auto insertOp = getInsertOp(xferOp);
1119+
auto vec = buildResultVector(rewriter, xferOp);
1120+
auto vecType = dyn_cast<VectorType>(vec.getType());
1121+
11191122
VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
11201123

11211124
int64_t dimSize = xferVecType.getShape()[0];

0 commit comments

Comments
 (0)