Skip to content

Commit f2e5417

Browse files
committed
simplifying tweaks
1 parent cce1483 commit f2e5417

File tree

2 files changed

+9
-24
lines changed

2 files changed

+9
-24
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6481,18 +6481,7 @@ struct TransposeToShapeCast final
64816481
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
64826482
PatternRewriter &rewriter) const override {
64836483

6484-
// This folder does
6485-
// shape_cast(transpose) -> shape_cast
6486-
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
6487-
// shape_cast -> shape_cast(transpose)
6488-
// i.e. the complete opposite. When paired, these 2 patterns can cause
6489-
// infinite cycles in pattern rewriting.
6490-
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable
6491-
// vectors, so by disabling this folder for scalable vectors the
6492-
// cycle is avoided.
6493-
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
6494-
// still needed. If it's not, then we can fold here.
6495-
if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) {
6484+
if (!isOrderPreserving(transpose)) {
64966485
return rewriter.notifyMatchFailure(
64976486
transpose, "not order preserving, so not semantically a 'copy'");
64986487
}

mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,25 +100,21 @@ func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector
100100

101101
// -----
102102

103-
// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
104-
// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
105-
// CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
106-
// CHECK: vector.transpose
107-
// CHECK: vector.shape_cast
108-
func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
103+
// CHECK-LABEL: @shape_cast_of_transpose_scalable
104+
// CHECK-NEXT: vector.shape_cast
105+
// CHECK-NEXT: return
106+
func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
109107
%0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
110108
%1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
111109
return %1 : vector<[4]xi8>
112110
}
113111

114112
// -----
115113

116-
// The conversion transpose(shape_cast) -> shape_cast is currently disabled for scalable
117-
// vectors.
118-
// CHECK-LABEL: @negative_transpose_of_shape_cast_scalable
119-
// CHECK: vector.shape_cast
120-
// CHECK: vector.transpose
121-
func.func @negative_transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
114+
// CHECK-LABEL: @transpose_of_shape_cast_scalable
115+
// CHECK-NEXT: vector.shape_cast
116+
// CHECK-NEXT: return
117+
func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> {
122118
%0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8>
123119
%1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8>
124120
return %1 : vector<[4]x1xi8>

0 commit comments

Comments
 (0)