Skip to content

Commit 3326d8e

Browse files
committed
Update to unspecified overflow behavior of int mul
* Do not use saturating multiplication * Emit a warning if overflows occur * Check the behavior in a test * Add a test with a small bit width which would overflow if the mul wouldn't always result a 32-bit int
1 parent 81c13b6 commit 3326d8e

File tree

2 files changed

+38
-17
lines changed

2 files changed

+38
-17
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,26 @@ struct TosaFoldConstantMul : public OpRewritePattern<MulOp> {
7474
assert(isa<IntegerType>(rhsElemType) &&
7575
isa<IntegerType>(resultElementType));
7676
auto resultElementWidth = resultElementType.getIntOrFloatBitWidth();
77-
assert(resultElementWidth == 32 &&
78-
"All integer multiplications in TOSA are specified to result in "
79-
"32 bit width");
77+
assert(resultElementWidth >= lhsElemType.getIntOrFloatBitWidth() &&
78+
"The multiplication is expected to have an at least a big output "
79+
"as input type");
8080
// TODO: To implement shifts > 0, capture the shift value stored in the
8181
// mul here
82-
auto intMulFun = [&resultElementWidth](const APInt &first,
83-
const APInt &second) {
84-
// TODO the documentation has conflicting definitions for the behavior
85-
// of overflows
86-
// The sign extend should always be valid as the result type is required
87-
// to be i32 and all other integer input types are smaller or equal
88-
// to 32.
89-
return first.sext(resultElementWidth)
90-
.smul_sat(second.sext(resultElementWidth));
82+
bool mulOverflowed;
83+
auto intMulFun = [&resultElementWidth, &mulOverflowed](
84+
const APInt &first, const APInt &second) {
85+
bool didOverflow;
86+
auto res = first.sext(resultElementWidth)
87+
.smul_ov(second.sext(resultElementWidth), didOverflow);
88+
mulOverflowed |= didOverflow;
89+
return res;
9190
};
9291
newTensor = applyElementWise<APInt, APInt>(lhsValues, rhsValues,
9392
resultType, intMulFun);
93+
if (mulOverflowed) {
94+
mulOp.emitWarning(
95+
"Multiplication did overflow. The results are unspecified.");
96+
}
9497
} else {
9598
assert(isa<FloatType>(lhsElemType) && isa<FloatType>(rhsElemType) &&
9699
isa<FloatType>(resultType.getElementType()));

mlir/test/Dialect/Tosa/constant-mul-opt.mlir

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,28 @@ func.func @mul_fold_int() -> tensor<4xi32> {
7373
return %2 : tensor<4xi32>
7474
}
7575

76-
// -----
77-
// self-multiplication
76+
// CHECK-LABEL: @mul_fold_i8
77+
func.func @mul_fold_i8() -> tensor<4xi32> {
78+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}204, -12, 0, 0
79+
// CHECK-NOT: tosa.mul
80+
// CHECK: return [[RES]]
81+
%0 = "tosa.const"() {value =
82+
dense<[-17, 4, -2, 0]> :
83+
tensor<4xi8>
84+
} : () -> tensor<4xi8>
85+
%1 = "tosa.const"() {value =
86+
dense<[-12, -3, 0, 5]> :
87+
tensor<4xi8>
88+
} : () -> tensor<4xi8>
89+
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi32>
90+
return %2 : tensor<4xi32>
91+
}
7892

7993
// CHECK-LABEL: @mul_fold_int_overflow
80-
// TODO: Change expected behavior if the tosa.mul on i32 should not be
81-
// saturating. Also add a test with different widths in that case.
8294
func.func @mul_fold_int_overflow() -> tensor<4xi32> {
83-
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2147483647, 2147483647, -2147483648, -2147483648
95+
// Don't expect any specific results for the overflowing multiplication, just
96+
// that it is folded.
97+
// CHECK: [[RES:]] ={{.*}}tosa.const
8498
// CHECK-NOT: tosa.mul
8599
// CHECK: return [[RES]]
86100
%0 = "tosa.const"() {value =
@@ -91,10 +105,14 @@ func.func @mul_fold_int_overflow() -> tensor<4xi32> {
91105
dense<[1, 10, 1, 30]> :
92106
tensor<4xi32>
93107
} : () -> tensor<4xi32>
108+
// expected-warning@below {{Multiplication did overflow. The results are unspecified.}}
94109
%2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
95110
return %2 : tensor<4xi32>
96111
}
97112

113+
// -----
114+
// self-multiplication
115+
98116
// CHECK-LABEL: @mul_fold_equal_args
99117
func.func @mul_fold_equal_args() -> tensor<3xi32> {
100118
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}289, 16, 0

0 commit comments

Comments
 (0)