Skip to content

Commit b91ce7b

Browse files
authored
[Tosa] : Fix integer overflow for computing intmax+1 in tosa.cast to linalg. (#112455)
This PR fixes an integer overflow that occurs when computing `(intmax+1)` for `i64` in the `tosa-to-linalg` pass for `tosa.cast`: the sum was evaluated in 64-bit integer arithmetic before being converted to a floating-point attribute, so it is now computed in floating point instead. I found this issue while debugging a numerical mismatch for the `deeplabv3` model from `torchvision`, represented in the `tosa` dialect using the `TorchToTosa` pipeline in the `torch-mlir` repository. In that pipeline, `torch.aten.to.dtype` is converted to a `tosa.cast` that casts `f32` to `i64`. Technically, according to the specification, `tosa.cast` doesn't handle casting `f32` to `i64`, so it would be possible to add a verifier that errors out for such tosa ops instead of producing incorrect code. However, I chose to fix the overflow issue so that the `deeplabv3` model can still be represented with `tosa` ops in the above-mentioned pipeline. I'm open to suggestions if adding the verifier is more appropriate instead.
1 parent 2e43a30 commit b91ce7b

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
555555
auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
556556
loc, rewriter.getFloatAttr(
557557
getElementTypeOrSelf(srcTy),
558-
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
559-
.getSExtValue() +
560-
1));
558+
static_cast<double>(
559+
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
560+
.getSExtValue()) +
561+
1.0f));
561562

562563
auto intMax = rewriter.create<arith::ConstantOp>(
563564
loc, rewriter.getIntegerAttr(

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,3 +1929,30 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
19291929
%output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
19301930
return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
19311931
}
1932+
1933+
1934+
// -----
1935+
1936+
// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
1937+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>
1938+
1939+
// CHECK-LABEL: func.func @test_cast_fp32_i64(
1940+
// CHECK-SAME: %[[ARG0:.*]]: tensor<1xf32>) -> tensor<1xi64> {
1941+
// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<1xi64>
1942+
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<1xi64>) {
1943+
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: i64):
1944+
// CHECK: %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32
1945+
// CHECK: %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32
1946+
// CHECK: %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32
1947+
// CHECK: %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64
1948+
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32
1949+
// CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64
1950+
// CHECK: %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32
1951+
// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64
1952+
// CHECK: linalg.yield %[[SELECT]] : i64
1953+
// CHECK: } -> tensor<1xi64>
1954+
// CHECK: return %[[RESULT]] : tensor<1xi64>
1955+
func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
1956+
%0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
1957+
return %0: tensor<1xi64>
1958+
}

0 commit comments

Comments
 (0)