-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Fix slice canonicalizer for out-of-bounds cases #132534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesSince #130487, This commit fixes canonicalization patterns that currently fold dynamically out-of-bounds ops (valid IR) to statically out-of-bounds ops (invalid IR). Full diff: https://github.com/llvm/llvm-project/pull/132534.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 8f07e43f847ae..e74326dba7c80 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
namespace mlir {
+/// Result for slice bounds verification;
+struct SliceBoundsVerificationResult {
+ /// If set to "true", the slice bounds verification was successful.
+ bool isValid;
+ /// An error message that can be printed during op verification.
+ std::string errorMessage;
+};
+
+/// Verify that the offsets/sizes/strides-style access into the given shape
+/// is in-bounds. Only static values are verified. If `generateErrorMessage`
+/// is set to "true", an error message is produced that can be printed by the
+/// op verifier.
+SliceBoundsVerificationResult
+verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides,
+ bool generateErrorMessage = false);
+SliceBoundsVerificationResult verifyInBoundsSlice(
+ ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
+ bool generateErrorMessage = false);
+
/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
/// constant arguments. This pattern assumes that the op has a suitable builder
/// that takes a result type, a "source" operand and mixed offsets, sizes and
@@ -54,7 +76,8 @@ namespace mlir {
/// returns the new result type of the op, based on the new offsets, sizes and
/// strides. `CastOpFunc` is used to generate a cast op if the result type of
/// the op has changed.
-template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
+template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
+ bool CheckInBounds = false>
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
: public OpRewritePattern<OpType> {
public:
@@ -72,11 +95,22 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
failed(foldDynamicIndexList(mixedStrides)))
return failure();
- // Create the new op in canonical form.
+ if (CheckInBounds) {
+ // Pattern does not apply if the produced op would not verify.
+ SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+ cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
+ mixedSizes, mixedStrides);
+ if (!sliceResult.isValid)
+ return failure();
+ }
+
+ // Compute the new result type.
auto resultType =
ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
if (!resultType)
return failure();
+
+ // Create the new op in canonical form.
auto newOp =
rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
mixedOffsets, mixedSizes, mixedStrides);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2d5df07f8af4b..5f8493de991f3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
}
-/// Verify that the offsets/sizes/strides-style access into the given tensor
-/// is in-bounds. Only static information is verified.
-static LogicalResult verifyInBoundsSlice(Operation *op,
- RankedTensorType tensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides) {
- for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
- // Nothing to verify for dynamic source dims.
- if (tensorType.isDynamicDim(i))
- continue;
- // Nothing to verify if the offset is dynamic.
- if (ShapedType::isDynamic(staticOffsets[i]))
- continue;
- if (staticOffsets[i] >= tensorType.getDimSize(i))
- return op->emitOpError("offset ")
- << i << " is out-of-bounds: " << staticOffsets[i]
- << " >= " << tensorType.getDimSize(i);
- if (ShapedType::isDynamic(staticSizes[i]) ||
- ShapedType::isDynamic(staticStrides[i]))
- continue;
- int64_t lastPos =
- staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
- if (lastPos >= tensorType.getDimSize(i))
- return op->emitOpError("slice along dimension ")
- << i << " runs out-of-bounds: " << lastPos
- << " >= " << tensorType.getDimSize(i);
- }
- return success();
-}
-
/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the source tensor.
- return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
- getStaticSizes(), getStaticStrides());
+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+ sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
+ getStaticStrides(), /*generateErrorMessage=*/true);
+ if (!boundsResult.isValid)
+ return getOperation()->emitError(boundsResult.errorMessage);
+
+ return success();
}
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
@@ -2470,6 +2445,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
if (!canFoldIntoConsumerOp(castOp))
return failure();
+ // Pattern does not apply if the produced op would not verify.
+ SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+ cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
+ sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
+ sliceOp.getStaticStrides());
+ if (!sliceResult.isValid)
+ return failure();
+
// Create folded extract.
Location loc = sliceOp.getLoc();
Value newResult = rewriter.create<ExtractSliceOp>(
@@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {
void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<
- OpWithOffsetSizesAndStridesConstantArgumentFolder<
- ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
- ExtractSliceOpCastFolder>(context);
+ results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+ ExtractSliceOp, SliceReturnTypeCanonicalizer,
+ SliceCanonicalizer, /*CheckInBounds=*/true>,
+ ExtractSliceOpCastFolder>(context);
}
//
@@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
return produceSliceErrorMsg(result, *this, expectedType);
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
- // to the source tensor.
- return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
- getStaticSizes(), getStaticStrides());
+ // to the destination tensor.
+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+ getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
+ getStaticStrides(), /*generateErrorMessage=*/true);
+ if (!boundsResult.isValid)
+ return getOperation()->emitError(boundsResult.errorMessage);
+
+ return success();
}
/// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2872,6 +2860,13 @@ class InsertSliceOpConstantArgumentFolder final
failed(foldDynamicStrideList(mixedStrides)))
return failure();
+ // Pattern does not apply if the produced op would not verify.
+ SliceBoundsVerificationResult sliceResult =
+ verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
+ mixedOffsets, mixedSizes, mixedStrides);
+ if (!sliceResult.isValid)
+ return failure();
+
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
@@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
size = srcType.getDimSize(rankReducedIdx++);
}
}
+
+ // Pattern does not apply if the produced op would not verify.
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
staticSizes, insertSliceOp.getStaticStrides()) !=
SliceVerificationResult::Success)
return failure();
+ SliceBoundsVerificationResult sliceResult =
+ verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
+ mixedSizes, insertSliceOp.getMixedStrides());
+ if (!sliceResult.isValid)
+ return failure();
Operation *replacement = rewriter.create<InsertOpTy>(
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
@@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
return produceSliceErrorMsg(result, *this, expectedType);
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
- // to the source tensor.
- return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
- getStaticSizes(), getStaticStrides());
+ // to the destination tensor.
+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+ getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
+ getStaticStrides(), /*generateErrorMessage=*/true);
+ if (!boundsResult.isValid)
+ return getOperation()->emitError(boundsResult.errorMessage);
+
+ return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 57b5cce7bb13b..70dd7b4aec88c 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -36,6 +36,64 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
return success();
}
+SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
+ ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
+ bool generateErrorMessage) {
+ SliceBoundsVerificationResult result;
+ result.isValid = true;
+ for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+ // Nothing to verify for dynamic source dims.
+ if (ShapedType::isDynamic(shape[i]))
+ continue;
+ // Nothing to verify if the offset is dynamic.
+ if (ShapedType::isDynamic(staticOffsets[i]))
+ continue;
+ if (staticOffsets[i] >= shape[i]) {
+ result.errorMessage =
+ std::string("offset ") + std::to_string(i) +
+ " is out-of-bounds: " + std::to_string(staticOffsets[i]) +
+ " >= " + std::to_string(shape[i]);
+ result.isValid = false;
+ return result;
+ }
+ if (ShapedType::isDynamic(staticSizes[i]) ||
+ ShapedType::isDynamic(staticStrides[i]))
+ continue;
+ int64_t lastPos =
+ staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
+ if (lastPos >= shape[i]) {
+ result.errorMessage = std::string("slice along dimension ") +
+ std::to_string(i) +
+ " runs out-of-bounds: " + std::to_string(lastPos) +
+ " >= " + std::to_string(shape[i]);
+ result.isValid = false;
+ return result;
+ }
+ }
+ return result;
+}
+
+SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
+ ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
+ bool generateErrorMessage) {
+ auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
+ SmallVector<int64_t> staticValues;
+ for (OpFoldResult ofr : ofrs) {
+ if (auto attr = dyn_cast<Attribute>(ofr)) {
+ staticValues.push_back(cast<IntegerAttr>(attr).getInt());
+ } else {
+ staticValues.push_back(ShapedType::kDynamic);
+ }
+ }
+ return staticValues;
+ };
+ return verifyInBoundsSlice(
+ shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
+ getStaticValues(mixedStrides), generateErrorMessage);
+}
+
LogicalResult
mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 90cc0ca658ffb..fd96328c6033d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -582,6 +582,56 @@ func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<1
// -----
+// CHECK-LABEL: func @out_of_bounds_extract_slice
+// CHECK: tensor.extract_slice %{{.*}}[0] [%{{.*}}] [1] : tensor<5xf32> to tensor<?xf32>
+func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<?xf32> {
+ %c10 = arith.constant 10 : index
+ %r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
+ return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_extract_slice
+// CHECK: tensor.extract_slice %{{.*}}[0] [10] [1] : tensor<?xf32> to tensor<10xf32>
+func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<10xf32> {
+ %t2 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
+ %r = tensor.extract_slice %t2 [0][10][1] : tensor<?xf32> to tensor<10xf32>
+ return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [5] [1] : tensor<5xf32> into tensor<10xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>) -> tensor<10xf32> {
+ %c10 = arith.constant 10 : index
+ %r = tensor.insert_slice %src into %dst[%c10] [5] [1] : tensor<5xf32> into tensor<10xf32>
+ return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [%{{.*}}] [1] : tensor<?xf32> into tensor<10xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<10xf32> {
+ %src2 = tensor.cast %src : tensor<5xf32> to tensor<?xf32>
+ %r = tensor.insert_slice %src2 into %dst[7] [%sz] [1] : tensor<?xf32> into tensor<10xf32>
+ return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<?xf32> {
+ %dst2 = tensor.cast %dst : tensor<10xf32> to tensor<?xf32>
+ %r = tensor.insert_slice %src into %dst2[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
+ return %r : tensor<?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
|
Since #130487,
tensor.extract_slice
andtensor.insert_slice
ops that are statically detected to go out of bounds are rejected by the verifier.This commit fixes canonicalization patterns that currently fold dynamically out-of-bounds ops (valid IR) to statically out-of-bounds ops (invalid IR).