Skip to content

Commit 8a57bc0

Browse files
authored
[mlir][tosa] Add verifiers to ReduceOps, fix shape inference crash (llvm#69843)
This patch adds verifiers to `tosa.reduce_*` ops that check, among other things, that the supplied `axis` argument is compatible with the input/output tensors' shapes. We allow for a special case of `axis == 0 && rank == 0` to be valid. This patch also adds a check to `ReduceInferReturnTypes()` to ensure that the shape inference pass doesn't crash on an invalid `axis` argument anymore. Fix llvm#68187
1 parent d593f6c commit 8a57bc0

File tree

4 files changed

+125
-6
lines changed

4 files changed

+125
-6
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,7 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
12711271
);
12721272

12731273
let hasFolder = 1;
1274+
let hasVerifier = 1;
12741275

12751276
let extraClassDeclaration = [{
12761277
/// Returns true when two result types are compatible for this op;
@@ -1304,6 +1305,7 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
13041305
);
13051306

13061307
let hasFolder = 1;
1308+
let hasVerifier = 1;
13071309

13081310
let extraClassDeclaration = [{
13091311
/// Returns true when two result types are compatible for this op;
@@ -1337,6 +1339,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
13371339
);
13381340

13391341
let hasFolder = 1;
1342+
let hasVerifier = 1;
13401343

13411344
let extraClassDeclaration = [{
13421345
/// Returns true when two result types are compatible for this op;
@@ -1371,6 +1374,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
13711374
);
13721375

13731376
let hasFolder = 1;
1377+
let hasVerifier = 1;
13741378

13751379
let extraClassDeclaration = [{
13761380
/// Returns true when two result types are compatible for this op;
@@ -1405,6 +1409,7 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
14051409
);
14061410

14071411
let hasFolder = 1;
1412+
let hasVerifier = 1;
14081413

14091414
let extraClassDeclaration = [{
14101415
/// Returns true when two result types are compatible for this op;
@@ -1436,8 +1441,10 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
14361441
let results = (outs
14371442
Tosa_Tensor:$output
14381443
);
1439-
let hasFolder = 1;
14401444

1445+
let hasFolder = 1;
1446+
let hasVerifier = 1;
1447+
14411448
let extraClassDeclaration = [{
14421449
/// Returns true when two result types are compatible for this op;
14431450
/// Method used by InferTypeOpInterface.

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,14 +1109,14 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
11091109
static LogicalResult ReduceInferReturnTypes(
11101110
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
11111111
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1112-
if (!operandShape.hasRank() || operandShape.getRank() == 0) {
1112+
int64_t axisVal = axis.getValue().getSExtValue();
1113+
if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
11131114
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
11141115
return success();
11151116
}
11161117

11171118
SmallVector<int64_t> outputShape;
11181119
operandShape.getDims(outputShape);
1119-
int64_t axisVal = axis.getValue().getSExtValue();
11201120
outputShape[axisVal] = 1;
11211121
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
11221122
return success();
@@ -1155,6 +1155,63 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
11551155
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
11561156
#undef COMPATIBLE_RETURN_TYPES
11571157

1158+
template <typename T>
1159+
static LogicalResult verifyReduceOp(T op) {
1160+
// All TOSA reduce Ops have input, output and axis.
1161+
TensorType inputType = op.getInput().getType();
1162+
TensorType outputType = op.getOutput().getType();
1163+
int32_t reduceAxis = op.getAxis();
1164+
1165+
if (reduceAxis < 0) {
1166+
op.emitOpError("reduce axis must not be negative");
1167+
return failure();
1168+
}
1169+
if (inputType.hasRank()) {
1170+
int64_t inputRank = inputType.getRank();
1171+
// We allow for a special case where the input/output shape has rank 0 and
1172+
// axis is also 0.
1173+
if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
1174+
op.emitOpError("expect input tensor rank (")
1175+
<< inputRank << ") to be larger than reduce axis (" << reduceAxis
1176+
<< ")";
1177+
return failure();
1178+
}
1179+
}
1180+
if (outputType.hasRank()) {
1181+
int64_t outputRank = outputType.getRank();
1182+
if (inputType.hasRank() && outputRank != inputType.getRank()) {
1183+
op.emitOpError(
1184+
"expect output tensor rank to be equal to input tensor rank");
1185+
return failure();
1186+
}
1187+
if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
1188+
op.emitOpError("expect output tensor rank (")
1189+
<< outputRank << ") to be larger than reduce axis (" << reduceAxis
1190+
<< ")";
1191+
return failure();
1192+
}
1193+
// We can only verify the reduced dimension size to be 1 if this is not the
1194+
// special case of output rank == 0.
1195+
if (outputRank != 0) {
1196+
auto outputShape = outputType.getShape();
1197+
if (!outputType.isDynamicDim(reduceAxis) &&
1198+
outputShape[reduceAxis] != 1) {
1199+
op.emitOpError("expect reduced dimension size to be 1, got ")
1200+
<< outputShape[reduceAxis];
1201+
return failure();
1202+
}
1203+
}
1204+
}
1205+
return success();
1206+
}
1207+
1208+
LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
1209+
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
1210+
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
1211+
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
1212+
LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
1213+
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
1214+
11581215
static LogicalResult NAryInferReturnTypes(
11591216
const ValueShapeRange &operands,
11601217
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ func.func nested @fold_reduce_rank_zero() {
599599
// CHECK-NOT: tosa.reduce_min
600600
// CHECK-NOT: tosa.reverse
601601
%0 = tensor.empty() : tensor<i32>
602-
%1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
602+
%1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
603603
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
604604
return
605605
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,69 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
128128
// -----
129129

130130
func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
131-
// expected-error@+2 {{failed to infer returned types}}
132-
// expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
131+
// expected-error@+1 {{'tosa.reduce_prod' op expect reduced dimension size to be 1, got 3}}
133132
%0 = tosa.reduce_prod %arg0 {axis = 1 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
134133
return
135134
}
136135

137136
// -----
138137

138+
func.func @test_reduce_all_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
139+
// expected-error@+1 {{'tosa.reduce_all' op expect input tensor rank (3) to be larger than reduce axis (3)}}
140+
%0 = tosa.reduce_all %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
141+
return
142+
}
143+
144+
// -----
145+
146+
func.func @test_reduce_any_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
147+
// expected-error@+1 {{'tosa.reduce_any' op expect input tensor rank (3) to be larger than reduce axis (3)}}
148+
%0 = tosa.reduce_any %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
149+
return
150+
}
151+
152+
// -----
153+
154+
func.func @test_reduce_max_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
155+
// expected-error@+1 {{'tosa.reduce_max' op expect input tensor rank (3) to be larger than reduce axis (3)}}
156+
%0 = tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
157+
return
158+
}
159+
160+
// -----
161+
162+
func.func @test_reduce_min_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
163+
// expected-error@+1 {{'tosa.reduce_min' op expect input tensor rank (3) to be larger than reduce axis (3)}}
164+
%0 = tosa.reduce_min %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
165+
return
166+
}
167+
168+
// -----
169+
170+
func.func @test_reduce_prod_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
171+
// expected-error@+1 {{'tosa.reduce_prod' op expect input tensor rank (3) to be larger than reduce axis (3)}}
172+
%0 = tosa.reduce_prod %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
173+
return
174+
}
175+
176+
// -----
177+
178+
func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
179+
// expected-error@+1 {{'tosa.reduce_sum' op expect input tensor rank (3) to be larger than reduce axis (3)}}
180+
%0 = tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
181+
return
182+
}
183+
184+
// -----
185+
186+
func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
187+
// expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}}
188+
%0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
189+
return
190+
}
191+
192+
// -----
193+
139194
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
140195
// expected-error@+2 {{failed to infer returned types}}
141196
// expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}

0 commit comments

Comments
 (0)