[mlir][Vector] Add fold transpose(shape_cast) -> shape_cast #73951


Merged: 3 commits into llvm:main on Dec 1, 2023

Conversation

@MacDue (Member) commented Nov 30, 2023

This folds transpose(shape_cast) into a new shape_cast, when the transpose just permutes a unit dim from the result of the shape_cast.

Example:

```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
```

Folds to:

```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<1x[4]xf32>
```

This is an (alternate) fix for lowering matmuls to ArmSME.
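The intended matching condition can be modeled outside MLIR. Below is a minimal Python sketch (hypothetical helper names, not the actual C++ pattern): a dim is treated as a unit dim only if it has static size 1 and is not scalable, and the fold is taken to apply when the transpose's source and result types agree after dropping such dims.

```python
# Hypothetical Python model of the fold's matching condition; the real
# pattern lives in C++ in mlir/lib/Dialect/Vector/IR/VectorOps.cpp.

def filter_unit_dims(shape, scalable):
    """Drop fixed-size unit dims; a scalable dim (e.g. [1]) is kept."""
    return [(size, s) for size, s in zip(shape, scalable) if size != 1 or s]

def can_fold(src_shape, src_scalable, res_shape, res_scalable):
    """The fold applies when source and result of the transpose agree
    once unit dims are dropped."""
    return (filter_unit_dims(src_shape, src_scalable)
            == filter_unit_dims(res_shape, res_scalable))

# vector<[4]x1xf32> -> vector<1x[4]xf32>: only a unit dim moves.
print(can_fold([4, 1], [True, False], [1, 4], [False, True]))  # True
```

Note that this condition compares only the sequences of non-unit dims, not the permutation itself; the limitations of that are discussed later in the thread.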

@llvmbot (Member) commented Nov 30, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)


Full diff: https://github.com/llvm/llvm-project/pull/73951.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+44-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c462b23e1133fc9..cf006adaee72a25 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5548,12 +5548,55 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
+/// permutes a unit dim from the result of the shape_cast.
+class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
+                                PatternRewriter &rewriter) const override {
+    Value transposeSrc = transpOp.getVector();
+    auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+
+    auto sourceType = transpOp.getSourceVectorType();
+    auto resultType = transpOp.getResultVectorType();
+
+    auto filterUnitDims = [](VectorType type) {
+      return llvm::make_filter_range(
+          llvm::zip_equal(type.getShape(), type.getScalableDims()),
+          [&](auto dim) {
+            auto [size, isScalable] = dim;
+            return size != 1 || isScalable;
+          });
+    };
+
+    auto sourceWithoutUnitDims = filterUnitDims(sourceType);
+    auto resultWithoutUnitDims = filterUnitDims(resultType);
+
+    // If this transpose just permutes a unit dim, then we can fold it into the
+    // shape_cast.
+    for (auto [srcDim, resDim] :
+         llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) {
+      if (srcDim != resDim)
+        return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType,
+                                                     shapeCastOp.getSource());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat>(context);
+              TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d341..6bfb477ecf97285 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,6 +67,18 @@ func.func @create_mask_transpose_to_transposed_create_mask(
 
 // -----
 
+// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
+//  CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
+  //     CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
+  // CHECK-NOT: vector.transpose
+  %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %1 : vector<1x[4]xf32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_from_create_mask
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
 func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {

@qedawkins (Contributor) left a comment

LGTM! I'd give time for others to take a look though given the ongoing discussion.

```
if (!shapeCastOp)
  return failure();

auto sourceType = transpOp.getSourceVectorType();
```
Contributor

If we had

```
%0 = shape_cast ... : vector<4x1x4> to vector<2x2x1x4>
transpose %0 [0, 1, 3, 2] : vector<2x2x1x4> to vector<2x2x4x1>
```

This devolves to the same discussion in the other PR. Since there's already a shape_cast in the source I won't block here, but would it still work to use the source vector type of the shape_cast?

@MacDue (Member, Author) commented Nov 30, 2023

It's still a legal shape_cast and will lower to (pretty much) the same thing. But yeah, the point here is we're not adding a shape_cast where there was not already one before, so this should not cause problems for SPIR-V :)

Contributor

The point I wanted to make was that the pattern in the test is definitely a reasonable canonicalization (i.e. the transpose is just shuffling unit dims introduced by the shape_cast), but unit dims present before the shape_cast touch on the same discussion. My main concern with making the earlier pattern a canonicalization wasn't SPIR-V specific, it was more a matter of whether it was a pattern we wanted globally. Semantic quirks of shape_cast aside (being discussed on discourse), vector.transpose or vector.contract (i.e. higher level vector operations) play nicer with transpose than shape_cast. That's why I see the vector lowering pattern as reasonable; SPIR-V should use the "LLVM" lowering for shape_cast in that case. Making it a canonicalization though means that this needs to be the canonical representation everywhere. Because there's already a shape_cast in the IR here, that's why I'm not blocking, but similarly there's no reason we couldn't have a shape_cast alongside "higher level" vector IR, hence my question.

I feel like we've been blocking your work with how much this conversation got blown up though, and I am sorry about that :(

Member Author

The point for us is vector<[4]x1xf32> is an impossible type, there's no legal lowering for that in LLVM (only trailing scalable dimensions are supported). So we need a mechanism (such as this), which allows it to be eliminated.

@c-rhodes (Collaborator) left a comment

Also LGTM and seems like a sensible path forward that should unlock us given this only applies when there's a shape cast to begin with.

@MaheshRavishankar (Contributor)

This does look good to me. Do you mind if I check if it fixes the issue and get back to you?

@MacDue (Member, Author) commented Nov 30, 2023

> This does look good to me. Do you mind if I check if it fixes the issue and get back to you?

Fine with me 👍

@nicolasvasilache (Contributor) left a comment

LGTM !

@banach-space (Contributor) left a comment

LGTM, thanks!

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Nov 30, 2023
@MaheshRavishankar (Contributor) commented Nov 30, 2023

Kicked off a PR on IREE (iree-org/iree#15748) that cherry-picks this change and, if it fixes the original issue, undoes the revert we were carrying locally.

@MacDue (Member, Author) commented Nov 30, 2023

> Kicked off a PR on IREE (iree-org/iree#15748) that cherry-picks this change and, if it fixes the original issue, undoes the revert we were carrying locally.

This does not depend on the previous transpose lowering (which will likely still lead to problems for you)

This patch alone is all we need for lowering SME matmuls :)

@MaheshRavishankar (Contributor)

> Kicked off a PR on IREE (openxla/iree#15748) that cherry-picks this change and, if it fixes the original issue, undoes the revert we were carrying locally.
>
> This does not depend on the previous transpose lowering (which will likely still lead to problems for you)
>
> This patch alone is all we need for lowering SME matmuls :)

Brain fade on my part... so really we will also need to revert the original patch upstream. That's a different discussion. I've fixed up the PR for now so that it just tests this patch.

@MacDue MacDue merged commit f42b761 into llvm:main Dec 1, 2023
@MacDue MacDue deleted the fold_transpose_unit_shape_cast branch December 1, 2023 14:24
@MacDue (Member, Author) commented Dec 1, 2023

> Kicked off a PR on IREE (openxla/iree#15748) that cherry-picks this change and, if it fixes the original issue, undoes the revert we were carrying locally.
>
> This does not depend on the previous transpose lowering (which will likely still lead to problems for you)
>
> This patch alone is all we need for lowering SME matmuls :)
>
> Brain fade on my part... so really we will also need to revert the original patch upstream. That's a different discussion. I've fixed up the PR for now so that it just tests this patch.

One thing to note: with the original patch reverted, the transpose lowering generates invalid code for scalable vectors. It must at least return `failure()` for scalable vectors.

@apaszke (Member) commented Dec 6, 2023

This folding rule seems wrong to me: it does not even look at the permutation of the non-unit dims. I have code that looks like this:

```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
%23 = vector.transpose %22, [1, 0] : vector<256x256xf32> to vector<256x256xf32>
```

and it gets simplified to

```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
```

which obviously produces different values!
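Modeling the merged pattern's check in plain Python (hypothetical helper name, mirroring the pattern's logic rather than the actual C++) makes the failure mode concrete: dropping unit dims from `vector<256x256xf32>` leaves the transpose's source and result identical, so the comparison passes even though two non-unit dims are swapped.

```python
# Hypothetical model of the merged pattern's check, showing why it
# incorrectly accepts a transpose that swaps two non-unit dims.

def filter_unit_dims(shape, scalable):
    # Mirrors the pattern's filter: drop fixed-size-1 dims, keep scalable ones.
    return [(s, sc) for s, sc in zip(shape, scalable) if s != 1 or sc]

src = [256, 256]     # transpose source: vector<256x256xf32>
res = [256, 256]     # transpose result: vector<256x256xf32> (after [1, 0])
fixed = [False, False]

# Both sides reduce to [(256, False), (256, False)], so the check passes
# and the transpose is folded away despite changing the value.
print(filter_unit_dims(src, fixed) == filter_unit_dims(res, fixed))  # True
```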

apaszke added a commit to apaszke/llvm-project that referenced this pull request Dec 6, 2023
…lvm#73951)"

This reverts commit f42b761.

The fold pattern is incorrect, because it does not even look at the
permutation of non-unit dims and is happy to replace a pattern
such as
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
%23 = vector.transpose %22, [1, 0] : vector<256x256xf32> to vector<256x256xf32>
```
with
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
```
which is obviously incorrect.
metaflow pushed a commit that referenced this pull request Dec 6, 2023
…73951)" (#74579)

This reverts commit f42b761.

The fold pattern is incorrect, because it does not even look at the
permutation of non-unit dims and is happy to replace a pattern such as
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
%23 = vector.transpose %22, [1, 0] : vector<256x256xf32> to vector<256x256xf32>
```
with
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
```
which is obviously incorrect.
MacDue added a commit to MacDue/llvm-project that referenced this pull request Dec 6, 2023
This folds transpose(shape_cast) into a new shape_cast, when the
transpose just permutes a unit dim from the result of the shape_cast.

Example:

```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
```

Folds to:
```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<1x[4]xf32>
```

This is an (alternate) fix for lowering matmuls to ArmSME.

---

Corrected version of llvm#73951.
MacDue added a commit to MacDue/llvm-project that referenced this pull request Dec 6, 2023
@MacDue (Member, Author) commented Dec 6, 2023

I've prepared a patch here: https://github.com/llvm/llvm-project/compare/main...MacDue:llvm-project:transpose_of_shape_cast_v2?expand=1 that I believe fixes the correctness issues. Sorry for the inconvenience! :)

I won't create a new PR because we have alternate solutions to what this aimed to solve (and the general n-D case for this fold is a little tricky).
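One way such a fix can work (a sketch under assumptions; the linked branch, not this sketch, is the authoritative version) is to examine the permutation itself and require that the non-unit dims keep their relative order, so the transpose genuinely only shuffles unit dims:

```python
# Sketch of a stricter check (hypothetical, not the code from the linked
# branch): the fold is safe only if the permutation preserves the relative
# order of all non-unit dims.

def transpose_only_moves_unit_dims(src_shape, src_scalable, perm):
    def is_unit(i):
        # Fixed size-1 dims only; scalable dims are never treated as unit.
        return src_shape[i] == 1 and not src_scalable[i]
    # Result dim i comes from source dim perm[i]; restrict to non-unit dims.
    non_unit_order = [i for i in perm if not is_unit(i)]
    # Safe iff the surviving source indices appear in increasing order.
    return non_unit_order == sorted(non_unit_order)

# vector<[4]x1xf32>, perm [1, 0]: only the unit dim moves -> foldable.
print(transpose_only_moves_unit_dims([4, 1], [True, False], [1, 0]))       # True
# vector<256x256xf32>, perm [1, 0]: swaps non-unit dims -> not foldable.
print(transpose_only_moves_unit_dims([256, 256], [False, False], [1, 0]))  # False
```

This rejects apaszke's counterexample above while still accepting the unit-dim shuffle from the original test case.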
