Skip to content

[mlir][linalg] Add support for masked vectorization of tensor.insert_slice (1/N) #122927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1731,11 +1731,6 @@ void populateDecomposePadPatterns(RewritePatternSet &patterns);
/// \see rewriteInIm2Col for more details.
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with vectorisation patterns for tensor.insert_slice.
/// TODO: Avoid having a dedicated `populate{}` for one pattern. Instead, either
/// expand or merge with other `populate{}`.
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that vectorize tensor.pad.
/// These patterns are meant to apply in a complementary fashion. Benefits
/// are used to encode a certain ordering of pattern application. To avoid
Expand Down
4 changes: 0 additions & 4 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::populatePadOpVectorizationPatterns(patterns);
linalg::populateInsertSliceVectorizationPatterns(patterns);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3504,9 +3503,6 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(

patterns.add<CopyVectorizationPattern>(ctx);

// Add misc. vectorization patterns (e.g. for tensor.insert_slice)
linalg::populateInsertSliceVectorizationPatterns(patterns);

if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
Expand Down
281 changes: 178 additions & 103 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
ArrayRef<bool> inputVecScalableFlags = {},
bool flatten1DDepthwiseConv = false);

/// Vectorize tensor::InsertSliceOp with:
/// * vector::TransferReadOp + vector::TransferWriteOp
/// The vector sizes are either:
/// * user-provided in `inputVectorSizes`, or
/// * inferred from the static dims in the input and output tensors.
/// Bails out if:
/// * vector sizes are not user-provided, and
/// * at least one dim is dynamic (in both the input and output tensors).
///
/// Before:
/// !t_in_type = tensor<1x2x3xf32>
/// !t_out_type = tensor<9x8x7x1x2x3xf32>
/// !v_type = vector<1x2x3xf32>
/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
/// into !t_out_type
/// After:
/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults);

/// Returns the effective Pad value for the input op, provided it's a scalar.
///
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
/// this Op performs padding, retrieve the padding value provided that it's
/// a scalar and static/fixed for all the padded values. Returns an empty value
/// otherwise.
static Value getStaticPadVal(Operation *op);

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
Expand Down Expand Up @@ -1557,6 +1588,7 @@ static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);

Expand Down Expand Up @@ -1633,6 +1665,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {

// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);

Expand Down Expand Up @@ -1763,7 +1796,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
auto padValue = padOp.getConstantPaddingValue();
Location loc = padOp.getLoc();

// transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(padOp);

Expand Down Expand Up @@ -1874,6 +1907,38 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
return success();
}

static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {
  // The element type of the source tensor must be vectorizable.
  RankedTensorType srcType = sliceOp.getSource().getType();
  if (!VectorType::isValidElementType(srcType.getElementType()))
    return failure();

  // Determine the pad value up-front. The TransferReadOp used to vectorize
  // InsertSliceOp requires a scalar padding value. Note that for in-bounds
  // accesses the actual value does not matter. There are two situations in
  // which the xfer.read accesses are known to be in-bounds:
  //  1. The source shape is fully static - the output vector sizes are
  //     derived from the source shape, so every access stays in-bounds.
  //  2. Masking is in effect, i.e. the output vector sizes were supplied by
  //     the user - in that case all accesses are assumed in-bounds.
  //
  // When the value is neither known nor needed, 0 is used. Otherwise, bail
  // out.
  Value padValue = getStaticPadVal(sliceOp);

  // A read may run out-of-bounds only when no user vector sizes were given
  // *and* the source shape has at least one dynamic dim.
  bool mayReadOutOfBounds =
      inputVectorSizes.empty() && !srcType.hasStaticShape();

  if (mayReadOutOfBounds && !padValue) {
    LDBG("Failed to get a pad value for out-of-bounds read access\n");
    return failure();
  }

  return success();
}

static LogicalResult vectorizeLinalgOpPrecondition(
LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
Expand Down Expand Up @@ -2144,6 +2209,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
})
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
})
.Default([](auto) { return failure(); });
}

Expand All @@ -2163,8 +2231,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
}

bool mlir::linalg::hasVectorizationImpl(Operation *op) {
return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
op);
return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
tensor::InsertSliceOp>(op);
}

/// Emit a suitable vector form for an operation. If provided,
Expand Down Expand Up @@ -2244,6 +2312,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
results);
})
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
results);
})
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
inputVectorSizes, results);
Expand Down Expand Up @@ -2540,6 +2612,9 @@ struct PadOpVectorizationWithTransferWritePattern
/// this Op performs padding, retrieve the padding value provided that it's
/// a scalar and static/fixed for all the padded values. Returns an empty value
/// otherwise.
///
/// TODO: This is used twice (when checking vectorization pre-conditions and
/// when vectorizing). Cache results instead of re-running.
static Value getStaticPadVal(Operation *op) {
if (!op)
return {};
Expand Down Expand Up @@ -2583,113 +2658,118 @@ static Value getStaticPadVal(Operation *op) {
return {};
}

/// Rewrite tensor.insert.slice as a vector.transfer_read +
/// vector.transfer_write pair. The vector size is inferred from the static
/// dims in the input and output tensors. If a dim is dynamic in both the input
/// and output tensors, bails out.
///
/// Before:
/// !t_in_type = tensor<1x2x3xf32>
/// !t_out_type = tensor<9x8x7x1x2x3xf32>
/// !v_type = vector<1x2x3xf32>
/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
/// into !t_out_type
/// After:
/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
///
/// TODO: Support masking
struct InsertSliceVectorizePattern
: public OpRewritePattern<tensor::InsertSliceOp> {
using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
// TODO: Introduce a parent class that will handle the insertion point update.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by a parent "class"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these patterns do exactly the same:

We may as well just wrap this into a parent class.

OpBuilder::InsertionGuard g(rewriter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: g -> guard?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with the suggestion, but I think that we should prioritise consistency and follow the existing style within the file:

😅

rewriter.setInsertionPoint(sliceOp);

LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
PatternRewriter &rewriter) const final {
auto sourceType = sliceOp.getSource().getType();
if (!VectorType::isValidElementType(sourceType.getElementType()))
return failure();
TypedValue<RankedTensorType> source = sliceOp.getSource();
auto sourceType = source.getType();
auto resultType = sliceOp.getResultType();

auto resultType = sliceOp.getResultType();

// 1. Get the pad value.
// TransferReadOp requires a scalar padding value. Note that:
// * for in-bounds access, the value is actually irrelevant.
// There are 2 cases in which xfer.read accesses are known to be in-bounds:
// 1. The source shape is static (output vector sizes would be based on
// the source shape and hence all memory accesses would be in-bounds),
// 2. Masking is used (output vector sizes would be user-provided, in which
// case it is assumed that all memory accesses are in-bounds). This
// remains a TODO.
//
// When the value is not known and not needed, use 0. Otherwise, bail out.
Value padValue = getStaticPadVal(sliceOp);
bool isOutOfBoundsRead = !sourceType.hasStaticShape();

if (!padValue && isOutOfBoundsRead) {
LDBG("Failed to get a pad value for out-of-bounds read access\n");
Value padValue = getStaticPadVal(sliceOp);

if (!padValue) {
auto elemType = sourceType.getElementType();
padValue = rewriter.create<arith::ConstantOp>(
sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
}

// 2. Get the vector shape and in-bounds attributes
SmallVector<int64_t> vecShape;
SmallVector<bool> readInBounds;
SmallVector<bool> writeInBounds;
size_t rankDiff = resultType.getRank() - sourceType.getRank();
for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
if (!inputVectorSizes.empty()) {
vecShape.push_back(inputVectorSizes[i]);
readInBounds.push_back(false);
writeInBounds.push_back(false);
} else if (!sourceType.isDynamicDim(i)) {
vecShape.push_back(sourceType.getDimSize(i));
// Source shape is statically known: Neither read nor write are
// out-of-bounds.
readInBounds.push_back(true);
writeInBounds.push_back(true);
} else if (!resultType.isDynamicDim(i)) {
// Source shape is not statically known, but result shape is.
// Vectorize with size of result shape. This may be larger than the
// source size.
// FIXME: Using rankDiff implies that the source tensor is inserted at
// the end of the destination tensor. However, that's not required.
vecShape.push_back(resultType.getDimSize(rankDiff + i));
// Read may be out-of-bounds because the result size could be larger
// than the source size.
readInBounds.push_back(false);
// Write will be in-bounds provided that the corresponding write idx is 0.
// To keep this logic simple, conservatively mark as out-of-bounds.
writeInBounds.push_back(false);
} else {
// Neither source nor result dim of padOp is static. Cannot vectorize
// the copy.
// TODO: Add support for masking
return failure();
}
}
auto vecType = VectorType::get(vecShape, sourceType.getElementType());

if (!padValue) {
auto elemType = sourceType.getElementType();
padValue = rewriter.create<arith::ConstantOp>(
sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
}
// 3. Generate TransferReadOp.
SmallVector<Value> readIndices(
vecType.getRank(),
rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
Operation *read = rewriter.create<vector::TransferReadOp>(
sliceOp.getLoc(), vecType, source, readIndices, padValue,
ArrayRef<bool>{readInBounds});

// 2. Get the vector shape and in-bounds attributes
SmallVector<int64_t> vecShape;
SmallVector<bool> readInBounds;
SmallVector<bool> writeInBounds;
size_t rankDiff = resultType.getRank() - sourceType.getRank();
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
if (!sourceType.isDynamicDim(i)) {
vecShape.push_back(sourceType.getDimSize(i));
// Source shape is statically known: Neither read nor write are
// out-of-bounds.
readInBounds.push_back(true);
writeInBounds.push_back(true);
} else if (!resultType.isDynamicDim(i)) {
// Source shape is not statically known, but result shape is.
// Vectorize with size of result shape. This may be larger than the
// source size.
// FIXME: Using rankDiff implies that the source tensor is inserted at
// the end of the destination tensor. However, that's not required.
vecShape.push_back(resultType.getDimSize(rankDiff + i));
// Read may be out-of-bounds because the result size could be larger
// than the source size.
readInBounds.push_back(false);
// Write will in-bounds provided that the corresponding write idx is 0.
// To keep this logic simple, conservatively mark as out-of-bounds.
writeInBounds.push_back(false);
} else {
// Neither source nor result dim of padOp is static. Cannot vectorize
// the copy.
// TODO: Add support for masking
return failure();
}
// If vector sizes are user provided, make sure to mask xfer_read.
if (!inputVectorSizes.empty()) {
auto *srcDefOp = source.getDefiningOp();
if (!srcDefOp) {
LDBG("Unable to get the defining Op of " << sliceOp);
return failure();
}
auto vecType = VectorType::get(vecShape, sourceType.getElementType());

// 3. Generate TransferReadOp.
SmallVector<Value> readIndices(
vecType.getRank(),
rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
ArrayRef<bool>{readInBounds});
ReifiedRankedShapedTypeDims reifiedSrcSizes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
rewriter, reifiedSrcSizes);
if (status.failed()) {
LDBG("Unable to reify result shapes of " << sliceOp);
return failure();
}

// 4. Generate TransferWriteOp.
auto writeIndices = getValueOrCreateConstantIndexOp(
rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
// Create the mask
SmallVector<int64_t> readMaskShape(
sliceOp.getSource().getType().getShape());
auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
Value maskOp = rewriter.create<vector::CreateMaskOp>(
sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);

// 5. Finalize
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
sliceOp, read, sliceOp.getDest(), writeIndices,
ArrayRef<bool>{writeInBounds});
// Mask the xfer_read Op
read = mlir::vector::maskOperation(rewriter, read, maskOp);
}

return success();
// 4. Generate TransferWriteOp.
if (!inputVectorSizes.empty() &&
ShapedType::isDynamicShape(resultType.getShape())) {
LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
return failure();
}
};

auto writeIndices = getValueOrCreateConstantIndexOp(
rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());

// 5. Finalize
Operation *write = rewriter.create<vector::TransferWriteOp>(
sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
ArrayRef<bool>{writeInBounds});
newResults.push_back(write->getResult(0));

return success();
}

/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
/// ```
Expand Down Expand Up @@ -2778,11 +2858,6 @@ struct PadOpVectorizationWithInsertSlicePattern
}
};

void mlir::linalg::populateInsertSliceVectorizationPatterns(
RewritePatternSet &patterns) {
patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
}

void mlir::linalg::populatePadOpVectorizationPatterns(
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
patterns.add<PadOpVectorizationWithTransferReadPattern,
Expand Down
Loading