Skip to content

[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. #98455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 17, 2024

Conversation

nujaa
Copy link
Contributor

@nujaa nujaa commented Jul 11, 2024

Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions.
This change stems from improving lowering of contractionOps for Arm SME. Where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts.
discussed here.

Fix after : #97652 showed an unhandled edge case when all dimensions are one. The generated target VectorType would be vector<f32> which is apparently not supported by the mulf.
In case all dimensions are dropped, the target vectorType is vector<1xf32>

…g / trailing dimensions. (llvm#92934)

Generalizes `DropUnitDimFromElementwiseOps` to support inner unit
dimensions.
This change stems from improving lowering of contractionOps for Arm SME.
Where we end up with inner unit dimensions on MulOp, BroadcastOp and
TransposeOp, preventing the generation of outerproducts.
discussed
[here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa).

---------

Co-authored-by: Benjamin Maxwell <[email protected]>
@nujaa nujaa marked this pull request as ready for review July 11, 2024 10:31
@llvmbot
Copy link
Member

llvmbot commented Jul 11, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Hugo Trachino (nujaa)

Changes

Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions.
This change stems from improving lowering of contractionOps for Arm SME. Where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts.
discussed here.

Fix after : #97652 showed an unhandled edge case when all dimensions are one.


Full diff: https://github.com/llvm/llvm-project/pull/98455.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+36-26)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+51)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index da5954b70a2ec..4edc85af9ee60 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1622,7 +1622,34 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
-/// For vectors with either leading or trailing unit dim, replaces:
+// Scalable unit dimensions are not supported. Folding such dimensions would
+// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
+// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
+// future.
+static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
+  auto inVecShape = inVecTy.getShape();
+  auto inVecScalableDims = inVecTy.getScalableDims();
+  SmallVector<int64_t> newShape;
+  SmallVector<bool> newScalableDims;
+  if (llvm::all_of(inVecShape, [](int64_t dim) { return dim == 1; }) &&
+      llvm::none_of(inVecScalableDims,
+                    [](bool isScalable) { return isScalable; })) {
+    newShape.push_back(1);
+    newScalableDims.push_back(false);
+  } else {
+    for (auto [dim, isScalable] :
+         llvm::zip_equal(inVecShape, inVecScalableDims)) {
+      if (dim == 1 && !isScalable)
+        continue;
+
+      newShape.push_back(dim);
+      newScalableDims.push_back(isScalable);
+    }
+  }
+  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
+}
+
+/// For vectors with at least an unit dim, replaces:
 ///   elementwise(a, b)
 /// with:
 ///   sc_a = shape_cast(a)
@@ -1634,20 +1661,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
 /// required to be rank > 1.
 ///
 /// Ex:
-/// ```
 ///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
 ///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
-/// ```
 ///
 /// gets converted to:
 ///
-/// ```
 ///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
 ///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
 ///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
 ///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
 ///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
-/// ```
 ///
 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
 /// `%cast`.
@@ -1667,42 +1690,29 @@ struct DropUnitDimFromElementwiseOps final
     // guaranteed to have identical shapes (with some exceptions such as
     // `arith.select`) and it suffices to only check one of them.
     auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
-    if (!sourceVectorType)
-      return failure();
-    if (sourceVectorType.getRank() < 2)
+    if (!sourceVectorType || sourceVectorType.getRank() < 2)
       return failure();
 
-    bool hasTrailingDimUnitFixed =
-        ((sourceVectorType.getShape().back() == 1) &&
-         (!sourceVectorType.getScalableDims().back()));
-    bool hasLeadingDimUnitFixed =
-        ((sourceVectorType.getShape().front() == 1) &&
-         (!sourceVectorType.getScalableDims().front()));
-    if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
-      return failure();
-
-    // Drop leading/trailing unit dim by applying vector.shape_cast to all
-    // operands
-    int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
       auto opVectorType = cast<VectorType>(operand.getType());
-      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
+      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
+      if (newVType == opVectorType)
+        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
+
       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
       newOperands.push_back(opSC);
     }
 
     VectorType newResultVectorType =
-        VectorType::Builder(resultVectorType).dropDim(dim);
-    // Create an updated elementwise Op without leading/trailing unit dim
+        dropNonScalableUnitDimFromType(resultVectorType);
+    // Create an updated elementwise Op without unit dim.
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                         newResultVectorType, op->getAttrs());
 
-    // Restore the leading/trailing unit dim by applying vector.shape_cast
-    // to the result
+    // Restore the unit dim by applying vector.shape_cast to the result.
     rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                              elementwiseOp->getResult(0));
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5fd3cbd54aa58..303f841e8a828 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -604,6 +604,57 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 
 // -----
 
+func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
+                              %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
+   %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
+   %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
+   %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
+   return %res : vector<8x3xf128>
+}
+
+// CHECK-LABEL: func.func @fold_inner_unit_dim(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x3xf128>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
+// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
+// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
+// CHECK:         return %[[VAL_4]] : vector<8x3xf128>
+
+// -----
+
+func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
+                              %arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
+   %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
+   %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
+   %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
+   return %res : vector<8x[1]x3xf128>
+}
+
+// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
+// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
+// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
+// CHECK:         return %[[VAL_4]] : vector<8x[1]x3xf128>
+
+// -----
+
+func.func @fold_all_unit_dims(%arg0: vector<1x1xf32>) -> vector<1xf32> {
+  %0 = arith.mulf %arg0, %arg0 : vector<1x1xf32>
+  %res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32>
+  return %res : vector<1xf32>
+}
+
+// CHECK-LABEL: func.func @fold_all_unit_dims(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
+// CHECK:         %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
+// CHECK:         %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
+// CHECK:         return %[[VAL_3]] : vector<1xf32>
+
+// -----
+
 func.func @negative_out_of_bound_transfer_read(
     %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index

@nujaa nujaa requested review from banach-space and MacDue July 11, 2024 10:43
@nujaa nujaa force-pushed the hugo.dropUnitDimsGen branch from 4af5cf1 to 3909145 Compare July 11, 2024 14:47
newShape.push_back(dim);
newScalableDims.push_back(isScalable);
}
// All dims have been dropped, we need to return a legal shape for VectorType.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: 0-D vectors are legal, but not well supported (I think that was the issue from IREE)?

Copy link
Contributor Author

@nujaa nujaa Jul 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I saw some instances essentially with data transfer memory shape ops (extract, broadcast, ...). Do you think the solution should rather allow elementwise op to support vector ?

It was indeed an issue coming from IREE yet, it was reproduced with LLVM

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to think about IREE here. We should generate valid IRs in upstream transformations. IMO, 0-D vector is weird. I can't really tell the distinction between 0-D vector and vector<1xT>. I'd suggest to return 1-D vector and document it in the function comment (i.e., l.1625-l.1629).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to what @hanhanW is suggesting, thanks!

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM (I see fold_all_unit_dims() covers the case from #92934 (comment))

// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
// future.
// Helper function dropping unit non-scalable dimension from a VectorType
// keeping at least 1 dimension. Scalable unit dimensions are not dropped.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit]

Suggested change
// keeping at least 1 dimension. Scalable unit dimensions are not dropped.
// keeping at least 1 dimension (to avoid generating 0-D vectors). Scalable unit dimensions are not dropped.

Tiny bit of extra context could save ourselves from scratching our heads in 6 months :)

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@nujaa nujaa merged commit de61875 into llvm:main Jul 17, 2024
7 checks passed
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…g / trailing dimensions. (#98455)

Summary:
Generalizes DropUnitDimFromElementwiseOps to support inner unit
dimensions.
This change stems from improving lowering of contractionOps for Arm SME.
Where we end up with inner unit dimensions on MulOp, BroadcastOp and
TransposeOp, preventing the generation of outerproducts.
discussed
[here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa).

Fix after : #97652 showed an
unhandled edge case when all dimensions are one. The generated target
VectorType would be `vector<f32>` which is apparently not supported by
the mulf.
In case all dimensions are dropped, the target vectorType is
vector<1xf32>

---------

Co-authored-by: Benjamin Maxwell <[email protected]>

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251689
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants