Skip to content

Commit 418e07b

Browse files
[mlir][Tensor] Check for out-of-bounds slice in insert/extract_slice verifier (#130487)
Also fix test cases that had invalid ops.
1 parent c86d884 commit 418e07b

File tree

5 files changed

+100
-19
lines changed

5 files changed

+100
-19
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2352,13 +2352,52 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
23522352
}
23532353
}
23542354

2355+
/// Verify that the offsets/sizes/strides-style access into the given tensor
2356+
/// is in-bounds. Only static information is verified.
2357+
static LogicalResult verifyInBoundsSlice(Operation *op,
2358+
RankedTensorType tensorType,
2359+
ArrayRef<int64_t> staticOffsets,
2360+
ArrayRef<int64_t> staticSizes,
2361+
ArrayRef<int64_t> staticStrides) {
2362+
for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
2363+
// Nothing to verify for dynamic source dims.
2364+
if (tensorType.isDynamicDim(i))
2365+
continue;
2366+
// Nothing to verify if the offset is dynamic.
2367+
if (ShapedType::isDynamic(staticOffsets[i]))
2368+
continue;
2369+
if (staticOffsets[i] >= tensorType.getDimSize(i))
2370+
return op->emitOpError("offset ")
2371+
<< i << " is out-of-bounds: " << staticOffsets[i]
2372+
<< " >= " << tensorType.getDimSize(i);
2373+
if (ShapedType::isDynamic(staticSizes[i]) ||
2374+
ShapedType::isDynamic(staticStrides[i]))
2375+
continue;
2376+
int64_t lastPos =
2377+
staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
2378+
if (lastPos >= tensorType.getDimSize(i))
2379+
return op->emitOpError("slice along dimension ")
2380+
<< i << " runs out-of-bounds: " << lastPos
2381+
<< " >= " << tensorType.getDimSize(i);
2382+
}
2383+
return success();
2384+
}
2385+
23552386
/// Verifier for ExtractSliceOp.
23562387
LogicalResult ExtractSliceOp::verify() {
2388+
RankedTensorType sourceType = getSourceType();
2389+
23572390
// Verify result type against inferred type.
23582391
RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2359-
getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2392+
sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
23602393
SliceVerificationResult result = isRankReducedType(expectedType, getType());
2361-
return produceSliceErrorMsg(result, *this, expectedType);
2394+
if (result != SliceVerificationResult::Success)
2395+
return produceSliceErrorMsg(result, *this, expectedType);
2396+
2397+
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
2398+
// to the source tensor.
2399+
return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
2400+
getStaticSizes(), getStaticStrides());
23622401
}
23632402

23642403
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
@@ -2729,11 +2768,18 @@ static SliceVerificationResult verifyInsertSliceOp(
27292768

27302769
/// Verifier for InsertSliceOp.
27312770
LogicalResult InsertSliceOp::verify() {
2771+
// Verify result type against inferred type.
27322772
RankedTensorType expectedType;
27332773
SliceVerificationResult result =
27342774
verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
27352775
getStaticSizes(), getStaticStrides(), &expectedType);
2736-
return produceSliceErrorMsg(result, *this, expectedType);
2776+
if (result != SliceVerificationResult::Success)
2777+
return produceSliceErrorMsg(result, *this, expectedType);
2778+
2779+
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
2780+
// to the source tensor.
2781+
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
2782+
getStaticSizes(), getStaticStrides());
27372783
}
27382784

27392785
/// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -3747,11 +3793,18 @@ LogicalResult ParallelInsertSliceOp::verify() {
37473793
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
37483794
<< *(getOperation()->getParentOp());
37493795

3796+
// Verify result type against inferred type.
37503797
RankedTensorType expectedType;
37513798
SliceVerificationResult result =
37523799
verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
37533800
getStaticSizes(), getStaticStrides(), &expectedType);
3754-
return produceSliceErrorMsg(result, *this, expectedType);
3801+
if (result != SliceVerificationResult::Success)
3802+
return produceSliceErrorMsg(result, *this, expectedType);
3803+
3804+
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
3805+
// to the source tensor.
3806+
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
3807+
getStaticSizes(), getStaticStrides());
37553808
}
37563809

37573810
void ParallelInsertSliceOp::getCanonicalizationPatterns(

mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func.func @no_bubble_up_extract_slice_on_non_contiguous(%src: tensor<60xf32>) ->
3131

3232
func.func @no_bubble_up_extract_slice_on_stride(%src: tensor<60xf32>) -> tensor<1x1x5xf32> {
3333
%expand = tensor.expand_shape %src [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
34-
%extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 2] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
34+
%extract = tensor.extract_slice %expand[0, 0, 0][1, 1, 5][1, 1, 2] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
3535
return %extract : tensor<1x1x5xf32>
3636
}
3737

mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,16 @@ func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8
5151

5252
// -----
5353

54-
func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
54+
func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x15x15xf32>) -> tensor<1x15x15xf32> {
5555
%extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
56-
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
57-
return %inserted_slice : tensor<1x4x4xf32>
56+
%inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x15x15xf32>
57+
return %inserted_slice : tensor<1x15x15xf32>
5858
}
5959
// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
6060
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
6161
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
6262
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
63-
// CHECK: return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
63+
// CHECK: return %[[INSERTED_SLICE]] : tensor<1x15x15xf32>
6464

6565
// -----
6666

mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -271,38 +271,34 @@ func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32
271271

272272
// -----
273273

274-
// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
275274
// CHECK-LABEL: func @insert_slice_of_insert_slice(
276275
// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor<f32>
277276
// CHECK-SAME: %[[r1:[0-9a-z]*]]: tensor<1x14xf32>
278277
// CHECK-SAME: %[[pos:[0-9a-z]*]]: index
279-
// CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
280-
// CHECK: tensor.insert_slice %[[t]] into %[[r1]][4, %[[add]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
278+
// CHECK: tensor.insert_slice %[[t]] into %[[r1]][0, %[[pos]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
281279
func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1x1xf32>, %r1: tensor<1x14xf32>, %pos: index)
282280
-> tensor<1x14xf32>
283281
{
284-
%0 = tensor.insert_slice %t into %r0[1, 2] [1, 1] [1, 1]
282+
%0 = tensor.insert_slice %t into %r0[0, 0] [1, 1] [1, 1]
285283
: tensor<f32> into tensor<1x1xf32>
286-
%1 = tensor.insert_slice %0 into %r1[3, %pos] [1, 1] [1, 1]
284+
%1 = tensor.insert_slice %0 into %r1[0, %pos] [1, 1] [1, 1]
287285
: tensor<1x1xf32> into tensor<1x14xf32>
288286
return %1 : tensor<1x14xf32>
289287
}
290288

291289
// -----
292290

293-
// CHECK-DAG: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
294291
// CHECK-LABEL: func @insert_slice_of_insert_slice(
295292
// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor<f32>
296293
// CHECK-SAME: %[[r1:[0-9a-z]*]]: tensor<1x14xf32>
297294
// CHECK-SAME: %[[pos:[0-9a-z]*]]: index
298-
// CHECK: %[[composed_pos:.+]] = affine.apply #[[$map]]()[%[[pos]]]
299-
// CHECK: tensor.insert_slice %[[t]] into %[[r1]][3, %[[composed_pos]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
295+
// CHECK: tensor.insert_slice %[[t]] into %[[r1]][0, %[[pos]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
300296
func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1xf32>, %r1: tensor<1x14xf32>, %pos: index)
301297
-> tensor<1x14xf32>
302298
{
303-
%0 = tensor.insert_slice %t into %r0[2] [1] [1]
299+
%0 = tensor.insert_slice %t into %r0[0] [1] [1]
304300
: tensor<f32> into tensor<1xf32>
305-
%1 = tensor.insert_slice %0 into %r1[3, %pos] [1, 1] [1, 1]
301+
%1 = tensor.insert_slice %0 into %r1[0, %pos] [1, 1] [1, 1]
306302
: tensor<1xf32> into tensor<1x14xf32>
307303
return %1 : tensor<1x14xf32>
308304
}

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,22 @@ func.func @illegal_num_offsets(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 :
258258

259259
// -----
260260

261+
func.func @extract_slice_offset_out_of_bounds(%arg0: tensor<10xf32>) {
262+
// expected-error@+1 {{offset 0 is out-of-bounds: 10 >= 10}}
263+
%0 = tensor.extract_slice %arg0 [10][1][1] : tensor<10xf32> to tensor<1xf32>
264+
return
265+
}
266+
267+
// -----
268+
269+
func.func @extract_slice_runs_out_of_bounds(%arg0: tensor<9xf32>) {
270+
// expected-error@+1 {{slice along dimension 0 runs out-of-bounds: 9 >= 9}}
271+
%0 = tensor.extract_slice %arg0 [3][4][2] : tensor<9xf32> to tensor<4xf32>
272+
return
273+
}
274+
275+
// -----
276+
261277
func.func @insert_slice_wrong_result_rank(%t1: tensor<?xf32>, %t2: tensor<?x?xf32>, %idx : index) {
262278
// expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
263279
%0 = tensor.insert_slice %t2 into %t1[0][4][1] : tensor<?x?xf32> into tensor<?xf32>
@@ -296,6 +312,22 @@ func.func @insert_slice_wrong_dynamic_type(%t1: tensor<?x4x4xf32>, %t2: tensor<8
296312

297313
// -----
298314

315+
func.func @insert_slice_offset_out_of_bounds(%arg0: tensor<1xf32>, %arg1: tensor<10xf32>) {
316+
// expected-error@+1 {{offset 0 is out-of-bounds: 10 >= 10}}
317+
%0 = tensor.insert_slice %arg0 into %arg1 [10][1][1] : tensor<1xf32> into tensor<10xf32>
318+
return
319+
}
320+
321+
// -----
322+
323+
func.func @insert_slice_runs_out_of_bounds(%arg0: tensor<4xf32>, %arg1: tensor<9xf32>) {
324+
// expected-error@+1 {{slice along dimension 0 runs out-of-bounds: 9 >= 9}}
325+
%0 = tensor.insert_slice %arg0 into %arg1 [3][4][2] : tensor<4xf32> into tensor<9xf32>
326+
return
327+
}
328+
329+
// -----
330+
299331
func.func @illegal_expanding_reshape_static_tensor
300332
(%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32> {
301333
// expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}

0 commit comments

Comments
 (0)