Skip to content

Commit 95acb33

Browse files
authored
[mlir][vector] Move transpose with unit-dim to shape_cast pattern (#72493)
Moved from lowering to canonicalization.
1 parent e77af7e commit 95acb33

File tree

4 files changed

+91
-70
lines changed

4 files changed

+91
-70
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5564,12 +5564,51 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
55645564
}
55655565
};
55665566

5567+
/// Folds transpose with non-scalable unit dims into a shape_cast.
5568+
///
5569+
/// Replace:
5570+
/// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
5571+
/// vector<1xnx<eltty>>
5572+
/// with:
5573+
/// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>
5574+
///
5575+
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
5576+
/// be fixed. Non-unit dims can be scalable.
5577+
class FoldTransposeWithNonScalableUnitDimsToShapeCast final
5578+
: public OpRewritePattern<TransposeOp> {
5579+
public:
5580+
using OpRewritePattern::OpRewritePattern;
5581+
5582+
LogicalResult matchAndRewrite(TransposeOp transpOp,
5583+
PatternRewriter &rewriter) const override {
5584+
Value input = transpOp.getVector();
5585+
VectorType resType = transpOp.getResultVectorType();
5586+
5587+
SmallVector<int64_t> permutation;
5588+
transpOp.getTransp(permutation);
5589+
5590+
if (resType.getRank() == 2 &&
5591+
((resType.getShape().front() == 1 &&
5592+
!resType.getScalableDims().front()) ||
5593+
(resType.getShape().back() == 1 &&
5594+
!resType.getScalableDims().back())) &&
5595+
permutation == ArrayRef<int64_t>({1, 0})) {
5596+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resType,
5597+
input);
5598+
return success();
5599+
}
5600+
5601+
return failure();
5602+
}
5603+
};
5604+
55675605
} // namespace
55685606

55695607
void vector::TransposeOp::getCanonicalizationPatterns(
55705608
RewritePatternSet &results, MLIRContext *context) {
55715609
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5572-
TransposeFolder, FoldTransposeSplat>(context);
5610+
TransposeFolder, FoldTransposeSplat,
5611+
FoldTransposeWithNonScalableUnitDimsToShapeCast>(context);
55735612
}
55745613

55755614
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -336,24 +336,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
336336
return rewriter.notifyMatchFailure(
337337
op, "Options specifies lowering to shuffle");
338338

339-
// Replace:
340-
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
341-
// vector<1xnxelty>
342-
// with:
343-
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
344-
//
345-
// Source with leading unit dim (inverse) is also replaced. Unit dim must
346-
// be fixed. Non-unit can be scalable.
347-
if (resType.getRank() == 2 &&
348-
((resType.getShape().front() == 1 &&
349-
!resType.getScalableDims().front()) ||
350-
(resType.getShape().back() == 1 &&
351-
!resType.getScalableDims().back())) &&
352-
transp == ArrayRef<int64_t>({1, 0})) {
353-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
354-
return success();
355-
}
356-
357339
if (inputType.isScalable())
358340
return failure();
359341

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2524,3 +2524,54 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
25242524
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
25252525
return %r : vector<1x100x4x5xf32>
25262526
}
2527+
2528+
// -----
2529+
2530+
/// Transpose of rank-2 vector with leading or trailing non-scalable unit dim to shape_cast.
2531+
2532+
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32
2533+
func.func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
2534+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
2535+
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
2536+
return %0 : vector<1x4xf32>
2537+
}
2538+
2539+
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32
2540+
func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
2541+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
2542+
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
2543+
return %0 : vector<1x[4]xf32>
2544+
}
2545+
2546+
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32
2547+
func.func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
2548+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
2549+
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
2550+
return %0 : vector<4x1xf32>
2551+
}
2552+
2553+
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32
2554+
func.func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
2555+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
2556+
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
2557+
return %0 : vector<[4]x1xf32>
2558+
}
2559+
2560+
/// Scalable unit dim should not be folded to shape_cast.
2561+
2562+
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32
2563+
func.func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
2564+
// CHECK-NOT: vector.shape_cast
2565+
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
2566+
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
2567+
return %0 : vector<[1]x4xf32>
2568+
}
2569+
2570+
// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32
2571+
func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
2572+
// CHECK-NOT: vector.shape_cast
2573+
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
2574+
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
2575+
2576+
return %0 : vector<[1]x4xf32>
2577+
}

mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -790,57 +790,6 @@ module attributes {transform.with_named_sequence} {
790790
}
791791
}
792792

793-
// -----
794-
795-
/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
796-
797-
// CHECK-LABEL: func @transpose10_4x1xf32
798-
func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
799-
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
800-
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
801-
return %0 : vector<1x4xf32>
802-
}
803-
804-
// CHECK-LABEL: func @transpose10_nx4x1xf32
805-
func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
806-
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
807-
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
808-
return %0 : vector<1x[4]xf32>
809-
}
810-
811-
// CHECK-LABEL: func @transpose10_1x4xf32
812-
func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
813-
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
814-
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
815-
return %0 : vector<4x1xf32>
816-
}
817-
818-
// CHECK-LABEL: func @transpose10_1xnx4xf32
819-
func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
820-
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
821-
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
822-
return %0 : vector<[4]x1xf32>
823-
}
824-
825-
/// Scalable unit dim should not be lowered to shape_cast.
826-
827-
// CHECK-LABEL: func @transpose10_4xnx1xf32
828-
func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
829-
// CHECK-NOT: vector.shape_cast
830-
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
831-
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
832-
return %0 : vector<[1]x4xf32>
833-
}
834-
835-
// CHECK-LABEL: func @transpose10_nx4xnx1xf32
836-
func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
837-
// CHECK-NOT: vector.shape_cast
838-
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
839-
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
840-
841-
return %0 : vector<[1]x4xf32>
842-
}
843-
844793
module attributes {transform.with_named_sequence} {
845794
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
846795
transform.apply_patterns to %func_op {

0 commit comments

Comments
 (0)