-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][arith] Add LLVM lowering for maxnumf
, minnumf
ops
#66431
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith ChangesThis patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics. Please note: this PR is part of a stack of patches and depends on #66429.Full diff: https://github.com/llvm/llvm-project/pull/66431.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 07708cf2d78a964..58e5385bf3ff268 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -857,6 +857,34 @@ def Arith_MaximumFOp : Arith_FloatBinaryOp<"maximumf", [Commutative]> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// MaxNumFOp +//===----------------------------------------------------------------------===// + +def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> { + let summary = "floating-point maximum operation"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `arith.maxnumf` ssa-use `,` ssa-use `:` type + ``` + + Returns the maximum of the two arguments. + If the arguments are -0.0 and +0.0, then the result is either of them. + If one of the arguments is NaN, then the result is the other argument. + + Example: + + ```mlir + // Scalar floating-point maximum. + %a = arith.maxnumf %b, %c : f64 + ``` + }]; + let hasFolder = 1; +} + + //===----------------------------------------------------------------------===// // MaxSIOp //===----------------------------------------------------------------------===// @@ -901,6 +929,33 @@ def Arith_MinimumFOp : Arith_FloatBinaryOp<"minimumf", [Commutative]> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// MinNumFOp +//===----------------------------------------------------------------------===// + +def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> { + let summary = "floating-point minimum operation"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `arith.minnumf` ssa-use `,` ssa-use `:` type + ``` + + Returns the minimum of the two arguments. + If the arguments are -0.0 and +0.0, then the result is either of them. + If one of the arguments is NaN, then the result is the other argument. + + Example: + + ```mlir + // Scalar floating-point minimum. + %a = arith.minnumf %b, %c : f64 + ``` + }]; + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // MinSIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index a695441fd8dd750..337f2dbcbe4edf5 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -57,6 +57,9 @@ using FPToUIOpLowering = using MaximumFOpLowering = VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp, arith::AttrConvertFastMathToLLVM>; +using MaxNumFOpLowering = + VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp, + arith::AttrConvertFastMathToLLVM>; using MaxSIOpLowering = VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>; using MaxUIOpLowering = @@ -64,6 +67,9 @@ using MaxUIOpLowering = using MinimumFOpLowering = VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp, arith::AttrConvertFastMathToLLVM>; +using MinNumFOpLowering = + VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp, + arith::AttrConvertFastMathToLLVM>; using MinSIOpLowering = VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>; using MinUIOpLowering = @@ -496,9 +502,11 @@ void mlir::arith::populateArithToLLVMConversionPatterns( IndexCastOpSILowering, IndexCastOpUILowering, MaximumFOpLowering, + MaxNumFOpLowering, MaxSIOpLowering, MaxUIOpLowering, MinimumFOpLowering, + MinNumFOpLowering, MinSIOpLowering, MinUIOpLowering, MulFOpLowering, diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 1e34ac598860f52..d39c5b6051122e4 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -927,11 +927,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) { - // maxf(x,x) -> x + // maximumf(x,x) -> x if (getLhs() == getRhs()) return getRhs(); - // maxf(x, -inf) -> x + // maximumf(x, -inf) -> x if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) return getLhs(); @@ -940,6 +940,25 @@ OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) { [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); } +//===----------------------------------------------------------------------===// +// MaxNumFOp +//===----------------------------------------------------------------------===// + +OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) { + // maxnumf(x,x) -> x + if (getLhs() == getRhs()) + return getRhs(); + + // maxnumf(x, -inf) -> x + if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) + return getLhs(); + + return constFoldBinaryOp<FloatAttr>( + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); +} + + //===----------------------------------------------------------------------===// // MaxSIOp //===----------------------------------------------------------------------===// @@ -995,11 +1014,11 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { - // minf(x,x) -> x + // minimumf(x,x) -> x if (getLhs() == getRhs()) return getRhs(); - // minf(x, +inf) -> x + // minimumf(x, +inf) -> x if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) return getLhs(); @@ -1008,6 +1027,24 @@ OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); } +//===----------------------------------------------------------------------===// +// MinNumFOp +//===----------------------------------------------------------------------===// + +OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) { + // minnumf(x,x) -> x + if (getLhs() == getRhs()) + return getRhs(); + + // minnumf(x, +inf) -> x + if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) + return getLhs(); + + return constFoldBinaryOp<FloatAttr>( + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); }); +} + //===----------------------------------------------------------------------===// // MinSIOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index 5855f7b3b9904fd..6f614b113788c7e 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -526,6 +526,10 @@ func.func @minmaxf(%arg0 : f32, %arg1 : f32) -> f32 { %0 = arith.minimumf %arg0, %arg1 : f32 // CHECK: = llvm.intr.maximum(%arg0, %arg1) : (f32, f32) -> f32 %1 = arith.maximumf %arg0, %arg1 : f32 + // CHECK: = llvm.intr.minnum(%arg0, %arg1) : (f32, f32) -> f32 + %2 = arith.minnumf %arg0, %arg1 : f32 + // CHECK: = llvm.intr.maxnum(%arg0, %arg1) : (f32, f32) -> f32 + %3 = arith.maxnumf %arg0, %arg1 : f32 return %0 : f32 } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 5c93be887107bb6..84096354e6afe33 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1635,8 +1635,8 @@ func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) { // ----- -// CHECK-LABEL: @test_minf( -func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) { +// CHECK-LABEL: @test_minimumf( +func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 // CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 @@ -1650,8 +1650,8 @@ func.func @test_minf(%arg0 : f32) -> (f32, f32, f32) { // ----- -// CHECK-LABEL: @test_maxf( -func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) { +// CHECK-LABEL: @test_maximumf( +func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) { // CHECK-DAG: %[[C0:.+]] = arith.constant // CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]] // CHECK-NEXT: return %[[X]], %arg0, %arg0 @@ -1665,6 +1665,36 @@ func.func @test_maxf(%arg0 : f32) -> (f32, f32, f32) { // ----- +// CHECK-LABEL: @test_minnumf( +func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0.0 + // CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]] + // CHECK-NEXT: return %[[X]], %arg0, %arg0 + %c0 = arith.constant 0.0 : f32 + %inf = arith.constant 0x7F800000 : f32 + %0 = arith.minnumf %c0, %arg0 : f32 + %1 = arith.minnumf %arg0, %arg0 : f32 + %2 = arith.minnumf %inf, %arg0 : f32 + return %0, %1, %2 : f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: @test_maxnumf( +func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) { + // CHECK-DAG: %[[C0:.+]] = arith.constant + // CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]] + // CHECK-NEXT: return %[[X]], %arg0, %arg0 + %c0 = arith.constant 0.0 : f32 + %-inf = arith.constant 0xFF800000 : f32 + %0 = arith.maxnumf %c0, %arg0 : f32 + %1 = arith.maxnumf %arg0, %arg0 : f32 + %2 = arith.maxnumf %-inf, %arg0 : f32 + return %0, %1, %2 : f32, f32, f32 +} + +// ----- + // CHECK-LABEL: @test_addf( func.func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) { // CHECK-DAG: %[[C2:.+]] = arith.constant 2.0 diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir index 5b5618bb03676bf..88cc0072c7c5704 100644 --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -1071,9 +1071,12 @@ func.func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>, %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>, %f1: f32, %f2: f32, %i1: i32, %i2: i32) { - %max_vector = arith.maximumf %v1, %v2 : vector<4xf32> - %max_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32> - %max_float = arith.maximumf %f1, %f2 : f32 + %maximum_vector = arith.maximumf %v1, %v2 : vector<4xf32> + %maximum_scalable_vector = arith.maximumf %sv1, %sv2 : vector<[4]xf32> + %maximum_float = arith.maximumf %f1, %f2 : f32 + %maxnum_vector = arith.maxnumf %v1, %v2 : vector<4xf32> + %maxnum_scalable_vector = arith.maxnumf %sv1, %sv2 : vector<[4]xf32> + %maxnum_float = arith.maxnumf %f1, %f2 : f32 %max_signed = arith.maxsi %i1, %i2 : i32 %max_unsigned = arith.maxui %i1, %i2 : i32 return @@ -1084,9 +1087,12 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>, %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>, %f1: f32, %f2: f32, %i1: i32, %i2: i32) { - %min_vector = arith.minimumf %v1, %v2 : vector<4xf32> - %min_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32> - %min_float = arith.minimumf %f1, %f2 : f32 + %minimum_vector = arith.minimumf %v1, %v2 : vector<4xf32> + %minimum_scalable_vector = arith.minimumf %sv1, %sv2 : vector<[4]xf32> + %minimum_float = arith.minimumf %f1, %f2 : f32 + %minnum_vector = arith.minnumf %v1, %v2 : vector<4xf32> + %minnum_scalable_vector = arith.minnumf %sv1, %sv2 : vector<[4]xf32> + %minnum_float = arith.minnumf %f1, %f2 : f32 %min_signed = arith.minsi %i1, %i2 : i32 %min_unsigned = arith.minui %i1, %i2 : i32 return |
maxnum
, minnum
opsmaxnum
, minnum
ops
maxnum
, minnum
opsmaxnumf
, minnumf
ops
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thanks!
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics. Please **note**: this PR is part of a stack of patches and depends on llvm#66429.
e2188fe
to
1f1c922
Compare
) This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics. Please **note**: this PR is part of a stack of patches and depends on llvm#66429.
This patch is part of a larger initiative aimed at fixing floating-point
max
andmin
operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.The commit addresses the task 1.4 of the RFC by adding LLVM lowering to the corresponding LLVM intrinsics.
Please note: this PR is part of a stack of patches and depends on #66429.