
Revert "[mlir][vector] Move transpose with unit-dim to shape_cast pattern (#72493)" #72918


Merged
1 commit merged on Nov 21, 2023

Conversation

MaheshRavishankar
Contributor

This reverts commit 95acb33.

@llvmbot
Member

llvmbot commented Nov 20, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

This reverts commit 95acb33.


Full diff: https://github.com/llvm/llvm-project/pull/72918.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-38)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+18)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (-51)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (+51)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6793f902a1e59e58..c7b74701fdbc8f20 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5552,49 +5552,12 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose with non-scalable unit dims into a shape_cast.
-///
-/// Replace:
-///   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
-///                                 vector<1xnxelty>
-/// with:
-///   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dims can be scalable.
-class FoldTransposeWithNonScalableUnitDimsToShapeCast final
-    : public OpRewritePattern<TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransposeOp transpOp,
-                                PatternRewriter &rewriter) const override {
-    Value input = transpOp.getVector();
-    VectorType resType = transpOp.getResultVectorType();
-    ArrayRef<int64_t> permutation = transpOp.getPermutation();
-
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        permutation == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resType,
-                                                       input);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat,
-              FoldTransposeWithNonScalableUnitDimsToShapeCast>(context);
+              TransposeFolder, FoldTransposeSplat>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 9475d273c1162607..97f6caca1b25cccc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -334,6 +334,24 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
+    // Replace:
+    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
+    //                                 vector<1xnxelty>
+    // with:
+    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
+    //
+    // Source with leading unit dim (inverse) is also replaced. Unit dim must
+    // be fixed. Non-unit can be scalable.
+    if (resType.getRank() == 2 &&
+        ((resType.getShape().front() == 1 &&
+          !resType.getScalableDims().front()) ||
+         (resType.getShape().back() == 1 &&
+          !resType.getScalableDims().back())) &&
+        transp == ArrayRef<int64_t>({1, 0})) {
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+      return success();
+    }
+
     if (inputType.isScalable())
       return failure();
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b3902d2d9b4dde00..1021c73cc57d341d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2524,54 +2524,3 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
-
-// -----
-
-/// Transpose of rank-2 vector with leading or trailing non-scalable unit dim to shape_cast.
-
-// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32
-func.func @fold_transpose_with_unit_dims_to_shape_cast_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
-  return %0 : vector<1x4xf32>
-}
-
-// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32
-func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32
-func.func @fold_transpose_with_unit_dims_to_shape_cast_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
-  return %0 : vector<4x1xf32>
-}
-
-// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32
-func.func @fold_transpose_with_unit_dims_to_shape_cast_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
-  return %0 : vector<[4]x1xf32>
-}
-
-/// Scalable unit dim should not be lowered to shape_cast.
-
-// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32
-func.func @fold_transpose_with_unit_dims_to_shape_cast_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
-  // CHECK-NOT: vector.shape_cast
-  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
-  return %0 : vector<[1]x4xf32>
-}
-
-// CHECK-LABEL: func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32
-func.func @fold_transpose_with_unit_dims_to_shape_cast_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
-  // CHECK-NOT: vector.shape_cast
-  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
-
-  return %0 : vector<[1]x4xf32>
-}
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 72be5e4dbe3ee163..c0b44428d5bcf305 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -790,6 +790,57 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
+
+// CHECK-LABEL: func @transpose10_4x1xf32
+func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func @transpose10_nx4x1xf32
+func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1x4xf32
+func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+  return %0 : vector<4x1xf32>
+}
+
+// CHECK-LABEL: func @transpose10_1xnx4xf32
+func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
+  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
+  return %0 : vector<[4]x1xf32>
+}
+
+/// Scalable unit dim should not be lowered to shape_cast.
+
+// CHECK-LABEL: func @transpose10_4xnx1xf32
+func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
+  return %0 : vector<[1]x4xf32>
+}
+
+// CHECK-LABEL: func @transpose10_nx4xnx1xf32
+func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
+
+  return %0 : vector<[1]x4xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
     transform.apply_patterns to %func_op {

Contributor

@qedawkins left a comment


+1, I'm not sure this is well suited for a canonicalization. vector.transpose carries an additional guarantee that the source and result vector ranks are the same, and it also records how dimensions of the source relate to dimensions of the result. I see that as useful information even when all of the transposed dimensions are unit dims. shape_cast does not give the same guarantee and might require analysis to recover the semantic meaning of the shape_cast in some cases (not that I have such a case on hand, but shape casts like this aren't easy to handle in SPIR-V).
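
To make the distinction concrete, a minimal sketch (hypothetical shapes, not taken from the failing SPIR-V case): the transpose names the permutation, so the mapping between source and result dimensions is explicit, while the shape_cast only records the two shapes and the correspondence has to be re-derived.

  // The permutation [1, 0] states explicitly that result dim 0 is source dim 1.
  %t = vector.transpose %v, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
  // Only the shapes are recorded; how the dimensions correspond must be re-derived.
  %c = vector.shape_cast %v : vector<4x1xf32> to vector<1x4xf32>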

@banach-space
Contributor

@MaheshRavishankar, while we agreed that it's fine to revert this patch to unblock SPIR-V code-gen, it was also expected that there would be a test that would capture this sort of issue in the future.

Please, could a test be added upstream so that patches are reverted based on upstream buildbot failures instead of users reporting issues downstream?

@MaheshRavishankar
Contributor Author

@MaheshRavishankar, while we agreed that it's fine to revert this patch to unblock SPIR-V code-gen, it was also expected that there would be a test that would capture this sort of issue in the future.

Please, could a test be added upstream so that patches are reverted based on upstream buildbot failures instead of users reporting issues downstream?

Sorry, maybe I pushed the button too hastily. I'll try to create a repro upstream, but it isn't always easy since it could be as complex as running the entire SPIR-V lowering pipeline used in IREE. Even with this particular revert the conversion is still broken. For this particular one, though, having patterns in canonicalization should have a high bar since it applies everywhere in the stack, and I thought there was agreement that this doesn't fit the bill of being a canonicalization; hence I pushed it since I already had the PR. Let's continue the conversation on the other PR w.r.t. having an upstream test and how to assess the broken SPIR-V conversion.
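
For context, after this revert the unit-dim transpose to shape_cast rewrite no longer runs as part of canonicalization; it only fires when the vector.transpose lowering patterns are requested explicitly, e.g. via a transform-dialect sequence like the one driving vector-transpose-lowering.mlir. A minimal sketch, assuming the default lowering options (the exact sequence in the test may differ):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
      transform.apply_patterns to %func_op {
        // Registers the vector.transpose lowering patterns, which now contain
        // the unit-dim transpose -> shape_cast rewrite moved by this revert.
        transform.apply_patterns.vector.lower_transpose
      } : !transform.op<"func.func">
      transform.yield
    }
  }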

@stellaraccident
Contributor

stellaraccident commented Nov 21, 2023

I think there are two separate issues here:

  1. There is a bar for what makes something a canonicalization, and I am not hearing consensus on it. I think that alone justifies a revert and further discussion. The price of just landing things by PR vs. RFC or broader discussion for things like this is that we have to have a low-cost path to revert when we misjudge consensus.
  2. Getting a SPIR-V conversion test which exercises this is a good thing to do. Thank you for identifying it.

This particular patch was reverted because of point 1. The bar for canonicalization specifically is universal accessibility and agreement that the form being canonicalized to is actually "canonical". Both are judgment calls and a matter of design.

(I'll note that the last time something like this came up which broke SPIR-V conversion, in the downstream we ended up concurring with the approach. We carried the revert locally for two weeks, and an engineer spent a week and landed multiple patches to robustify it. It was hard but the right call, so we spent the time to fix it forward. So it does go both ways. This time, it is a question of consensus on the design.)

@banach-space
Contributor

Sorry, maybe I pushed the button too hastily.

No worries, we did agree to have this reverted. I just wanted to make sure that we use this opportunity to improve test coverage in MLIR - ideally there would be a test that would have prevented us from landing this in the first place.

Both are judgment calls and a matter of design.

I think there's also friction between "canonical for CPUs" and "canonical for GPUs" - currently there is no mechanism to distinguish between the two.

This has been a very useful discussion - I didn't realise that things would be so different for GPUs. Apologies for missing that.

@joker-eph
Collaborator

Side note: please provide clear reasoning in the commit message when reverting.

6 participants