Skip to content

Commit fbbb33f

Browse files
authored
[mlir] Fix crash when verifying linalg.transpose (#131733)
Adds checks in `isPermutationVector` for indices that are outside of the bounds and removes the assert. Signed-off-by: Ian Wood <[email protected]>
1 parent d039af3 commit fbbb33f

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,10 @@ bool mlir::isIdentityPermutation(ArrayRef<int64_t> permutation) {
220220
}
221221

222222
bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
223-
assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
224-
"permutation must be non-negative");
225223
llvm::SmallDenseSet<int64_t, 4> seenVals;
226224
for (auto val : interchange) {
225+
if (val < 0 || static_cast<uint64_t>(val) >= interchange.size())
226+
return false;
227227
if (seenVals.count(val))
228228
return false;
229229
seenVals.insert(val);

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,28 @@ func.func @transpose_invalid_permutation(%input: tensor<16x32x64xf32>,
949949

950950
// -----
951951

952+
func.func @transpose_out_of_range_permutation(%input: tensor<16x32x64xf32>,
953+
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
954+
// expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
955+
%transpose = linalg.transpose
956+
ins(%input:tensor<16x32x64xf32>)
957+
outs(%init:tensor<32x64x16xf32>)
958+
permutation = [1, 2, 3]
959+
func.return %transpose : tensor<32x64x16xf32>
960+
}
961+
962+
// -----
963+
964+
func.func @transpose_negative_permutation(%input: tensor<16x32x64xf32>,
965+
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
966+
// expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
967+
%transpose = linalg.transpose
968+
ins(%input:tensor<16x32x64xf32>)
969+
outs(%init:tensor<32x64x16xf32>)
970+
permutation = [1, 2, -1]
971+
func.return %transpose : tensor<32x64x16xf32>
972+
}
973+
// -----
952974
func.func @transpose_permutated_dims_mismatch(%input: tensor<16x32x64xf32>,
953975
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
954976
// expected-error @+1 {{'linalg.transpose' op dim(result, 0) = 32 doesn't match dim(input, permutation[0]) = 16}}

mlir/test/Dialect/Linalg/transform-op-pack.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ module attributes {transform.with_named_sequence} {
595595
%unpack = transform.get_consumers_of_result %1[0]
596596
: (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.unpack">)
597597
%2, %pack_2, %unpack_2 =
598-
// expected-error @below {{invalid outer_perm}}
598+
// expected-error @below {{"outer_perm" is not a valid permutation}}
599599
transform.structured.pack_transpose %unpack with_compute_op(%1)
600600
outer_perm = [1]
601601
: (!transform.op<"linalg.unpack">, !transform.op<"linalg.generic">)
@@ -623,7 +623,7 @@ module attributes {transform.with_named_sequence} {
623623
%unpack = transform.get_consumers_of_result %1[0]
624624
: (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.unpack">)
625625
%2, %pack_2, %unpack_2 =
626-
// expected-error @below {{invalid inner_perm}}
626+
// expected-error @below {{"inner_perm" is not a valid permutation}}
627627
transform.structured.pack_transpose %unpack with_compute_op(%1)
628628
inner_perm = [1]
629629
: (!transform.op<"linalg.unpack">, !transform.op<"linalg.generic">)

0 commit comments

Comments
 (0)