Skip to content

Commit 695b02a

Browse files
committed
[mlir][tosa] Add verifier check for Concat Op
This adds verifier check for Concat Op to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension add tests in verifier.mlir also moved existing concat verifier checks to verifier.mlir Signed-off-by: Tai Ly <[email protected]> Change-Id: I53e41ca3c1f4ee48997c510fee2c16ed912dfaa0
1 parent e643050 commit 695b02a

File tree

3 files changed

+58
-31
lines changed

3 files changed

+58
-31
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,6 +1178,25 @@ LogicalResult tosa::ConcatOp::verify() {
11781178
<< " on operands 0 and " << operandNum;
11791179
}
11801180
}
1181+
1182+
// ERROR_IF(axis_sum != shape[axis]);
1183+
int64_t axis_sum = 0;
1184+
for (const auto &input : inputList) {
1185+
const ShapeAdaptor inputShape(input.getType());
1186+
if (inputShape.isDynamicDim(axis)) {
1187+
// make axis_sum negative to indicate invalid value
1188+
axis_sum = -1;
1189+
break;
1190+
}
1191+
axis_sum += inputShape.getDimSize(axis);
1192+
}
1193+
const ShapeAdaptor outputShape(outType);
1194+
if (axis_sum >= 0 && outputShape.hasRank() &&
1195+
!outputShape.isDynamicDim(axis) &&
1196+
axis_sum != outputShape.getDimSize(axis))
1197+
return emitOpError("requires sum of axis dimensions of input1 "
1198+
"equal to output axis dimension, got ")
1199+
<< axis_sum << " and " << outputShape.getDimSize(axis);
11811200
}
11821201

11831202
return success();

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
272272

273273
// -----
274274

275-
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
276-
// expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
277-
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
278-
return %0 : tensor<?x?xi8>
279-
}
280-
281-
// -----
282-
283-
func.func @test_concat_zero_inputs() {
284-
// expected-error@+1 {{'tosa.concat' op expect at least one input}}
285-
%0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
286-
}
287-
288-
// -----
289-
290-
func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
291-
// expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
292-
%0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
293-
return %0 : tensor<2x2xf32>
294-
}
295-
296-
// -----
297-
298-
func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
299-
// expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
300-
%0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
301-
return %0 : tensor<2x2xf32>
302-
}
303-
304-
// -----
305-
306275
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
307276
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
308277
// expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,42 @@ func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
167167
%2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
168168
return %2 : tensor<f32>
169169
}
170+
171+
// -----
172+
173+
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
174+
// expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
175+
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
176+
return %0 : tensor<?x?xi8>
177+
}
178+
179+
// -----
180+
181+
func.func @test_concat_zero_inputs() {
182+
// expected-error@+1 {{'tosa.concat' op expect at least one input}}
183+
%0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
184+
}
185+
186+
// -----
187+
188+
func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
189+
// expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
190+
%0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
191+
return %0 : tensor<2x2xf32>
192+
}
193+
194+
// -----
195+
196+
func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
197+
// expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
198+
%0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
199+
return %0 : tensor<2x2xf32>
200+
}
201+
202+
// -----
203+
204+
func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
205+
// expected-error@+1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
206+
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
207+
return %0 : tensor<2x?xf32>
208+
}

0 commit comments

Comments
 (0)