
Commit 9802803

[mlir][linalg] Add support for masked vectorization of tensor.insert_slice (2/N)
For context, recall that `tensor.insert_slice` is vectorised using the `vector.transfer_read` + `vector.transfer_write` pair. An unmasked example is shown below:

```mlir
// BEFORE VECTORIZATION
%res = tensor.insert_slice %slice into %dest[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<5x3xi32>

// AFTER VECTORIZATION
%read = vector.transfer_read %source[%c0, %c0], %pad : tensor<5x1xi32>, vector<8x1xi32>
%res = vector.transfer_write %read, %dest[%c0, %c2] : vector<8x1xi32>, tensor<5x3xi32>
```

This PR extends `vectorizeAsInsertSliceOp` to add masking support for the `vector.transfer_write` operation. This complements the changes in #122927, which introduced masking for the `vector.transfer_read`.
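With user-provided vector sizes (e.g. `[8, 1]`), both transfers are now wrapped in `vector.mask` using a single shared mask. Below is a sketch of the masked output, adapted from the updated CHECK lines in `vectorization.mlir` further down this page; the `%c5`/`%c1` constant names are illustrative:

```mlir
// AFTER MASKED VECTORIZATION (vector sizes = [8, 1])
// The mask covers the static source shape 5x1 within the 8x1 vector.
%mask = vector.create_mask %c5, %c1 : vector<8x1xi1>
%read = vector.mask %mask { vector.transfer_read %source[%c0, %c0], %pad : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
%res = vector.mask %mask { vector.transfer_write %read, %dest[%c0, %c2] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
```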
1 parent 709dd82 · commit 9802803

File tree: 3 files changed, +108 −52 lines


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 26 additions & 22 deletions
```diff
@@ -2716,56 +2716,60 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   }
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
-  // 3. Generate TransferReadOp.
-  SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
+  // 3. Generate TransferReadOp + TransferWriteOp
+  ReifiedRankedShapedTypeDims reifiedSrcSizes;
+  Value maskOp;
 
-  // If vector sizes are user provided, make sure to mask xfer_read.
+  // If vector sizes are user provided, make sure to mask. First, generate the
+  // mask.
   if (!inputVectorSizes.empty()) {
     auto *srcDefOp = source.getDefiningOp();
     if (!srcDefOp) {
       LDBG("Unable to get the defining Op of " << sliceOp);
       return failure();
     }
 
-    ReifiedRankedShapedTypeDims reifiedSrcSizes;
     LogicalResult status =
         cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
             rewriter, reifiedSrcSizes);
     if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << sliceOp);
+      LDBG("Unable to reify result shapes of " << srcDefOp);
       return failure();
     }
 
     // Create the mask
-    SmallVector<int64_t> readMaskShape(
-        sliceOp.getSource().getType().getShape());
     auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+    maskOp = rewriter.create<vector::CreateMaskOp>(
         sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-
-    // Mask the xfer_read Op
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }
 
-  // 4. Generate TransferWriteOp.
-  if (!inputVectorSizes.empty() &&
-      ShapedType::isDynamicShape(resultType.getShape())) {
-    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
-    return failure();
+  // 3.a. TransferReadOp
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});
+
+  // Mask the xfer_read Op
+  if (!inputVectorSizes.empty()) {
+    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }
 
+  // 3.b. TransferWriteOp
   auto writeIndices = getValueOrCreateConstantIndexOp(
       rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
 
-  // 5. Finalize
   Operation *write = rewriter.create<vector::TransferWriteOp>(
       sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
       ArrayRef<bool>{writeInBounds});
+
+  // Mask the xfer_write Op
+  if (!inputVectorSizes.empty()) {
+    write = mlir::vector::maskOperation(rewriter, write, maskOp);
+  }
+
+  // 4. Finalize
   newResults.push_back(write->getResult(0));
 
   return success();
```

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 0 additions & 23 deletions
```diff
@@ -280,26 +280,3 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
-// -----
-
-// One of the _destination_ dimensions is dynamic (but _source_ dimensions are static).
-
-func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
-  %c2 = arith.constant 2 : index
-  %init = tensor.empty(%size) : tensor<?x3xi32>
-
-  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
-  // expected-error @+1 {{Attempted to vectorize, but failed}}
-  %res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>
-
-  return %res : tensor<?x3xi32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
-    transform.yield
-  }
-}
```

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 82 additions & 7 deletions
```diff
@@ -1130,14 +1130,14 @@ func.func private @insert_slice_static_sizes(%source: tensor<?x3x?x1xi32>) -> te
 // CHECK: %[[C_2:.*]] = arith.constant 2 : index
 // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<5x3xi32>
 // CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SEC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
-// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C_5:.*]] = arith.constant 5 : index
-// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C_5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
 // CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C0]], %[[C0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
 // CHECK: %[[C_0:.*]] = arith.constant 0 : index
-// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
 // CHECK: return %[[RES]] : tensor<5x3xi32>
 
 module attributes {transform.with_named_sequence} {
@@ -1170,11 +1170,11 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x3x?x1xi32>, %s
 // CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
 // CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
 // CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
 // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
 // CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
 // CHECK: return %[[RES]] : tensor<5x3xi32>
 
 module attributes {transform.with_named_sequence} {
@@ -1184,3 +1184,78 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x3x?x1xi32>, %s
     transform.yield
   }
 }
+
+// -----
+
+// One of the _destination_ dimensions is dynamic (but _source_ dimensions are static).
+
+func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty(%size) : tensor<?x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>
+
+  return %res : tensor<?x3xi32>
+}
+
+// CHECK-LABEL: func.func private @insert_slice_dynamic_dest_dim(
+// CHECK-SAME: %[[SRC:.*]]: tensor<?x3x?x1xi32>,
+// CHECK-SAME: %[[size:.*]]: index) -> tensor<?x3xi32> {
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[size]]) : tensor<?x3xi32>
+// CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C_5:.*]] = arith.constant 5 : index
+// CHECK: %[[C_1:.*]] = arith.constant 1 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
+// CHECK: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
+// CHECK: return %[[WRITE]] : tensor<?x3xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// At least one _source_ and one _destination_ dimensions are dynamic.
+
+func.func private @insert_slice_dynamic_source_and_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty(%size) : tensor<?x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [%size, 1] [1, 1] : tensor<?x1xi32> into tensor<?x3xi32>
+
+  return %res : tensor<?x3xi32>
+}
+
+// CHECK-LABEL: func.func private @insert_slice_dynamic_source_and_dest_dim(
+// CHECK-SAME: %[[SRC:.*]]: tensor<?x3x?x1xi32>,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?x3xi32> {
+// CHECK: %[[C_2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty(%[[SIZE]]) : tensor<?x3xi32>
+// CHECK: %[[SRC_SIZE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C1]] : vector<8x1xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SIZE]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]]{{\[}}%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
+// CHECK: return %[[WRITE]] : tensor<?x3xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
+}
```
