Skip to content

Commit 34df537

Browse files
authored
Revert "[mlir][Vector] Add fold transpose(shape_cast) -> shape_cast (#73951)" (#74579)
This reverts commit f42b761. The fold pattern is incorrect, because it does not even look at the permutation of non-unit dims and is happy to replace a pattern such as ``` %22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32> %23 = vector.transpose %22, [1, 0] : vector<256x256xf32> to vector<256x256xf32> ``` with ``` %22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32> ``` which is obviously incorrect.
1 parent 6b1aa31 commit 34df537

File tree

2 files changed

+1
-58
lines changed

2 files changed

+1
-58
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5548,57 +5548,12 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
55485548
}
55495549
};
55505550

5551-
/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
5552-
/// permutes a unit dim from the result of the shape_cast.
5553-
class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
5554-
using OpRewritePattern::OpRewritePattern;
5555-
5556-
LogicalResult matchAndRewrite(TransposeOp transpOp,
5557-
PatternRewriter &rewriter) const override {
5558-
Value transposeSrc = transpOp.getVector();
5559-
auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
5560-
if (!shapeCastOp)
5561-
return rewriter.notifyMatchFailure(
5562-
transpOp, "TransposeOp source is not ShapeCastOp");
5563-
5564-
auto sourceType = transpOp.getSourceVectorType();
5565-
auto resultType = transpOp.getResultVectorType();
5566-
5567-
auto filterUnitDims = [](VectorType type) {
5568-
return llvm::make_filter_range(
5569-
llvm::zip_equal(type.getShape(), type.getScalableDims()),
5570-
[&](auto dim) {
5571-
auto [size, isScalable] = dim;
5572-
return size != 1 || isScalable;
5573-
});
5574-
};
5575-
5576-
auto sourceWithoutUnitDims = filterUnitDims(sourceType);
5577-
auto resultWithoutUnitDims = filterUnitDims(resultType);
5578-
5579-
// If this transpose just permutes a unit dim, then we can fold it into the
5580-
// shape_cast.
5581-
for (auto [srcDim, resDim] :
5582-
llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) {
5583-
if (srcDim != resDim)
5584-
return rewriter.notifyMatchFailure(transpOp,
5585-
"TransposeOp permutes non-unit dim");
5586-
}
5587-
5588-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType,
5589-
shapeCastOp.getSource());
5590-
5591-
return success();
5592-
};
5593-
};
5594-
55955551
} // namespace
55965552

55975553
void vector::TransposeOp::getCanonicalizationPatterns(
55985554
RewritePatternSet &results, MLIRContext *context) {
55995555
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5600-
TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
5601-
context);
5556+
TransposeFolder, FoldTransposeSplat>(context);
56025557
}
56035558

56045559
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,6 @@ func.func @create_mask_transpose_to_transposed_create_mask(
6767

6868
// -----
6969

70-
// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
71-
// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
72-
func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
73-
// CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
74-
// CHECK-NOT: vector.transpose
75-
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
76-
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
77-
return %1 : vector<1x[4]xf32>
78-
}
79-
80-
// -----
81-
8270
// CHECK-LABEL: extract_from_create_mask
8371
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
8472
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {

0 commit comments

Comments
 (0)