Skip to content

Commit 8190369

Browse files
authored
[mlir][tosa] Add verifier for tosa.transpose (#75376)
This patch adds a verifier to `tosa.transpose` which fixes a crash. Related: #74367 Fix #74479
1 parent fc3adf7 commit 8190369

File tree

4 files changed

+85
-12
lines changed

4 files changed

+85
-12
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,6 +1678,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose"> {
16781678

16791679
let hasCanonicalizer = 1;
16801680
let hasFolder = 1;
1681+
let hasVerifier = 1;
16811682
}
16821683

16831684
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1818
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
1919
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20+
#include "mlir/Dialect/Utils/IndexingUtils.h"
2021
#include "mlir/IR/BuiltinTypes.h"
2122
#include "mlir/IR/DialectImplementation.h"
2223
#include "mlir/IR/Matchers.h"
@@ -1054,6 +1055,46 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
10541055
return success();
10551056
}
10561057

1058+
LogicalResult tosa::TransposeOp::verify() {
1059+
TensorType inputType = getInput1().getType();
1060+
TensorType permType = getPerms().getType();
1061+
TensorType outputType = getOutput().getType();
1062+
1063+
if (permType.hasRank() && permType.getRank() != 1)
1064+
return emitOpError()
1065+
<< "expected permutation tensor to be rank 1 but got rank "
1066+
<< permType.getRank();
1067+
if (inputType.hasRank() && permType.hasRank())
1068+
if (!permType.isDynamicDim(0) &&
1069+
permType.getDimSize(0) != inputType.getRank())
1070+
return emitOpError() << "expected permutation tensor dim 0 to have size "
1071+
<< inputType.getRank()
1072+
<< " (input rank) but got size "
1073+
<< permType.getDimSize(0);
1074+
if (inputType.hasRank() && outputType.hasRank() &&
1075+
inputType.getRank() != outputType.getRank())
1076+
return emitOpError()
1077+
<< "expected input tensor rank to equal result tensor rank";
1078+
if (outputType.hasRank() && permType.hasRank())
1079+
if (!permType.isDynamicDim(0) &&
1080+
permType.getDimSize(0) != outputType.getRank())
1081+
return emitOpError() << "expected permutation tensor dim 0 to have size "
1082+
<< outputType.getRank()
1083+
<< " (output rank) but got size "
1084+
<< permType.getDimSize(0);
1085+
1086+
SmallVector<int64_t> constantPerms;
1087+
if (succeeded(getConstantPerms(constantPerms))) {
1088+
// Assert that the permutation tensor has a rank, which means that the rank
1089+
// has been verified above.
1090+
assert(permType.hasRank() &&
1091+
"Unexpectedly found permutation tensor without rank");
1092+
if (!isPermutationVector(constantPerms))
1093+
return emitOpError() << "expected valid permutation tensor";
1094+
}
1095+
return success();
1096+
}
1097+
10571098
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
10581099
MLIRContext *context, ::std::optional<Location> location,
10591100
GatherOp::Adaptor adaptor,

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,49 @@ func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x
8080

8181
// -----
8282

83+
func.func @test_transpose_io_rank_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21x1xf32> {
84+
// expected-error@+1 {{'tosa.transpose' op expected input tensor rank to equal result tensor rank}}
85+
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21x1xf32>
86+
return %0 : tensor<3x13x21x1xf32>
87+
}
88+
89+
// -----
90+
91+
func.func @test_transpose_invalid_perms_rank(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<3x13x21xf32> {
92+
// expected-error@+1 {{'tosa.transpose' op expected permutation tensor to be rank 1 but got rank 2}}
93+
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<3x13x21xf32>
94+
return %0 : tensor<3x13x21xf32>
95+
}
96+
97+
// -----
98+
99+
func.func @test_transpose_rank0_perms() {
100+
%14 = tensor.empty() : tensor<5x27xi64>
101+
%cst = tensor.empty() : tensor<i32>
102+
// expected-error@+1 {{'tosa.transpose' op expected permutation tensor to be rank 1 but got rank 0}}
103+
%72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
104+
return
105+
}
106+
107+
// -----
108+
109+
func.func @test_transpose_invalid_perms_size(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7xi32>) -> tensor<3x13x21xf32> {
110+
// expected-error@+1 {{'tosa.transpose' op expected permutation tensor dim 0 to have size 3 (input rank) but got size 7}}
111+
%0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<7xi32>) -> tensor<3x13x21xf32>
112+
return %0 : tensor<3x13x21xf32>
113+
}
114+
115+
// -----
116+
117+
func.func @test_transpose_invalid_permutation_tensor(%arg0: tensor<13x21x3xf32>) -> tensor<?x?x?xf32> {
118+
%perms = arith.constant dense<[2, 0, 0]> : tensor<3xi32>
119+
// expected-error@+1 {{'tosa.transpose' op expected valid permutation tensor}}
120+
%0 = tosa.transpose %arg0, %perms : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
121+
return %0 : tensor<?x?x?xf32>
122+
}
123+
124+
// -----
125+
83126
func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
84127
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
85128
%1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,15 +1309,3 @@ func.func @test_large_constant_permutation() {
13091309
%72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
13101310
return
13111311
}
1312-
1313-
// -----
1314-
1315-
// CHECK-LABEL: test_rank0_transpose_perms
1316-
// Fail to infer the shape but not crash.
1317-
func.func @test_rank0_transpose_perms() {
1318-
%14 = tensor.empty() : tensor<5x27xi64>
1319-
%cst = tensor.empty() : tensor<i32>
1320-
// CHECK: tosa.transpose
1321-
%72 = tosa.transpose %14, %cst : (tensor<5x27xi64>, tensor<i32>) -> tensor<?x?xi64>
1322-
return
1323-
}

0 commit comments

Comments
 (0)