Skip to content

[MLIR][Math] Add floating point value folders #127947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 20, 2025
Merged

[MLIR][Math] Add floating point value folders #127947

merged 1 commit into from
Feb 20, 2025

Conversation

wsmoses
Copy link
Member

@wsmoses wsmoses commented Feb 20, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Feb 20, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: William Moses (wsmoses)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/127947.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+4)
  • (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+64)
  • (modified) mlir/test/Dialect/Math/canonicalize.mlir (+72)
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 16ce4e2366c76..56370388dea87 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -736,6 +736,7 @@ def Math_IsFiniteOp : Math_FloatClassificationOp<"isfinite"> {
     %f = math.isfinite %a : f32
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -754,6 +755,7 @@ def Math_IsInfOp : Math_FloatClassificationOp<"isinf"> {
     %f = math.isinf %a : f32
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -772,6 +774,7 @@ def Math_IsNaNOp : Math_FloatClassificationOp<"isnan"> {
     %f = math.isnan %a : f32
     ```
   }];
+  let hasFolder = 1;
 }
 
 
@@ -791,6 +794,7 @@ def Math_IsNormalOp : Math_FloatClassificationOp<"isnormal"> {
     %f = math.isnormal %a : f32
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 9c4d88e2191ce..26441a9d78658 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -579,6 +579,70 @@ OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// IsFiniteOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
+  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
+    return BoolAttr::get(val.getContext(), val.getValue().isFinite());
+  }
+  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
+    return DenseElementsAttr::get(
+        cast<ShapedType>(getType()),
+        APInt(1, splat.getSplatValue<APFloat>().isFinite()));
+  }
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// IsInfOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::IsInfOp::fold(FoldAdaptor adaptor) {
+  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
+    return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
+  }
+  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
+    return DenseElementsAttr::get(
+        cast<ShapedType>(getType()),
+        APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
+  }
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// IsNaNOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::IsNaNOp::fold(FoldAdaptor adaptor) {
+  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
+    return BoolAttr::get(val.getContext(), val.getValue().isNaN());
+  }
+  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
+    return DenseElementsAttr::get(
+        cast<ShapedType>(getType()),
+        APInt(1, splat.getSplatValue<APFloat>().isNaN()));
+  }
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// IsNormalOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
+  if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
+    return BoolAttr::get(val.getContext(), val.getValue().isNormal());
+  }
+  if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
+    return DenseElementsAttr::get(
+        cast<ShapedType>(getType()),
+        APInt(1, splat.getSplatValue<APFloat>().isNormal()));
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TanOp folder
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index d24f7649269fe..f5c57f312aa7a 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -492,3 +492,75 @@ func.func @abs_poison() -> f32 {
   %1 = math.absf %0 : f32
   return %1 : f32
 }
+
+// CHECK-LABEL: @isfinite_fold
+// CHECK: %[[cst:.+]] = arith.constant true
+  // CHECK: return %[[cst]]
+func.func @isfinite_fold() -> i1 {
+  %c = arith.constant 2.0 : f32
+  %r = math.isfinite %c : f32
+  return %r : i1
+}
+
+// CHECK-LABEL: @isfinite_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<true>
+// CHECK: return %[[cst]]
+func.func @isfinite_fold_vec() -> (vector<4xi1>) {
+  %v1 = arith.constant dense<2.0> : vector<4xf32>
+  %0 = math.isfinite %v1 : vector<4xf32>
+  return %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: @isinf_fold
+// CHECK: %[[cst:.+]] = arith.constant false
+  // CHECK: return %[[cst]]
+func.func @isinf_fold() -> i1 {
+  %c = arith.constant 2.0 : f32
+  %r = math.isinf %c : f32
+  return %r : i1
+}
+
+// CHECK-LABEL: @isinf_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<false>
+// CHECK: return %[[cst]]
+func.func @isinf_fold_vec() -> (vector<4xi1>) {
+  %v1 = arith.constant dense<2.0> : vector<4xf32>
+  %0 = math.isinf %v1 : vector<4xf32>
+  return %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: @isnan_fold
+// CHECK: %[[cst:.+]] = arith.constant false
+  // CHECK: return %[[cst]]
+func.func @isnan_fold() -> i1 {
+  %c = arith.constant 2.0 : f32
+  %r = math.isnan %c : f32
+  return %r : i1
+}
+
+// CHECK-LABEL: @isnan_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<false>
+// CHECK: return %[[cst]]
+func.func @isnan_fold_vec() -> (vector<4xi1>) {
+  %v1 = arith.constant dense<2.0> : vector<4xf32>
+  %0 = math.isnan %v1 : vector<4xf32>
+  return %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: @isnormal_fold
+// CHECK: %[[cst:.+]] = arith.constant true
+  // CHECK: return %[[cst]]
+func.func @isnormal_fold() -> i1 {
+  %c = arith.constant 2.0 : f32
+  %r = math.isnormal %c : f32
+  return %r : i1
+}
+
+// CHECK-LABEL: @isnormal_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<true>
+// CHECK: return %[[cst]]
+func.func @isnormal_fold_vec() -> (vector<4xi1>) {
+  %v1 = arith.constant dense<2.0> : vector<4xf32>
+  %0 = math.isnormal %v1 : vector<4xf32>
+  return %0 : vector<4xi1>
+}

Copy link
Contributor

@ivanradanov ivanradanov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good % nit

@wsmoses wsmoses merged commit 1b610e6 into main Feb 20, 2025
8 checks passed
@wsmoses wsmoses deleted the users/wm/mlirfld branch February 20, 2025 15:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants