Skip to content

[TOSA] TosaToLinalg: fix int64_t min/max lowering of clamp #82641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 22, 2024

Conversation

mgehre-amd
Copy link
Contributor

tosa.clamp takes min/max attributes as i64, so ensure that the lowering to linalg works for the whole range.

@llvmbot
Copy link
Member

llvmbot commented Feb 22, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Gehre (mgehre-amd)

Changes

tosa.clamp takes min/max attributes as i64, so ensure that the lowering to linalg works for the whole range.


Full diff: https://github.com/llvm/llvm-project/pull/82641.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+14-14)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+15)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7eb32ebe3228fb..b706ac35c5ab15 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -382,25 +382,25 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return clampFloatHelper(loc, args[0], min, max, rewriter);
   }
 
-  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
-    auto intTy = cast<IntegerType>(elementTy);
-    int32_t min = static_cast<int32_t>(
-        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
-    int32_t max = static_cast<int32_t>(
-        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
+  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
+    auto intTy = elementTy.cast<IntegerType>();
+    int64_t min =
+        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+    int64_t max =
+        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
 
     if (intTy.isUnsignedInteger()) {
-      min = std::max<int32_t>(min, 0);
-      max = std::min<int32_t>(
+      min = std::max(min, (int64_t)0);
+      max = std::min(
           max,
           APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
     } else {
-      min = std::max<int32_t>(
-          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
-                   .getSExtValue());
-      max = std::min<int32_t>(
-          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
-                   .getSExtValue());
+      min =
+          std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+                            .getSExtValue());
+      max =
+          std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+                            .getSExtValue());
     }
 
     auto minVal = rewriter.create<arith::ConstantIntOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index febe74e8767465..1fa783f05f04ee 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -759,6 +759,21 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_i64
+func.func @test_i64(%arg0: tensor<1xi64>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.+]]: i64,
+  // CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808
+  // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
+  // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
+  // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
+  %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_clamp_f16
 func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK: linalg.generic

@mgehre-amd mgehre-amd force-pushed the matthias.tosa_linalg_clamp branch from fb85db3 to d39cc53 Compare February 22, 2024 15:59
Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spec-wise TOSA doesn't support i64 clamp, but the dialect does (and linalg definitely does).

It's likely we'll expand support at some point for i64 at least as an optional data type, so this is a useful change to have to support in the path to linalg.

@mgehre-amd mgehre-amd merged commit c1e9883 into llvm:main Feb 22, 2024
@mgehre-amd mgehre-amd deleted the matthias.tosa_linalg_clamp branch February 22, 2024 20:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants