-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Improve support for vector.extract(broadcast) #116234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Vector] Improve support for vector.extract(broadcast) #116234
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesThis patch improves support for vector.extract(broadcast) dynamic dimension folders. This is mostly a matter of moving a conservative condition for dynamic dimensions. This patch also improves test coverage for vector.extract + broadcast folders/canonicalizers. The folders/canonicalizers now enumerate every supported / unsupported case. Full diff: https://github.com/llvm/llvm-project/pull/116234.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..12f0ae25f4dc7d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1648,10 +1648,6 @@ static bool hasZeroDimVectors(Operation *op) {
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- // TODO: Canonicalization for dynamic position not implemented yet.
- if (extractOp.hasDynamicPosition())
- return Value();
-
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return Value();
@@ -1680,6 +1676,16 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
broadcastVecType.getShape().take_back(extractResultRank))
return Value();
+ // The dim-1 broadcast -> ExtractOp folder requires in place operation
+ // modifications. For dynamic position, this means we have to change the
+ // number of operands. This cannot be done in place since it changes the
+ // operation storage. For dynamic dimensions, the dim-1 broadcasting should
+ // be implemented as a canonicalization pattern.
+ // TODO: Implement canonicalization pattern for dim-1 broadcasting +
+ // extractop.
+ if (extractOp.hasDynamicPosition())
+ return Value();
+
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5ae769090dac66..766f0e09b6d753 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -652,24 +652,44 @@ func.func @fold_extract_transpose(
// -----
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_same_type
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_broadcast(%a : f32) -> f32 {
+func.func @fold_extract_broadcast_same_type(%a : f32,
+ %idx0 : index,
+ %idx1 : index) -> f32 {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ // The indices don't batter for this folder, so we use mixed indices.
+ %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast_0dvec
+// CHECK-LABEL: fold_extract_broadcast_same_type_vec
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>
+// CHECK: return %[[A]] : vector<4xf32>
+func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>,
+ %idx0 : index)
+ -> vector<4xf32> {
+ %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+ // The indices don't batter for this folder, so we use mixed indices.
+ %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
// CHECK-SAME: %[[A:.*]]: vector<f32>
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
// CHECK: return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+func.func @fold_extract_broadcast_0dvec_and_scalar(%a : vector<f32>,
+ %idx0 : index,
+ %idx1 : index) -> f32 {
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ // The indices don't batter for this folder, so we use mixed indices.
+ %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
@@ -689,57 +709,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
// CHECK-LABEL: fold_extract_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32) -> f32 {
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ // The indices don't batter for this folder, so we use mixed indices.
+ %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast_vector
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK: return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+// CHECK: return %[[R]] : f32
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
- return %r : vector<4xf32>
+ %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+// CHECK-SAME: %[[IDX:.*]]: index
+// CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
// CHECK: return %[[R]] : f32
-func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
+// This folder is not yet implemented. Check that this does not fold.
+func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
+ %a : vector<4xf32>,
+ %idx : index) -> f32 {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ %r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
+func.func @canonicalize_extract_broadcast_to_higher_rank(%a : f32,
+ %idx0 : index)
+ -> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
+ // The indices don't batter for this canonicalizer, so we use mixed indices.
+ %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
// -----
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
// CHECK: return %[[R]] : vector<8xf32>
-func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+func.func @canonicalize_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
+ %idx0 : index)
+ -> vector<8xf32> {
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
- %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
+ // The indices don't batter for this canonicalizer, so we use mixed indices.
+ %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
return %r : vector<8xf32>
}
+
// -----
// CHECK-LABEL: @fold_extract_shuffle
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this to match the folder for extractelements?
yes, i updated the description to mention that. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but please wait for an OK from Andrzej or Diego.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic makes sense, thanks! I've made a few suggestions how to improve consistency in tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry about the delay, this one require a bit of context-switching.
Thanks for all the detective work re folders, that's much appreciated! I think that we should also formalise a bit better what "dim1 broadcasting" means. I can help with that, but will be OOO for a week.
Some comments inline, lets continue the discussion :)
// CHECK-LABEL: fold_extract_broadcast_same_type_vec | ||
// CHECK-SAME: %[[A:.*]]: vector<4xf32> | ||
// CHECK: return %[[A]] : vector<4xf32> | ||
func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>, | |
func.func @fold_extract_broadcast_vector(%a : vector<4xf32>, |
154c550
to
67dc952
Compare
@banach-space @dcaballe After a rethink, I found out I was wrong, and we can also extend the folder to dim-1 broadcasting cases. This should be a much simpler change now. Thanks for the reviews so far! |
Just to clarify - you have re-worked this? Before reviewing, can we make sure we are on the same page re dim-1 broadcasting? #116234 (comment) |
Yes, I reworked it a bit. Before, this PR did not work for dynamic indices when doing dim-1 broadcasting (because I thought folders could not change the number of operands, but they can). Now, this PR is adding dynamic indices support for all cases the existing folder covered for static cases. re: dim-1 broadcasting. Happy to iterate on that PR first before we come back to this, having a look now. |
f5c130c
to
e2a4e6d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I've noticed a few typos :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % nit, thanks for seeing this through!
Please wait ~1 day before landing, I see that @dcaballe also commented on this in the past and he may want to take another look.
This patch improves support for vector.extract(broadcast) dynamic dimension folders. This is mostly a matter of moving a conservative condition for dynamic dimensions. The broadcast folder for vector.extract now covers the cases that the vector.extractelement + broadcast folder does.
This patch also improves test coverage for vector.extract + broadcast folders/canonicalizers. The folders/canonicalizers now enumerate every supported / unsupported case.