Commit ca54c1f
[mlir][linalg] Add support for masked vectorization of tensor.insert_slice
This PR refactors the `InsertSliceVectorizePattern` to enable masked vectorization of `tensor.insert_slice`. Note that `tensor.insert_slice` is vectorized using a `vector.transfer_read` + `vector.transfer_write` pair; at the moment, only `vector.transfer_read` is masked. If `vector.transfer_write` also requires masking, the vectorizer bails out. This will be addressed in a subsequent PR.

Summary of changes:

* Added an argument to specify vector sizes (behavior remains unchanged if vector sizes are not specified).
* Renamed `InsertSliceVectorizePattern` to `vectorizeAsInsertSliceOp` and integrated it into `linalg::vectorize`, alongside the other vectorization hooks.
* Removed `populateInsertSliceVectorizationPatterns`, as `InsertSliceVectorizePattern` was its only pattern.
* Updated `vectorizeAsInsertSliceOp` to support masking for the "read" operation.
* Updated `@pad_and_insert_slice_dest` in "vectorization-pad-patterns.mlir" to reflect the removal of `populateInsertSliceVectorizationPatterns` from `ApplyPadVectorizationPatternsOp`.
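For illustration, here is a hand-written sketch of the lowering this enables (not taken from the patch; `%src`, `%dest`, `%sz0`, `%sz1`, and the `[8, 16]` vector sizes are invented). With user-provided vector sizes, a dynamically shaped source is read via a masked `vector.transfer_read`, while the write stays unmasked because the destination shape is static:

```mlir
// Before: insert a dynamically sized slice into a statically shaped dest.
%res = tensor.insert_slice %src into %dest[0, 0] [%sz0, %sz1] [1, 1]
    : tensor<?x?xf32> into tensor<8x16xf32>

// After vectorization with vector sizes [8, 16]: the read is masked by the
// (reified) source sizes; the pad value is irrelevant under the mask.
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%mask = vector.create_mask %sz0, %sz1 : vector<8x16xi1>
%read = vector.mask %mask {
  vector.transfer_read %src[%c0, %c0], %pad
      : tensor<?x?xf32>, vector<8x16xf32>
} : vector<8x16xi1> -> vector<8x16xf32>
%res = vector.transfer_write %read, %dest[%c0, %c0]
    : vector<8x16xf32>, tensor<8x16xf32>
```

If the destination shape were also dynamic, the write would need its own mask, which is exactly the case the vectorizer currently rejects.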
1 parent df40b05 commit ca54c1f

6 files changed (+284, −150 lines)

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 0 additions & 5 deletions
@@ -1723,11 +1723,6 @@ void populateDecomposePadPatterns(RewritePatternSet &patterns);
 /// \see rewriteInIm2Col for more details.
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
 
-/// Populates `patterns` with vectorisation patterns for tensor.insert_slice.
-/// TODO: Avoid having a dedicated `populate{}` for one pattern. Instead, either
-/// expand or merge with other `populate{}`.
-void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns);
-
 /// Populates `patterns` with patterns that vectorize tensor.pad.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 0 additions & 4 deletions
@@ -265,7 +265,6 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populatePadOpVectorizationPatterns(patterns);
-  linalg::populateInsertSliceVectorizationPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3504,9 +3503,6 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
-  // Add misc. vectorization patterns (e.g. for tensor.insert_slice)
-  linalg::populateInsertSliceVectorizationPatterns(patterns);
-
   if (getVectorizePadding()) {
     linalg::populatePadOpVectorizationPatterns(patterns);
     // This creates an alternative path for lowering tensor.pad - by
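With the dedicated pattern gone, `tensor.insert_slice` is vectorized through the generic `linalg::vectorize` entry point, so it can be driven from the transform dialect like the other supported ops. A minimal sketch, assuming a payload containing a `tensor.insert_slice` (the handle names and `[8, 16]` sizes are illustrative, not from the patch):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Match the tensor.insert_slice op and vectorize it with explicit
    // vector sizes, which triggers masking of the transfer_read.
    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
    transform.yield
  }
}
```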

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 167 additions & 103 deletions
@@ -59,6 +59,30 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                      ArrayRef<bool> inputVecScalableFlags = {},
                      bool flatten1DDepthwiseConv = false);
 
+/// Vectorize tensor::InsertSliceOp with:
+///   * vector::TransferReadOp + vector::TransferWriteOp
+/// The vector sizes are either:
+///   * user-provided in `inputVectorSizes`, or
+///   * inferred from the static dims in the input and output tensors.
+/// Bails out if:
+///   * vector sizes are not user-provided, and
+///   * at least one dim is dynamic (in both the input and output tensors).
+///
+/// Before:
+///     !t_in_type = tensor<1x2x3xf32>
+///     !t_out_type = tensor<9x8x7x1x2x3xf32>
+///     !v_type = vector<1x2x3xf32>
+///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///     into !t_out_type
+/// After:
+///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+                         ArrayRef<int64_t> inputVectorSizes,
+                         SmallVectorImpl<Value> &newResults);
+
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
 template <typename OpType>
@@ -1557,6 +1581,7 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
 
@@ -1633,6 +1658,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           SmallVectorImpl<Value> &newResults) {
 
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
@@ -1763,7 +1789,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   auto padValue = padOp.getConstantPaddingValue();
   Location loc = padOp.getLoc();
 
-  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(padOp);
 
@@ -1874,6 +1900,15 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
   return success();
 }
 
+/// Need to check if the inner-tiles are static/constant.
+static LogicalResult
+vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
+                                   ArrayRef<int64_t> inputVectorSizes) {
+
+  // TODO: Move pre-conditions from the vectorization logic
+  return success();
+}
+
 static LogicalResult vectorizeLinalgOpPrecondition(
     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -2144,6 +2179,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
+      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
+      })
      .Default([](auto) { return failure(); });
 }
 
@@ -2163,8 +2201,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 }
 
 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
-  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
-      op);
+  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
+             tensor::InsertSliceOp>(op);
 }
 
 /// Emit a suitable vector form for an operation. If provided,
@@ -2178,6 +2216,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<bool> inputScalableVecDims,
                                       bool vectorizeNDExtract,
                                       bool flatten1DDepthwiseConv) {
+  rewriter.getInsertionPoint();
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2244,6 +2283,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                            results);
           })
+          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+            return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
+                                            results);
+          })
          .Case<tensor::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
@@ -2583,113 +2626,139 @@ static Value getStaticPadVal(Operation *op) {
   return {};
 }
 
-/// Rewrite tensor.insert.slice as a vector.transfer_read +
-/// vector.transfer_write pair. The vector size is inferred from the static
-/// dims in the input and output tensors. If a dim is dynamic in both the input
-/// and output tensors, bails out.
-///
-/// Before:
-///     !t_in_type = tensor<1x2x3xf32>
-///     !t_out_type = tensor<9x8x7x1x2x3xf32>
-///     !v_type = vector<1x2x3xf32>
-///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
-///     into !t_out_type
-/// After:
-///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
-///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
-///
-/// TODO: Support masking
-struct InsertSliceVectorizePattern
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+                         ArrayRef<int64_t> inputVectorSizes,
+                         SmallVectorImpl<Value> &newResults) {
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(sliceOp);
 
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
-                                PatternRewriter &rewriter) const final {
-    auto sourceType = sliceOp.getSource().getType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
+  TypedValue<RankedTensorType> source = sliceOp.getSource();
+  auto sourceType = source.getType();
+  if (!VectorType::isValidElementType(sourceType.getElementType()))
+    return failure();
 
-    auto resultType = sliceOp.getResultType();
-
-    // 1. Get the pad value.
-    // TransferReadOp requires a scalar padding value. Note that:
-    // * for in-bounds access, the value is actually irrelevant.
-    // There are 2 cases in which xfer.read accesses are known to be in-bounds:
-    //  1. The source shape is static (output vector sizes would be based on
-    //     the source shape and hence all memory accesses would be in-bounds),
-    //  2. Masking is used (output vector sizes would be user-provided, in which
-    //     case it is assumed that all memory accesses are in-bounds). This
-    //     remains a TODO.
-    //
-    // When the value is not known and not needed, use 0. Otherwise, bail out.
-    Value padValue = getStaticPadVal(sliceOp);
-    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
-
-    if (!padValue && isOutOfBoundsRead) {
-      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+  auto resultType = sliceOp.getResultType();
+
+  // 1. Get the pad value.
+  // TransferReadOp requires a scalar padding value. Note that:
+  // * for in-bounds access, the value is actually irrelevant.
+  // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+  //  1. The source shape is static (output vector sizes would be based on
+  //     the source shape and hence all memory accesses would be in-bounds),
+  //  2. Masking is used (output vector sizes would be user-provided, in which
+  //     case it is assumed that all memory accesses are in-bounds). This
+  //     remains a TODO.
+  //
+  // When the value is not known and not needed, use 0. Otherwise, bail out.
+  Value padValue = getStaticPadVal(sliceOp);
+  bool isOutOfBoundsRead =
+      !sourceType.hasStaticShape() && inputVectorSizes.empty();
+
+  if (!padValue && isOutOfBoundsRead) {
+    LDBG("Failed to get a pad value for out-of-bounds read access\n");
+    return failure();
+  }
+
+  if (!padValue) {
+    auto elemType = sourceType.getElementType();
+    padValue = rewriter.create<arith::ConstantOp>(
+        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+  }
+
+  // 2. Get the vector shape and in-bounds attributes
+  SmallVector<int64_t> vecShape;
+  SmallVector<bool> readInBounds;
+  SmallVector<bool> writeInBounds;
+  size_t rankDiff = resultType.getRank() - sourceType.getRank();
+  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
+    if (!inputVectorSizes.empty()) {
+      vecShape.push_back(inputVectorSizes[i]);
+      readInBounds.push_back(false);
+      writeInBounds.push_back(false);
+    } else if (!sourceType.isDynamicDim(i)) {
+      vecShape.push_back(sourceType.getDimSize(i));
+      // Source shape is statically known: Neither read nor write are
+      // out-of-bounds.
+      readInBounds.push_back(true);
+      writeInBounds.push_back(true);
+    } else if (!resultType.isDynamicDim(i)) {
+      // Source shape is not statically known, but result shape is.
+      // Vectorize with size of result shape. This may be larger than the
+      // source size.
+      // FIXME: Using rankDiff implies that the source tensor is inserted at
+      // the end of the destination tensor. However, that's not required.
+      vecShape.push_back(resultType.getDimSize(rankDiff + i));
+      // Read may be out-of-bounds because the result size could be larger
+      // than the source size.
+      readInBounds.push_back(false);
+      // Write will be in-bounds provided that the corresponding write idx is 0.
+      // To keep this logic simple, conservatively mark as out-of-bounds.
+      writeInBounds.push_back(false);
+    } else {
+      // Neither source nor result dim of padOp is static. Cannot vectorize
+      // the copy.
+      // TODO: Add support for masking
       return failure();
     }
+  }
+  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
-    if (!padValue) {
-      auto elemType = sourceType.getElementType();
-      padValue = rewriter.create<arith::ConstantOp>(
-          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
-    }
+  // 3. Generate TransferReadOp.
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});
 
-    // 2. Get the vector shape and in-bounds attributes
-    SmallVector<int64_t> vecShape;
-    SmallVector<bool> readInBounds;
-    SmallVector<bool> writeInBounds;
-    size_t rankDiff = resultType.getRank() - sourceType.getRank();
-    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
-      if (!sourceType.isDynamicDim(i)) {
-        vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are
-        // out-of-bounds.
-        readInBounds.push_back(true);
-        writeInBounds.push_back(true);
-      } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is.
-        // Vectorize with size of result shape. This may be larger than the
-        // source size.
-        // FIXME: Using rankDiff implies that the source tensor is inserted at
-        // the end of the destination tensor. However, that's not required.
-        vecShape.push_back(resultType.getDimSize(rankDiff + i));
-        // Read may be out-of-bounds because the result size could be larger
-        // than the source size.
-        readInBounds.push_back(false);
-        // Write will in-bounds provided that the corresponding write idx is 0.
-        // To keep this logic simple, conservatively mark as out-of-bounds.
-        writeInBounds.push_back(false);
-      } else {
-        // Neither source nor result dim of padOp is static. Cannot vectorize
-        // the copy.
-        // TODO: Add support for masking
-        return failure();
-      }
+  // If vector sizes are user provided, make sure to mask xfer_read.
+  if (!inputVectorSizes.empty()) {
+    auto *srcDefOp = source.getDefiningOp();
+    if (!srcDefOp) {
+      LDBG("Unable to get the defining Op of " << sliceOp);
+      return failure();
     }
-    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
-    // 3. Generate TransferReadOp.
-    SmallVector<Value> readIndices(
-        vecType.getRank(),
-        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
-        ArrayRef<bool>{readInBounds});
+    ReifiedRankedShapedTypeDims reifiedSrcSizes;
+    LogicalResult status =
+        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp)
+            .reifyResultShapes(rewriter, reifiedSrcSizes);
+    if (status.failed()) {
+      LDBG("Unable to reify result shapes of " << sliceOp);
+      return failure();
+    }
 
-    // 4. Generate TransferWriteOp.
-    auto writeIndices = getValueOrCreateConstantIndexOp(
-        rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+    // Create the mask
+    SmallVector<int64_t> readMaskShape(
+        sliceOp.getSource().getType().getShape());
+    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
 
-    // 5. Finalize
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        sliceOp, read, sliceOp.getDest(), writeIndices,
-        ArrayRef<bool>{writeInBounds});
+    // Mask the xfer_read Op
+    read = mlir::vector::maskOperation(rewriter, read, maskOp);
+  }
 
-    return success();
+  // 4. Generate TransferWriteOp.
+  if (!inputVectorSizes.empty() &&
+      ShapedType::isDynamicShape(resultType.getShape())) {
+    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
+    return failure();
   }
-};
+
+  auto writeIndices = getValueOrCreateConstantIndexOp(
+      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+
+  // 5. Finalize
+  Operation *write = rewriter.create<vector::TransferWriteOp>(
+      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
+      ArrayRef<bool>{writeInBounds});
+  newResults.push_back(write->getResult(0));
+
+  return success();
+}
 
 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
 /// ```
@@ -2778,11 +2847,6 @@ struct PadOpVectorizationWithInsertSlicePattern
   }
 };
 
-void mlir::linalg::populateInsertSliceVectorizationPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
-}
-
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add<PadOpVectorizationWithTransferReadPattern,
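To round out the picture, here is a hand-written sketch of the size-inference path taken when no vector sizes are passed (the types and `%sz` are invented for illustration). The source's dynamic dim is covered by the corresponding static result dim, so the vector shape is taken from the result and the read is merely marked out-of-bounds with a pad value instead of being masked:

```mlir
// Source dim 1 is dynamic, but result dim 1 is static (6), so the
// inferred vector type is vector<5x6xf32>.
%res = tensor.insert_slice %src into %dest[0, 0] [5, %sz] [1, 1]
    : tensor<5x?xf32> into tensor<5x6xf32>

// After vectorization (no input vector sizes): the read may run past the
// source in dim 1, so it carries a pad value and in_bounds = [true, false].
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%read = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, false]}
    : tensor<5x?xf32>, vector<5x6xf32>
%res = vector.transfer_write %read, %dest[%c0, %c0] {in_bounds = [true, false]}
    : vector<5x6xf32>, tensor<5x6xf32>
```

In this out-of-bounds case `%pad` stands for the statically known pad value recovered by `getStaticPadVal`; without one the rewrite bails out, and if a dim is dynamic in both the source and the result, it bails out as well (the masking TODO noted above).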
