Skip to content

Commit 1f9f3cc

Browse files
committed
Fix LowerVectorTransfer patterns to remove unsupported transpose ops for scalable vectors
Signed-off-by: Crefeda Rodrigues <[email protected]>
1 parent 3b3de48 commit 1f9f3cc

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,19 @@ struct TransferWritePermutationLowering
205205
// Generate new transfer_write operation.
206206
Value newVec = rewriter.create<vector::TransposeOp>(
207207
op.getLoc(), op.getVector(), indices);
208+
209+
auto vectorType = cast<VectorType>(newVec.getType());
210+
211+
if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
212+
rewriter.eraseOp(newVec.getDefiningOp());
213+
return failure();
214+
}
215+
208216
auto newMap = AffineMap::getMinorIdentityMap(
209217
map.getNumDims(), map.getNumResults(), rewriter.getContext());
210218
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
211219
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
212220
op.getMask(), newInBoundsAttr);
213-
214221
return success();
215222
}
216223
};
@@ -273,7 +280,7 @@ struct TransferWriteNonPermutationLowering
273280
missingInnerDim.size());
274281
// Mask: add unit dims at the end of the shape.
275282
Value newMask;
276-
if (op.getMask())
283+
if (op.getMask() && !op.getVectorType().isScalable())
277284
newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
278285
missingInnerDim.size());
279286
exprs.append(map.getResults().begin(), map.getResults().end());

mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,23 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
4141
return %1 : vector<8x[4]x2xf32>
4242
}
4343

44-
// CHECK: func.func @permutation_with_mask_transfer_write_scalable(
45-
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
46-
// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
47-
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
48-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
49-
// CHECK: %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
50-
// CHECK: %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
51-
// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
52-
// CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
53-
// CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[TRANSPOSE_1]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
54-
// CHECK: return
55-
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
44+
// CHECK-LABEL: func.func @permutation_with_mask_transfer_write_scalable(
45+
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
46+
// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
47+
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
48+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
49+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
50+
// CHECK: vector.transfer_write %[[BCAST]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true, true, true], permutation_map = #map} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
51+
// CHECK: return
52+
// CHECK: }
53+
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
5654
%c0 = arith.constant 0 : index
5755
vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
5856
} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
5957

6058
return
61-
}
59+
}
60+
6261
module attributes {transform.with_named_sequence} {
6362
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
6463
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)