Skip to content

[mlir][tosa] Change ClampOp's min/max attributes #125197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 11, 2025
Merged

Conversation

Hsiangkai
Copy link
Contributor

This changes Tosa ClampOp attributes to min_val and max_val which are either integer attributes or float attributes, and adds verify checks that these attribute element types must match element types of input and output

@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: Hsiangkai Wang (Hsiangkai)

Changes

This changes Tosa ClampOp attributes to min_val and max_val which are either integer attributes or float attributes, and adds verify checks that these attribute element types must match element types of input and output


Patch is 56.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125197.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-4)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+8)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+94-35)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+25-11)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+6-35)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+28-37)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+6-6)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+22-22)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9e3e41d288e4ac..41acff74321fdc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -387,10 +387,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
 
   let arguments = (ins
     Tosa_Tensor:$input,
-    I64Attr:$min_int,
-    I64Attr:$max_int,
-    Tosa_FloatAttr:$min_fp,
-    Tosa_FloatAttr:$max_fp,
+    Tosa_IntOrFloatAttr:$min_val,
+    Tosa_IntOrFloatAttr:$max_val,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5693acf3a01db4..3795d51e5afce3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,6 +202,14 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
   let returnType = [{ ::mlir::APFloat }];
 }
 
+def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+                          "arbitrary integer attribute"> {
+  let storageType = [{ ::mlir::IntegerAttr }];
+  let returnType = [{ ::llvm::APInt }];
+}
+
+def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
+
 //===----------------------------------------------------------------------===//
 // Iterable attributes.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..49cb87a8786f95 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -385,8 +385,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   // tosa::ClampOp
   if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
     bool losesInfo = false;
-    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
-    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
+    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
+    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
     minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                    APFloat::rmNearestTiesToEven, &losesInfo);
     maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
@@ -401,9 +401,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
     auto intTy = cast<IntegerType>(elementTy);
     int64_t min =
-        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
     int64_t max =
-        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
+        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();
 
     int64_t minRepresentable = std::numeric_limits<int64_t>::min();
     int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 98871268e313b6..71369d81fbe908 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -287,10 +287,12 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
 
     if (isa<FloatType>(inputElementType)) {
       // Unlike integer types, floating point types can represent infinity.
-      auto minClamp = op.getMinFp();
-      auto maxClamp = op.getMaxFp();
-      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
-      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
+      auto minClamp =
+          llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
+      auto maxClamp =
+          llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
+      bool isMin = minClamp.isNegInfinity();
+      bool isMax = maxClamp.isInfinity();
 
       if (isMin && isMax) {
         rewriter.replaceOp(op, input);
@@ -300,8 +302,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
     }
 
     if (inputElementType.isUnsignedInteger()) {
-      int64_t minClamp = op.getMinInt();
-      int64_t maxClamp = op.getMaxInt();
+      int64_t minClamp =
+          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
+      int64_t maxClamp =
+          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
 
       int64_t intMin =
           APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
@@ -318,8 +322,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
     }
 
     if (llvm::isa<IntegerType>(inputElementType)) {
-      int64_t minClamp = op.getMinInt();
-      int64_t maxClamp = op.getMaxInt();
+      int64_t minClamp =
+          llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
+      int64_t maxClamp =
+          llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
 
       int64_t intMin =
           APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
@@ -374,9 +380,10 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
 
   LogicalResult matchAndRewrite(tosa::ClampOp op,
                                 PatternRewriter &rewriter) const override {
+    Value input = op.getInput();
+
     // Check the input to the CLAMP op is itself a CLAMP.
-    auto clampOp =
-        dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
+    auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
     if (!clampOp)
       return failure();
 
@@ -386,34 +393,86 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
     if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
       return failure();
 
-    // Check we have intersecting ranges.
-    const auto opMinInt = op.getMinInt();
-    const auto opMaxInt = op.getMaxInt();
-    const auto clampOpMinInt = clampOp.getMinInt();
-    const auto clampOpMaxInt = clampOp.getMaxInt();
-    ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
-    ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
-    if (!opRangeIntRange.intersects(clampRangeIntRange))
-      return failure();
+    auto maxValAttr = op.getMaxValAttr();
+    auto minValAttr = op.getMinValAttr();
+    auto clampOpMaxValAttr = clampOp.getMaxValAttr();
+    auto clampOpMinValAttr = clampOp.getMinValAttr();
 
-    const auto opMinFloat = op.getMinFp();
-    const auto opMaxFloat = op.getMaxFp();
-    const auto clampOpMinFloat = clampOp.getMinFp();
-    const auto clampOpMaxFloat = clampOp.getMaxFp();
-    ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
-    ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
-    if (!opRangeFloatRange.intersects(clampRangeFloatRange))
-      return failure();
+    auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
+    if (auto quantType =
+            llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
+      inputEType = quantType.getStorageType();
+    }
+
+    Attribute newMinValAttr, newMaxValAttr;
+    if (mlir::isa<FloatType>(inputEType)) {
+      auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
+      auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
+      auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
+      auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);
+
+      // Check we have intersecting ranges.
+      const auto opMinFloat = floatMinValAttr.getValue();
+      const auto opMaxFloat = floatMaxValAttr.getValue();
+      const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
+      const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
+      ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
+      ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
+                                               clampOpMaxFloat);
+      if (!opRangeFloatRange.intersects(clampRangeFloatRange))
+        return failure();
+
+      // Run the transformation.
+      auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
+      auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
+      newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
+      newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
+    } else {
+      assert(mlir::isa<IntegerType>(inputEType));
+      auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
+      auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
+      auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
+      auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);
+
+      if (inputEType.isUnsignedInteger()) {
+        // Check we have intersecting ranges.
+        const auto opMinInt = intMinValAttr.getUInt();
+        const auto opMaxInt = intMaxValAttr.getUInt();
+        const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
+        const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
+        ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
+        ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
+                                                     clampOpMaxInt);
+        if (!opRangeIntRange.intersects(clampRangeIntRange))
+          return failure();
+
+        // Run the transformation.
+        auto newMinVal = std::max(opMinInt, clampOpMinInt);
+        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
+        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
+        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
+      } else {
+        // Check we have intersecting ranges.
+        const auto opMinInt = intMinValAttr.getInt();
+        const auto opMaxInt = intMaxValAttr.getInt();
+        const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
+        const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
+        ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
+        ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
+                                                    clampOpMaxInt);
+        if (!opRangeIntRange.intersects(clampRangeIntRange))
+          return failure();
+
+        // Run the transformation.
+        auto newMinVal = std::max(opMinInt, clampOpMinInt);
+        auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
+        newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
+        newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
+      }
+    }
 
-    // Run the transformation.
-    const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
-    const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
-    const auto minInt = std::max(opMinInt, clampOpMinInt);
-    const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-        op, op.getType(), clampOp.getInput(),
-        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
-        rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
+        op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
         rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
                                                            : opNanMode));
     return success();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c0b419b6f473c8..23c1d45f7c1057 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -381,26 +381,40 @@ LogicalResult tosa::ClampOp::verify() {
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
     inputETy = quantType.getStorageType();
   }
-  mlir::Type maxFpType = getMaxFpAttr().getType();
-  mlir::Type minFpType = getMinFpAttr().getType();
   mlir::Type outputETy =
       llvm::cast<ShapedType>(getOutput().getType()).getElementType();
   if (auto quantType =
           llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
     outputETy = quantType.getStorageType();
   }
-  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
-
   if (inputETy != outputETy)
     return emitOpError("input/output element types are incompatible.");
 
-  // If input datatype is float, check that the two min/max_fp attributes
-  // share the same type and that their type is either the same of the input's
-  // datatype, or a float type whose bitwidth > input datatype bitwidth.
-  if (!inputETy.isInteger(dataTypeBitWidth)) {
-    if (((maxFpType != minFpType) ||
-         (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
-                                       inputETy.getIntOrFloatBitWidth())))
+  auto maxValAttr = getMaxValAttr();
+  auto minValAttr = getMinValAttr();
+
+  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
+
+  if (inputETy.isInteger(dataTypeBitWidth)) {
+    // if input datatype is integer, check that the min_val/max_val attributes
+    // are integer attributes, and that their type is the same as the input's
+    // datatype
+    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
+    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
+    if (!intMaxValAttr || !intMinValAttr ||
+        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
+        (intMaxValAttr.getType() != inputETy))
+      return emitOpError("min/max attributes types are incompatible with "
+                         "input/output element types.");
+  } else {
+    // otherwise, input datatype is float, check that the min_val/max_val
+    // attributes share the same type and that their type is the same as the
+    // input's datatype
+    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
+    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
+    if (!floatMaxValAttr || !floatMinValAttr ||
+        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
+        (floatMaxValAttr.getType() != inputETy))
       return emitOpError("min/max attributes types are incompatible with "
                          "input/output element types.");
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcefa35317a..9ba08b427b1ae5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -529,7 +529,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: linalg.generic
   // CHECK: arith.minimumf
   // CHECK: arith.maximumf
-  %18 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
+  %18 = tosa.clamp %0 {min_val = 1.0 : f32, max_val = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
@@ -729,35 +729,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   // CHECK: linalg.generic
   // CHECK-DAG: arith.maxsi
   // CHECK-DAG: arith.minsi
-  %19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %19 = tosa.clamp %0 {min_val = 1 : i32, max_val = 5 : i32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
   // CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
   // CHECK-DAG: arith.maxui %[[LB]],
   // CHECK-DAG: arith.minui %[[UB]],
-  %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
-
-  // CHECK: linalg.generic
-  // CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
-  // CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
-  // CHECK-DAG: arith.maxui %[[LB]],
-  // CHECK-DAG: arith.minui %[[UB]],
-  %u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
-
-  // CHECK: linalg.generic
-  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
-  // CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
-  // CHECK-DAG: arith.maxui %[[LB]],
-  // CHECK-DAG: arith.minui %[[UB]],
-  %u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
-
-  // CHECK: linalg.generic
-  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
-  // CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
-  // CHECK-DAG: arith.maxui %[[LB]],
-  // CHECK-DAG: arith.minui %[[UB]],
-  %u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
+  %u0 = tosa.clamp %unsigned {min_val = 4 : ui32, max_val = 32 : ui32} : (tensor<1xui32>) -> tensor<1xui32>
 
   // CHECK: linalg.generic
   // CHECK: arith.trunci
@@ -807,15 +786,7 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
   // CHECK-DAG: %[[C126:.+]] = arith.constant 126
   // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
   // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
-  %0 = tosa.clamp %arg0 {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
-
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[ARG1:.+]]: i8,
-  // CHECK-DAG: %[[C128:.+]] = arith.constant -128
-  // CHECK-DAG: %[[C127:.+]] = arith.constant 127
-  // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C128]], %[[ARG1]]
-  // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C127]], %[[LOWER]]
-  %1 = tosa.clamp %arg0 {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>
+  %0 = tosa.clamp %arg0 {min_val = -127 : i8, max_val = 126 : i8} : (tensor<1xi8>) -> tensor<1xi8>
 
   return
 }
@@ -830,7 +801,7 @@ func.func @test_i64(%arg0: tensor<1xi64>) -> () {
   // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
   // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
   // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
-  %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+  %0 = tosa.clamp %arg0 {min_val = -9223372036854775808 : i64, max_val = 9223372036854775807 : i64} : (tensor<1xi64>) -> tensor<1xi64>
 
   return
 }
@@ -845,7 +816,7 @@ func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
   // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
   // CHECK-DAG: %[[MIN:.+]] = arith.minimumf %[[ARG1]], %[[C6]]
   // CHECK-DAG: %[[MAX:.+]] = arith.maximumf %[[MIN]], %[[C0]]
-  %0 = tosa.clamp %arg0 {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>
+  %0 = tosa.clamp %arg0 {min_val = 0.0 : f16, max_val = 6.0 : f16} : (tensor<1xf16>) -> tensor<1xf16>
 
   return
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e2826a63cc..c104ac10f64b92 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -52,25 +52,16 @@ func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
 // CHECK-LABEL: @clamp_i32_not_noop
 func.func @clamp_i32_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK: tosa.clamp
-  %0 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 4 : i64, min_fp = 1.0 : f32, max_fp = 4.0 : f32} : (tensor<4xi32>) -> tensor<4xi32>
+  %0 = tosa.clamp %arg0 {min_val = 1 : i32, max_val = 4 : i32} : (tensor<4xi32>) -> tensor<4xi32>
   return %0 : tensor<4xi32>
 }
 
 // -----
 
-// CHECK-LABEL: @clamp_f16_not_noop
-func.func @clamp_f16_not_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
-  // CHECK: tosa.clamp
-  %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xf16>) -> tensor<4xf16>
-  return %0 : tensor<4xf16>
-}
-
-...
[truncated]

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the verifier needs to be extended to check if fp min/max values are present on fp inputs and integer on int inputs?
Edit: Ignore it is already there.

This changes Tosa ClampOp attributes to min_val and max_val
which are either integer attributes or float attributes,
and adds verify checks that these attribute element types
must match element types of input and output

Co-authored-by: Tai Ly <[email protected]>
Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great one @Hsiangkai

@Jerry-Ge Jerry-Ge merged commit ab93bd6 into llvm:main Feb 11, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
This changes Tosa ClampOp attributes to min_val and max_val which are
either integer attributes or float attributes, and adds verify checks
that these attribute element types must match element types of input and
output

Co-authored-by: Tai Ly <[email protected]>
flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
This changes Tosa ClampOp attributes to min_val and max_val which are
either integer attributes or float attributes, and adds verify checks
that these attribute element types must match element types of input and
output

Co-authored-by: Tai Ly <[email protected]>
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
This changes Tosa ClampOp attributes to min_val and max_val which are
either integer attributes or float attributes, and adds verify checks
that these attribute element types must match element types of input and
output

Co-authored-by: Tai Ly <[email protected]>
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
This changes Tosa ClampOp attributes to min_val and max_val which are
either integer attributes or float attributes, and adds verify checks
that these attribute element types must match element types of input and
output

Co-authored-by: Tai Ly <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants