-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Update castAwayContractionLeadingOneDim
to omit transposes solely on leading unit dims.
#85694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
castAwayContractionLeadingOneDim
to omit transposes solely on leading unit dims. castAwayContractionLeadingOneDim
to omit transposes solely on leading unit dims.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Kojo Acquah (KoolJBlack) ChangesFixes #85691 Full diff: https://github.com/llvm/llvm-project/pull/85694.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..6b69f5f1932ad7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -399,13 +399,28 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
transposeResults.push_back(targetExpr);
}
}
+
+ // Check if the transpose effects outer unit dims only. Such transposes do
+ // not materially effect the underlying vector and can be omitted.
+ bool tranposeNonOuterUnitDims = false;
+ for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
+ if (perm[i] != i && i != (int64_t)perm.size() - 1) {
+ if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
+ 1) {
+ tranposeNonOuterUnitDims = true;
+ }
+ }
+ }
+
// Do the tranpose now if needed so that we can drop the
// correct dim using extract later.
if (tranposeNeeded) {
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
- operands[it.index()] = rewriter.create<vector::TransposeOp>(
- contractOp.getLoc(), operands[it.index()], perm);
+ if (tranposeNonOuterUnitDims) {
+ operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+ contractOp.getLoc(), operands[it.index()], perm);
+ }
}
}
// We have taken care to have the dim to be dropped be
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..31b0867c851f58 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -166,6 +166,28 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
// -----
+// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+
+// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
+// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
+// CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
+// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
+// CHECK: return %[[VAL_6]] : vector<1x8xi32>
+// CHECK: }
+func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
+ %rhs: vector<1x8x8xi32>,
+ %acc: vector<1x8xi32>) -> vector<1x8xi32> {
+ %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
+ return %result : vector<1x8xi32>
+}
+
+// -----
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
|
b1ff4b0
to
a561286
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, forgot to submit this earlier :(
Similar comments to what @hanhanW has posted. I would also ask for a some justification in the commit summay (see https://mlir.llvm.org/getting_started/Contributing/#commit-messages). A reference to IREE issue is very helpful, but the commit summary should be self-contained - could you add a brief overview?
mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
Outdated
Show resolved
Hide resolved
a561286
to
163ea73
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates - few more comments inline
mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
Outdated
Show resolved
Hide resolved
// Checks if only the outer, unit dimensions (of size 1) are permuted. | ||
// Such transposes do not materially effect the underlying vector and can | ||
// be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32> | ||
bool tranposeNonOuterUnitDims = false; | ||
for (auto [index, dim] : |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the example, I now understand why we need this.. Instead of adding the ad-hoc logic here, would it make sense to add the canonicalization pattern to vector.transpose op?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I considered this. I wasn't certain if it was %100 safe to remove transposes like this in every scenario (for instance, if there is some pass that creates transposes like this in one pattern and consumes them in a subsequent pattern).
On the flipside, since this was an acute issue created in this pass it is pretty straightforward to handle these transposes here understanding the entire pass.
I would be in favor of using this modification for now and can look intro transpose canonicalizer in following, less you feel differently?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The follow up SGTM as well. Having that canonicalization pattern makes sense but given that we already have logic to decide if a transpose should be generated or not, it would make sense to extend that logic to support this extra case (and hopefully avoid a canonicalization pass just for this).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
going to commit this for now and start work on a broader transpose fold
f40efe1
to
bfd2fc4
Compare
bfd2fc4
to
55af0b0
Compare
55af0b0
to
d36279f
Compare
Updates
castAwayContractionLeadingOneDim
to check for leading unit dimensions before insertingvector.transpose
ops.Currently
castAwayContractionLeadingOneDim
removes all leading unit dims based on the accumulator and transpose any subsequent operands to match the accumulator indexing. This does not take into account if the transpose is strictly necessary, for instance when given this vector-matrix contract:Passing this through
castAwayContractionLeadingOneDim
pattern produces the following:The
vector.transpose
introduced does not affect the underlying data layout (effectively a no op), but it cannot be folded automatically. This change avoids inserting transposes when only leading unit dimensions are involved.Fixes #85691