Skip to content

[mlir][tensor] Fix slice canonicalizer for out-of-bounds cases #132534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2025

Conversation

matthias-springer
Copy link
Member

Since #130487, tensor.extract_slice and tensor.insert_slice ops that are statically detected to go out of bounds are rejected by the verifier.

This commit fixes canonicalization patterns that currently fold dynamically out-of-bounds ops (valid IR) to statically out-of-bounds ops (invalid IR).

@llvmbot
Copy link
Member

llvmbot commented Mar 22, 2025

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Since #130487, tensor.extract_slice and tensor.insert_slice ops that are statically detected to go out of bounds are rejected by the verifier.

This commit fixes canonicalization patterns that currently fold dynamically out-of-bounds ops (valid IR) to statically out-of-bounds ops (invalid IR).


Full diff: https://github.com/llvm/llvm-project/pull/132534.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/ViewLikeInterface.h (+36-2)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+50-43)
  • (modified) mlir/lib/Interfaces/ViewLikeInterface.cpp (+58)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+50)
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 8f07e43f847ae..e74326dba7c80 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
 
 namespace mlir {
 
+/// Result for slice bounds verification;
+struct SliceBoundsVerificationResult {
+  /// If set to "true", the slice bounds verification was successful.
+  bool isValid;
+  /// An error message that can be printed during op verification.
+  std::string errorMessage;
+};
+
+/// Verify that the offsets/sizes/strides-style access into the given shape
+/// is in-bounds. Only static values are verified. If `generateErrorMessage`
+/// is set to "true", an error message is produced that can be printed by the
+///  op verifier.
+SliceBoundsVerificationResult
+verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
+                    ArrayRef<int64_t> staticSizes,
+                    ArrayRef<int64_t> staticStrides,
+                    bool generateErrorMessage = false);
+SliceBoundsVerificationResult verifyInBoundsSlice(
+    ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
+    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
+    bool generateErrorMessage = false);
+
 /// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
 /// constant arguments. This pattern assumes that the op has a suitable builder
 /// that takes a result type, a "source" operand and mixed offsets, sizes and
@@ -54,7 +76,8 @@ namespace mlir {
 /// returns the new result type of the op, based on the new offsets, sizes and
 /// strides. `CastOpFunc` is used to generate a cast op if the result type of
 /// the op has changed.
-template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
+template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
+          bool CheckInBounds = false>
 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     : public OpRewritePattern<OpType> {
 public:
@@ -72,11 +95,22 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
-    // Create the new op in canonical form.
+    if (CheckInBounds) {
+      // Pattern does not apply if the produced op would not verify.
+      SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+          cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
+          mixedSizes, mixedStrides);
+      if (!sliceResult.isValid)
+        return failure();
+    }
+
+    // Compute the new result type.
     auto resultType =
         ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
     if (!resultType)
       return failure();
+
+    // Create the new op in canonical form.
     auto newOp =
         rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
                                 mixedOffsets, mixedSizes, mixedStrides);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2d5df07f8af4b..5f8493de991f3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
   }
 }
 
-/// Verify that the offsets/sizes/strides-style access into the given tensor
-/// is in-bounds. Only static information is verified.
-static LogicalResult verifyInBoundsSlice(Operation *op,
-                                         RankedTensorType tensorType,
-                                         ArrayRef<int64_t> staticOffsets,
-                                         ArrayRef<int64_t> staticSizes,
-                                         ArrayRef<int64_t> staticStrides) {
-  for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
-    // Nothing to verify for dynamic source dims.
-    if (tensorType.isDynamicDim(i))
-      continue;
-    // Nothing to verify if the offset is dynamic.
-    if (ShapedType::isDynamic(staticOffsets[i]))
-      continue;
-    if (staticOffsets[i] >= tensorType.getDimSize(i))
-      return op->emitOpError("offset ")
-             << i << " is out-of-bounds: " << staticOffsets[i]
-             << " >= " << tensorType.getDimSize(i);
-    if (ShapedType::isDynamic(staticSizes[i]) ||
-        ShapedType::isDynamic(staticStrides[i]))
-      continue;
-    int64_t lastPos =
-        staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
-    if (lastPos >= tensorType.getDimSize(i))
-      return op->emitOpError("slice along dimension ")
-             << i << " runs out-of-bounds: " << lastPos
-             << " >= " << tensorType.getDimSize(i);
-  }
-  return success();
-}
-
 /// Verifier for ExtractSliceOp.
 LogicalResult ExtractSliceOp::verify() {
   RankedTensorType sourceType = getSourceType();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
 
   // Verify that offsets, sizes, strides do not run out-of-bounds with respect
   // to the source tensor.
-  return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
-                             getStaticSizes(), getStaticStrides());
+  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
+      getStaticStrides(), /*generateErrorMessage=*/true);
+  if (!boundsResult.isValid)
+    return getOperation()->emitError(boundsResult.errorMessage);
+
+  return success();
 }
 
 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
@@ -2470,6 +2445,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
     if (!canFoldIntoConsumerOp(castOp))
       return failure();
 
+    // Pattern does not apply if the produced op would not verify.
+    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
+        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
+        sliceOp.getStaticStrides());
+    if (!sliceResult.isValid)
+      return failure();
+
     // Create folded extract.
     Location loc = sliceOp.getLoc();
     Value newResult = rewriter.create<ExtractSliceOp>(
@@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {
 
 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<
-      OpWithOffsetSizesAndStridesConstantArgumentFolder<
-          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
-      ExtractSliceOpCastFolder>(context);
+  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+                  ExtractSliceOp, SliceReturnTypeCanonicalizer,
+                  SliceCanonicalizer, /*CheckInBounds=*/true>,
+              ExtractSliceOpCastFolder>(context);
 }
 
 //
@@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
     return produceSliceErrorMsg(result, *this, expectedType);
 
   // Verify that offsets, sizes, strides do not run out-of-bounds with respect
-  // to the source tensor.
-  return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
-                             getStaticSizes(), getStaticStrides());
+  // to the destination tensor.
+  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
+      getStaticStrides(), /*generateErrorMessage=*/true);
+  if (!boundsResult.isValid)
+    return getOperation()->emitError(boundsResult.errorMessage);
+
+  return success();
 }
 
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2872,6 +2860,13 @@ class InsertSliceOpConstantArgumentFolder final
         failed(foldDynamicStrideList(mixedStrides)))
       return failure();
 
+    // Pattern does not apply if the produced op would not verify.
+    SliceBoundsVerificationResult sliceResult =
+        verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
+                            mixedOffsets, mixedSizes, mixedStrides);
+    if (!sliceResult.isValid)
+      return failure();
+
     // Create the new op in canonical form.
     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
@@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
         size = srcType.getDimSize(rankReducedIdx++);
       }
     }
+
+    // Pattern does not apply if the produced op would not verify.
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                             staticSizes, insertSliceOp.getStaticStrides()) !=
         SliceVerificationResult::Success)
       return failure();
+    SliceBoundsVerificationResult sliceResult =
+        verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
+                            mixedSizes, insertSliceOp.getMixedStrides());
+    if (!sliceResult.isValid)
+      return failure();
 
     Operation *replacement = rewriter.create<InsertOpTy>(
         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
@@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
     return produceSliceErrorMsg(result, *this, expectedType);
 
   // Verify that offsets, sizes, strides do not run out-of-bounds with respect
-  // to the source tensor.
-  return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
-                             getStaticSizes(), getStaticStrides());
+  // to the destination tensor.
+  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
+      getStaticStrides(), /*generateErrorMessage=*/true);
+  if (!boundsResult.isValid)
+    return getOperation()->emitError(boundsResult.errorMessage);
+
+  return success();
 }
 
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 57b5cce7bb13b..70dd7b4aec88c 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -36,6 +36,64 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
   return success();
 }
 
+SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
+    ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
+    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
+    bool generateErrorMessage) {
+  SliceBoundsVerificationResult result;
+  result.isValid = true;
+  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+    // Nothing to verify for dynamic source dims.
+    if (ShapedType::isDynamic(shape[i]))
+      continue;
+    // Nothing to verify if the offset is dynamic.
+    if (ShapedType::isDynamic(staticOffsets[i]))
+      continue;
+    if (staticOffsets[i] >= shape[i]) {
+      result.errorMessage =
+          std::string("offset ") + std::to_string(i) +
+          " is out-of-bounds: " + std::to_string(staticOffsets[i]) +
+          " >= " + std::to_string(shape[i]);
+      result.isValid = false;
+      return result;
+    }
+    if (ShapedType::isDynamic(staticSizes[i]) ||
+        ShapedType::isDynamic(staticStrides[i]))
+      continue;
+    int64_t lastPos =
+        staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
+    if (lastPos >= shape[i]) {
+      result.errorMessage = std::string("slice along dimension ") +
+                            std::to_string(i) +
+                            " runs out-of-bounds: " + std::to_string(lastPos) +
+                            " >= " + std::to_string(shape[i]);
+      result.isValid = false;
+      return result;
+    }
+  }
+  return result;
+}
+
+SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
+    ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
+    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
+    bool generateErrorMessage) {
+  auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
+    SmallVector<int64_t> staticValues;
+    for (OpFoldResult ofr : ofrs) {
+      if (auto attr = dyn_cast<Attribute>(ofr)) {
+        staticValues.push_back(cast<IntegerAttr>(attr).getInt());
+      } else {
+        staticValues.push_back(ShapedType::kDynamic);
+      }
+    }
+    return staticValues;
+  };
+  return verifyInBoundsSlice(
+      shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
+      getStaticValues(mixedStrides), generateErrorMessage);
+}
+
 LogicalResult
 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 90cc0ca658ffb..fd96328c6033d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -582,6 +582,56 @@ func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<1
 
 // -----
 
+// CHECK-LABEL: func @out_of_bounds_extract_slice
+//       CHECK:   tensor.extract_slice %{{.*}}[0] [%{{.*}}] [1] : tensor<5xf32> to tensor<?xf32>
+func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<?xf32> {
+  %c10 = arith.constant 10 : index
+  %r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
+  return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_extract_slice
+//       CHECK:   tensor.extract_slice %{{.*}}[0] [10] [1] : tensor<?xf32> to tensor<10xf32>
+func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<10xf32> {
+  %t2 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
+  %r = tensor.extract_slice %t2 [0][10][1] : tensor<?xf32> to tensor<10xf32>
+  return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+//       CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [5] [1] : tensor<5xf32> into tensor<10xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>) -> tensor<10xf32> {
+  %c10 = arith.constant 10 : index
+  %r = tensor.insert_slice %src into %dst[%c10] [5] [1] : tensor<5xf32> into tensor<10xf32>
+  return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+//       CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[7] [%{{.*}}] [1] : tensor<?xf32> into tensor<10xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<10xf32> {
+  %src2 = tensor.cast %src : tensor<5xf32> to tensor<?xf32>
+  %r = tensor.insert_slice %src2 into %dst[7] [%sz] [1] : tensor<?xf32> into tensor<10xf32>
+  return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+//       CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<?xf32> {
+  %dst2 = tensor.cast %dst : tensor<10xf32> to tensor<?xf32>
+  %r = tensor.insert_slice %src into %dst2[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
+  return %r : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
 //  CHECK-SAME:   %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
 //  CHECK-SAME:   %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>

@matthias-springer matthias-springer merged commit 529ee3c into main Mar 24, 2025
14 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/slice_can branch March 24, 2025 13:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants