Skip to content

[mlir][Vector] Improve support for vector.extract(broadcast) #116234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 24, 2025

Conversation

Groverkss
Copy link
Member

@Groverkss Groverkss commented Nov 14, 2024

This patch improves support for vector.extract(broadcast) dynamic dimension folders. This is mostly a matter of moving a conservative condition for dynamic dimensions. The broadcast folder for vector.extract now covers the cases that the vector.extractelement + broadcast folder does.

This patch also improves test coverage for vector.extract + broadcast folders/canonicalizers. The folders/canonicalizers now enumerate every supported / unsupported case.

@llvmbot
Copy link
Member

llvmbot commented Nov 14, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

This patch improves support for vector.extract(broadcast) dynamic dimension folders. This is mostly a matter of moving a conservative condition for dynamic dimensions.

This patch also improves test coverage for vector.extract + broadcast folders/canonicalizers. The folders/canonicalizers now enumerate every supported / unsupported case.


Full diff: https://github.com/llvm/llvm-project/pull/116234.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+10-4)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+57-23)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..12f0ae25f4dc7d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1648,10 +1648,6 @@ static bool hasZeroDimVectors(Operation *op) {
 
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  // TODO: Canonicalization for dynamic position not implemented yet.
-  if (extractOp.hasDynamicPosition())
-    return Value();
-
   Operation *defOp = extractOp.getVector().getDefiningOp();
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
@@ -1680,6 +1676,16 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
           broadcastVecType.getShape().take_back(extractResultRank))
     return Value();
 
+  // The dim-1 broadcast -> ExtractOp folder requires in place operation
+  // modifications. For dynamic position, this means we have to change the
+  // number of operands. This cannot be done in place since it changes the
+  // operation storage. For dynamic dimensions, the dim-1 broadcasting should
+  // be implemented as a canonicalization pattern.
+  // TODO: Implement canonicalization pattern for dim-1 broadcasting +
+  // extractop.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
   auto broadcastOp = cast<vector::BroadcastOp>(defOp);
   int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5ae769090dac66..766f0e09b6d753 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -652,24 +652,44 @@ func.func @fold_extract_transpose(
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_same_type
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_broadcast(%a : f32) -> f32 {
+func.func @fold_extract_broadcast_same_type(%a : f32, 
+                                            %idx0 : index, 
+                                            %idx1 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  // The indices don't batter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_0dvec
+// CHECK-LABEL: fold_extract_broadcast_same_type_vec
+//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
+//       CHECK:   return %[[A]] : vector<4xf32>
+func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>, 
+                                                %idx0 : index) 
+                                                -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  // The indices don't batter for this folder, so we use mixed indices.
+  %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
 //       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
 //       CHECK:   return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+func.func @fold_extract_broadcast_0dvec_and_scalar(%a : vector<f32>, 
+                                                   %idx0 : index, 
+                                                   %idx1 : index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  // The indices don't batter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
@@ -689,57 +709,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
 // CHECK-LABEL: fold_extract_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32) -> f32 {
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  // The indices don't batter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_vector
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//       CHECK:   return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+//       CHECK:   return %[[R]] : f32
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
-  return %r : vector<4xf32>
+  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+//  CHECK-SAME:   %[[IDX:.*]]: index
+//       CHECK:   %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
 //       CHECK:   return %[[R]] : f32
-func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
+// This folder is not yet implemented. Check that this does not fold.
+func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
+                                                            %a : vector<4xf32>, 
+                                                            %idx : index) -> f32 {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
+func.func @canonicalize_extract_broadcast_to_higher_rank(%a : f32, 
+                                                         %idx0 : index) 
+                                                         -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
+  // The indices don't batter for this canonicalizer, so we use mixed indices.
+  %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
 //       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 //       CHECK:   return %[[R]] : vector<8xf32>
-func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+func.func @canonicalize_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
+                                                         %idx0 : index) 
+                                                         -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
-  %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
+  // The indices don't batter for this canonicalizer, so we use mixed indices.
+  %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }
+
 // -----
 
 // CHECK-LABEL: @fold_extract_shuffle

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this to match the folder for extractelements?

@Groverkss
Copy link
Member Author

Is this to match the folder for extractelements?

yes, i updated the description to mention that.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but please wait for an OK from Andrzej or Diego.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic makes sense, thanks! I've made a few suggestions how to improve consistency in tests.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about the delay, this one require a bit of context-switching.

Thanks for all the detective work re folders, that's much appreciated! I think that we should also formalise a bit better what "dim1 broadcasting" means. I can help with that, but will be OOO for a week.

Some comments inline, lets continue the discussion :)

// CHECK-LABEL: fold_extract_broadcast_same_type_vec
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
// CHECK: return %[[A]] : vector<4xf32>
func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>,
func.func @fold_extract_broadcast_vector(%a : vector<4xf32>,

@Groverkss Groverkss force-pushed the vector_extract_broadcast_dynamic branch from 154c550 to 67dc952 Compare February 8, 2025 23:46
@Groverkss
Copy link
Member Author

@banach-space @dcaballe After a rethink, I found out I was wrong, and we can also extend the folder to dim-1 broadcasting cases. This should be a much simpler change now. Thanks for the reviews so far!

@banach-space
Copy link
Contributor

@banach-space @dcaballe After a rethink, I found out I was wrong, and we can also extend the folder to dim-1 broadcasting cases. This should be a much simpler change now. Thanks for the reviews so far!

Just to clarify - you have re-worked this?

Before reviewing, can we make sure we are on the same page re dim-1 broadcasting? #116234 (comment)

@Groverkss
Copy link
Member Author

Groverkss commented Feb 10, 2025

@banach-space @dcaballe After a rethink, I found out I was wrong, and we can also extend the folder to dim-1 broadcasting cases. This should be a much simpler change now. Thanks for the reviews so far!

Just to clarify - you have re-worked this?

Before reviewing, can we make sure we are on the same page re dim-1 broadcasting? #116234 (comment)

Yes, I reworked it a bit. Before, this PR did not work for dynamic indices when doing dim-1 broadcasting (because I thought folders could not change the number of operands, but they can). Now, this PR is adding dynamic indices support for all cases the existing folder covered for static cases.

re: dim-1 broadcasting. Happy to iterate on that PR first before we come back to this, having a look now.

@Groverkss Groverkss force-pushed the vector_extract_broadcast_dynamic branch from f5c130c to e2a4e6d Compare February 19, 2025 20:01
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've noticed a few typos :)

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % nit, thanks for seeing this through!

Please wait ~1 day before landing, I see that @dcaballe also commented on this in the past and he may want to take another look.

@Groverkss Groverkss merged commit 61fb954 into llvm:main Feb 24, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants