Skip to content

[mlir][arith] Improve extf folder #80232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 2, 2024
Merged

[mlir][arith] Improve extf folder #80232

merged 3 commits into from
Feb 2, 2024

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Feb 1, 2024

  • Use APFloat conversion function to avoid losing information by converting to double. This would be the case with large types like f80 or f128.
  • Check for potential information loss. This is intended for small floating point types that may have values not present in larger ones (e.g., f8e5m2fnuz and f16).
  • Support folding vector constants.

@llvmbot
Copy link
Member

llvmbot commented Feb 1, 2024

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes
  • Use APFloat conversion function to avoid losing information by converting to double. This would be the case with large types like f80 or f128.
  • Check for potential information loss. This is intended for small floating point types that may have values not present in larger ones (e.g., negative zero).
  • Support folding vector constants.

Full diff: https://github.com/llvm/llvm-project/pull/80232.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+36-11)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+9)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 270df3f3f9e99..7b026310ab7d6 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -9,6 +9,7 @@
 #include <cassert>
 #include <cstdint>
 #include <functional>
+#include <optional>
 #include <utility>
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1258,6 +1259,21 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
                                      srcType.getIntOrFloatBitWidth());
 }
 
+/// Attempts to convert `sourceValue` to an APFloat value with
+/// `targetSemantics`, without any information loss or rounding. Return
+/// std::nullopt on failure.
+static std::optional<APFloat>
+convertFloatValue(APFloat sourceValue,
+                  const llvm::fltSemantics &targetSemantics) {
+  bool losesInfo = false;
+  auto status = sourceValue.convert(
+      targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+  if (losesInfo || status != APFloat::opOK)
+    return std::nullopt;
+
+  return sourceValue;
+}
+
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
@@ -1321,14 +1337,21 @@ LogicalResult arith::ExtSIOp::verify() {
 // ExtFOp
 //===----------------------------------------------------------------------===//
 
-/// Always fold extension of FP constants.
+/// Fold extension of float constants when there is no information loss due the
+/// difference in fp semantics.
 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
-  if (!constOperand)
-    return {};
+  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
+  return constFoldCastOp<FloatAttr, FloatAttr>(
+      adaptor.getOperands(), getType(),
+      [&targetSemantics](const APFloat &a, bool &castStatus) {
+        if (std::optional<APFloat> result =
+                convertFloatValue(a, targetSemantics))
+          return *result;
 
-  // Convert to target type via 'double'.
-  return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
+        castStatus = false;
+        return a;
+      });
 }
 
 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -1403,11 +1426,12 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
       adaptor.getOperands(), getType(),
-      [&targetSemantics](APFloat a, bool &castStatus) {
-        bool losesInfo = false;
-        auto status = a.convert(
-            targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
-        castStatus = !losesInfo && status == APFloat::opOK;
+      [&targetSemantics](const APFloat &a, bool &castStatus) {
+        if (std::optional<APFloat> result =
+                convertFloatValue(a, targetSemantics))
+          return *result;
+
+        castStatus = false;
         return a;
       });
 }
@@ -1496,6 +1520,7 @@ OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
         return apf;
       });
 }
+
 //===----------------------------------------------------------------------===//
 // FPToUIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 44df11ab2433a..57a71bcc8feeb 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -701,6 +701,15 @@ func.func @extFPConstant() -> f64 {
   return %0 : f64
 }
 
+// CHECK-LABEL: @extFPVectorConstant
+//       CHECK:   %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf128>
+//       CHECK:   return %[[cres]]
+func.func @extFPVectorConstant() -> vector<2xf128> {
+  %cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf80>
+  %0 = arith.extf %cst : vector<2xf80> to vector<2xf128>
+  return %0 : vector<2xf128>
+}
+
 // CHECK-LABEL: @truncConstant
 //       CHECK:   %[[cres:.+]] = arith.constant -2 : i16
 //       CHECK:   return %[[cres]]

@llvmbot
Copy link
Member

llvmbot commented Feb 1, 2024

@llvm/pr-subscribers-mlir-arith

Author: Jakub Kuderski (kuhar)

Changes
  • Use APFloat conversion function to avoid losing information by converting to double. This would be the case with large types like f80 or f128.
  • Check for potential information loss. This is intended for small floating point types that may have values not present in larger ones (e.g., negative zero).
  • Support folding vector constants.

Full diff: https://github.com/llvm/llvm-project/pull/80232.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+36-11)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+9)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 270df3f3f9e99..7b026310ab7d6 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -9,6 +9,7 @@
 #include <cassert>
 #include <cstdint>
 #include <functional>
+#include <optional>
 #include <utility>
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1258,6 +1259,21 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
                                      srcType.getIntOrFloatBitWidth());
 }
 
+/// Attempts to convert `sourceValue` to an APFloat value with
+/// `targetSemantics`, without any information loss or rounding. Return
+/// std::nullopt on failure.
+static std::optional<APFloat>
+convertFloatValue(APFloat sourceValue,
+                  const llvm::fltSemantics &targetSemantics) {
+  bool losesInfo = false;
+  auto status = sourceValue.convert(
+      targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+  if (losesInfo || status != APFloat::opOK)
+    return std::nullopt;
+
+  return sourceValue;
+}
+
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
@@ -1321,14 +1337,21 @@ LogicalResult arith::ExtSIOp::verify() {
 // ExtFOp
 //===----------------------------------------------------------------------===//
 
-/// Always fold extension of FP constants.
+/// Fold extension of float constants when there is no information loss due the
+/// difference in fp semantics.
 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
-  if (!constOperand)
-    return {};
+  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
+  return constFoldCastOp<FloatAttr, FloatAttr>(
+      adaptor.getOperands(), getType(),
+      [&targetSemantics](const APFloat &a, bool &castStatus) {
+        if (std::optional<APFloat> result =
+                convertFloatValue(a, targetSemantics))
+          return *result;
 
-  // Convert to target type via 'double'.
-  return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
+        castStatus = false;
+        return a;
+      });
 }
 
 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -1403,11 +1426,12 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
       adaptor.getOperands(), getType(),
-      [&targetSemantics](APFloat a, bool &castStatus) {
-        bool losesInfo = false;
-        auto status = a.convert(
-            targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
-        castStatus = !losesInfo && status == APFloat::opOK;
+      [&targetSemantics](const APFloat &a, bool &castStatus) {
+        if (std::optional<APFloat> result =
+                convertFloatValue(a, targetSemantics))
+          return *result;
+
+        castStatus = false;
         return a;
       });
 }
@@ -1496,6 +1520,7 @@ OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
         return apf;
       });
 }
+
 //===----------------------------------------------------------------------===//
 // FPToUIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 44df11ab2433a..57a71bcc8feeb 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -701,6 +701,15 @@ func.func @extFPConstant() -> f64 {
   return %0 : f64
 }
 
+// CHECK-LABEL: @extFPVectorConstant
+//       CHECK:   %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf128>
+//       CHECK:   return %[[cres]]
+func.func @extFPVectorConstant() -> vector<2xf128> {
+  %cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf80>
+  %0 = arith.extf %cst : vector<2xf80> to vector<2xf128>
+  return %0 : vector<2xf128>
+}
+
 // CHECK-LABEL: @truncConstant
 //       CHECK:   %[[cres:.+]] = arith.constant -2 : i16
 //       CHECK:   return %[[cres]]

kuhar added 2 commits February 2, 2024 16:22
* Use APFloat conversion function to avoid losing information by
  converting to `double`. This would be the case with large types like
  `f80` or `f128`.
* Check for potential information loss. This is intended for small
  floating point types that may have values not present in larger ones
  (e.g., negative zero).
* Support folding vector constants.
@kuhar kuhar merged commit 5294ad1 into llvm:main Feb 2, 2024
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
* Use APFloat conversion function to avoid losing information by
converting to `double`. This would be the case with large types like
`f80` or `f128`.
* Check for potential information loss. This is intended for small
floating point types that may have values not present in larger ones
(e.g., f8m2e5fnuz and f16).
* Support folding vector constants.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants