27
27
#include " mlir/Interfaces/InferIntRangeInterface.h"
28
28
#include " mlir/Interfaces/LoopLikeInterface.h"
29
29
#include " mlir/Interfaces/Utils/InferIntRangeCommon.h"
30
+ #include " mlir/Interfaces/ViewLikeInterface.h"
30
31
#include " mlir/Support/LLVM.h"
31
32
#include " llvm/ADT/DenseSet.h"
32
33
#include " llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2352
2353
}
2353
2354
}
2354
2355
2355
- // / Verify that the offsets/sizes/strides-style access into the given tensor
2356
- // / is in-bounds. Only static information is verified.
2357
- static LogicalResult verifyInBoundsSlice (Operation *op,
2358
- RankedTensorType tensorType,
2359
- ArrayRef<int64_t > staticOffsets,
2360
- ArrayRef<int64_t > staticSizes,
2361
- ArrayRef<int64_t > staticStrides) {
2362
- for (int64_t i = 0 , e = tensorType.getRank (); i < e; ++i) {
2363
- // Nothing to verify for dynamic source dims.
2364
- if (tensorType.isDynamicDim (i))
2365
- continue ;
2366
- // Nothing to verify if the offset is dynamic.
2367
- if (ShapedType::isDynamic (staticOffsets[i]))
2368
- continue ;
2369
- if (staticOffsets[i] >= tensorType.getDimSize (i))
2370
- return op->emitOpError (" offset " )
2371
- << i << " is out-of-bounds: " << staticOffsets[i]
2372
- << " >= " << tensorType.getDimSize (i);
2373
- if (ShapedType::isDynamic (staticSizes[i]) ||
2374
- ShapedType::isDynamic (staticStrides[i]))
2375
- continue ;
2376
- int64_t lastPos =
2377
- staticOffsets[i] + (staticSizes[i] - 1 ) * staticStrides[i];
2378
- if (lastPos >= tensorType.getDimSize (i))
2379
- return op->emitOpError (" slice along dimension " )
2380
- << i << " runs out-of-bounds: " << lastPos
2381
- << " >= " << tensorType.getDimSize (i);
2382
- }
2383
- return success ();
2384
- }
2385
-
2386
2356
// / Verifier for ExtractSliceOp.
2387
2357
LogicalResult ExtractSliceOp::verify () {
2388
2358
RankedTensorType sourceType = getSourceType ();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
2396
2366
2397
2367
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
2398
2368
// to the source tensor.
2399
- return verifyInBoundsSlice (getOperation (), sourceType, getStaticOffsets (),
2400
- getStaticSizes (), getStaticStrides ());
2369
+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice (
2370
+ sourceType.getShape (), getStaticOffsets (), getStaticSizes (),
2371
+ getStaticStrides (), /* generateErrorMessage=*/ true );
2372
+ if (!boundsResult.isValid )
2373
+ return getOperation ()->emitError (boundsResult.errorMessage );
2374
+
2375
+ return success ();
2401
2376
}
2402
2377
2403
2378
llvm::SmallBitVector ExtractSliceOp::getDroppedDims () {
@@ -2470,6 +2445,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2470
2445
if (!canFoldIntoConsumerOp (castOp))
2471
2446
return failure ();
2472
2447
2448
+ // Pattern does not apply if the produced op would not verify.
2449
+ SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice (
2450
+ cast<RankedTensorType>(castOp.getSource ().getType ()).getShape (),
2451
+ sliceOp.getStaticOffsets (), sliceOp.getStaticSizes (),
2452
+ sliceOp.getStaticStrides ());
2453
+ if (!sliceResult.isValid )
2454
+ return failure ();
2455
+
2473
2456
// Create folded extract.
2474
2457
Location loc = sliceOp.getLoc ();
2475
2458
Value newResult = rewriter.create <ExtractSliceOp>(
@@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {
2634
2617
2635
2618
void ExtractSliceOp::getCanonicalizationPatterns (RewritePatternSet &results,
2636
2619
MLIRContext *context) {
2637
- results.add <
2638
- OpWithOffsetSizesAndStridesConstantArgumentFolder<
2639
- ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2640
- ExtractSliceOpCastFolder>(context);
2620
+ results.add <OpWithOffsetSizesAndStridesConstantArgumentFolder<
2621
+ ExtractSliceOp, SliceReturnTypeCanonicalizer,
2622
+ SliceCanonicalizer, /* CheckInBounds= */ true >,
2623
+ ExtractSliceOpCastFolder>(context);
2641
2624
}
2642
2625
2643
2626
//
@@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
2775
2758
return produceSliceErrorMsg (result, *this , expectedType);
2776
2759
2777
2760
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
2778
- // to the source tensor.
2779
- return verifyInBoundsSlice (getOperation (), getDestType (), getStaticOffsets (),
2780
- getStaticSizes (), getStaticStrides ());
2761
+ // to the destination tensor.
2762
+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice (
2763
+ getDestType ().getShape (), getStaticOffsets (), getStaticSizes (),
2764
+ getStaticStrides (), /* generateErrorMessage=*/ true );
2765
+ if (!boundsResult.isValid )
2766
+ return getOperation ()->emitError (boundsResult.errorMessage );
2767
+
2768
+ return success ();
2781
2769
}
2782
2770
2783
2771
// / If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2872,6 +2860,13 @@ class InsertSliceOpConstantArgumentFolder final
2872
2860
failed (foldDynamicStrideList (mixedStrides)))
2873
2861
return failure ();
2874
2862
2863
+ // Pattern does not apply if the produced op would not verify.
2864
+ SliceBoundsVerificationResult sliceResult =
2865
+ verifyInBoundsSlice (insertSliceOp.getDest ().getType ().getShape (),
2866
+ mixedOffsets, mixedSizes, mixedStrides);
2867
+ if (!sliceResult.isValid )
2868
+ return failure ();
2869
+
2875
2870
// Create the new op in canonical form.
2876
2871
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType (
2877
2872
insertSliceOp.getSourceType ().getRank (), insertSliceOp.getDestType (),
@@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
2969
2964
size = srcType.getDimSize (rankReducedIdx++);
2970
2965
}
2971
2966
}
2967
+
2968
+ // Pattern does not apply if the produced op would not verify.
2972
2969
if (verifyInsertSliceOp (srcType, dstType, insertSliceOp.getStaticOffsets (),
2973
2970
staticSizes, insertSliceOp.getStaticStrides ()) !=
2974
2971
SliceVerificationResult::Success)
2975
2972
return failure ();
2973
+ SliceBoundsVerificationResult sliceResult =
2974
+ verifyInBoundsSlice (dstType.getShape (), insertSliceOp.getMixedOffsets (),
2975
+ mixedSizes, insertSliceOp.getMixedStrides ());
2976
+ if (!sliceResult.isValid )
2977
+ return failure ();
2976
2978
2977
2979
Operation *replacement = rewriter.create <InsertOpTy>(
2978
2980
insertSliceOp.getLoc (), src, dst, insertSliceOp.getMixedOffsets (),
@@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
3800
3802
return produceSliceErrorMsg (result, *this , expectedType);
3801
3803
3802
3804
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
3803
- // to the source tensor.
3804
- return verifyInBoundsSlice (getOperation (), getDestType (), getStaticOffsets (),
3805
- getStaticSizes (), getStaticStrides ());
3805
+ // to the destination tensor.
3806
+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice (
3807
+ getDestType ().getShape (), getStaticOffsets (), getStaticSizes (),
3808
+ getStaticStrides (), /* generateErrorMessage=*/ true );
3809
+ if (!boundsResult.isValid )
3810
+ return getOperation ()->emitError (boundsResult.errorMessage );
3811
+
3812
+ return success ();
3806
3813
}
3807
3814
3808
3815
void ParallelInsertSliceOp::getCanonicalizationPatterns (
0 commit comments