Skip to content

Commit 57cf689

Browse files
authored
[mlir][vector] Fix vector.broadcast lowering for scalable vectors (#66344)
This patch makes sure that the following case is lowered correctly ("duplication"): ``` func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> { %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32> return %res : vector<1x[32]xf32> } ```
1 parent cadabb5 commit 57cf689

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
8484
// %x = [%b,%b,%b,%b] : n-D
8585
if (srcRank < dstRank) {
8686
// Duplication.
87-
VectorType resType =
88-
VectorType::get(dstType.getShape().drop_front(), eltType);
87+
VectorType resType = VectorType::Builder(dstType).dropDim(0);
8988
Value bcst =
9089
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
9190
Value result = rewriter.create<arith::ConstantOp>(

mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,17 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
162162
return %0 : vector<4x3x2xf32>
163163
}
164164

165+
// CHECK-LABEL: func.func @broadcast_scalable_duplication
166+
// CHECK-SAME: %[[ARG0:.*]]: vector<[32]xf32>)
167+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x[32]xf32>
168+
// CHECK: %[[RES:.*]] = vector.insert %[[ARG0]], %[[CST]] [0] : vector<[32]xf32> into vector<1x[32]xf32>
169+
// CHECK: return %[[RES]] : vector<1x[32]xf32>
170+
171+
func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
172+
%res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
173+
return %res : vector<1x[32]xf32>
174+
}
175+
165176
transform.sequence failures(propagate) {
166177
^bb1(%module_op: !transform.any_op):
167178
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)