
Commit cf4a2c5

fixup! [mlir][linalg] Refactor vectorization hooks to improve code reuse

* Restore the original behaviour in `vectorizeAsInsertSliceOp`, whereby the `in_bounds` attribute was used to identify potentially out-of-bounds accesses. Masks are only used when input vector sizes are specified.
* Revert the changes in insert-slice-with-patterns.mlir and pad-with-patterns.mlir, i.e. the tests in which we don't specify vector sizes.
* Other minor updates.
1 parent 82cc2fe commit cf4a2c5
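The behavioural difference is easiest to see in the IR. Below is a minimal sketch assembled from the test expectations in this commit (value names such as %src and %pad are illustrative): without user-specified vector sizes, a potentially out-of-bounds dynamic dim is recorded via the `in_bounds` attribute; with vector sizes, the access is guarded by a mask instead.

// Without input vector sizes: the dynamic dim is simply marked as
// potentially out-of-bounds via `in_bounds`; no mask is created.
%read = vector.transfer_read %src[%c0, %c0, %c0], %pad
    {in_bounds = [true, false, true]} : tensor<2x?x2xf32>, vector<2x3x2xf32>

// With input vector sizes: the same read is guarded by a mask built
// from the runtime extent of the dynamic dim.
%d1 = tensor.dim %src, %c1 : tensor<2x?x2xf32>
%mask = vector.create_mask %c2, %d1, %c2 : vector<2x3x2xi1>
%masked = vector.mask %mask {
  vector.transfer_read %src[%c0, %c0, %c0], %pad
      {in_bounds = [true, true, true]} : tensor<2x?x2xf32>, vector<2x3x2xf32>
} : vector<2x3x2xi1> -> vector<2x3x2xf32>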

File tree: 3 files changed, +23 -53 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 12 additions & 26 deletions
@@ -1659,9 +1659,10 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                static_cast<size_t>(vecToStoreType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < destRank; i++)
-      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
-                       !ShapedType::isDynamic(destShape[i]);
+    for (unsigned i = 0; i < vecToStoreRank; i++)
+      inBoundsVal[i] =
+          (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+          !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
   }

   // If missing, initialize the write indices to 0.
@@ -1670,7 +1671,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
          "Invalid number of write indices!");
   if (writeIndices.empty()) {
     auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-    writeIndices = SmallVector<Value>(destRank, zero);
+    writeIndices.assign(destRank, zero);
   }

   // Generate the xfer_write Op
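With this logic, a dim stays `in_bounds` only when the destination size is static and matches the corresponding input vector size; everything else is conservatively marked out-of-bounds. A hypothetical sketch of the resulting write (shapes are illustrative, not taken from this patch):

// dest dim 0 is dynamic, so it cannot be proven in-bounds; dim 1 is
// static and matches the vector size, so it stays in-bounds.
%w = vector.transfer_write %vec, %dest[%c0, %c0]
    {in_bounds = [false, true]} : vector<4x8xf32>, tensor<?x8xf32>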
@@ -1826,8 +1827,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
       transposeOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2000,8 +2000,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3007,39 +3006,24 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
         sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
   }

-  // 2. Get the vector shape and in-bounds attributes
+  // 2. Get the vector shape
   SmallVector<int64_t> vecShape;
-  SmallVector<bool> readInBounds;
-  SmallVector<bool> writeInBounds;
   size_t rankDiff = resultType.getRank() - sourceType.getRank();
   for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
     if (!inputVectorSizes.empty()) {
       vecShape.push_back(inputVectorSizes[i]);
-      readInBounds.push_back(false);
-      writeInBounds.push_back(false);
     } else if (!sourceType.isDynamicDim(i)) {
       vecShape.push_back(sourceType.getDimSize(i));
-      // Source shape is statically known: Neither read nor write are
-      // out-of-bounds.
-      readInBounds.push_back(true);
-      writeInBounds.push_back(true);
     } else if (!resultType.isDynamicDim(i)) {
       // Source shape is not statically known, but result shape is.
       // Vectorize with size of result shape. This may be larger than the
       // source size.
       // FIXME: Using rankDiff implies that the source tensor is inserted at
       // the end of the destination tensor. However, that's not required.
       vecShape.push_back(resultType.getDimSize(rankDiff + i));
-      // Read may be out-of-bounds because the result size could be larger
-      // than the source size.
-      readInBounds.push_back(false);
-      // Write will be in-bounds provided that the corresponding write idx is 0.
-      // To keep this logic simple, conservatively mark as out-of-bounds.
-      writeInBounds.push_back(false);
     } else {
       // Neither source nor result dim of padOp is static. Cannot vectorize
       // the copy.
-      // TODO: Add support for masking
       return failure();
     }
   }
@@ -3052,13 +3036,15 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   SmallVector<Value> readIndices(
       vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
-      rewriter, loc, source, vecType.getShape(), padValue);
+      rewriter, loc, source, vecType.getShape(), padValue,
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());

   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
   Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());

   // 4. Finalize
   newResults.push_back(write->getResult(0));
mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir

Lines changed: 1 addition & 10 deletions
@@ -67,19 +67,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:    %[[ARG_0:.*]]: tensor<1x?x3xf32>,
 // CHECK-SAME:    %[[PAD:.*]]: f32,
 // CHECK-SAME:    %[[SIZE:.*]]: index) -> tensor<9x8x7x1x2x3xf32> {
-// CHECK:         %[[C3:.*]] = arith.constant 3 : index
-// CHECK:         %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[C0:.*]] = arith.constant 0 : index
 // CHECK:         %[[EMPTY:.*]] = tensor.empty() : tensor<9x8x7x1x2x3xf32>
 // CHECK:         %[[BC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<9x8x7x1x2x3xf32>
 // CHECK:         %[[WRITE:.*]] = vector.transfer_write %[[BC]], %[[EMPTY]]{{.*}} {in_bounds = [true, true, true, true, true, true]} : vector<9x8x7x1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
-
-// CHECK:         %[[D1:.*]] = tensor.dim %[[ARG_0]], %[[C1]] : tensor<1x?x3xf32>
-// CHECK:         %[[MASK:.*]] = vector.create_mask %[[C1]], %[[D1]], %[[C3]] : vector<1x2x3xi1>
-// CHECK:         %[[READ:.*]] = vector.mask %[[MASK]] {
-// CHECK-SAME:      vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x?x3xf32>, vector<1x2x3xf32>
-// CHECK-SAME:    } : vector<1x2x3xi1> -> vector<1x2x3xf32>
-
+// CHECK:         %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{.*}}, %[[PAD]] {in_bounds = [true, false, true]} : tensor<1x?x3xf32>, vector<1x2x3xf32>
 // CHECK:         %[[RES:.*]] = vector.transfer_write %[[READ]], %[[WRITE]]{{.*}} {in_bounds = [true, true, true]} : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
 // CHECK:         return %[[RES]] : tensor<9x8x7x1x2x3xf32>
 func.func @insert_dynamic_slice_non_zero_pad(%arg0: tensor<1x?x3xf32>, %pad : f32, %size: index) -> tensor<9x8x7x1x2x3xf32> {
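For contrast, when vector sizes are specified, the insert-slice path still produces masked transfers rather than relying on `in_bounds`. A rough sketch of the masked write side (shapes, mask operands, and value names are illustrative, not from this test):

// Mask the write with the runtime extent of the dynamic dim so that
// out-of-bounds lanes are never stored.
%d0 = tensor.dim %src, %c0 : tensor<?x8xf32>
%wmask = vector.create_mask %d0, %c8 : vector<4x8xi1>
%res = vector.mask %wmask {
  vector.transfer_write %read, %dest[%c0, %c0]
      {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x8xf32>
} : vector<4x8xi1> -> tensor<?x8xf32>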

mlir/test/Dialect/Linalg/vectorization/pad-with-patterns.mlir

Lines changed: 10 additions & 17 deletions
@@ -5,23 +5,16 @@
 ///----------------------------------------------------------------------------------------

 // CHECK-LABEL: func @pad_static(
-// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x?x2xf32>,
-// CHECK-SAME:  %[[ARG1:.*]]: f32) -> tensor<2x3x4xf32> {
-// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:       %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf32>
-// CHECK:       %[[INIT:.*]] = vector.broadcast %[[ARG1]] : f32 to vector<2x3x4xf32>
-// CHECK:       %[[OUT_TENSOR:.*]] = vector.transfer_write %[[INIT]], %[[EMPTY]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<2x3x4xf32>, tensor<2x3x4xf32>
-// CHECK:       %[[DIM_1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<2x?x2xf32>
-// CHECK:       %[[MASK_READ:.*]] = vector.create_mask %[[C2]], %[[DIM_1]], %[[C2]] : vector<2x3x2xi1>
-// CHECK:       %[[READ:.*]] = vector.mask %[[MASK_READ]] {
-// CHECK-SAME:    vector.transfer_read %[[ARG0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[ARG1]]
-// CHECK-SAME:    {in_bounds = [true, true, true]} : tensor<2x?x2xf32>, vector<2x3x2xf32>
-// CHECK-SAME:  } : vector<2x3x2xi1> -> vector<2x3x2xf32>
-// CHECK:       %[[RESULT:.*]] = vector.transfer_write %[[READ]], %[[OUT_TENSOR]]{{\[}}%[[C0]], %[[C0]], %[[C2]]]
-// CHECK-SAME:  {in_bounds = [true, true, true]} : vector<2x3x2xf32>, tensor<2x3x4xf32>
-// CHECK:       return %[[RESULT]] : tensor<2x3x4xf32>
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2x?x2xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT:   tensor.pad
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[INIT:.*]] = tensor.empty() : tensor<2x3x4xf32>
+// CHECK-DAG:   %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x3x4xf32>
+// CHECK:       %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]]{{.*}} : vector<2x3x4xf32>, tensor<2x3x4xf32>
+// CHECK:       %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, false, true]} : tensor<2x?x2xf32>, vector<2x3x2xf32>
+// CHECK:       %[[RESULT:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C0]], %[[C0]], %[[C2]]] {in_bounds = [true, true, true]} : vector<2x3x2xf32>, tensor<2x3x4xf32>
+// CHECK:       return %[[RESULT]]
 func.func @pad_static(%arg0: tensor<2x?x2xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
   %0 = tensor.pad %arg0 low[0, 0, 2] high[0, 1, 0] {
   ^bb0(%arg1: index, %arg2: index, %arg3: index):