Skip to content

Commit ede11f4

Browse files
[mlir][tosa] Avoid overflow in reduction folders
Avoid operations that can overflow in constant folders for tosa.reduce_max and tosa.reduce_min Includes tests to avoid regressions Signed-off-by: Ian Tayler Lessa <[email protected]>
1 parent f7aea4d commit ede11f4

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,8 +1731,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
17311731

17321732
/// Return the max of the two integer operands
17331733
static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
1734-
const llvm::APInt subtractRes = leftOperand - rightOperand;
1735-
return (!subtractRes.isNegative()) ? leftOperand : rightOperand;
1734+
return (leftOperand.sge(rightOperand)) ? leftOperand : rightOperand;
17361735
}
17371736
}];
17381737
}
@@ -1772,8 +1771,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
17721771

17731772
/// Return the min of the two integer operands
17741773
static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
1775-
const llvm::APInt subtractRes = leftOperand - rightOperand;
1776-
return (!subtractRes.isNegative()) ? rightOperand : leftOperand;
1774+
return (leftOperand.sle(rightOperand)) ? leftOperand : rightOperand;
17771775
}
17781776
}];
17791777
}

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,18 @@ func.func @reduce_max_constant() -> tensor<1x1x1xi32> {
883883
return %0 : tensor<1x1x1xi32>
884884
}
885885

886+
// -----
887+
888+
func.func @reduce_max_constant_no_overflow() -> tensor<1xi8> {
889+
// CHECK-LABEL: func.func @reduce_max_constant_no_overflow() -> tensor<1xi8> {
890+
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<120> : tensor<1xi8>}> : () -> tensor<1xi8>
891+
// CHECK: return %[[VAL_0]] : tensor<1xi8>
892+
// CHECK: }
893+
%const = "tosa.const"() <{values = dense<[-127, 120, -126]> : tensor<3xi8>}> : () -> tensor<3xi8>
894+
%0 = tosa.reduce_max %const {axis = 0 : i32} : (tensor<3xi8>) -> tensor<1xi8>
895+
return %0 : tensor<1xi8>
896+
}
897+
886898
// -----
887899

888900
func.func @reduce_min_constant() -> tensor<1x3xi32> {
@@ -968,6 +980,19 @@ func.func @reduce_min_constant() -> tensor<1x1x1xi32> {
968980
return %0 : tensor<1x1x1xi32>
969981
}
970982

983+
// -----
984+
985+
func.func @reduce_min_constant_no_overflow() -> tensor<1xi8> {
986+
// CHECK-LABEL: func.func @reduce_min_constant_no_overflow() -> tensor<1xi8> {
987+
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<-127> : tensor<1xi8>}> : () -> tensor<1xi8>
988+
// CHECK: return %[[VAL_0]] : tensor<1xi8>
989+
// CHECK: }
990+
%const = "tosa.const"() <{values = dense<[-127, 120, -126]> : tensor<3xi8>}> : () -> tensor<3xi8>
991+
%0 = tosa.reduce_min %const {axis = 0 : i32} : (tensor<3xi8>) -> tensor<1xi8>
992+
return %0 : tensor<1xi8>
993+
}
994+
995+
971996
// -----
972997

973998
func.func @reduce_any_constant() -> tensor<1x3xi1> {

0 commit comments

Comments
 (0)