Skip to content

Commit c1e9883

Browse files
mgehre-amdttjost
andauthored
[TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp (#82641)
tosa.clamp takes `min`/`max` attributes as i64, so ensure that the lowering to linalg works for the whole range. Co-authored-by: Tiago Trevisan Jost <[email protected]>
1 parent f5c8e9e commit c1e9883

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -384,23 +384,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
384384

385385
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
386386
auto intTy = cast<IntegerType>(elementTy);
387-
int32_t min = static_cast<int32_t>(
388-
cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
389-
int32_t max = static_cast<int32_t>(
390-
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
387+
int64_t min =
388+
cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
389+
int64_t max =
390+
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
391391

392392
if (intTy.isUnsignedInteger()) {
393-
min = std::max<int32_t>(min, 0);
394-
max = std::min<int32_t>(
393+
min = std::max(min, (int64_t)0);
394+
max = std::min(
395395
max,
396396
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
397397
} else {
398-
min = std::max<int32_t>(
399-
min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
400-
.getSExtValue());
401-
max = std::min<int32_t>(
402-
max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
403-
.getSExtValue());
398+
min =
399+
std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
400+
.getSExtValue());
401+
max =
402+
std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
403+
.getSExtValue());
404404
}
405405

406406
auto minVal = rewriter.create<arith::ConstantIntOp>(

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,21 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
759759

760760
// -----
761761

762+
// CHECK-LABEL: @test_i64
763+
func.func @test_i64(%arg0: tensor<1xi64>) -> () {
764+
// CHECK: linalg.generic
765+
// CHECK: ^bb0(%[[ARG1:.+]]: i64,
766+
// CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808
767+
// CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
768+
// CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
769+
// CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
770+
%0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
771+
772+
return
773+
}
774+
775+
// -----
776+
762777
// CHECK-LABEL: @test_clamp_f16
763778
func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
764779
// CHECK: linalg.generic

0 commit comments

Comments
 (0)