[mlir][vector] Split TransposeOpLowering into 2 patterns #91935


Merged
merged 2 commits into from
May 14, 2024

Conversation

banach-space
Contributor

Splits `TransposeOpLowering` into two patterns:

  1. `Transpose2DWithUnitDimToShapeCast` - rewrites a 2-D `vector.transpose`
     as `vector.shape_cast` (there has to be at least one unit dim),
  2. `TransposeOpLowering` - the original pattern, minus the part
     extracted into `Transpose2DWithUnitDimToShapeCast`.

The rationale behind the split:

  * the output generated by `Transpose2DWithUnitDimToShapeCast` doesn't
    match the intended output of `TransposeOpLowering` as documented in
    the source file, so it makes little sense to keep it embedded inside
    `TransposeOpLowering`,
  * `Transpose2DWithUnitDimToShapeCast` _does_ work for scalable vectors,
    whereas `TransposeOpLowering` does not.

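The new pattern's match condition can be modeled outside of MLIR. The Python below is an illustrative sketch, not the MLIR C++ API: it checks a 2-D vector shape directly (the C++ checks the result type, but for a `[1, 0]` permutation the condition is equivalent for source and result). The point is that a `[1, 0]` transpose with a fixed unit dim moves no data, so only the type needs to change:

```python
def is_transpose_to_shape_cast(shape, scalable_dims, permutation):
    """Model of the Transpose2DWithUnitDimToShapeCast match condition.

    A 2-D transpose with permutation [1, 0] does not move any data when
    one of the two dims is a fixed (non-scalable) unit dim: the elements
    are already laid out in the right order, so only the type changes and
    a vector.shape_cast suffices.
    """
    if len(shape) != 2 or permutation != [1, 0]:
        return False
    leading_fixed_unit = shape[0] == 1 and not scalable_dims[0]
    trailing_fixed_unit = shape[1] == 1 and not scalable_dims[1]
    return leading_fixed_unit or trailing_fixed_unit


# vector<4x1xf32> -> vector<1x4xf32>: a pure metadata change
print(is_transpose_to_shape_cast([4, 1], [False, False], [1, 0]))  # True
# vector<4x[1]xf32>: the unit dim is scalable, so the rewrite must not fire
print(is_transpose_to_shape_cast([4, 1], [False, True], [1, 0]))   # False
```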
@llvmbot
Member

llvmbot commented May 13, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/91935.diff

1 file affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+48-22)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 7011c478fefba..0706f22cb8b12 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -326,6 +326,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     VectorType inputType = op.getSourceVectorType();
     VectorType resType = op.getResultVectorType();
 
+    if (inputType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "This lowering does not support scalable vectors");
+
     // Set up convenience transposition table.
     ArrayRef<int64_t> transp = op.getPermutation();
 
@@ -334,28 +338,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
-    // Replace:
-    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
-    //                                 vector<1xnxelty>
-    // with:
-    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
-    //
-    // Source with leading unit dim (inverse) is also replaced. Unit dim must
-    // be fixed. Non-unit can be scalable.
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
-    // TODO: Add support for scalable vectors
-    if (inputType.isScalable())
-      return failure();
-
     // Handle a true 2-D matrix transpose differently when requested.
     if (vectorTransformOptions.vectorTransposeLowering ==
             vector::VectorTransposeLowering::Flat &&
@@ -411,6 +393,48 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransformsOptions vectorTransformOptions;
 };
 
+/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
+/// to 2D vectors with at least one unit dim. For example:
+///
+/// Replace:
+///   vector.transpose %0, [1, 0] : vector<4x1xi32> to
+///                                 vector<1x4xi32>
+/// with:
+///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
+///
+/// Source with leading unit dim (inverse) is also replaced. Unit dim must
+/// be fixed. Non-unit can be scalable.
+class Transpose2DWithUnitDimToShapeCast
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
+                                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getVector();
+    VectorType resType = op.getResultVectorType();
+
+    // Set up convenience transposition table.
+    ArrayRef<int64_t> transp = op.getPermutation();
+
+    if (resType.getRank() == 2 &&
+        ((resType.getShape().front() == 1 &&
+          !resType.getScalableDims().front()) ||
+         (resType.getShape().back() == 1 &&
+          !resType.getScalableDims().back())) &&
+        transp == ArrayRef<int64_t>({1, 0})) {
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -483,6 +507,8 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns, VectorTransformsOptions options,
     PatternBenefit benefit) {
+  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
+                                                  benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       options, patterns.getContext(), benefit);
 }
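To see how the two patterns in the diff above divide the work after the split, here is a hypothetical Python model (names are illustrative, not the MLIR API): the shape-cast pattern handles the unit-dim case, including scalable vectors, while the general lowering now bails out on any scalable input:

```python
def lower_transpose(shape, scalable_dims, permutation):
    """Toy dispatcher modeling the split pattern set for 2-D transposes."""
    # Pattern 1: Transpose2DWithUnitDimToShapeCast. Fires for a [1, 0]
    # 2-D transpose with at least one fixed (non-scalable) unit dim.
    if (len(shape) == 2 and permutation == [1, 0]
            and ((shape[0] == 1 and not scalable_dims[0])
                 or (shape[1] == 1 and not scalable_dims[1]))):
        return "vector.shape_cast"
    # Pattern 2: TransposeOpLowering. After this PR it explicitly bails
    # out on any scalable vector via notifyMatchFailure.
    if any(scalable_dims):
        return "match failure"
    return "unrolled extract/insert lowering"


print(lower_transpose([4, 1], [True, False], [1, 0]))   # vector.shape_cast
print(lower_transpose([4, 1], [False, True], [1, 0]))   # match failure
print(lower_transpose([2, 3], [False, False], [1, 0]))  # unrolled extract/insert lowering
```

The scalable-unit-dim case (e.g. `vector<4x[1]xf32>`) falls through both patterns, matching the `CHECK-NOT: vector.shape_cast` tests later in this thread.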

@MacDue
Member

MacDue commented May 13, 2024

> Transpose2DWithUnitDimToShapeCast does work for scalable vectors, TransposeOpLowering does not.

Note that it only works if (and only if) the resulting vector.shape_cast cancels out.

@MacDue (Member) left a comment

Approving this PR, but I do wonder (with the benefit of hindsight) whether the Transpose2DWithUnitDimToShapeCast lowering would be better as a (more specific) canonicalization of a shape_cast + transpose, since what we really want here is for the illegal vector types (e.g. vector<[4]x1xf32>) to be eliminated, which won't always happen if the transpose alone is replaced with a shape cast.

E.g. the rewrite could be:

%0 = some_op 
%1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
%2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>

->

%0 = some_op 
%1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
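The suggested canonicalization is profitable because a unit-dim transpose moves no data, so a producing shape_cast and the transpose compose into a single shape_cast. A small Python demonstration of that data-level equivalence (plain nested lists stand in for vectors; this is an illustration, not MLIR semantics):

```python
def reshape_2d(flat, rows, cols):
    """Model vector.shape_cast from a 1-D vector to a rows x cols vector."""
    assert len(flat) == rows * cols
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def transpose_2d(matrix):
    """Model vector.transpose with permutation [1, 0]."""
    return [list(row) for row in zip(*matrix)]


flat = [1.0, 2.0, 3.0, 4.0]
# shape_cast 4 -> 4x1, then transpose [1, 0] -> 1x4 ...
two_step = transpose_2d(reshape_2d(flat, 4, 1))
# ... equals a single shape_cast 4 -> 1x4, so the pair folds away
one_step = reshape_2d(flat, 1, 4)
print(two_step == one_step)  # True
```

When the intermediate type (e.g. vector<[4]x1xf32>) is illegal for the target, folding at this level removes it entirely, whereas replacing only the transpose still leaves a value of that illegal type feeding the new shape_cast.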

@MaheshRavishankar (Contributor) left a comment

I think such a change causes failure on SPIR-V backend. I think we've had this discussion previously.

Waiting for more people to chime in. cc @kuhar and @antiagainst

@banach-space
Contributor Author

> I think such a change causes failure on SPIR-V backend.

It shouldn't, this is merely moving code around.

@MaheshRavishankar
Contributor

> I think such a change causes failure on SPIR-V backend.

> It shouldn't, this is merely moving code around.

IIRC this is now going to be generating a vector.shape_cast that isn't handled on SPIR-V backends, and this pattern is being added to a "generic vector lowering" pattern set. I'd suggest adding a "populateVectorTransposeLoweringForLLVMPatterns" and adding this to that path. I know the SPIR-V path isn't tested as well in MLIR, but the shape_cast isn't supported on the SPIR-V path.

@MacDue
Member

MacDue commented May 13, 2024

> I think such a change causes failure on SPIR-V backend.

> It shouldn't, this is merely moving code around.

> IIRC this is now going to be generating a vector.shape_cast that isn't handled on SPIR-V backends, and this pattern is being added to a "generic vector lowering" pattern set. I'd suggest adding a "populateVectorTransposeLoweringForLLVMPatterns" and adding this to that path. I know the SPIR-V path isn't tested as well in MLIR, but the shape_cast isn't supported on the SPIR-V path.

IIRC this lowering has not been a problem on SPIR-V for a while now (since populateVectorShapeCastLoweringPatterns() is now called in the IREE SPIR-V backend). We removed any option to disable this lowering as it was not needed for SPIR-V (see #75062).

@banach-space
Contributor Author

> IIRC this is now going to be generating a vector.shape_cast that isn't handled on SPIR-V backends

No, TransposeOpLowering is already generating vector.shape_cast as demonstrated by these tests:

/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
// CHECK-LABEL: func @transpose10_4x1xf32
func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
return %0 : vector<1x4xf32>
}
// CHECK-LABEL: func @transpose10_nx4x1xf32
func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}
// CHECK-LABEL: func @transpose10_1x4xf32
func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
return %0 : vector<4x1xf32>
}
// CHECK-LABEL: func @transpose10_1xnx4xf32
func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
return %0 : vector<[4]x1xf32>
}
/// Scalable unit dim should not be lowered to shape_cast.
// CHECK-LABEL: func @transpose10_4xnx1xf32
func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
// CHECK-NOT: vector.shape_cast
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
return %0 : vector<[1]x4xf32>
}
// CHECK-LABEL: func @transpose10_nx4xnx1xf32
func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
// CHECK-NOT: vector.shape_cast
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
return %0 : vector<[1]x4xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.lower_transpose
    } : !transform.op<"func.func">
    transform.yield
  }
}

In our previous discussion about these patterns, @antiagainst confirmed that those are not a problem for SPIR-V:

> We explicitly control when to perform transpose lowering and thus we can slot the lowering for type_cast in there too.

In any case, let me re-iterate, this patch does not add any new patterns, it is merely moving the existing patterns.

> and this pattern is being added to a "generic vector lowering" pattern set

No, this pattern is not being added; it already exists inside TransposeOpLowering. This PR merely extracts that logic into a separate class, Transpose2DWithUnitDimToShapeCast. The rationale for that is summarised ... in the summary.

@MaheshRavishankar (Contributor) left a comment

Ah, I understand now. Thanks for the explanation and sorry for the noise!

@banach-space
Contributor Author

banach-space commented May 14, 2024

> I do wonder (with the benefit of hindsight), if the Transpose2DWithUnitDimToShapeCast lowering would be better as a (more specific) canonicalization of a shape_cast + transpose.

That's a good point, but I'd like to get input from other reviewers first (specifically, folks who are using these patterns for something other than scalable vectors). Let me add a TODO with your suggestion and we can revisit at some later time. Hopefully somebody will notice and bring it up.

> Ah, I understand now. Thanks for the explanation and sorry for the noise!

No worries, it's good to have somebody with a SPIR-V hat on scanning these PRs, thanks for taking a look!

@banach-space banach-space merged commit cbd72cb into llvm:main May 14, 2024
3 of 4 checks passed
@banach-space banach-space deleted the andrzej/split_lower_transpose branch May 17, 2024 15:49