Skip to content

Commit 18f8928

Browse files
lhutton1Tai78641
andauthored
[mlir][tosa] Fix mul folder conformance to the spec (#137601)
Change the folder for mul with a shift such that the rounding happens correctly according to the spec pesudo-code. Fixes: https://discourse.llvm.org/t/tosa-mul-i32-shift-incorrect-result/86040 Partial cherry-pick from: #128059 Co-authored-by: Tai Ly <[email protected]>
1 parent 2a32d73 commit 18f8928

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,27 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
918918
}
919919

920920
namespace {
921+
// calculate lhs * rhs >> shift according to TOSA Spec
922+
// return nullopt if result is not in range of int32_t when shift > 0
923+
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
924+
unsigned bitwidth) {
925+
APInt result = lhs.sext(64) * rhs.sext(64);
926+
927+
if (shift > 0) {
928+
auto round = APInt(64, 1) << (shift - 1);
929+
result += round;
930+
result.ashrInPlace(shift);
931+
// REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
932+
if (!(result.getSExtValue() >= INT32_MIN &&
933+
result.getSExtValue() <= INT32_MAX)) {
934+
// REQUIRE failed
935+
return std::nullopt;
936+
}
937+
}
938+
939+
return result.trunc(bitwidth);
940+
}
941+
921942
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
922943
RankedTensorType ty, int32_t shift) {
923944
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
@@ -930,12 +951,10 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
930951
}
931952

932953
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
933-
l = l.sext(bitwidth * 2);
934-
r = r.sext(bitwidth * 2);
935-
auto result = l * r;
936-
result.lshrInPlace(shift);
937-
result = result.trunc(bitwidth);
938-
return DenseElementsAttr::get(ty, result);
954+
const std::optional<APInt> result = mulInt(l, r, shift, bitwidth);
955+
if (!result)
956+
return {};
957+
return DenseElementsAttr::get(ty, result.value());
939958
}
940959

941960
if (llvm::isa<FloatType>(ty.getElementType())) {

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1226,4 +1226,43 @@ func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?
12261226
%1 = tosa.const_shape {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
12271227
%2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
12281228
return %2 : tensor<2x60x58x?xf32>
1229-
}
1229+
}
1230+
1231+
// -----
1232+
1233+
// CHECK-LABEL: @fold_mul_shift
1234+
// CHECK-DAG: "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
1235+
func.func @fold_mul_shift() -> tensor<i32> {
1236+
%0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
1237+
%1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
1238+
%2 = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8>
1239+
%3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
1240+
return %3 : tensor<i32>
1241+
}
1242+
1243+
// -----
1244+
1245+
// CHECK-LABEL: @fold_mul_no_shift
1246+
// CHECK-DAG: "tosa.const"() <{values = dense<781333542> : tensor<i32>}> : () -> tensor<i32>
1247+
func.func @fold_mul_no_shift() -> tensor<i32> {
1248+
%0 = "tosa.const"() <{values = dense<-23661> : tensor<i32>}> : () -> tensor<i32>
1249+
%1 = "tosa.const"() <{values = dense<-33022> : tensor<i32>}> : () -> tensor<i32>
1250+
%2 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
1251+
%3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
1252+
return %3 : tensor<i32>
1253+
}
1254+
1255+
// -----
1256+
1257+
// CHECK-LABEL: @no_fold_mul_result_exceeds_i32
1258+
// CHECK-DAG: %[[LHS:.*]] = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
1259+
// CHECK-DAG: %[[RHS:.*]] = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
1260+
// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
1261+
// CHECK: tosa.mul %[[LHS]], %[[RHS]], %[[SHIFT]] : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
1262+
func.func @no_fold_mul_result_exceeds_i32() -> tensor<i32> {
1263+
%0 = "tosa.const"() <{values = dense<23661> : tensor<i32>}> : () -> tensor<i32>
1264+
%1 = "tosa.const"() <{values = dense<330222> : tensor<i32>}> : () -> tensor<i32>
1265+
%2 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
1266+
%3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
1267+
return %3 : tensor<i32>
1268+
}

0 commit comments

Comments
 (0)