Skip to content

[mlir][vector] Update castAwayContractionLeadingOneDim to omit transposes solely on leading unit dims. #85694

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 3, 2024

Conversation

KoolJBlack
Copy link
Contributor

@KoolJBlack KoolJBlack commented Mar 18, 2024

Updates castAwayContractionLeadingOneDim to check for leading unit dimensions before inserting vector.transpose ops.

Currently castAwayContractionLeadingOneDim removes all leading unit dims based on the accumulator and transpose any subsequent operands to match the accumulator indexing. This does not take into account if the transpose is strictly necessary, for instance when given this vector-matrix contract:

  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>

Passing this through castAwayContractionLeadingOneDim pattern produces the following:

    %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x1x8xi32> to vector<1x1x8xi32>
    %1 = vector.extract %0[0] : vector<1x8xi32> from vector<1x1x8xi32>
    %2 = vector.extract %arg2[0] : vector<8xi32> from vector<1x8xi32>
    %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %arg1, %2 : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
    %4 = vector.broadcast %3 : vector<8xi32> to vector<1x8xi32>

The vector.transpose introduced does not affect the underlying data layout (effectively a no op), but it cannot be folded automatically. This change avoids inserting transposes when only leading unit dimensions are involved.

Fixes #85691

@KoolJBlack KoolJBlack changed the title Update castAwayContractionLeadingOneDim to omit transposes solely on leading unit dims. [mlir][vector] Update castAwayContractionLeadingOneDim to omit transposes solely on leading unit dims. Mar 18, 2024
@KoolJBlack KoolJBlack requested a review from dcaballe March 19, 2024 16:46
@KoolJBlack KoolJBlack marked this pull request as ready for review March 19, 2024 16:46
@llvmbot
Copy link
Member

llvmbot commented Mar 19, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Kojo Acquah (KoolJBlack)

Changes

Fixes #85691


Full diff: https://github.com/llvm/llvm-project/pull/85694.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+17-2)
  • (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+22)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..6b69f5f1932ad7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -399,13 +399,28 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
           transposeResults.push_back(targetExpr);
         }
       }
+
+      // Check if the transpose effects outer unit dims only. Such transposes do
+      // not materially effect the underlying vector and can be omitted.
+      bool tranposeNonOuterUnitDims = false;
+      for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
+        if (perm[i] != i && i != (int64_t)perm.size() - 1) {
+          if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
+              1) {
+            tranposeNonOuterUnitDims = true;
+          }
+        }
+      }
+
       // Do the tranpose now if needed so that we can drop the
       // correct dim using extract later.
       if (tranposeNeeded) {
         map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                              contractOp.getContext());
-        operands[it.index()] = rewriter.create<vector::TransposeOp>(
-            contractOp.getLoc(), operands[it.index()], perm);
+        if (tranposeNonOuterUnitDims) {
+          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+              contractOp.getLoc(), operands[it.index()], perm);
+        }
       }
     }
     // We have taken care to have the dim to be dropped be
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..31b0867c851f58 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -166,6 +166,28 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
 
 // -----
 
+// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dims_vec_mat(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: vector<1x1x8xi32>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: vector<1x8x8xi32>,
+// CHECK-SAME:                                 %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK:           %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
+// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
+// CHECK:           %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
+// CHECK:           %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
+// CHECK:           return %[[VAL_6]] : vector<1x8xi32>
+// CHECK:         }
+func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
+                          %rhs: vector<1x8x8xi32>,
+                          %acc: vector<1x8xi32>) -> vector<1x8xi32> {
+  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
+  return %result : vector<1x8xi32>
+}
+
+// -----
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

@KoolJBlack KoolJBlack force-pushed the vector_drop_unit_transpose branch from b1ff4b0 to a561286 Compare March 19, 2024 17:04
@dcaballe dcaballe requested a review from banach-space March 19, 2024 17:12
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, forgot to submit this earlier :(

Similar comments to what @hanhanW has posted. I would also ask for a some justification in the commit summay (see https://mlir.llvm.org/getting_started/Contributing/#commit-messages). A reference to IREE issue is very helpful, but the commit summary should be self-contained - could you add a brief overview?

@KoolJBlack KoolJBlack force-pushed the vector_drop_unit_transpose branch from a561286 to 163ea73 Compare March 21, 2024 22:04
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates - few more comments inline

Comment on lines 403 to 407
// Checks if only the outer, unit dimensions (of size 1) are permuted.
// Such transposes do not materially effect the underlying vector and can
// be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
bool tranposeNonOuterUnitDims = false;
for (auto [index, dim] :
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the example, I now understand why we need this.. Instead of adding the ad-hoc logic here, would it make sense to add the canonicalization pattern to vector.transpose op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered this. I wasn't certain if it was %100 safe to remove transposes like this in every scenario (for instance, if there is some pass that creates transposes like this in one pattern and consumes them in a subsequent pattern).

On the flipside, since this was an acute issue created in this pass it is pretty straightforward to handle these transposes here understanding the entire pass.

I would be in favor of using this modification for now and can look intro transpose canonicalizer in following, less you feel differently?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The follow up SGTM as well. Having that canonicalization pattern makes sense but given that we already have logic to decide if a transpose should be generated or not, it would make sense to extend that logic to support this extra case (and hopefully avoid a canonicalization pass just for this).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

going to commit this for now and start work on a broader transpose fold

@KoolJBlack KoolJBlack force-pushed the vector_drop_unit_transpose branch 2 times, most recently from f40efe1 to bfd2fc4 Compare April 3, 2024 21:28
@KoolJBlack KoolJBlack force-pushed the vector_drop_unit_transpose branch from bfd2fc4 to 55af0b0 Compare April 3, 2024 23:21
@KoolJBlack KoolJBlack force-pushed the vector_drop_unit_transpose branch from 55af0b0 to d36279f Compare April 3, 2024 23:26
@KoolJBlack KoolJBlack merged commit 66fed33 into llvm:main Apr 3, 2024
@KoolJBlack KoolJBlack deleted the vector_drop_unit_transpose branch April 3, 2024 23:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

castAwayContractionLeadingOneDim introduces unnecessary transposes on outer unit dims
5 participants