Skip to content

[mlir][vector] Move transpose with unit-dim to shape_cast pattern #72493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5564,12 +5564,51 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};

/// Folds transpose with non-scalable unit dims into a shape_cast.
///
/// Replace:
/// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
/// vector<1xnxelty>
/// with:
/// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
///
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
/// be fixed. Non-unit dims can be scalable.
class FoldTransposeWithNonScalableUnitDimsToShapeCast final
: public OpRewritePattern<TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(TransposeOp transpOp,
PatternRewriter &rewriter) const override {
Value input = transpOp.getVector();
VectorType resType = transpOp.getResultVectorType();

SmallVector<int64_t> permutation;
transpOp.getTransp(permutation);

if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
permutation == ArrayRef<int64_t>({1, 0})) {
Comment on lines +5590 to +5595
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a follow-up patch, I wonder if we could generalize this to n-D dimensions where 0 or 1 of them is != 1? If I'm not missing something, the permutation pattern itself shouldn't even matter for those cases?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good I'll look into it soon 👍

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resType,
input);
return success();
}

return failure();
}
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
TransposeFolder, FoldTransposeSplat>(context);
TransposeFolder, FoldTransposeSplat,
FoldTransposeWithNonScalableUnitDimsToShapeCast>(context);
}

void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
Expand Down
18 changes: 0 additions & 18 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,24 +336,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");

// Replace:
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
// vector<1xnxelty>
// with:
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
//
// Source with leading unit dim (inverse) is also replaced. Unit dim must
// be fixed. Non-unit can be scalable.
if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
}

if (inputType.isScalable())
return failure();

Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2524,3 +2524,54 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
return %r : vector<1x100x4x5xf32>
}

// -----

/// Transpose of rank-2 vector with leading or trailing non-scalable unit dim to shape_cast.

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
return %0 : vector<1x4xf32>
}

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
return %0 : vector<4x1xf32>
}

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
return %0 : vector<[4]x1xf32>
}

/// Scalable unit dim should not be lowered to shape_cast.

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
// CHECK-NOT: vector.shape_cast
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
return %0 : vector<[1]x4xf32>
}

// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32
func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
// CHECK-NOT: vector.shape_cast
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>

return %0 : vector<[1]x4xf32>
}
51 changes: 0 additions & 51 deletions mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -790,57 +790,6 @@ module attributes {transform.with_named_sequence} {
}
}

// -----

/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.

// CHECK-LABEL: func @transpose10_4x1xf32
func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
return %0 : vector<1x4xf32>
}

// CHECK-LABEL: func @transpose10_nx4x1xf32
func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}

// CHECK-LABEL: func @transpose10_1x4xf32
func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
return %0 : vector<4x1xf32>
}

// CHECK-LABEL: func @transpose10_1xnx4xf32
func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
return %0 : vector<[4]x1xf32>
}

/// Scalable unit dim should not be lowered to shape_cast.

// CHECK-LABEL: func @transpose10_4xnx1xf32
func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
// CHECK-NOT: vector.shape_cast
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
return %0 : vector<[1]x4xf32>
}

// CHECK-LABEL: func @transpose10_nx4xnx1xf32
func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
// CHECK-NOT: vector.shape_cast
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>

return %0 : vector<[1]x4xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
transform.apply_patterns to %func_op {
Expand Down