@@ -2352,13 +2352,52 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2352
2352
}
2353
2353
}
2354
2354
2355
+ // / Verify that the offsets/sizes/strides-style access into the given tensor
2356
+ // / is in-bounds. Only static information is verified.
2357
+ static LogicalResult verifyInBoundsSlice (Operation *op,
2358
+ RankedTensorType tensorType,
2359
+ ArrayRef<int64_t > staticOffsets,
2360
+ ArrayRef<int64_t > staticSizes,
2361
+ ArrayRef<int64_t > staticStrides) {
2362
+ for (int64_t i = 0 , e = tensorType.getRank (); i < e; ++i) {
2363
+ // Nothing to verify for dynamic source dims.
2364
+ if (tensorType.isDynamicDim (i))
2365
+ continue ;
2366
+ // Nothing to verify if the offset is dynamic.
2367
+ if (ShapedType::isDynamic (staticOffsets[i]))
2368
+ continue ;
2369
+ if (staticOffsets[i] >= tensorType.getDimSize (i))
2370
+ return op->emitOpError (" offset " )
2371
+ << i << " is out-of-bounds: " << staticOffsets[i]
2372
+ << " >= " << tensorType.getDimSize (i);
2373
+ if (ShapedType::isDynamic (staticSizes[i]) ||
2374
+ ShapedType::isDynamic (staticStrides[i]))
2375
+ continue ;
2376
+ int64_t lastPos =
2377
+ staticOffsets[i] + (staticSizes[i] - 1 ) * staticStrides[i];
2378
+ if (lastPos >= tensorType.getDimSize (i))
2379
+ return op->emitOpError (" slice along dimension " )
2380
+ << i << " runs out-of-bounds: " << lastPos
2381
+ << " >= " << tensorType.getDimSize (i);
2382
+ }
2383
+ return success ();
2384
+ }
2385
+
2355
2386
// / Verifier for ExtractSliceOp.
2356
2387
LogicalResult ExtractSliceOp::verify () {
2388
+ RankedTensorType sourceType = getSourceType ();
2389
+
2357
2390
// Verify result type against inferred type.
2358
2391
RankedTensorType expectedType = ExtractSliceOp::inferResultType (
2359
- getSourceType () , getMixedOffsets (), getMixedSizes (), getMixedStrides ());
2392
+ sourceType , getMixedOffsets (), getMixedSizes (), getMixedStrides ());
2360
2393
SliceVerificationResult result = isRankReducedType (expectedType, getType ());
2361
- return produceSliceErrorMsg (result, *this , expectedType);
2394
+ if (result != SliceVerificationResult::Success)
2395
+ return produceSliceErrorMsg (result, *this , expectedType);
2396
+
2397
+ // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2398
+ // to the source tensor.
2399
+ return verifyInBoundsSlice (getOperation (), sourceType, getStaticOffsets (),
2400
+ getStaticSizes (), getStaticStrides ());
2362
2401
}
2363
2402
2364
2403
llvm::SmallBitVector ExtractSliceOp::getDroppedDims () {
@@ -2729,11 +2768,18 @@ static SliceVerificationResult verifyInsertSliceOp(
2729
2768
2730
2769
// / Verifier for InsertSliceOp.
2731
2770
LogicalResult InsertSliceOp::verify () {
2771
+ // Verify result type against inferred type.
2732
2772
RankedTensorType expectedType;
2733
2773
SliceVerificationResult result =
2734
2774
verifyInsertSliceOp (getSourceType (), getType (), getStaticOffsets (),
2735
2775
getStaticSizes (), getStaticStrides (), &expectedType);
2736
- return produceSliceErrorMsg (result, *this , expectedType);
2776
+ if (result != SliceVerificationResult::Success)
2777
+ return produceSliceErrorMsg (result, *this , expectedType);
2778
+
2779
+ // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2780
+ // to the source tensor.
2781
+ return verifyInBoundsSlice (getOperation (), getDestType (), getStaticOffsets (),
2782
+ getStaticSizes (), getStaticStrides ());
2737
2783
}
2738
2784
2739
2785
// / If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -3747,11 +3793,18 @@ LogicalResult ParallelInsertSliceOp::verify() {
3747
3793
return this ->emitError (" expected ParallelCombiningOpInterface parent, got:" )
3748
3794
<< *(getOperation ()->getParentOp ());
3749
3795
3796
+ // Verify result type against inferred type.
3750
3797
RankedTensorType expectedType;
3751
3798
SliceVerificationResult result =
3752
3799
verifyInsertSliceOp (getSourceType (), getDestType (), getStaticOffsets (),
3753
3800
getStaticSizes (), getStaticStrides (), &expectedType);
3754
- return produceSliceErrorMsg (result, *this , expectedType);
3801
+ if (result != SliceVerificationResult::Success)
3802
+ return produceSliceErrorMsg (result, *this , expectedType);
3803
+
3804
+ // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3805
+ // to the source tensor.
3806
+ return verifyInBoundsSlice (getOperation (), getDestType (), getStaticOffsets (),
3807
+ getStaticSizes (), getStaticStrides ());
3755
3808
}
3756
3809
3757
3810
void ParallelInsertSliceOp::getCanonicalizationPatterns (
0 commit comments