Skip to content

Commit 019fbcc

Browse files
authored
[mlir][tosa] Add missing check for new_shape of tosa.reshape (#104394)
This patch adds check for new_shape of `tosa.reshape`. Tensor dimension with size less than -1 is invalid. Fix #103062.
1 parent 5f01fda commit 019fbcc

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,11 +990,16 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
990990
return emitOpError() << "new shape does not match result rank";
991991

992992
for (auto [newShapeDim, outputShapeDim] :
993-
zip(getNewShape(), outputType.getShape()))
993+
zip(getNewShape(), outputType.getShape())) {
994994
if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
995995
newShapeDim != outputShapeDim)
996996
return emitOpError() << "new shape is inconsistent with result shape";
997997

998+
if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
999+
return emitOpError() << "new shape has invalid tensor dimension size "
1000+
<< newShapeDim;
1001+
}
1002+
9981003
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
9991004
int64_t inputElementsNum = inputType.getNumElements();
10001005
int64_t outputElementsNum = outputType.getNumElements();

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,14 @@ func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
291291

292292
// -----
293293

294+
func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
295+
// expected-error@+1 {{'tosa.reshape' op new shape has invalid tensor dimension size -2}}
296+
%0 = "tosa.reshape" (%arg0) {new_shape = array<i64: -2, -1>} : (tensor<4x?xf32>) -> tensor<?x4xf32>
297+
return
298+
}
299+
300+
// -----
301+
294302
func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
295303
// expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
296304
%0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>

0 commit comments

Comments
 (0)