-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][arith] Introduce minnumf
and maxnumf
operations
#66429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. Here we introduce new operations for floating-point numbers: `minnum` and `maxnum`. These operations have different semantics than `minumumf` and `maximumf` ops. They follow the eponymous LLVM intrinsics semantics, which differs in the handling positive and negative zeros and NaNs. This patch addresses the 1.3 task from the RFC.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir ChangesThis patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.Here we introduce new operations for floating-point numbers: This patch addresses the 1.3 task from the RFC. -- 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 07708cf2d78a964..58e5385bf3ff268 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -857,6 +857,34 @@ def Arith_MaximumFOp : Arith_FloatBinaryOp<"maximumf", [Commutative]> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// MaxNumFOp +//===----------------------------------------------------------------------===// + +def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> { + let summary = "floating-point maximum operation"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `arith.maxnumf` ssa-use `,` ssa-use `:` type + ``` + + Returns the maximum of the two arguments. + If the arguments are -0.0 and +0.0, then the result is either of them. + If one of the arguments is NaN, then the result is the other argument. + + Example: + + ```mlir + // Scalar floating-point maximum. + %a = arith.maxnumf %b, %c : f64 + ``` + }]; + let hasFolder = 1; +} + + //===----------------------------------------------------------------------===// // MaxSIOp //===----------------------------------------------------------------------===// @@ -901,6 +929,33 @@ def Arith_MinimumFOp : Arith_FloatBinaryOp<"minimumf", [Commutative]> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// MinNumFOp +//===----------------------------------------------------------------------===// + +def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> { + let summary = "floating-point minimum operation"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `arith.minnumf` ssa-use `,` ssa-use `:` type + ``` + + Returns the minimum of the two arguments. + If the arguments are -0.0 and +0.0, then the result is either of them. + If one of the arguments is NaN, then the result is the other argument. + + Example: + + ```mlir + // Scalar floating-point minimum. + %a = arith.minnumf %b, %c : f64 + ``` + }]; + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // MinSIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 1e34ac598860f52..d39c5b6051122e4 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -927,11 +927,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) { - // maxf(x,x) -> x + // maximumf(x,x) -> x if (getLhs() == getRhs()) return getRhs(); - // maxf(x, -inf) -> x + // maximumf(x, -inf) -> x if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) return getLhs(); @@ -940,6 +940,25 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) { [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); } +//===----------------------------------------------------------------------===// +// MaxNumFOp +//===----------------------------------------------------------------------===// + +OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) { + // maxnumf(x,x) -> x + if (getLhs() == getRhs()) + return getRhs(); + + // maxnumf(x, -inf) -> x + if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) + return getLhs(); + + return constFoldBinaryOp<FloatAttr>( + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); +} + + //===----------------------------------------------------------------------===// // MaxSIOp //===----------------------------------------------------------------------===// @@ -995,11 +1014,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { - // minf(x,x) -> x + // minimumf(x,x) -> x if (getLhs() == getRhs()) return getRhs(); - // minf(x, +inf) -> x + // minimumf(x, +inf) -> x if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) return getLhs(); @@ -1008,6 +1027,24 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); } +//===----------------------------------------------------------------------===// +// MinNumFOp +//===----------------------------------------------------------------------===// + +OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) { + // minnumf(x,x) -> x + if (getLhs() == getRhs()) + return getRhs(); + + // minnumf(x, +inf) -> x + if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) + return getLhs(); + + return constFoldBinaryOp<FloatAttr>( + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); }); +} + //===----------------------------------------------------------------------===// // MinSIOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 5c93be887107bb6..84096354e6afe33 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1635,8 +1635,8 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) { // ----- -// CHECK-LABEL: @test_minf( -func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) { +// CHECK-LABEL: @test_minimumf( +func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 // CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 @@ -1650,8 +1650,8 @@ func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) { // ----- -// CHECK-LABEL: @test_maxf( -func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) { +// CHECK-LABEL: @test_maximumf( +func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-DAG: %[[C0:.+]] = arith.constant // CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 @@ -1665,6 +1665,36 @@ func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) { // ----- +// CHECK-LABEL: @test_minnumf( +func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 + // CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]] + // CHECK-NEXT: return %[[X]], %arg0, %arg0 + %c0 = arith.constant 0.0 : f32 + %inf = arith.constant 0x7F800000 : f32 + %0 = arith.minnumf %c0, %arg0 : f32 + %1 = arith.minnumf %arg0, %arg0 : f32 + %2 = arith.minnumf %inf, %arg0 : f32 + return %0, %1, %2 : f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: @test_maxnumf( +func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) { + // CHECK-DAG: %[[C0:.+]] = arith.constant + // CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]] + // CHECK-NEXT: return %[[X]], %arg0, %arg0 + %c0 = arith.constant 0.0 : f32 + %-inf = arith.constant 0xFF800000 : f32 + %0 = arith.maxnumf %c0, %arg0 : f32 + %1 = arith.maxnumf %arg0, %arg0 : f32 + %2 = arith.maxnumf %-inf, %arg0 : f32 + return %0, %1, %2 : f32, f32, f32 +} + +// ----- + // CHECK-LABEL: @test_addf( func.func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) { // CHECK-DAG: %[[C2:.+]] = arith.constant 2.0 diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index 5b5618bb03676bf..88cc0072c7c5704 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -1071,9 +1071,12 @@ func.func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>, %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>, %f1: f32, %f2: f32, %i1: i32, %i2: i32) { - %max_vector = arith.maximumf %v1, %v2 : vector<4xf32> - %max_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32> - %max_float = arith.maximumf %f1, %f2 : f32 + %maximum_vector = arith.maximumf %v1, %v2 : vector<4xf32> + %maximum_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32> + %maximum_float = arith.maximumf %f1, %f2 : f32 + %maxnum_vector = arith.maxnumf %v1, %v2 : vector<4xf32> + %maxnum_scalable_vector = arith.maxnumf %sv1, %sv2 : vector<[4]xf32> + %maxnum_float = arith.maxnumf %f1, %f2 : f32 %max_signed = arith.maxsi %i1, %i2 : i32 %max_unsigned = arith.maxui %i1, %i2 : i32 return @@ -1084,9 +1087,12 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>, %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>, %f1: f32, %f2: f32, %i1: i32, %i2: i32) { - %min_vector = arith.minimumf %v1, %v2 : vector<4xf32> - %min_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32> - %min_float = arith.minimumf %f1, %f2 : f32 + %minimum_vector = arith.minimumf %v1, %v2 : vector<4xf32> + %minimum_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32> + %minimum_float = arith.minimumf %f1, %f2 : f32 + %minnum_vector = arith.minnumf %v1, %v2 : vector<4xf32> + %minnum_scalable_vector = arith.minnumf %sv1, %sv2 : vector<[4]xf32> + %minnum_float = arith.minnumf %f1, %f2 : f32 %min_signed = arith.minsi %i1, %i2 : i32 %min_unsigned = arith.minui %i1, %i2 : i32 return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the lowering to LLVM coming next?
Yes, please take a look: #66431 |
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics. Please **note**: this PR is part of a stack of patches and depends on llvm#66429.
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics. Please **note**: this PR is part of a stack of patches and depends on #66429.
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. Here we introduce new operations for floating-point numbers: `minnum` and `maxnum`. These operations have different semantics than `minumumf` and `maximumf` ops. They follow the eponymous LLVM intrinsics semantics, which differs in the handling positive and negative zeros and NaNs. This patch addresses the 1.3 task from the RFC.
) This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics. Please **note**: this PR is part of a stack of patches and depends on llvm#66429.
This patch is part of a larger initiative aimed at fixing floating-point
max
andmin
operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.Here we introduce new operations for floating-point numbers:
minnum
andmaxnum
.These operations have different semantics than
minumumf
andmaximumf
ops.They follow the eponymous LLVM intrinsics semantics, which differs
in the handling positive and negative zeros and NaNs.
This patch addresses the 1.3 task from the RFC.