Skip to content

[mlir][vector] Fix vector.broadcast lowering for scalable vectors #66344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 15, 2023

Conversation

banach-space
Copy link
Contributor

This patch makes sure that the following case is lowered correctly
("duplication"):

  func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
    %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
    return %res : vector<1x[32]xf32>
  }

This patch makes sure that the following case is lowered correctly
("duplication"):
```
  func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
    %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
    return %res : vector<1x[32]xf32>
  }
```
@llvmbot
Copy link
Member

llvmbot commented Sep 14, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-core

Changes This patch makes sure that the following case is lowered correctly ("duplication"): ``` func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> { %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32> return %res : vector<1x[32]xf32> } ```

--
Full diff: https://github.com/llvm/llvm-project/pull/66344.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+2-1)
  • (modified) mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir (+11)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 7c606e0c35f0899..2937b2d08b06979 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -85,7 +85,8 @@ class BroadcastOpLowering : public OpRewritePattern&lt;vector::BroadcastOp&gt; {
     if (srcRank &lt; dstRank) {
       // Duplication.
       VectorType resType =
-          VectorType::get(dstType.getShape().drop_front(), eltType);
+          VectorType::get(dstType.getShape().drop_front(), eltType,
+                          dstType.getScalableDims().drop_front());
       Value bcst =
           rewriter.create&lt;vector::BroadcastOp&gt;(loc, resType, op.getSource());
       Value result = rewriter.create&lt;arith::ConstantOp&gt;(
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 2d3c88d751192aa..386102cf5b4d225 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -162,6 +162,17 @@ func.func @broadcast_stretch_in_middle(%arg0: vector&lt;4x1x2xf32&gt;) -&gt; vector&lt;4x3x2
   return %0 : vector&lt;4x3x2xf32&gt;
 }
 
+// CHECK-LABEL:   func.func @broadcast_scalable_duplication
+// CHECK-SAME:      %[[ARG0:.*]]: vector&lt;[32]xf32&gt;)
+// CHECK:           %[[CST:.*]] = arith.constant dense&lt;0.000000e+00&gt; : vector&lt;1x[32]xf32&gt;
+// CHECK:           %[[RES:.*]] = vector.insert %[[ARG0]], %[[CST]] [0] : vector&lt;[32]xf32&gt; into vector&lt;1x[32]xf32&gt;
+// CHECK:           return %[[RES]] : vector&lt;1x[32]xf32&gt;
+
+func.func @broadcast_scalable_duplication(%arg0: vector&lt;[32]xf32&gt;) -&gt; vector&lt;1x[32]xf32&gt; {
+  %res = vector.broadcast %arg0 : vector&lt;[32]xf32&gt; to vector&lt;1x[32]xf32&gt;
+  return %res : vector&lt;1x[32]xf32&gt;
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):
   %f = transform.structured.match ops{[&quot;func.func&quot;]} in %module_op 

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One small suggestion but otherwise LGTM, cheers

@banach-space banach-space merged commit 57cf689 into llvm:main Sep 15, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
…vm#66344)

This patch makes sure that the following case is lowered correctly
("duplication"):
```
  func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
    %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
    return %res : vector<1x[32]xf32>
  }
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants