
[mlir][tensor] Fix slice canonicalizer for out-of-bounds cases #132534

Merged: 1 commit, Mar 24, 2025
38 changes: 36 additions & 2 deletions mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,

namespace mlir {

/// Result for slice bounds verification.
struct SliceBoundsVerificationResult {
/// If set to "true", the slice bounds verification was successful.
bool isValid;
/// An error message that can be printed during op verification.
std::string errorMessage;
};

/// Verify that the offsets/sizes/strides-style access into the given shape
/// is in-bounds. Only static values are verified. If `generateErrorMessage`
/// is set to "true", an error message is produced that can be printed by the
/// op verifier.
SliceBoundsVerificationResult
verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides,
bool generateErrorMessage = false);
SliceBoundsVerificationResult verifyInBoundsSlice(
ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
bool generateErrorMessage = false);

/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
/// constant arguments. This pattern assumes that the op has a suitable builder
/// that takes a result type, a "source" operand and mixed offsets, sizes and
@@ -54,7 +76,8 @@ namespace mlir {
/// returns the new result type of the op, based on the new offsets, sizes and
/// strides. `CastOpFunc` is used to generate a cast op if the result type of
/// the op has changed.
template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
bool CheckInBounds = false>
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
: public OpRewritePattern<OpType> {
public:
Expand All @@ -72,11 +95,22 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
failed(foldDynamicIndexList(mixedStrides)))
return failure();

// Create the new op in canonical form.
if (CheckInBounds) {
// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
mixedSizes, mixedStrides);
if (!sliceResult.isValid)
return failure();
}

// Compute the new result type.
auto resultType =
ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
if (!resultType)
return failure();

// Create the new op in canonical form.
auto newOp =
rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
mixedOffsets, mixedSizes, mixedStrides);
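As a usage note (not part of the diff): the new helper can also be called directly from a rewrite pattern so that it bails out instead of creating an op that would no longer verify. A minimal sketch, assuming a hypothetical `MySliceLikeOp` with the usual mixed offset/size/stride accessors:

```c++
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

using namespace mlir;

// Hypothetical pattern; `MySliceLikeOp` is a placeholder, not an op from
// this PR.
struct FoldMySliceLikeOp : public OpRewritePattern<MySliceLikeOp> {
  using OpRewritePattern<MySliceLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MySliceLikeOp op,
                                PatternRewriter &rewriter) const override {
    // Only static offsets/sizes/strides are checked; dynamic entries are
    // skipped, so the check is conservative.
    SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
        cast<ShapedType>(op.getSource().getType()).getShape(),
        op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
    if (!boundsResult.isValid)
      return failure(); // Do not rewrite into an out-of-bounds slice.

    // ... perform the actual rewrite here ...
    return success();
  }
};
```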
93 changes: 50 additions & 43 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
}

/// Verify that the offsets/sizes/strides-style access into the given tensor
/// is in-bounds. Only static information is verified.
static LogicalResult verifyInBoundsSlice(Operation *op,
RankedTensorType tensorType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides) {
for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
// Nothing to verify for dynamic source dims.
if (tensorType.isDynamicDim(i))
continue;
// Nothing to verify if the offset is dynamic.
if (ShapedType::isDynamic(staticOffsets[i]))
continue;
if (staticOffsets[i] >= tensorType.getDimSize(i))
return op->emitOpError("offset ")
<< i << " is out-of-bounds: " << staticOffsets[i]
<< " >= " << tensorType.getDimSize(i);
if (ShapedType::isDynamic(staticSizes[i]) ||
ShapedType::isDynamic(staticStrides[i]))
continue;
int64_t lastPos =
staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
if (lastPos >= tensorType.getDimSize(i))
return op->emitOpError("slice along dimension ")
<< i << " runs out-of-bounds: " << lastPos
<< " >= " << tensorType.getDimSize(i);
}
return success();
}

/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {

// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the source tensor.
return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
getStaticSizes(), getStaticStrides());
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);

return success();
}

llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
@@ -2470,6 +2445,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
if (!canFoldIntoConsumerOp(castOp))
return failure();

// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
sliceOp.getStaticStrides());
if (!sliceResult.isValid)
return failure();

// Create folded extract.
Location loc = sliceOp.getLoc();
Value newResult = rewriter.create<ExtractSliceOp>(
@@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
OpWithOffsetSizesAndStridesConstantArgumentFolder<
ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
ExtractSliceOpCastFolder>(context);
results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
ExtractSliceOp, SliceReturnTypeCanonicalizer,
SliceCanonicalizer, /*CheckInBounds=*/true>,
ExtractSliceOpCastFolder>(context);
}

//
@@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
return produceSliceErrorMsg(result, *this, expectedType);

// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the source tensor.
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
getStaticSizes(), getStaticStrides());
// to the destination tensor.
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);

return success();
}

/// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2872,6 +2860,13 @@ class InsertSliceOpConstantArgumentFolder final
failed(foldDynamicStrideList(mixedStrides)))
return failure();

// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult =
verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
mixedOffsets, mixedSizes, mixedStrides);
if (!sliceResult.isValid)
return failure();

// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
@@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
size = srcType.getDimSize(rankReducedIdx++);
}
}

// Pattern does not apply if the produced op would not verify.
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
staticSizes, insertSliceOp.getStaticStrides()) !=
SliceVerificationResult::Success)
return failure();
SliceBoundsVerificationResult sliceResult =
verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
mixedSizes, insertSliceOp.getMixedStrides());
if (!sliceResult.isValid)
return failure();

Operation *replacement = rewriter.create<InsertOpTy>(
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
@@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
return produceSliceErrorMsg(result, *this, expectedType);

// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the source tensor.
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
getStaticSizes(), getStaticStrides());
// to the destination tensor.
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);

return success();
}

void ParallelInsertSliceOp::getCanonicalizationPatterns(
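Downstream dialects with their own slice-like ops can opt into the same guard through the new `CheckInBounds` template parameter rather than calling the helper by hand. A hedged sketch, where `MyExtractSliceOp`, `MyReturnTypeFn`, and `MyCastOpFunc` are placeholders for the op and functors such a dialect would already have:

```c++
void MyExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // With /*CheckInBounds=*/true, the folder returns failure() whenever
  // folding dynamic operands to constants would produce an out-of-bounds
  // slice, leaving the op unchanged instead of creating an invalid one.
  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
      MyExtractSliceOp, MyReturnTypeFn, MyCastOpFunc,
      /*CheckInBounds=*/true>>(context);
}
```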
58 changes: 58 additions & 0 deletions mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -36,6 +36,64 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
return success();
}

SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
bool generateErrorMessage) {
SliceBoundsVerificationResult result;
result.isValid = true;
for (int64_t i = 0, e = shape.size(); i < e; ++i) {
// Nothing to verify for dynamic source dims.
if (ShapedType::isDynamic(shape[i]))
continue;
// Nothing to verify if the offset is dynamic.
if (ShapedType::isDynamic(staticOffsets[i]))
continue;
if (staticOffsets[i] >= shape[i]) {
result.errorMessage =
std::string("offset ") + std::to_string(i) +
" is out-of-bounds: " + std::to_string(staticOffsets[i]) +
" >= " + std::to_string(shape[i]);
result.isValid = false;
return result;
}
if (ShapedType::isDynamic(staticSizes[i]) ||
ShapedType::isDynamic(staticStrides[i]))
continue;
int64_t lastPos =
staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
if (lastPos >= shape[i]) {
result.errorMessage = std::string("slice along dimension ") +
std::to_string(i) +
" runs out-of-bounds: " + std::to_string(lastPos) +
" >= " + std::to_string(shape[i]);
result.isValid = false;
return result;
}
}
return result;
}

SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
bool generateErrorMessage) {
auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
SmallVector<int64_t> staticValues;
for (OpFoldResult ofr : ofrs) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
staticValues.push_back(cast<IntegerAttr>(attr).getInt());
} else {
staticValues.push_back(ShapedType::kDynamic);
}
}
return staticValues;
};
return verifyInBoundsSlice(
shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
getStaticValues(mixedStrides), generateErrorMessage);
}

LogicalResult
mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
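To illustrate the `OpFoldResult` overload (illustrative only, not part of the diff): attribute entries are folded to their static values while `Value` entries become `ShapedType::kDynamic` and are skipped by the check. Assuming an `OpBuilder b` and a dynamic index `Value dynSize` in scope:

```c++
SmallVector<int64_t> shape = {5};                        // static dim of size 5
SmallVector<OpFoldResult> offsets = {b.getIndexAttr(7)}; // static offset 7
SmallVector<OpFoldResult> sizes = {dynSize};             // Value -> kDynamic, skipped
SmallVector<OpFoldResult> strides = {b.getIndexAttr(1)}; // static stride 1

SliceBoundsVerificationResult res = verifyInBoundsSlice(
    shape, offsets, sizes, strides, /*generateErrorMessage=*/true);
// res.isValid == false
// res.errorMessage == "offset 0 is out-of-bounds: 7 >= 5"
```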
50 changes: 50 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -582,6 +582,56 @@ func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<1

// -----

// CHECK-LABEL: func @out_of_bounds_extract_slice
// CHECK: tensor.extract_slice %{{.*}}[0] [%{{.*}}] [1] : tensor<5xf32> to tensor<?xf32>
func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<?xf32> {
%c10 = arith.constant 10 : index
%r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
return %r : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @out_of_bounds_extract_slice
// CHECK: tensor.extract_slice %{{.*}}[0] [10] [1] : tensor<?xf32> to tensor<10xf32>
func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<10xf32> {
%t2 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
%r = tensor.extract_slice %t2 [0][10][1] : tensor<?xf32> to tensor<10xf32>
return %r : tensor<10xf32>
}

// -----

// CHECK-LABEL: func @out_of_bounds_insert_slice
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [5] [1] : tensor<5xf32> into tensor<10xf32>
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>) -> tensor<10xf32> {
%c10 = arith.constant 10 : index
%r = tensor.insert_slice %src into %dst[%c10] [5] [1] : tensor<5xf32> into tensor<10xf32>
return %r : tensor<10xf32>
}

// -----

// CHECK-LABEL: func @out_of_bounds_insert_slice
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [%{{.*}}] [1] : tensor<?xf32> into tensor<10xf32>
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<10xf32> {
%src2 = tensor.cast %src : tensor<5xf32> to tensor<?xf32>
%r = tensor.insert_slice %src2 into %dst[7] [%sz] [1] : tensor<?xf32> into tensor<10xf32>
return %r : tensor<10xf32>
}

// -----

// CHECK-LABEL: func @out_of_bounds_insert_slice
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<?xf32> {
%dst2 = tensor.cast %dst : tensor<10xf32> to tensor<?xf32>
%r = tensor.insert_slice %src into %dst2[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
return %r : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>