Skip to content

[mlir][arith] Introduce minnumf and maxnumf operations #66429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 14, 2023

Conversation

unterumarmung
Copy link
Contributor

This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Here we introduce new operations for floating-point numbers: minnum and maxnum.
These operations have different semantics than minumumf and maximumf ops.
They follow the eponymous LLVM intrinsics semantics, which differs
in the handling positive and negative zeros and NaNs.

This patch addresses the 1.3 task from the RFC.

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Here we introduce new operations for floating-point numbers: `minnum` and `maxnum`.
These operations have different semantics than `minumumf` and `maximumf` ops.
They follow the eponymous LLVM intrinsics semantics, which differs
in the handling positive and negative zeros and NaNs.

This patch addresses the 1.3 task from the RFC.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:arith labels Sep 14, 2023
@unterumarmung unterumarmung requested review from a team September 14, 2023 20:49
@llvmbot
Copy link
Member

llvmbot commented Sep 14, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Changes This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Here we introduce new operations for floating-point numbers: minnum and maxnum.
These operations have different semantics than minumumf and maximumf ops.
They follow the eponymous LLVM intrinsics semantics, which differs
in the handling positive and negative zeros and NaNs.

This patch addresses the 1.3 task from the RFC.

--
Full diff: https://github.com/llvm/llvm-project/pull/66429.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+55)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+41-4)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+34-4)
  • (modified) mlir/test/Dialect/Arith/ops.mlir (+12-6)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 07708cf2d78a964..58e5385bf3ff268 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -857,6 +857,34 @@ def Arith_MaximumFOp : Arith_FloatBinaryOp<"maximumf", [Commutative]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
+  let summary = "floating-point maximum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arith.maxnumf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the maximum of the two arguments.
+    If the arguments are -0.0 and +0.0, then the result is either of them.
+    If one of the arguments is NaN, then the result is the other argument.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point maximum.
+    %a = arith.maxnumf %b, %c : f64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
+
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -901,6 +929,33 @@ def Arith_MinimumFOp : Arith_FloatBinaryOp<"minimumf", [Commutative]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
+  let summary = "floating-point minimum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `arith.minnumf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the minimum of the two arguments.
+    If the arguments are -0.0 and +0.0, then the result is either of them.
+    If one of the arguments is NaN, then the result is the other argument.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point minimum.
+    %a = arith.minnumf %b, %c : f64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1e34ac598860f52..d39c5b6051122e4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -927,11 +927,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
-  // maxf(x,x) -> x
+  // maximumf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
 
-  // maxf(x, -inf) -> x
+  // maximumf(x, -inf) -> x
   if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
     return getLhs();
 
@@ -940,6 +940,25 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
 }
 
+//===----------------------------------------------------------------------===//
+// MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
+  // maxnumf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // maxnumf(x, -inf) -> x
+  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
+    return getLhs();
+
+  return constFoldBinaryOp<FloatAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
+}
+
+
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -995,11 +1014,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
-  // minf(x,x) -> x
+  // minimumf(x,x) -> x
   if (getLhs() == getRhs())
     return getRhs();
 
-  // minf(x, +inf) -> x
+  // minimumf(x, +inf) -> x
   if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
     return getLhs();
 
@@ -1008,6 +1027,24 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
 }
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
+  // minnumf(x,x) -> x
+  if (getLhs() == getRhs())
+    return getRhs();
+
+  // minnumf(x, +inf) -> x
+  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
+    return getLhs();
+
+  return constFoldBinaryOp<FloatAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
+}
+
 //===----------------------------------------------------------------------===//
 // MinSIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 5c93be887107bb6..84096354e6afe33 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1635,8 +1635,8 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {
 
 // -----
 
-// CHECK-LABEL: @test_minf(
-func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_minimumf(
+func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
   // CHECK-NEXT:  %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
   // CHECK-NEXT:  return %[[X]], %arg0, %arg0
@@ -1650,8 +1650,8 @@ func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
 
 // -----
 
-// CHECK-LABEL: @test_maxf(
-func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
+// CHECK-LABEL: @test_maximumf(
+func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-DAG:   %[[C0:.+]] = arith.constant
   // CHECK-NEXT:  %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
   // CHECK-NEXT:   return %[[X]], %arg0, %arg0
@@ -1665,6 +1665,36 @@ func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
 
 // -----
 
+// CHECK-LABEL: @test_minnumf(
+func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0.0
+  // CHECK-NEXT:  %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
+  // CHECK-NEXT:  return %[[X]], %arg0, %arg0
+  %c0 = arith.constant 0.0 : f32
+  %inf = arith.constant 0x7F800000 : f32
+  %0 = arith.minnumf %c0, %arg0 : f32
+  %1 = arith.minnumf %arg0, %arg0 : f32
+  %2 = arith.minnumf %inf, %arg0 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_maxnumf(
+func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant
+  // CHECK-NEXT:  %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
+  // CHECK-NEXT:   return %[[X]], %arg0, %arg0
+  %c0 = arith.constant 0.0 : f32
+  %-inf = arith.constant 0xFF800000 : f32
+  %0 = arith.maxnumf %c0, %arg0 : f32
+  %1 = arith.maxnumf %arg0, %arg0 : f32
+  %2 = arith.maxnumf %-inf, %arg0 : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
 // CHECK-LABEL: @test_addf(
 func.func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
   // CHECK-DAG:   %[[C2:.+]] = arith.constant 2.0
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 5b5618bb03676bf..88cc0072c7c5704 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1071,9 +1071,12 @@ func.func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
                %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
                %f1: f32, %f2: f32,
                %i1: i32, %i2: i32) {
-  %max_vector = arith.maximumf %v1, %v2 : vector<4xf32>
-  %max_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
-  %max_float = arith.maximumf %f1, %f2 : f32
+  %maximum_vector = arith.maximumf %v1, %v2 : vector<4xf32>
+  %maximum_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32>
+  %maximum_float = arith.maximumf %f1, %f2 : f32
+  %maxnum_vector = arith.maxnumf %v1, %v2 : vector<4xf32>
+  %maxnum_scalable_vector = arith.maxnumf %sv1, %sv2 : vector<[4]xf32>
+  %maxnum_float = arith.maxnumf %f1, %f2 : f32
   %max_signed = arith.maxsi %i1, %i2 : i32
   %max_unsigned = arith.maxui %i1, %i2 : i32
   return
@@ -1084,9 +1087,12 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
                %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>,
                %f1: f32, %f2: f32,
                %i1: i32, %i2: i32) {
-  %min_vector = arith.minimumf %v1, %v2 : vector<4xf32>
-  %min_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
-  %min_float = arith.minimumf %f1, %f2 : f32
+  %minimum_vector = arith.minimumf %v1, %v2 : vector<4xf32>
+  %minimum_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32>
+  %minimum_float = arith.minimumf %f1, %f2 : f32
+  %minnum_vector = arith.minnumf %v1, %v2 : vector<4xf32>
+  %minnum_scalable_vector = arith.minnumf %sv1, %sv2 : vector<[4]xf32>
+  %minnum_float = arith.minnumf %f1, %f2 : f32
   %min_signed = arith.minsi %i1, %i2 : i32
   %min_unsigned = arith.minui %i1, %i2 : i32
   return

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the lowering to LLVM coming next?

@unterumarmung
Copy link
Contributor Author

Is the lowering to LLVM coming next?

Yes, please take a look: #66431

@unterumarmung unterumarmung merged commit ca8cba7 into llvm:main Sep 14, 2023
unterumarmung added a commit to unterumarmung/llvm-project that referenced this pull request Sep 14, 2023
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics.

Please **note**: this PR is part of a stack of patches and depends on llvm#66429.
unterumarmung added a commit that referenced this pull request Sep 14, 2023
This patch is part of a larger initiative aimed at fixing floating-point
`max` and `min` operations in MLIR:
https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
    
The commit addresses the task 1.4 of the RFC by adding LLVM lowering to
the corresponding LLVM intrinsics.
    
Please **note**: this PR is part of a stack of patches and depends on
#66429.
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
This patch is part of a larger initiative aimed at fixing floating-point
`max` and `min` operations in MLIR:
https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

Here we introduce new operations for floating-point numbers: `minnum`
and `maxnum`.
These operations have different semantics than `minumumf` and `maximumf`
ops.
They follow the eponymous LLVM intrinsics semantics, which differs
in the handling positive and negative zeros and NaNs.

This patch addresses the 1.3 task from the RFC.
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
)

This patch is part of a larger initiative aimed at fixing floating-point
`max` and `min` operations in MLIR:
https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.
    
The commit addresses the task 1.4 of the RFC by adding LLVM lowering to
the corresponding LLVM intrinsics.
    
Please **note**: this PR is part of a stack of patches and depends on
llvm#66429.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:arith mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants