Skip to content

Commit 529ee3c

Browse files
[mlir][tensor] Fix slice canonicalizer for out-of-bounds cases (#132534)
Since #130487, `tensor.extract_slice` and `tensor.insert_slice` ops that are statically detected to go out of bounds are rejected by the verifier. This commit fixes canonicalization patterns that currently fold dynamically out-of-bounds ops (valid IR) to statically out-of-bounds ops (invalid IR).
1 parent 85974a0 commit 529ee3c

File tree

4 files changed

+194
-45
lines changed

4 files changed

+194
-45
lines changed

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
4545

4646
namespace mlir {
4747

48+
/// Result of a slice bounds verification.
struct SliceBoundsVerificationResult {
  /// If set to "true", the slice bounds verification was successful.
  /// Defaults to "false" so that a default-constructed result never holds an
  /// indeterminate value and never reads as a successful verification.
  bool isValid = false;
  /// An error message that can be printed during op verification.
  std::string errorMessage;
};
55+
56+
/// Verify that the offsets/sizes/strides-style access into the given shape
57+
/// is in-bounds. Only static values are verified. If `generateErrorMessage`
58+
/// is set to "true", an error message is produced that can be printed by the
59+
/// op verifier.
60+
SliceBoundsVerificationResult
61+
verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
62+
ArrayRef<int64_t> staticSizes,
63+
ArrayRef<int64_t> staticStrides,
64+
bool generateErrorMessage = false);
65+
SliceBoundsVerificationResult verifyInBoundsSlice(
66+
ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
67+
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
68+
bool generateErrorMessage = false);
69+
4870
/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
4971
/// constant arguments. This pattern assumes that the op has a suitable builder
5072
/// that takes a result type, a "source" operand and mixed offsets, sizes and
@@ -54,7 +76,8 @@ namespace mlir {
5476
/// returns the new result type of the op, based on the new offsets, sizes and
5577
/// strides. `CastOpFunc` is used to generate a cast op if the result type of
5678
/// the op has changed.
57-
template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
79+
template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
80+
bool CheckInBounds = false>
5881
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
5982
: public OpRewritePattern<OpType> {
6083
public:
@@ -72,11 +95,22 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
7295
failed(foldDynamicIndexList(mixedStrides)))
7396
return failure();
7497

75-
// Create the new op in canonical form.
98+
if (CheckInBounds) {
99+
// Pattern does not apply if the produced op would not verify.
100+
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
101+
cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
102+
mixedSizes, mixedStrides);
103+
if (!sliceResult.isValid)
104+
return failure();
105+
}
106+
107+
// Compute the new result type.
76108
auto resultType =
77109
ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
78110
if (!resultType)
79111
return failure();
112+
113+
// Create the new op in canonical form.
80114
auto newOp =
81115
rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
82116
mixedOffsets, mixedSizes, mixedStrides);

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Interfaces/InferIntRangeInterface.h"
2828
#include "mlir/Interfaces/LoopLikeInterface.h"
2929
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
30+
#include "mlir/Interfaces/ViewLikeInterface.h"
3031
#include "mlir/Support/LLVM.h"
3132
#include "llvm/ADT/DenseSet.h"
3233
#include "llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
23522353
}
23532354
}
23542355

2355-
/// Verify that the offsets/sizes/strides-style access into the given tensor
2356-
/// is in-bounds. Only static information is verified.
2357-
static LogicalResult verifyInBoundsSlice(Operation *op,
2358-
RankedTensorType tensorType,
2359-
ArrayRef<int64_t> staticOffsets,
2360-
ArrayRef<int64_t> staticSizes,
2361-
ArrayRef<int64_t> staticStrides) {
2362-
for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
2363-
// Nothing to verify for dynamic source dims.
2364-
if (tensorType.isDynamicDim(i))
2365-
continue;
2366-
// Nothing to verify if the offset is dynamic.
2367-
if (ShapedType::isDynamic(staticOffsets[i]))
2368-
continue;
2369-
if (staticOffsets[i] >= tensorType.getDimSize(i))
2370-
return op->emitOpError("offset ")
2371-
<< i << " is out-of-bounds: " << staticOffsets[i]
2372-
<< " >= " << tensorType.getDimSize(i);
2373-
if (ShapedType::isDynamic(staticSizes[i]) ||
2374-
ShapedType::isDynamic(staticStrides[i]))
2375-
continue;
2376-
int64_t lastPos =
2377-
staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
2378-
if (lastPos >= tensorType.getDimSize(i))
2379-
return op->emitOpError("slice along dimension ")
2380-
<< i << " runs out-of-bounds: " << lastPos
2381-
<< " >= " << tensorType.getDimSize(i);
2382-
}
2383-
return success();
2384-
}
2385-
23862356
/// Verifier for ExtractSliceOp.
23872357
LogicalResult ExtractSliceOp::verify() {
23882358
RankedTensorType sourceType = getSourceType();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
23962366

23972367
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
23982368
// to the source tensor.
2399-
return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
2400-
getStaticSizes(), getStaticStrides());
2369+
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2370+
sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2371+
getStaticStrides(), /*generateErrorMessage=*/true);
2372+
if (!boundsResult.isValid)
2373+
return getOperation()->emitError(boundsResult.errorMessage);
2374+
2375+
return success();
24012376
}
24022377

24032378
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
@@ -2470,6 +2445,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
24702445
if (!canFoldIntoConsumerOp(castOp))
24712446
return failure();
24722447

2448+
// Pattern does not apply if the produced op would not verify.
2449+
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
2450+
cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2451+
sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2452+
sliceOp.getStaticStrides());
2453+
if (!sliceResult.isValid)
2454+
return failure();
2455+
24732456
// Create folded extract.
24742457
Location loc = sliceOp.getLoc();
24752458
Value newResult = rewriter.create<ExtractSliceOp>(
@@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {
26342617

26352618
void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
26362619
MLIRContext *context) {
2637-
results.add<
2638-
OpWithOffsetSizesAndStridesConstantArgumentFolder<
2639-
ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2640-
ExtractSliceOpCastFolder>(context);
2620+
results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2621+
ExtractSliceOp, SliceReturnTypeCanonicalizer,
2622+
SliceCanonicalizer, /*CheckInBounds=*/true>,
2623+
ExtractSliceOpCastFolder>(context);
26412624
}
26422625

26432626
//
@@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
27752758
return produceSliceErrorMsg(result, *this, expectedType);
27762759

27772760
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
2778-
// to the source tensor.
2779-
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
2780-
getStaticSizes(), getStaticStrides());
2761+
// to the destination tensor.
2762+
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
2763+
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
2764+
getStaticStrides(), /*generateErrorMessage=*/true);
2765+
if (!boundsResult.isValid)
2766+
return getOperation()->emitError(boundsResult.errorMessage);
2767+
2768+
return success();
27812769
}
27822770

27832771
/// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2872,6 +2860,13 @@ class InsertSliceOpConstantArgumentFolder final
28722860
failed(foldDynamicStrideList(mixedStrides)))
28732861
return failure();
28742862

2863+
// Pattern does not apply if the produced op would not verify.
2864+
SliceBoundsVerificationResult sliceResult =
2865+
verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
2866+
mixedOffsets, mixedSizes, mixedStrides);
2867+
if (!sliceResult.isValid)
2868+
return failure();
2869+
28752870
// Create the new op in canonical form.
28762871
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
28772872
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
@@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
29692964
size = srcType.getDimSize(rankReducedIdx++);
29702965
}
29712966
}
2967+
2968+
// Pattern does not apply if the produced op would not verify.
29722969
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
29732970
staticSizes, insertSliceOp.getStaticStrides()) !=
29742971
SliceVerificationResult::Success)
29752972
return failure();
2973+
SliceBoundsVerificationResult sliceResult =
2974+
verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
2975+
mixedSizes, insertSliceOp.getMixedStrides());
2976+
if (!sliceResult.isValid)
2977+
return failure();
29762978

29772979
Operation *replacement = rewriter.create<InsertOpTy>(
29782980
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
@@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
38003802
return produceSliceErrorMsg(result, *this, expectedType);
38013803

38023804
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
3803-
// to the source tensor.
3804-
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
3805-
getStaticSizes(), getStaticStrides());
3805+
// to the destination tensor.
3806+
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
3807+
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
3808+
getStaticStrides(), /*generateErrorMessage=*/true);
3809+
if (!boundsResult.isValid)
3810+
return getOperation()->emitError(boundsResult.errorMessage);
3811+
3812+
return success();
38063813
}
38073814

38083815
void ParallelInsertSliceOp::getCanonicalizationPatterns(

mlir/lib/Interfaces/ViewLikeInterface.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,64 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
3636
return success();
3737
}
3838

39+
/// Verifies that an offsets/sizes/strides-style access into `shape` is
/// in-bounds, looking at static values only.
///
/// For every dimension with a static extent and a static offset, the offset
/// must be smaller than the extent. If size and stride are static as well,
/// the last accessed position `offset + (size - 1) * stride` must also be
/// smaller than the extent. Verification stops at the first violation.
///
/// The returned `isValid` flag is "false" iff a violation was found. The
/// `errorMessage` is only built when `generateErrorMessage` is "true", so
/// callers that merely need the yes/no answer (e.g. rewrite patterns) do not
/// pay for string formatting.
SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
    bool generateErrorMessage) {
  SliceBoundsVerificationResult result;
  result.isValid = true;
  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
    // Nothing to verify for dynamic source dims.
    if (ShapedType::isDynamic(shape[i]))
      continue;
    // Nothing to verify if the offset is dynamic.
    if (ShapedType::isDynamic(staticOffsets[i]))
      continue;
    if (staticOffsets[i] >= shape[i]) {
      result.isValid = false;
      if (generateErrorMessage)
        result.errorMessage =
            std::string("offset ") + std::to_string(i) +
            " is out-of-bounds: " + std::to_string(staticOffsets[i]) +
            " >= " + std::to_string(shape[i]);
      return result;
    }
    // The last accessed position can only be computed if both the size and
    // the stride are static.
    if (ShapedType::isDynamic(staticSizes[i]) ||
        ShapedType::isDynamic(staticStrides[i]))
      continue;
    // NOTE(review): computed in plain int64_t arithmetic; confirm that
    // extreme static sizes/strides cannot overflow here.
    int64_t lastPos =
        staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
    if (lastPos >= shape[i]) {
      result.isValid = false;
      if (generateErrorMessage)
        result.errorMessage = std::string("slice along dimension ") +
                              std::to_string(i) +
                              " runs out-of-bounds: " +
                              std::to_string(lastPos) +
                              " >= " + std::to_string(shape[i]);
      return result;
    }
  }
  return result;
}
76+
77+
/// Overload for mixed static/dynamic offsets, sizes and strides. Each list is
/// lowered to plain integers (attribute entries become their integer value,
/// every other entry becomes `ShapedType::kDynamic`) and the static overload
/// performs the actual verification.
SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
    ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
    bool generateErrorMessage) {
  // Maps a list of OpFoldResults to static integers / dynamic markers.
  auto toStaticOrDynamic = [](ArrayRef<OpFoldResult> values) {
    SmallVector<int64_t> converted;
    for (OpFoldResult value : values) {
      auto attr = dyn_cast<Attribute>(value);
      converted.push_back(attr ? cast<IntegerAttr>(attr).getInt()
                               : ShapedType::kDynamic);
    }
    return converted;
  };

  SmallVector<int64_t> offsets = toStaticOrDynamic(mixedOffsets);
  SmallVector<int64_t> sizes = toStaticOrDynamic(mixedSizes);
  SmallVector<int64_t> strides = toStaticOrDynamic(mixedStrides);
  return verifyInBoundsSlice(shape, offsets, sizes, strides,
                             generateErrorMessage);
}
96+
3997
LogicalResult
4098
mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
4199
std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,56 @@ func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<1
582582

583583
// -----
584584

585+
// CHECK-LABEL: func @out_of_bounds_extract_slice
586+
// CHECK: tensor.extract_slice %{{.*}}[0] [%{{.*}}] [1] : tensor<5xf32> to tensor<?xf32>
587+
func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<?xf32> {
588+
%c10 = arith.constant 10 : index
589+
%r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
590+
return %r : tensor<?xf32>
591+
}
592+
593+
// -----
594+
595+
// CHECK-LABEL: func @out_of_bounds_extract_slice
596+
// CHECK: tensor.extract_slice %{{.*}}[0] [10] [1] : tensor<?xf32> to tensor<10xf32>
597+
func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<10xf32> {
598+
%t2 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
599+
%r = tensor.extract_slice %t2 [0][10][1] : tensor<?xf32> to tensor<10xf32>
600+
return %r : tensor<10xf32>
601+
}
602+
603+
// -----
604+
605+
// CHECK-LABEL: func @out_of_bounds_insert_slice
606+
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [5] [1] : tensor<5xf32> into tensor<10xf32>
607+
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>) -> tensor<10xf32> {
608+
%c10 = arith.constant 10 : index
609+
%r = tensor.insert_slice %src into %dst[%c10] [5] [1] : tensor<5xf32> into tensor<10xf32>
610+
return %r : tensor<10xf32>
611+
}
612+
613+
// -----
614+
615+
// CHECK-LABEL: func @out_of_bounds_insert_slice
616+
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [%{{.*}}] [1] : tensor<?xf32> into tensor<10xf32>
617+
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<10xf32> {
618+
%src2 = tensor.cast %src : tensor<5xf32> to tensor<?xf32>
619+
%r = tensor.insert_slice %src2 into %dst[7] [%sz] [1] : tensor<?xf32> into tensor<10xf32>
620+
return %r : tensor<10xf32>
621+
}
622+
623+
// -----
624+
625+
// CHECK-LABEL: func @out_of_bounds_insert_slice
626+
// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
627+
func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<?xf32> {
628+
%dst2 = tensor.cast %dst : tensor<10xf32> to tensor<?xf32>
629+
%r = tensor.insert_slice %src into %dst2[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
630+
return %r : tensor<?xf32>
631+
}
632+
633+
// -----
634+
585635
// CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
586636
// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
587637
// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>

0 commit comments

Comments
 (0)