Skip to content

Commit d68a4b9

Browse files
authored
[mlir][linalg] Add support for masked vectorization of tensor.insert_slice (1/N) (#122927)
For context, `tensor.insert_slice` is vectorized using a `vector.transfer_read` + `vector.transfer_write` pair. An unmasked example is shown below: ```mlir // BEFORE VECTORIZATION %res = tensor.insert_slice %slice into %dest[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<5x3xi32> // AFTER VECTORIZATION %read = vector.transfer_read %source[%c0, %c0], %pad : tensor<5x1xi32>, vector<8x1xi32> %res = vector.transfer_write %read, %dest[%c0, %c2] : vector<8x1xi32>, tensor<5x3xi32> ``` This PR refactors `InsertSliceVectorizePattern` (which is used to vectorize `tensor.insert_slice`) to enable masked vectorization. ATM, only `vector.transfer_read` is masked. If `vector.transfer_write` also requires masking, the vectorizer will bail out. This will be addressed in a subsequent PR. Summary of changes: * Added an argument to specify vector sizes (behavior remains unchanged if vector sizes are not specified). * Renamed `InsertSliceVectorizePattern` to `vectorizeAsInsertSliceOp` and integrated it (alongside other vectorization hooks) into `linalg::vectorize`. * Removed `populateInsertSliceVectorizationPatterns`, as `InsertSliceVectorizePattern` was its only pattern. * Updated `vectorizeAsInsertSliceOp` to support masking for the "read" operation. * Updated `@pad_and_insert_slice_dest` in "vectorization-pad-patterns.mlir" to reflect the removal of `populateInsertSliceVectorizationPatterns` from `ApplyPadVectorizationPatternsOps`.
1 parent d00579b commit d68a4b9

File tree

6 files changed

+294
-149
lines changed

6 files changed

+294
-149
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,11 +1731,6 @@ void populateDecomposePadPatterns(RewritePatternSet &patterns);
17311731
/// \see rewriteInIm2Col for more details.
17321732
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
17331733

1734-
/// Populates `patterns` with vectorisation patterns for tensor.insert_slice.
1735-
/// TODO: Avoid having a dedicated `populate{}` for one pattern. Instead, either
1736-
/// expand or merge with other `populate{}`.
1737-
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns);
1738-
17391734
/// Populates `patterns` with patterns that vectorize tensor.pad.
17401735
/// These patterns are meant to apply in a complementary fashion. Benefits
17411736
/// are used to encode a certain ordering of pattern application. To avoid

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,6 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
265265
void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
266266
RewritePatternSet &patterns) {
267267
linalg::populatePadOpVectorizationPatterns(patterns);
268-
linalg::populateInsertSliceVectorizationPatterns(patterns);
269268
}
270269

271270
//===----------------------------------------------------------------------===//
@@ -3504,9 +3503,6 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
35043503

35053504
patterns.add<CopyVectorizationPattern>(ctx);
35063505

3507-
// Add misc. vectorization patterns (e.g. for tensor.insert_slice)
3508-
linalg::populateInsertSliceVectorizationPatterns(patterns);
3509-
35103506
if (getVectorizePadding()) {
35113507
linalg::populatePadOpVectorizationPatterns(patterns);
35123508
// This creates an alternative path for lowering tensor.pad - by

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 178 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,37 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
5959
ArrayRef<bool> inputVecScalableFlags = {},
6060
bool flatten1DDepthwiseConv = false);
6161

62+
/// Vectorize tensor::InsertSliceOp with:
63+
/// * vector::TransferReadOp + vector::TransferWriteOp
64+
/// The vector sizes are either:
65+
/// * user-provided in `inputVectorSizes`, or
66+
/// * inferred from the static dims in the input and output tensors.
67+
/// Bails out if:
68+
/// * vector sizes are not user-provided, and
69+
/// * at least one dim is dynamic (in both the input and output tensors).
70+
///
71+
/// Before:
72+
/// !t_in_type = tensor<1x2x3xf32>
73+
/// !t_out_type = tensor<9x8x7x1x2x3xf32>
74+
/// !v_type = vector<1x2x3xf32>
75+
/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
76+
/// into !t_out_type
77+
/// After:
78+
/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
79+
/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
80+
static LogicalResult
81+
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
82+
ArrayRef<int64_t> inputVectorSizes,
83+
SmallVectorImpl<Value> &newResults);
84+
85+
/// Returns the effective Pad value for the input op, provided it's a scalar.
86+
///
87+
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
88+
/// this Op performs padding, retrieve the padding value provided that it's
89+
/// a scalar and static/fixed for all the padded values. Returns an empty value
90+
/// otherwise.
91+
static Value getStaticPadVal(Operation *op);
92+
6293
/// Return the unique instance of OpType in `block` if it is indeed unique.
6394
/// Return null if none or more than 1 instances exist.
6495
template <typename OpType>
@@ -1557,6 +1588,7 @@ static LogicalResult
15571588
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
15581589
ArrayRef<int64_t> inputVectorSizes,
15591590
SmallVectorImpl<Value> &newResults) {
1591+
// TODO: Introduce a parent class that will handle the insertion point update.
15601592
OpBuilder::InsertionGuard g(rewriter);
15611593
rewriter.setInsertionPoint(packOp);
15621594

@@ -1633,6 +1665,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
16331665
ArrayRef<int64_t> inputVectorSizes,
16341666
SmallVectorImpl<Value> &newResults) {
16351667

1668+
// TODO: Introduce a parent class that will handle the insertion point update.
16361669
OpBuilder::InsertionGuard g(rewriter);
16371670
rewriter.setInsertionPoint(unpackOp);
16381671

@@ -1763,7 +1796,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
17631796
auto padValue = padOp.getConstantPaddingValue();
17641797
Location loc = padOp.getLoc();
17651798

1766-
// transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1799+
// TODO: Introduce a parent class that will handle the insertion point update.
17671800
OpBuilder::InsertionGuard g(rewriter);
17681801
rewriter.setInsertionPoint(padOp);
17691802

@@ -1874,6 +1907,38 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
18741907
return success();
18751908
}
18761909

1910+
static LogicalResult
1911+
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
1912+
ArrayRef<int64_t> inputVectorSizes) {
1913+
1914+
TypedValue<RankedTensorType> source = sliceOp.getSource();
1915+
auto sourceType = source.getType();
1916+
if (!VectorType::isValidElementType(sourceType.getElementType()))
1917+
return failure();
1918+
1919+
// Get the pad value.
1920+
// TransferReadOp (which is used to vectorize InsertSliceOp), requires a
1921+
// scalar padding value. Note that:
1922+
// * for in-bounds accesses,
1923+
// the value is actually irrelevant. There are 2 cases in which xfer.read
1924+
// accesses are known to be in-bounds:
1925+
// 1. The source shape is static (output vector sizes would be based on
1926+
// the source shape and hence all memory accesses would be in-bounds),
1927+
// 2. Masking is used, i.e. the output vector sizes are user-provided. In
1928+
// this case it is safe to assume that all memory accesses are in-bounds.
1929+
//
1930+
// When the value is not known and not needed, use 0. Otherwise, bail out.
1931+
Value padValue = getStaticPadVal(sliceOp);
1932+
bool isOutOfBoundsRead =
1933+
!sourceType.hasStaticShape() && inputVectorSizes.empty();
1934+
1935+
if (!padValue && isOutOfBoundsRead) {
1936+
LDBG("Failed to get a pad value for out-of-bounds read access\n");
1937+
return failure();
1938+
}
1939+
return success();
1940+
}
1941+
18771942
static LogicalResult vectorizeLinalgOpPrecondition(
18781943
LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
18791944
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -2144,6 +2209,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
21442209
.Case<tensor::UnPackOp>([&](auto unpackOp) {
21452210
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
21462211
})
2212+
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2213+
return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2214+
})
21472215
.Default([](auto) { return failure(); });
21482216
}
21492217

@@ -2163,8 +2231,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
21632231
}
21642232

21652233
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2166-
return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
2167-
op);
2234+
return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
2235+
tensor::InsertSliceOp>(op);
21682236
}
21692237

21702238
/// Emit a suitable vector form for an operation. If provided,
@@ -2244,6 +2312,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
22442312
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
22452313
results);
22462314
})
2315+
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2316+
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2317+
results);
2318+
})
22472319
.Case<tensor::UnPackOp>([&](auto unpackOp) {
22482320
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
22492321
inputVectorSizes, results);
@@ -2540,6 +2612,9 @@ struct PadOpVectorizationWithTransferWritePattern
25402612
/// this Op performs padding, retrieve the padding value provided that it's
25412613
/// a scalar and static/fixed for all the padded values. Returns an empty value
25422614
/// otherwise.
2615+
///
2616+
/// TODO: This is used twice (when checking vectorization pre-conditions and
2617+
/// when vectorizing). Cache results instead of re-running.
25432618
static Value getStaticPadVal(Operation *op) {
25442619
if (!op)
25452620
return {};
@@ -2583,113 +2658,118 @@ static Value getStaticPadVal(Operation *op) {
25832658
return {};
25842659
}
25852660

2586-
/// Rewrite tensor.insert.slice as a vector.transfer_read +
2587-
/// vector.transfer_write pair. The vector size is inferred from the static
2588-
/// dims in the input and output tensors. If a dim is dynamic in both the input
2589-
/// and output tensors, bails out.
2590-
///
2591-
/// Before:
2592-
/// !t_in_type = tensor<1x2x3xf32>
2593-
/// !t_out_type = tensor<9x8x7x1x2x3xf32>
2594-
/// !v_type = vector<1x2x3xf32>
2595-
/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2596-
/// into !t_out_type
2597-
/// After:
2598-
/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2599-
/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2600-
///
2601-
/// TODO: Support masking
2602-
struct InsertSliceVectorizePattern
2603-
: public OpRewritePattern<tensor::InsertSliceOp> {
2604-
using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2661+
static LogicalResult
2662+
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
2663+
ArrayRef<int64_t> inputVectorSizes,
2664+
SmallVectorImpl<Value> &newResults) {
2665+
// TODO: Introduce a parent class that will handle the insertion point update.
2666+
OpBuilder::InsertionGuard g(rewriter);
2667+
rewriter.setInsertionPoint(sliceOp);
26052668

2606-
LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
2607-
PatternRewriter &rewriter) const final {
2608-
auto sourceType = sliceOp.getSource().getType();
2609-
if (!VectorType::isValidElementType(sourceType.getElementType()))
2610-
return failure();
2669+
TypedValue<RankedTensorType> source = sliceOp.getSource();
2670+
auto sourceType = source.getType();
2671+
auto resultType = sliceOp.getResultType();
26112672

2612-
auto resultType = sliceOp.getResultType();
2613-
2614-
// 1. Get the pad value.
2615-
// TransferReadOp requires a scalar padding value. Note that:
2616-
// * for in-bounds access, the value is actually irrelevant.
2617-
// There are 2 cases in which xfer.read accesses are known to be in-bounds:
2618-
// 1. The source shape is static (output vector sizes would be based on
2619-
// the source shape and hence all memory accesses would be in-bounds),
2620-
// 2. Masking is used (output vector sizes would be user-provided, in which
2621-
// case it is assumed that all memory accesses are in-bounds). This
2622-
// remains a TODO.
2623-
//
2624-
// When the value is not known and not needed, use 0. Otherwise, bail out.
2625-
Value padValue = getStaticPadVal(sliceOp);
2626-
bool isOutOfBoundsRead = !sourceType.hasStaticShape();
2627-
2628-
if (!padValue && isOutOfBoundsRead) {
2629-
LDBG("Failed to get a pad value for out-of-bounds read access\n");
2673+
Value padValue = getStaticPadVal(sliceOp);
2674+
2675+
if (!padValue) {
2676+
auto elemType = sourceType.getElementType();
2677+
padValue = rewriter.create<arith::ConstantOp>(
2678+
sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2679+
}
2680+
2681+
// 2. Get the vector shape and in-bounds attributes
2682+
SmallVector<int64_t> vecShape;
2683+
SmallVector<bool> readInBounds;
2684+
SmallVector<bool> writeInBounds;
2685+
size_t rankDiff = resultType.getRank() - sourceType.getRank();
2686+
for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2687+
if (!inputVectorSizes.empty()) {
2688+
vecShape.push_back(inputVectorSizes[i]);
2689+
readInBounds.push_back(false);
2690+
writeInBounds.push_back(false);
2691+
} else if (!sourceType.isDynamicDim(i)) {
2692+
vecShape.push_back(sourceType.getDimSize(i));
2693+
// Source shape is statically known: Neither read nor write are
2694+
// out-of-bounds.
2695+
readInBounds.push_back(true);
2696+
writeInBounds.push_back(true);
2697+
} else if (!resultType.isDynamicDim(i)) {
2698+
// Source shape is not statically known, but result shape is.
2699+
// Vectorize with size of result shape. This may be larger than the
2700+
// source size.
2701+
// FIXME: Using rankDiff implies that the source tensor is inserted at
2702+
// the end of the destination tensor. However, that's not required.
2703+
vecShape.push_back(resultType.getDimSize(rankDiff + i));
2704+
// Read may be out-of-bounds because the result size could be larger
2705+
// than the source size.
2706+
readInBounds.push_back(false);
2707+
// Write will be in-bounds provided that the corresponding write idx is 0.
2708+
// To keep this logic simple, conservatively mark as out-of-bounds.
2709+
writeInBounds.push_back(false);
2710+
} else {
2711+
// Neither source nor result dim of padOp is static. Cannot vectorize
2712+
// the copy.
2713+
// TODO: Add support for masking
26302714
return failure();
26312715
}
2716+
}
2717+
auto vecType = VectorType::get(vecShape, sourceType.getElementType());
26322718

2633-
if (!padValue) {
2634-
auto elemType = sourceType.getElementType();
2635-
padValue = rewriter.create<arith::ConstantOp>(
2636-
sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2637-
}
2719+
// 3. Generate TransferReadOp.
2720+
SmallVector<Value> readIndices(
2721+
vecType.getRank(),
2722+
rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2723+
Operation *read = rewriter.create<vector::TransferReadOp>(
2724+
sliceOp.getLoc(), vecType, source, readIndices, padValue,
2725+
ArrayRef<bool>{readInBounds});
26382726

2639-
// 2. Get the vector shape and in-bounds attributes
2640-
SmallVector<int64_t> vecShape;
2641-
SmallVector<bool> readInBounds;
2642-
SmallVector<bool> writeInBounds;
2643-
size_t rankDiff = resultType.getRank() - sourceType.getRank();
2644-
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2645-
if (!sourceType.isDynamicDim(i)) {
2646-
vecShape.push_back(sourceType.getDimSize(i));
2647-
// Source shape is statically known: Neither read nor write are
2648-
// out-of-bounds.
2649-
readInBounds.push_back(true);
2650-
writeInBounds.push_back(true);
2651-
} else if (!resultType.isDynamicDim(i)) {
2652-
// Source shape is not statically known, but result shape is.
2653-
// Vectorize with size of result shape. This may be larger than the
2654-
// source size.
2655-
// FIXME: Using rankDiff implies that the source tensor is inserted at
2656-
// the end of the destination tensor. However, that's not required.
2657-
vecShape.push_back(resultType.getDimSize(rankDiff + i));
2658-
// Read may be out-of-bounds because the result size could be larger
2659-
// than the source size.
2660-
readInBounds.push_back(false);
2661-
// Write will in-bounds provided that the corresponding write idx is 0.
2662-
// To keep this logic simple, conservatively mark as out-of-bounds.
2663-
writeInBounds.push_back(false);
2664-
} else {
2665-
// Neither source nor result dim of padOp is static. Cannot vectorize
2666-
// the copy.
2667-
// TODO: Add support for masking
2668-
return failure();
2669-
}
2727+
// If vector sizes are user provided, make sure to mask xfer_read.
2728+
if (!inputVectorSizes.empty()) {
2729+
auto *srcDefOp = source.getDefiningOp();
2730+
if (!srcDefOp) {
2731+
LDBG("Unable to get the defining Op of " << sliceOp);
2732+
return failure();
26702733
}
2671-
auto vecType = VectorType::get(vecShape, sourceType.getElementType());
26722734

2673-
// 3. Generate TransferReadOp.
2674-
SmallVector<Value> readIndices(
2675-
vecType.getRank(),
2676-
rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2677-
auto read = rewriter.create<vector::TransferReadOp>(
2678-
sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
2679-
ArrayRef<bool>{readInBounds});
2735+
ReifiedRankedShapedTypeDims reifiedSrcSizes;
2736+
LogicalResult status =
2737+
cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
2738+
rewriter, reifiedSrcSizes);
2739+
if (status.failed()) {
2740+
LDBG("Unable to reify result shapes of " << sliceOp);
2741+
return failure();
2742+
}
26802743

2681-
// 4. Generate TransferWriteOp.
2682-
auto writeIndices = getValueOrCreateConstantIndexOp(
2683-
rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2744+
// Create the mask
2745+
SmallVector<int64_t> readMaskShape(
2746+
sliceOp.getSource().getType().getShape());
2747+
auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
2748+
Value maskOp = rewriter.create<vector::CreateMaskOp>(
2749+
sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
26842750

2685-
// 5. Finalize
2686-
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2687-
sliceOp, read, sliceOp.getDest(), writeIndices,
2688-
ArrayRef<bool>{writeInBounds});
2751+
// Mask the xfer_read Op
2752+
read = mlir::vector::maskOperation(rewriter, read, maskOp);
2753+
}
26892754

2690-
return success();
2755+
// 4. Generate TransferWriteOp.
2756+
if (!inputVectorSizes.empty() &&
2757+
ShapedType::isDynamicShape(resultType.getShape())) {
2758+
LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
2759+
return failure();
26912760
}
2692-
};
2761+
2762+
auto writeIndices = getValueOrCreateConstantIndexOp(
2763+
rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2764+
2765+
// 5. Finalize
2766+
Operation *write = rewriter.create<vector::TransferWriteOp>(
2767+
sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
2768+
ArrayRef<bool>{writeInBounds});
2769+
newResults.push_back(write->getResult(0));
2770+
2771+
return success();
2772+
}
26932773

26942774
/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
26952775
/// ```
@@ -2778,11 +2858,6 @@ struct PadOpVectorizationWithInsertSlicePattern
27782858
}
27792859
};
27802860

2781-
void mlir::linalg::populateInsertSliceVectorizationPatterns(
2782-
RewritePatternSet &patterns) {
2783-
patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
2784-
}
2785-
27862861
void mlir::linalg::populatePadOpVectorizationPatterns(
27872862
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
27882863
patterns.add<PadOpVectorizationWithTransferReadPattern,

0 commit comments

Comments (0)