Skip to content

Commit 721105e

Browse files
[mlir][Transforms] Add missing check in applyPermutation
The applyPermutation() utility should make sure that the permutation numbers are within the size of the input array. Otherwise it will cause a cryptic array out of bound assertion later.
1 parent 1745c8e commit 721105e

File tree

3 files changed

+22
-0
lines changed

3 files changed

+22
-0
lines changed

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ SmallVector<T> applyPermutation(ArrayRef<T> input,
202202
ArrayRef<int64_t> permutation) {
203203
assert(input.size() == permutation.size() &&
204204
"expected input rank to equal permutation rank");
205+
assert(
206+
llvm::all_of(permutation, [&](size_t s) { return s < input.size(); }) &&
207+
"permutation must be within input bounds");
205208
auto permutationRange = llvm::map_range(
206209
llvm::seq<unsigned>(0, input.size()),
207210
[&](int64_t idx) -> T { return input[permutation[idx]]; });

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
#include "mlir/Transforms/InliningUtils.h"
2929
#include "llvm/ADT/APFloat.h"
3030
#include "llvm/ADT/DenseMap.h"
31+
#include "llvm/ADT/STLExtras.h"
3132
#include "llvm/ADT/TypeSwitch.h"
33+
#include <sys/_types/_int64_t.h>
3234

3335
using namespace mlir;
3436
using namespace mlir::tosa;
@@ -1116,6 +1118,12 @@ LogicalResult tosa::TransposeOp::verify() {
11161118
"Unexpectedly found permutation tensor without rank");
11171119
if (!isPermutationVector(constantPerms))
11181120
return emitOpError() << "expected valid permutation tensor";
1121+
1122+
if (inputType.hasRank() && (!inputType.getNumDynamicDims()) &&
1123+
!llvm::all_of(constantPerms,
1124+
[&](int64_t s) { return s < inputType.getRank(); })) {
1125+
return emitOpError() << "permutation must be within input bounds";
1126+
}
11191127
}
11201128
return success();
11211129
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,14 @@ func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x
3636
%real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
3737
return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
3838
}
39+
40+
// -----
41+
42+
// CHECK-LABEL: @test_invalid_constant_permutation
43+
func.func @test_invalid_constant_permutation() {
44+
// expected-error@+3 {{permutation must be within input bounds}}
45+
%14 = tensor.empty() : tensor<3x4x5xi32>
46+
%c1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
47+
%72 = tosa.transpose %14, %c1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
48+
return
49+
}

0 commit comments

Comments
 (0)