Skip to content

Commit b25b586

Browse files
committed
Revert "[mlir][vector] Move transpose with unit-dim to shape_cast pattern (llvm#72493)"
This reverts commit 95acb33.
1 parent 88f0e4c commit b25b586

File tree

4 files changed

+70
-89
lines changed

4 files changed

+70
-89
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5552,49 +5552,12 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
55525552
}
55535553
};
55545554

5555-
/// Folds transpose with non-scalable unit dims into a shape_cast.
5556-
///
5557-
/// Replace:
5558-
/// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
5559-
/// vector<1xnxelty>
5560-
/// with:
5561-
/// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
5562-
///
5563-
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
5564-
/// be fixed. Non-unit dims can be scalable.
5565-
class FoldTransposeWithNonScalableUnitDimsToShapeCast final
5566-
: public OpRewritePattern<TransposeOp> {
5567-
public:
5568-
using OpRewritePattern::OpRewritePattern;
5569-
5570-
LogicalResult matchAndRewrite(TransposeOp transpOp,
5571-
PatternRewriter &rewriter) const override {
5572-
Value input = transpOp.getVector();
5573-
VectorType resType = transpOp.getResultVectorType();
5574-
ArrayRef<int64_t> permutation = transpOp.getPermutation();
5575-
5576-
if (resType.getRank() == 2 &&
5577-
((resType.getShape().front() == 1 &&
5578-
!resType.getScalableDims().front()) ||
5579-
(resType.getShape().back() == 1 &&
5580-
!resType.getScalableDims().back())) &&
5581-
permutation == ArrayRef<int64_t>({1, 0})) {
5582-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resType,
5583-
input);
5584-
return success();
5585-
}
5586-
5587-
return failure();
5588-
}
5589-
};
5590-
55915555
} // namespace
55925556

55935557
void vector::TransposeOp::getCanonicalizationPatterns(
55945558
RewritePatternSet &results, MLIRContext *context) {
55955559
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5596-
TransposeFolder, FoldTransposeSplat,
5597-
FoldTransposeWithNonScalableUnitDimsToShapeCast>(context);
5560+
TransposeFolder, FoldTransposeSplat>(context);
55985561
}
55995562

56005563
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,24 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
334334
return rewriter.notifyMatchFailure(
335335
op, "Options specifies lowering to shuffle");
336336

337+
// Replace:
338+
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
339+
// vector<1xnxelty>
340+
// with:
341+
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
342+
//
343+
// Source with leading unit dim (inverse) is also replaced. Unit dim must
344+
// be fixed. Non-unit can be scalable.
345+
if (resType.getRank() == 2 &&
346+
((resType.getShape().front() == 1 &&
347+
!resType.getScalableDims().front()) ||
348+
(resType.getShape().back() == 1 &&
349+
!resType.getScalableDims().back())) &&
350+
transp == ArrayRef<int64_t>({1, 0})) {
351+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
352+
return success();
353+
}
354+
337355
if (inputType.isScalable())
338356
return failure();
339357

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2524,54 +2524,3 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
25242524
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
25252525
return %r : vector<1x100x4x5xf32>
25262526
}
2527-
2528-
// -----
2529-
2530-
/// Transpose of rank-2 vector with leading or trailing non-scalable unit dim to shape_cast.
2531-
2532-
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32
2533-
func.func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
2534-
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
2535-
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
2536-
return %0 : vector<1x4xf32>
2537-
}
2538-
2539-
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32
2540-
func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
2541-
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
2542-
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
2543-
return %0 : vector<1x[4]xf32>
2544-
}
2545-
2546-
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32
2547-
func.func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
2548-
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
2549-
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
2550-
return %0 : vector<4x1xf32>
2551-
}
2552-
2553-
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32
2554-
func.func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
2555-
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
2556-
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
2557-
return %0 : vector<[4]x1xf32>
2558-
}
2559-
2560-
/// Scalable unit dim should not be lowered to shape_cast.
2561-
2562-
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32
2563-
func.func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
2564-
// CHECK-NOT: vector.shape_cast
2565-
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
2566-
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
2567-
return %0 : vector<[1]x4xf32>
2568-
}
2569-
2570-
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32
2571-
func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
2572-
// CHECK-NOT: vector.shape_cast
2573-
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
2574-
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
2575-
2576-
return %0 : vector<[1]x4xf32>
2577-
}

mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,57 @@ module attributes {transform.with_named_sequence} {
790790
}
791791
}
792792

793+
// -----
794+
795+
/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
796+
797+
// CHECK-LABEL: func @transpose10_4x1xf32
798+
func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
799+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
800+
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
801+
return %0 : vector<1x4xf32>
802+
}
803+
804+
// CHECK-LABEL: func @transpose10_nx4x1xf32
805+
func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
806+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
807+
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
808+
return %0 : vector<1x[4]xf32>
809+
}
810+
811+
// CHECK-LABEL: func @transpose10_1x4xf32
812+
func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
813+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
814+
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
815+
return %0 : vector<4x1xf32>
816+
}
817+
818+
// CHECK-LABEL: func @transpose10_1xnx4xf32
819+
func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
820+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
821+
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
822+
return %0 : vector<[4]x1xf32>
823+
}
824+
825+
/// Scalable unit dim should not be lowered to shape_cast.
826+
827+
// CHECK-LABEL: func @transpose10_4xnx1xf32
828+
func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
829+
// CHECK-NOT: vector.shape_cast
830+
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
831+
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
832+
return %0 : vector<[1]x4xf32>
833+
}
834+
835+
// CHECK-LABEL: func @transpose10_nx4xnx1xf32
836+
func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
837+
// CHECK-NOT: vector.shape_cast
838+
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
839+
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
840+
841+
return %0 : vector<[1]x4xf32>
842+
}
843+
793844
module attributes {transform.with_named_sequence} {
794845
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
795846
transform.apply_patterns to %func_op {

0 commit comments

Comments
 (0)