Skip to content

Commit 153d795

Browse files
cfRodMacDue
authored andcommitted
[mlir][vector] Propagate scalability in TransferWriteNonPermutationLowering (llvm#85632)
Updates `extendVectorRank` so that scalability in patterns that use it (in particular, `TransferWriteNonPermutationLowering`), is correctly propagated. Closed related previous PR llvm#85270 --------- Signed-off-by: Crefeda Rodrigues <[email protected]> Co-authored-by: Benjamin Maxwell <[email protected]>
1 parent 12d6039 commit 153d795

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,12 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
4141
SmallVector<int64_t> newShape(addedRank, 1);
4242
newShape.append(originalVecType.getShape().begin(),
4343
originalVecType.getShape().end());
44-
VectorType newVecType =
45-
VectorType::get(newShape, originalVecType.getElementType());
44+
45+
SmallVector<bool> newScalableDims(addedRank, false);
46+
newScalableDims.append(originalVecType.getScalableDims().begin(),
47+
originalVecType.getScalableDims().end());
48+
VectorType newVecType = VectorType::get(
49+
newShape, originalVecType.getElementType(), newScalableDims);
4650
return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
4751
}
4852

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,24 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
4141
return %1 : vector<8x[4]x2xf32>
4242
}
4343

44+
// CHECK: func.func @permutation_with_mask_transfer_write_scalable(
45+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
46+
// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
47+
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
48+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
49+
// CHECK: %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
50+
// CHECK: %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
51+
// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
52+
// CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
53+
// CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[TRANSPOSE_1]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
54+
// CHECK: return
55+
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
56+
%c0 = arith.constant 0 : index
57+
vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
58+
} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
59+
60+
return
61+
}
4462
module attributes {transform.with_named_sequence} {
4563
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
4664
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)