@@ -2354,11 +2354,44 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
2354
2354
2355
2355
// / Verifier for ExtractSliceOp.
2356
2356
LogicalResult ExtractSliceOp::verify () {
2357
+ RankedTensorType sourceType = getSourceType ();
2358
+ SmallVector<OpFoldResult> mixedOffsets = getMixedOffsets ();
2359
+ SmallVector<OpFoldResult> mixedSizes = getMixedSizes ();
2360
+ SmallVector<OpFoldResult> mixedStrides = getMixedStrides ();
2361
+
2357
2362
// Verify result type against inferred type.
2358
2363
RankedTensorType expectedType = ExtractSliceOp::inferResultType (
2359
- getSourceType (), getMixedOffsets (), getMixedSizes (), getMixedStrides () );
2364
+ sourceType, mixedOffsets, mixedSizes, mixedStrides );
2360
2365
SliceVerificationResult result = isRankReducedType (expectedType, getType ());
2361
- return produceSliceErrorMsg (result, *this , expectedType);
2366
+ if (result != SliceVerificationResult::Success)
2367
+ return produceSliceErrorMsg (result, *this , expectedType);
2368
+
2369
+ // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2370
+ // to the source tensor.
2371
+ for (int64_t i = 0 , e = sourceType.getRank (); i < e; ++i) {
2372
+ // Nothing to verify for dynamic source dims.
2373
+ if (sourceType.isDynamicDim (i))
2374
+ continue ;
2375
+ auto offsetOfr = dyn_cast<Attribute>(mixedOffsets[i]);
2376
+ // Nothing to verify if the offset is dynamic.
2377
+ if (!offsetOfr)
2378
+ continue ;
2379
+ int64_t staticOffset = *getConstantIntValue (offsetOfr);
2380
+ if (staticOffset >= sourceType.getDimSize (i))
2381
+ return emitOpError (" offset " ) << i << " is out-of-bounds" ;
2382
+ auto sizeOfr = dyn_cast<Attribute>(mixedSizes[i]);
2383
+ auto strideOfr = dyn_cast<Attribute>(mixedStrides[i]);
2384
+ if (!sizeOfr || !strideOfr)
2385
+ continue ;
2386
+ int64_t staticSize = *getConstantIntValue (sizeOfr);
2387
+ int64_t staticStride = *getConstantIntValue (strideOfr);
2388
+ if (staticOffset + (staticSize - 1 ) * staticStride >=
2389
+ sourceType.getDimSize (i))
2390
+ return emitOpError (" extraction along source dimension " )
2391
+ << i << " runs out-of-bounds" ;
2392
+ }
2393
+
2394
+ return success ();
2362
2395
}
2363
2396
2364
2397
llvm::SmallBitVector ExtractSliceOp::getDroppedDims () {
0 commit comments