Skip to content

Commit e98a61d

Browse files
authored
[mlir][tosa] Add verifier check for Concat Op (llvm#136047)
This adds verifier check for Concat Op to make sure the sum of concatenated axis dimensions is equal to the output's axis dimension add tests in verifier.mlir also moved existing concat verifier checks to verifier.mlir Signed-off-by: Tai Ly <[email protected]>
1 parent 94a14f9 commit e98a61d

File tree

3 files changed

+58
-31
lines changed

3 files changed

+58
-31
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,25 @@ LogicalResult tosa::ConcatOp::verify() {
13221322
<< " on operands 0 and " << operandNum;
13231323
}
13241324
}
1325+
1326+
// ERROR_IF(axis_sum != shape[axis]);
1327+
int64_t axisSum = 0;
1328+
for (const auto &input : inputList) {
1329+
const ShapeAdaptor inputShape(input.getType());
1330+
if (inputShape.isDynamicDim(axis)) {
1331+
// make axisSum negative to indicate invalid value
1332+
axisSum = -1;
1333+
break;
1334+
}
1335+
axisSum += inputShape.getDimSize(axis);
1336+
}
1337+
const ShapeAdaptor outputShape(outType);
1338+
if (axisSum >= 0 && outputShape.hasRank() &&
1339+
!outputShape.isDynamicDim(axis) &&
1340+
axisSum != outputShape.getDimSize(axis))
1341+
return emitOpError("requires sum of axis dimensions of input1 "
1342+
"equal to output axis dimension, got ")
1343+
<< axisSum << " and " << outputShape.getDimSize(axis);
13251344
}
13261345

13271346
return success();

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -272,37 +272,6 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
272272

273273
// -----
274274

275-
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
276-
// expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
277-
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
278-
return %0 : tensor<?x?xi8>
279-
}
280-
281-
// -----
282-
283-
func.func @test_concat_zero_inputs() {
284-
// expected-error@+1 {{'tosa.concat' op expect at least one input}}
285-
%0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
286-
}
287-
288-
// -----
289-
290-
func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
291-
// expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
292-
%0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
293-
return %0 : tensor<2x2xf32>
294-
}
295-
296-
// -----
297-
298-
func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
299-
// expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
300-
%0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
301-
return %0 : tensor<2x2xf32>
302-
}
303-
304-
// -----
305-
306275
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
307276
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
308277
// expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,42 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf
319319
: (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32>
320320
return %0 : tensor<1x4x8x19x34xf32>
321321
}
322+
323+
// -----
324+
325+
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
326+
// expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
327+
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
328+
return %0 : tensor<?x?xi8>
329+
}
330+
331+
// -----
332+
333+
func.func @test_concat_zero_inputs() {
334+
// expected-error@+1 {{'tosa.concat' op expect at least one input}}
335+
%0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
336+
}
337+
338+
// -----
339+
340+
func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
341+
// expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
342+
%0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
343+
return %0 : tensor<2x2xf32>
344+
}
345+
346+
// -----
347+
348+
func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
349+
// expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
350+
%0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
351+
return %0 : tensor<2x2xf32>
352+
}
353+
354+
// -----
355+
356+
func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?xf32> {
357+
// expected-error@+1 {{'tosa.concat' op requires sum of axis dimensions of input1 equal to output axis dimension, got 3 and 2}}
358+
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
359+
return %0 : tensor<2x?xf32>
360+
}

0 commit comments

Comments
 (0)