Skip to content

[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. #92934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 20, 2024

Conversation

nujaa
Copy link
Contributor

@nujaa nujaa commented May 21, 2024

Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions.
This change stems from improving lowering of contractionOps for Arm SME. Where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts.
discussed here.

Copy link

github-actions bot commented May 21, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@nujaa nujaa force-pushed the hugo.dropUnitDimsGen branch from 8c33c6c to 9acb3da Compare May 21, 2024 16:58
@nujaa nujaa marked this pull request as ready for review May 21, 2024 17:03
@llvmbot
Copy link
Member

llvmbot commented May 21, 2024

@llvm/pr-subscribers-mlir-vector

Author: Hugo Trachino (nujaa)

Changes

Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions.

discussed here.


Full diff: https://github.com/llvm/llvm-project/pull/92934.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+27-23)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+20)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..e772d4bbea311 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1607,7 +1607,23 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
-/// For vectors with either leading or trailing unit dim, replaces:
+FailureOr<VectorType> dropNonScalableUnitDimType(VectorType VT) {
+  VectorType newVT = VT;
+  int removed = 0;
+  auto shape = VT.getShape();
+  for (unsigned i = 0; i < shape.size(); i++) {
+    if (shape[i] == 1 && !VT.getScalableDims()[i]) {
+      newVT = VectorType::Builder(newVT).dropDim(i - removed);
+      removed++;
+    }
+  }
+
+  if (removed == 0)
+    return failure();
+  return newVT;
+}
+
+/// For vectors with at least an unit dim, replaces:
 ///   elementwise(a, b)
 /// with:
 ///   sc_a = shape_cast(a)
@@ -1652,42 +1668,30 @@ struct DropUnitDimFromElementwiseOps final
     // guaranteed to have identical shapes (with some exceptions such as
     // `arith.select`) and it suffices to only check one of them.
     auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
-    if (!sourceVectorType)
-      return failure();
-    if (sourceVectorType.getRank() < 2)
-      return failure();
-
-    bool hasTrailingDimUnitFixed =
-        ((sourceVectorType.getShape().back() == 1) &&
-         (!sourceVectorType.getScalableDims().back()));
-    bool hasLeadingDimUnitFixed =
-        ((sourceVectorType.getShape().front() == 1) &&
-         (!sourceVectorType.getScalableDims().front()));
-    if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
+    if (!sourceVectorType || sourceVectorType.getRank() < 2)
       return failure();
 
-    // Drop leading/trailing unit dim by applying vector.shape_cast to all
-    // operands
-    int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
       auto opVectorType = cast<VectorType>(operand.getType());
-      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
-      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
+      auto newVType = dropNonScalableUnitDimType(opVectorType);
+      if (failed(newVType)) {
+        return failure();
+      }
+      auto opSC =
+          rewriter.create<vector::ShapeCastOp>(loc, newVType.value(), operand);
       newOperands.push_back(opSC);
     }
 
     VectorType newResultVectorType =
-        VectorType::Builder(resultVectorType).dropDim(dim);
-    // Create an updated elementwise Op without leading/trailing unit dim
+        dropNonScalableUnitDimType(resultVectorType).value();
+    // Create an updated elementwise Op without unit dim
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                         newResultVectorType, op->getAttrs());
 
-    // Restore the leading/trailing unit dim by applying vector.shape_cast
-    // to the result
+    // Restore the unit dim by applying vector.shape_cast to the result
     rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                              elementwiseOp->getResult(0));
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..03c19742355bf 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -459,6 +459,26 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK-128B-LABEL: func @fold_unit_dims_entirely(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
+// -----
+
+func.func @fold_unit_center_dim_scalable(%arg0 : vector<8x1x[1]xf128>,
+                              %arg1 : vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
+   %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]xf128> to vector<8x1x[1]xf128>
+   %add = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]xf128>
+   %res = vector.shape_cast %add : vector<8x1x[1]xf128> to vector<8x[1]xf128>
+   return %res : vector<8x[1]xf128>
+}
+
+// CHECK-LABEL: func.func @fold_unit_center_dim_scalable(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x[1]xf128>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]xf128> to vector<8x[1]xf128>
+// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]xf128> to vector<8x[1]xf128>
+// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]xf128>
+// CHECK:         return %[[VAL_4]] : vector<8x[1]xf128>
+
+
+
 
 // -----
 

@llvmbot
Copy link
Member

llvmbot commented May 21, 2024

@llvm/pr-subscribers-mlir

Author: Hugo Trachino (nujaa)

Changes

Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions.

discussed here.


Full diff: https://github.com/llvm/llvm-project/pull/92934.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+27-23)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+20)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..e772d4bbea311 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1607,7 +1607,23 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
-/// For vectors with either leading or trailing unit dim, replaces:
+FailureOr<VectorType> dropNonScalableUnitDimType(VectorType VT) {
+  VectorType newVT = VT;
+  int removed = 0;
+  auto shape = VT.getShape();
+  for (unsigned i = 0; i < shape.size(); i++) {
+    if (shape[i] == 1 && !VT.getScalableDims()[i]) {
+      newVT = VectorType::Builder(newVT).dropDim(i - removed);
+      removed++;
+    }
+  }
+
+  if (removed == 0)
+    return failure();
+  return newVT;
+}
+
+/// For vectors with at least an unit dim, replaces:
 ///   elementwise(a, b)
 /// with:
 ///   sc_a = shape_cast(a)
@@ -1652,42 +1668,30 @@ struct DropUnitDimFromElementwiseOps final
     // guaranteed to have identical shapes (with some exceptions such as
     // `arith.select`) and it suffices to only check one of them.
     auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
-    if (!sourceVectorType)
-      return failure();
-    if (sourceVectorType.getRank() < 2)
-      return failure();
-
-    bool hasTrailingDimUnitFixed =
-        ((sourceVectorType.getShape().back() == 1) &&
-         (!sourceVectorType.getScalableDims().back()));
-    bool hasLeadingDimUnitFixed =
-        ((sourceVectorType.getShape().front() == 1) &&
-         (!sourceVectorType.getScalableDims().front()));
-    if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
+    if (!sourceVectorType || sourceVectorType.getRank() < 2)
       return failure();
 
-    // Drop leading/trailing unit dim by applying vector.shape_cast to all
-    // operands
-    int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
       auto opVectorType = cast<VectorType>(operand.getType());
-      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
-      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
+      auto newVType = dropNonScalableUnitDimType(opVectorType);
+      if (failed(newVType)) {
+        return failure();
+      }
+      auto opSC =
+          rewriter.create<vector::ShapeCastOp>(loc, newVType.value(), operand);
       newOperands.push_back(opSC);
     }
 
     VectorType newResultVectorType =
-        VectorType::Builder(resultVectorType).dropDim(dim);
-    // Create an updated elementwise Op without leading/trailing unit dim
+        dropNonScalableUnitDimType(resultVectorType).value();
+    // Create an updated elementwise Op without unit dim
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                         newResultVectorType, op->getAttrs());
 
-    // Restore the leading/trailing unit dim by applying vector.shape_cast
-    // to the result
+    // Restore the unit dim by applying vector.shape_cast to the result
     rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                              elementwiseOp->getResult(0));
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..03c19742355bf 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -459,6 +459,26 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK-128B-LABEL: func @fold_unit_dims_entirely(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
+// -----
+
+func.func @fold_unit_center_dim_scalable(%arg0 : vector<8x1x[1]xf128>,
+                              %arg1 : vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
+   %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]xf128> to vector<8x1x[1]xf128>
+   %add = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]xf128>
+   %res = vector.shape_cast %add : vector<8x1x[1]xf128> to vector<8x[1]xf128>
+   return %res : vector<8x[1]xf128>
+}
+
+// CHECK-LABEL: func.func @fold_unit_center_dim_scalable(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x[1]xf128>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x[1]xf128>) -> vector<8x[1]xf128> {
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]xf128> to vector<8x[1]xf128>
+// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]xf128> to vector<8x[1]xf128>
+// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]xf128>
+// CHECK:         return %[[VAL_4]] : vector<8x[1]xf128>
+
+
+
 
 // -----
 

@nujaa
Copy link
Contributor Author

nujaa commented May 21, 2024

Hi, @banach-space and @MacDue , let's get started with the drop unit dimension MRs.

@banach-space
Copy link
Contributor

discussed here.

I don’t see unit dims mentioned in that particular thread. Did you mean some other thread?

@nujaa
Copy link
Contributor Author

nujaa commented May 22, 2024

discussed here.

I don’t see unit dims mentioned in that particular thread. Did you mean some other thread?

Indeed, thanks for checking it out. updated to https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa

@banach-space
Copy link
Contributor

discussed here.

I don’t see unit dims mentioned in that particular thread. Did you mean some other thread?

Indeed, thanks for checking it out. updated to https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa

Thanks, now I see what you meant :) Nit - I would add a bit more context in your PR summary to make it more self-contained (I find it super useful when commit msgs contain all the context). Specifically, I’d add that this was suggested in the discussion on improving the lowerings for ArmSME. And then add a link for folks interested in the discussion.

@banach-space
Copy link
Contributor

banach-space commented May 27, 2024

Sorry for the delay, was OOO last week. I've finally managed to catch-up with the context and I have one high-level comment/question.

The pattern that you are updating was designed to help with specific scenarios that are documented here:

/// For vectors with either leading or trailing unit dim, replaces:
/// elementwise(a, b)
/// with:
/// sc_a = shape_cast(a)
/// sc_b = shape_cast(b)
/// res = elementwise(sc_a, sc_b)
/// return shape_cast(res)
/// The newly inserted shape_cast Ops fold (before elementwise Op) and then
/// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
/// required to be rank > 1.
///
/// Ex:
/// ```
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
/// ```
///
/// gets converted to:
///
/// ```
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
/// ```
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
/// `%cast`.

However, those cases look very different to what you are trying to "fix":

discussed here

(copied from Ben's "Canonical form")

%lhsCast = vector.shape_cast %inputLHS : vector<[4]xf32> to vector<[4]x1xf32>
%lhsBcast = vector.broadcast %lhsCast : vector<[4]x1xf32> to vector<[4]x[4]x1xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0, 2] : vector<[4]x[4]x1xf32> to vector<[4]x[4]x1xf32>
%rhsCast = vector.shape_cast %inputRHS : vector<[4]xf32> to vector<1x[4]xf32>
%rhsBcast = vector.broadcast %rhsCast : vector<1x[4]xf32> to vector<[4]x1x[4]xf32>
%rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[4]x1x[4]xf32> to vector<[4]x[4]x1xf32>
%mul = arith.mulf %lhsT, %rhs : vector<[4]x[4]x1xf32>
%tileMask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
%dropDim = vector.shape_cast %mul : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32>
%addAcc = arith.addf %acc, %dropDim : vector<[4]x[4]xf32>
%applyMask = arith.select %tileMask, %acc, %addAcc : vector<[4]x[4]xi1>, vector<[4]x[4]xf32>

In the example above there aren't that many internal unit dims. Here are 2 examples:

%rhsBcast = vector.broadcast %rhsCast : vector<1x[4]xf32> to vector<[4]x1x[4]xf32>
%rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[4]x1x[4]xf32> to vector<[4]x[4]x1xf32>

Would DropUnitDimFromElementwiseOps help here at all? If yes, could you write tests for that? From what I can tell, that won't work as neither vector.broadcast nor vector.transpose are elementwise. But perhaps I missed something?

EDIT

Btw, "vector-transfer-flatten.mlir" is failing with this change - not sure why pre-commit CI didn't capture that. In fact, looks like the tests didn't run 🤔

@nujaa
Copy link
Contributor Author

nujaa commented May 28, 2024

Would DropUnitDimFromElementwiseOps help here at all? If yes, could you write tests for that? From what I can tell, that won't work as neither vector.broadcast nor vector.transpose are elementwise. But perhaps I missed something?

I simply found it weird elementwiseOps would only clean off outer dims when other similar patterns would not be limited by it. Simply some normalization. Also originally I thought I could reuse this method for Broadcast care but I realised unit unbroadcasted UnitDim were an issue.

Would DropUnitDimFromElementwiseOps help here at all? If yes, could you write tests for that? From what I can tell, that won't work as neither vector.broadcast nor vector.transpose are elementwise. But perhaps I missed something?
That s why I submitted separate MRs to support BroadcastOps and TransposeOps.

Btw, "vector-transfer-flatten.mlir" is failing with this change - not sure why pre-commit CI didn't capture that. In fact, looks like the tests didn't run 🤔

Hah, committed a fix directly inside Github but apparently I got too confident. 👼

newOperands.push_back(opSC);
}

VectorType newResultVectorType =
VectorType::Builder(resultVectorType).dropDim(dim);
// Create an updated elementwise Op without leading/trailing unit dim
dropNonScalableUnitDimType(resultVectorType).value();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good example that we expect it will always return valid values.

@nujaa nujaa force-pushed the hugo.dropUnitDimsGen branch from f5da261 to 5e1c3bd Compare May 31, 2024 11:06
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall looks good to me, just few nits! Thanks for pushing on this!

Comment on lines 1617 to 1620
if (dim != 1 || isScalable) {
newShape.push_back(dim);
newScalableDims.push_back(isScalable);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style nit: prefer using early-exit and continue to simplify code. It also saves the levels of nesting for us.

Suggested change
if (dim != 1 || isScalable) {
newShape.push_back(dim);
newScalableDims.push_back(isScalable);
}
if (dim == 1 && !isScalable)
continue;
newShape.push_back(dim);
newScalableDims.push_back(isScalable);

https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one nit. Thanks!

@nujaa nujaa force-pushed the hugo.dropUnitDimsGen branch from 36ef826 to 6b2204f Compare June 10, 2024 09:51
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@MacDue Do you have any other comments?

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one nit:

@nujaa nujaa changed the title [MLIR][Vector]Generalize DropUnitDimFromElementwiseOps [MLIR][Vector] Generalize DropUnitDimFromElementwiseOps Jun 19, 2024
@nujaa nujaa changed the title [MLIR][Vector] Generalize DropUnitDimFromElementwiseOps [MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. Jun 19, 2024
@nujaa nujaa merged commit 2c06fb8 into llvm:main Jun 20, 2024
7 checks passed
Max191 added a commit to iree-org/llvm-project that referenced this pull request Jun 28, 2024
…n leading / trailing dimensions. (llvm#92934)"

This reverts commit 2c06fb8.
qedawkins pushed a commit to iree-org/llvm-project that referenced this pull request Jul 2, 2024
…n leading / trailing dimensions. (llvm#92934)"

This reverts commit 2c06fb8.
qedawkins pushed a commit to iree-org/llvm-project that referenced this pull request Jul 3, 2024
…n leading / trailing dimensions. (llvm#92934)"

This reverts commit 2c06fb8.
@hanhanW
Copy link
Contributor

hanhanW commented Jul 3, 2024

The commit breaks downstream project (iree-org/iree#17778). Here is the repro: mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir

func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> {
  %cst = arith.constant dense<0.000000e+00> : vector<1x1xf32>
  %0 = arith.mulf %arg0, %cst : vector<1x1xf32>
  return %0 : vector<1x1xf32>
}

Error:

repro.mlir:3:8: error: 'arith.mulf' op operand #0 must be floating-point-like, but got 'vector<f32>'
  %0 = arith.mulf %arg0, %cst : vector<1x1xf32>

The commit seems problematic for corner cases, can we revert it and re-land it with a fix?

https://llvm.org/docs/DeveloperPolicy.html#patch-reversion-policy

@joker-eph
Copy link
Collaborator

Any upstream repro is enough for a revert.

@hanhanW
Copy link
Contributor

hanhanW commented Jul 3, 2024

Any upstream repro is enough for a revert.

got it, thanks!

hanhanW added a commit that referenced this pull request Jul 3, 2024
…n leading / trailing dimensions." (#97652)

Reverts #92934 because it breaks some lowering. To
repro: `mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir`

```mlir
func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> {
  %cst = arith.constant dense<0.000000e+00> : vector<1x1xf32>
  %0 = arith.mulf %arg0, %cst : vector<1x1xf32>
  return %0 : vector<1x1xf32>
}
```
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
…n leading / trailing dimensions." (llvm#97652)

Reverts llvm#92934 because it breaks some lowering. To
repro: `mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir`

```mlir
func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> {
  %cst = arith.constant dense<0.000000e+00> : vector<1x1xf32>
  %0 = arith.mulf %arg0, %cst : vector<1x1xf32>
  return %0 : vector<1x1xf32>
}
```
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…g / trailing dimensions. (llvm#92934)

Generalizes `DropUnitDimFromElementwiseOps` to support inner unit
dimensions.
This change stems from improving lowering of contractionOps for Arm SME.
Where we end up with inner unit dimensions on MulOp, BroadcastOp and
TransposeOp, preventing the generation of outerproducts.
discussed
[here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa).

---------

Co-authored-by: Benjamin Maxwell <[email protected]>
nujaa added a commit to nujaa/llvm-project that referenced this pull request Jul 11, 2024
…g / trailing dimensions. (llvm#92934)

Generalizes `DropUnitDimFromElementwiseOps` to support inner unit
dimensions.
This change stems from improving lowering of contractionOps for Arm SME.
Where we end up with inner unit dimensions on MulOp, BroadcastOp and
TransposeOp, preventing the generation of outerproducts.
discussed
[here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa).

---------

Co-authored-by: Benjamin Maxwell <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants