Skip to content

Commit c68f170

Browse files
committed
Add verifiers, tests for invalid reduce ops
1 parent 9c43021 commit c68f170

File tree

4 files changed

+97
-13
lines changed

4 files changed

+97
-13
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,7 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
12711271
);
12721272

12731273
let hasFolder = 1;
1274+
let hasVerifier = 1;
12741275

12751276
let extraClassDeclaration = [{
12761277
/// Returns true when two result types are compatible for this op;
@@ -1304,6 +1305,7 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
13041305
);
13051306

13061307
let hasFolder = 1;
1308+
let hasVerifier = 1;
13071309

13081310
let extraClassDeclaration = [{
13091311
/// Returns true when two result types are compatible for this op;
@@ -1337,6 +1339,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
13371339
);
13381340

13391341
let hasFolder = 1;
1342+
let hasVerifier = 1;
13401343

13411344
let extraClassDeclaration = [{
13421345
/// Returns true when two result types are compatible for this op;
@@ -1371,6 +1374,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
13711374
);
13721375

13731376
let hasFolder = 1;
1377+
let hasVerifier = 1;
13741378

13751379
let extraClassDeclaration = [{
13761380
/// Returns true when two result types are compatible for this op;
@@ -1405,6 +1409,7 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
14051409
);
14061410

14071411
let hasFolder = 1;
1412+
let hasVerifier = 1;
14081413

14091414
let extraClassDeclaration = [{
14101415
/// Returns true when two result types are compatible for this op;
@@ -1436,8 +1441,10 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
14361441
let results = (outs
14371442
Tosa_Tensor:$output
14381443
);
1439-
let hasFolder = 1;
14401444

1445+
let hasFolder = 1;
1446+
let hasVerifier = 1;
1447+
14411448
let extraClassDeclaration = [{
14421449
/// Returns true when two result types are compatible for this op;
14431450
/// Method used by InferTypeOpInterface.

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,46 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
11551155
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
11561156
#undef COMPATIBLE_RETURN_TYPES
11571157

1158+
template <typename T> static LogicalResult verifyReduceOp(T op) {
1159+
// All TOSA reduce Ops have input, output and axis.
1160+
TensorType inputType = op.getInput().getType();
1161+
TensorType outputType = op.getOutput().getType();
1162+
int32_t reduceAxis = op.getAxis();
1163+
1164+
if (reduceAxis < 0) {
1165+
op.emitOpError("reduce axis must not be negative");
1166+
return failure();
1167+
}
1168+
if (inputType.hasRank() && reduceAxis >= inputType.getRank()) {
1169+
op.emitOpError("expect input tensor rank (")
1170+
<< inputType.getRank() << ") to be larger than reduce axis ("
1171+
<< reduceAxis << ")";
1172+
return failure();
1173+
}
1174+
if (outputType.hasRank()) {
1175+
if (reduceAxis >= outputType.getRank()) {
1176+
op.emitOpError("expect output tensor rank (")
1177+
<< outputType.getRank() << ") to be larger than reduce axis ("
1178+
<< reduceAxis << ")";
1179+
return failure();
1180+
}
1181+
auto outputShape = outputType.getShape();
1182+
if (!outputType.isDynamicDim(reduceAxis) && outputShape[reduceAxis] != 1) {
1183+
op.emitOpError("expect reduced dimension size to be 1, got ")
1184+
<< outputShape[reduceAxis];
1185+
return failure();
1186+
}
1187+
}
1188+
return success();
1189+
}
1190+
1191+
LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
1192+
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
1193+
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
1194+
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
1195+
LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
1196+
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
1197+
11581198
static LogicalResult NAryInferReturnTypes(
11591199
const ValueShapeRange &operands,
11601200
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -593,13 +593,3 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
593593
}
594594

595595
// -----
596-
597-
// CHECK-LABEL: @fold_reduce_rank_zero
598-
func.func nested @fold_reduce_rank_zero() {
599-
// CHECK-NOT: tosa.reduce_min
600-
// CHECK-NOT: tosa.reverse
601-
%0 = tensor.empty() : tensor<i32>
602-
%1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
603-
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
604-
return
605-
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,61 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
128128
// -----
129129

130130
func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
131-
// expected-error@+2 {{failed to infer returned types}}
132-
// expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
131+
// expected-error@+1 {{'tosa.reduce_prod' op expect reduced dimension size to be 1, got 3}}
133132
%0 = tosa.reduce_prod %arg0 {axis = 1 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
134133
return
135134
}
136135

137136
// -----
138137

138+
func.func @test_reduce_all_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
139+
// expected-error@+1 {{'tosa.reduce_all' op expect input tensor rank (3) to be larger than reduce axis (3)}}
140+
%0 = tosa.reduce_all %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
141+
return
142+
}
143+
144+
// -----
145+
146+
func.func @test_reduce_any_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
147+
// expected-error@+1 {{'tosa.reduce_any' op expect input tensor rank (3) to be larger than reduce axis (3)}}
148+
%0 = tosa.reduce_any %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
149+
return
150+
}
151+
152+
// -----
153+
154+
func.func @test_reduce_max_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
155+
// expected-error@+1 {{'tosa.reduce_max' op expect input tensor rank (3) to be larger than reduce axis (3)}}
156+
%0 = tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
157+
return
158+
}
159+
160+
// -----
161+
162+
func.func @test_reduce_min_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
163+
// expected-error@+1 {{'tosa.reduce_min' op expect input tensor rank (3) to be larger than reduce axis (3)}}
164+
%0 = tosa.reduce_min %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
165+
return
166+
}
167+
168+
// -----
169+
170+
func.func @test_reduce_prod_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
171+
// expected-error@+1 {{'tosa.reduce_prod' op expect input tensor rank (3) to be larger than reduce axis (3)}}
172+
%0 = tosa.reduce_prod %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
173+
return
174+
}
175+
176+
// -----
177+
178+
func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
179+
// expected-error@+1 {{'tosa.reduce_sum' op expect input tensor rank (3) to be larger than reduce axis (3)}}
180+
%0 = tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
181+
return
182+
}
183+
184+
// -----
185+
139186
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
140187
// expected-error@+2 {{failed to infer returned types}}
141188
// expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}

0 commit comments

Comments
 (0)