Skip to content

Commit c8b5d30

Browse files
[mlir][Transforms] Add missing check in tosa::transpose::verify() (#102099)
The tosa::transpose::verify() should make sure that the permutation numbers are within the size of the input array. Otherwise it will cause a cryptic array out of bound assertion later.Fix #99513.
1 parent 5a42a67 commit c8b5d30

File tree

4 files changed

+45
-27
lines changed

4 files changed

+45
-27
lines changed

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ SmallVector<T> applyPermutation(ArrayRef<T> input,
202202
ArrayRef<int64_t> permutation) {
203203
assert(input.size() == permutation.size() &&
204204
"expected input rank to equal permutation rank");
205+
assert(
206+
llvm::all_of(permutation, [&](size_t s) { return s < input.size(); }) &&
207+
"permutation must be within input bounds");
205208
auto permutationRange = llvm::map_range(
206209
llvm::seq<unsigned>(0, input.size()),
207210
[&](int64_t idx) -> T { return input[permutation[idx]]; });

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,12 @@ LogicalResult tosa::TransposeOp::verify() {
11161116
"Unexpectedly found permutation tensor without rank");
11171117
if (!isPermutationVector(constantPerms))
11181118
return emitOpError() << "expected valid permutation tensor";
1119+
1120+
if (inputType.hasRank() && !llvm::all_of(constantPerms, [&](int64_t s) {
1121+
return s < inputType.getRank();
1122+
})) {
1123+
return emitOpError() << "permutation must be within input bounds";
1124+
}
11191125
}
11201126
return success();
11211127
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,38 @@ func.func @test_tile_invalid_multiples() {
413413
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
414414
return
415415
}
416+
417+
// -----
418+
419+
// CHECK-LABEL: @test_invalid_constant_permutation
420+
func.func @test_invalid_constant_permutation() {
421+
// expected-error@+3 {{permutation must be within input bounds}}
422+
%0 = tensor.empty() : tensor<3x4x5xi32>
423+
%1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
424+
%2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
425+
return
426+
}
427+
428+
// -----
429+
430+
// CHECK-LABEL: test_rank_size_constant_permutation
431+
func.func @test_rank_size_constant_permutation() {
432+
// expected-error@+4 {{permutation must be within input bounds}}
433+
%0 = arith.constant 6 : index
434+
%1 = arith.constant dense<[0, 2]> : tensor<2xi32>
435+
%2 = tensor.empty(%0) : tensor<?x27xi64>
436+
%3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
437+
return
438+
}
439+
440+
// -----
441+
442+
// CHECK-LABEL: test_large_constant_permutation
443+
func.func @test_large_constant_permutation() {
444+
// expected-error@+4 {{permutation must be within input bounds}}
445+
%0 = arith.constant 6 : index
446+
%1 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
447+
%2 = tensor.empty(%0) : tensor<?x27xi64>
448+
%3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
449+
return
450+
}

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,30 +1373,4 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
13731373
// CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
13741374
%1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
13751375
return %1 : tensor<?x16x16x16xf32>
1376-
}
1377-
1378-
// -----
1379-
1380-
// CHECK-LABEL: test_rank_size_constant_permutation
1381-
func.func @test_rank_size_constant_permutation() {
1382-
%c6 = arith.constant 6 : index
1383-
%cst_26 = arith.constant dense<[0, 2]> : tensor<2xi32>
1384-
%14 = tensor.empty(%c6) : tensor<?x27xi64>
1385-
// Fail to infer the shape but not crash.
1386-
// CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
1387-
%72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
1388-
return
1389-
}
1390-
1391-
// -----
1392-
1393-
// CHECK-LABEL: test_large_constant_permutation
1394-
func.func @test_large_constant_permutation() {
1395-
%c6 = arith.constant 6 : index
1396-
%cst_26 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
1397-
%14 = tensor.empty(%c6) : tensor<?x27xi64>
1398-
// Fail to infer the shape but not crash.
1399-
// CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
1400-
%72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
1401-
return
1402-
}
1376+
}

0 commit comments

Comments
 (0)