Skip to content

Commit d50705e

Browse files
authored
[mlir][vector] Support scalable vec in TransferReadAfterWriteToBroadcast (#79162)
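Makes `TransferReadAfterWriteToBroadcast` propagate the scalable-dimension flags of the read's vector type (permuted the same way as its shape) to the result type of the `vector.broadcast` it creates, so the pattern handles scalable vectors correctly.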
1 parent bae1ada commit d50705e

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4079,10 +4079,15 @@ struct TransferReadAfterWriteToBroadcast
40794079
// final shape we want.
40804080
ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
40814081
SmallVector<int64_t> broadcastShape(destShape.size());
4082-
for (const auto &pos : llvm::enumerate(permutation))
4082+
SmallVector<bool> broadcastScalableFlags(destShape.size());
4083+
for (const auto &pos : llvm::enumerate(permutation)) {
40834084
broadcastShape[pos.value()] = destShape[pos.index()];
4085+
broadcastScalableFlags[pos.value()] =
4086+
readOp.getVectorType().getScalableDims()[pos.index()];
4087+
}
40844088
VectorType broadcastedType = VectorType::get(
4085-
broadcastShape, defWrite.getVectorType().getElementType());
4089+
broadcastShape, defWrite.getVectorType().getElementType(),
4090+
broadcastScalableFlags);
40864091
vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
40874092
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
40884093
rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,6 +1302,24 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
13021302

13031303
// -----
13041304

1305+
// CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
1306+
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
1307+
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
1308+
// CHECK: return %[[B]] : vector<6x[4]xf32>
1309+
func.func @store_to_load_tensor_broadcast_scalable(%arg0 : tensor<?xf32>,
1310+
%v0 : vector<[4]xf32>) -> vector<6x[4]xf32> {
1311+
%c0 = arith.constant 0 : index
1312+
%cf0 = arith.constant 0.0 : f32
1313+
%w0 = vector.transfer_write %v0, %arg0[%c0] {in_bounds = [true]} :
1314+
vector<[4]xf32>, tensor<?xf32>
1315+
%0 = vector.transfer_read %w0[%c0], %cf0 {in_bounds = [true, true],
1316+
permutation_map = affine_map<(d0) -> (0, d0)>} :
1317+
tensor<?xf32>, vector<6x[4]xf32>
1318+
return %0 : vector<6x[4]xf32>
1319+
}
1320+
1321+
// -----
1322+
13051323
// CHECK-LABEL: func @store_to_load_tensor_perm_broadcast
13061324
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x4x4xf32>, %[[V0:.*]]: vector<4x1xf32>)
13071325
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x1xf32> to vector<100x5x4x1xf32>

0 commit comments

Comments
 (0)