Skip to content

Commit 106cc32

Browse files
[mlir][Tensor] Check for out-of-bounds extraction in extract_slice verifier
1 parent f3390fc commit 106cc32

File tree

3 files changed

+52
-3
lines changed

3 files changed

+52
-3
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2354,11 +2354,44 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
23542354

23552355
/// Verifier for ExtractSliceOp.
23562356
LogicalResult ExtractSliceOp::verify() {
2357+
RankedTensorType sourceType = getSourceType();
2358+
SmallVector<OpFoldResult> mixedOffsets = getMixedOffsets();
2359+
SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2360+
SmallVector<OpFoldResult> mixedStrides = getMixedStrides();
2361+
23572362
// Verify result type against inferred type.
23582363
RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2359-
getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2364+
sourceType, mixedOffsets, mixedSizes, mixedStrides);
23602365
SliceVerificationResult result = isRankReducedType(expectedType, getType());
2361-
return produceSliceErrorMsg(result, *this, expectedType);
2366+
if (result != SliceVerificationResult::Success)
2367+
return produceSliceErrorMsg(result, *this, expectedType);
2368+
2369+
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
2370+
// to the source tensor.
2371+
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
2372+
// Nothing to verify for dynamic source dims.
2373+
if (sourceType.isDynamicDim(i))
2374+
continue;
2375+
auto offsetOfr = dyn_cast<Attribute>(mixedOffsets[i]);
2376+
// Nothing to verify if the offset is dynamic.
2377+
if (!offsetOfr)
2378+
continue;
2379+
int64_t staticOffset = *getConstantIntValue(offsetOfr);
2380+
if (staticOffset >= sourceType.getDimSize(i))
2381+
return emitOpError("offset ") << i << " is out-of-bounds";
2382+
auto sizeOfr = dyn_cast<Attribute>(mixedSizes[i]);
2383+
auto strideOfr = dyn_cast<Attribute>(mixedStrides[i]);
2384+
if (!sizeOfr || !strideOfr)
2385+
continue;
2386+
int64_t staticSize = *getConstantIntValue(sizeOfr);
2387+
int64_t staticStride = *getConstantIntValue(strideOfr);
2388+
if (staticOffset + (staticSize - 1) * staticStride >=
2389+
sourceType.getDimSize(i))
2390+
return emitOpError("extraction along source dimension ")
2391+
<< i << " runs out-of-bounds";
2392+
}
2393+
2394+
return success();
23622395
}
23632396

23642397
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {

mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func.func @no_bubble_up_extract_slice_on_non_contiguous(%src: tensor<60xf32>) ->
3131

3232
func.func @no_bubble_up_extract_slice_on_stride(%src: tensor<60xf32>) -> tensor<1x1x5xf32> {
3333
%expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
34-
%extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 2] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
34+
%extract = tensor.extract_slice %expand[0, 0, 0][1, 1, 5][1, 1, 2] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
3535
return %extract : tensor<1x1x5xf32>
3636
}
3737

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,22 @@ func.func @illegal_num_offsets(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 :
258258

259259
// -----
260260

261+
func.func @extract_slice_offset_out_of_bounds(%arg0: tensor<10xf32>) {
262+
// expected-error@+1 {{offset 0 is out-of-bounds}}
263+
%0 = tensor.extract_slice %arg0 [10][1][1] : tensor<10xf32> to tensor<1xf32>
264+
return
265+
}
266+
267+
// -----
268+
269+
func.func @extract_slice_runs_out_of_bounds(%arg0: tensor<9xf32>) {
270+
// expected-error@+1 {{extraction along source dimension 0 runs out-of-bounds}}
271+
%0 = tensor.extract_slice %arg0 [3][4][2] : tensor<9xf32> to tensor<4xf32>
272+
return
273+
}
274+
275+
// -----
276+
261277
func.func @insert_slice_wrong_result_rank(%t1: tensor<?xf32>, %t2: tensor<?x?xf32>, %idx : index) {
262278
// expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
263279
%0 = tensor.insert_slice %t2 into %t1[0][4][1] : tensor<?x?xf32> into tensor<?xf32>

0 commit comments

Comments
 (0)