Skip to content

Commit e6eb94d

Browse files
authored
[mlir][tosa] Add missing verifier check for tosa.reshape (llvm#109301)
This PR adds a missing verifier check for `tosa.reshape`, ensuring that the number of elements in `new_shape` matches the number of elements in the input tensor. Fixes llvm#108151 and fixes llvm#107969.
1 parent 76b827b commit e6eb94d

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
#include "llvm/ADT/DenseMap.h"
3131
#include "llvm/ADT/TypeSwitch.h"
3232

33+
#include <numeric>
34+
3335
using namespace mlir;
3436
using namespace mlir::tosa;
3537

@@ -1015,12 +1017,25 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
10151017
<< newShapeDim;
10161018
}
10171019

1018-
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
1020+
if (inputType.hasStaticShape()) {
10191021
int64_t inputElementsNum = inputType.getNumElements();
1020-
int64_t outputElementsNum = outputType.getNumElements();
1021-
if (inputElementsNum != outputElementsNum) {
1022+
if (outputType.hasStaticShape()) {
1023+
int64_t outputElementsNum = outputType.getNumElements();
1024+
if (inputElementsNum != outputElementsNum) {
1025+
return emitOpError() << "cannot reshape " << inputElementsNum
1026+
<< " elements into " << outputElementsNum;
1027+
}
1028+
}
1029+
1030+
int64_t newShapeElementsNum = std::accumulate(
1031+
getNewShape().begin(), getNewShape().end(), 1LL,
1032+
[](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1033+
bool isStaticNewShape =
1034+
llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
1035+
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1036+
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
10221037
return emitOpError() << "cannot reshape " << inputElementsNum
1023-
<< " elements into " << outputElementsNum;
1038+
<< " elements into " << newShapeElementsNum;
10241039
}
10251040
}
10261041

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,22 @@ func.func @test_reshape_invalid_size(%arg0 : tensor<2x4xf32>) -> () {
360360

361361
// -----
362362

363+
func.func @test_reshape_invalid_newshape(%arg0 : tensor<1xf32>) -> () {
364+
// expected-error@+1 {{'tosa.reshape' op cannot reshape 1 elements into 4}}
365+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 4>} : (tensor<1xf32>) -> tensor<?x4xf32>
366+
return
367+
}
368+
369+
// -----
370+
371+
func.func @test_reshape_invalid_newshape(%arg0 : tensor<8xf32>) -> () {
372+
// expected-error@+1 {{'tosa.reshape' op cannot reshape 8 elements into 4}}
373+
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 4>} : (tensor<8xf32>) -> tensor<?x4xf32>
374+
return
375+
}
376+
377+
// -----
378+
363379
func.func @test_reshape_invalid_placeholders(%arg0 : tensor<?xf32>) -> () {
364380
// expected-error@+1 {{'tosa.reshape' op expected at most one target dimension to be -1}}
365381
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, -1>} : (tensor<?xf32>) -> tensor<2x?x?xf32>

0 commit comments

Comments
 (0)