Skip to content

Commit fbcd0c6

Browse files
authored
Updates to 'tosa.reshape' verifier (#87416)
This addition catches common cases of malformed `tosa.reshape` ops. This prevents the `--tosa-to-tensor` pass from asserting when fed invalid operations, as these will be caught ahead of time by the verifier. Closes #87396
1 parent cd29126 commit fbcd0c6

File tree

2 files changed

+58
-17
lines changed

2 files changed

+58
-17
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -955,25 +955,34 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
955955
}
956956

957957
mlir::LogicalResult tosa::ReshapeOp::verify() {
958-
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
959-
ShapedType outputType = llvm::cast<ShapedType>(getType());
958+
TensorType inputType = getInput1().getType();
959+
RankedTensorType outputType = getType();
960960

961961
if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
962962
return emitOpError() << "tensor has a dimension with size zero. Each "
963963
"dimension of a tensor must have size >= 1";
964964

965+
if ((int64_t) getNewShape().size() != outputType.getRank())
966+
return emitOpError() << "new shape does not match result rank";
967+
968+
for (auto [newShapeDim, outputShapeDim] :
969+
zip(getNewShape(), outputType.getShape()))
970+
if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
971+
newShapeDim != outputShapeDim)
972+
return emitOpError() << "new shape is inconsistent with result shape";
973+
965974
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
966975
int64_t inputElementsNum = inputType.getNumElements();
967976
int64_t outputElementsNum = outputType.getNumElements();
968977
if (inputElementsNum != outputElementsNum) {
969-
return emitOpError() << "Cannot reshape " << inputElementsNum
978+
return emitOpError() << "cannot reshape " << inputElementsNum
970979
<< " elements into " << outputElementsNum;
971980
}
972981
}
973982

974983
int missingDims = llvm::count(getNewShape(), -1);
975984
if (missingDims > 1)
976-
return emitOpError() << "At most one target dimension can be -1";
985+
return emitOpError() << "expected at most one target dimension to be -1";
977986

978987
return mlir::success();
979988
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -243,38 +243,70 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
243243

244244
// -----
245245

246-
func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
247-
// expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
248-
%0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
246+
func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
247+
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
248+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
249249
return
250250
}
251251

252252
// -----
253253

254-
func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
255-
// expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
256-
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
257-
return %0 : tensor<100x100xf32>
254+
func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
255+
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
256+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
257+
return
258258
}
259259

260260
// -----
261261

262-
func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
263-
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
264-
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
262+
func.func @test_reshape_rank_mismatch(%arg0 : tensor<?xf32>) -> () {
263+
// expected-error@+1 {{'tosa.reshape' op new shape does not match result rank}}
264+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 4>} : (tensor<?xf32>) -> tensor<?xf32>
265265
return
266266
}
267267

268268
// -----
269269

270-
func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
271-
// expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
272-
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
270+
func.func @test_reshape_inconsistent_result_type(%arg0 : tensor<?xf32>) -> () {
271+
// expected-error@+1 {{'tosa.reshape' op new shape is inconsistent with result shape}}
272+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 4, -1>} : (tensor<?xf32>) -> tensor<?x3x5xf32>
273+
return
274+
}
275+
276+
// -----
277+
278+
func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
279+
// expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 15}}
280+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 3, 5>} : (tensor<2x4xf32>) -> tensor<3x5xf32>
281+
return
282+
}
283+
284+
// -----
285+
286+
func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
287+
// expected-error@+1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
288+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>
273289
return
274290
}
275291

276292
// -----
277293

294+
func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
295+
// expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
296+
%0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
297+
return
298+
}
299+
300+
// -----
301+
302+
func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
303+
// expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
304+
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
305+
return %0 : tensor<100x100xf32>
306+
}
307+
308+
// -----
309+
278310
func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
279311
// expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
280312
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}

0 commit comments

Comments
 (0)