Skip to content

[mlir] Refactor LegalizeToF32 to specify extra supported float types and target type as arguments #108815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 27, 2024

Conversation

dhernandez0
Copy link
Contributor

@dhernandez0 dhernandez0 commented Sep 16, 2024

Instead of hardcoding all fp smaller than 32 bits are unsupported we provide a way to pass supported floating point types as well as the target type. fp64 and fp32 are implicitly supported.

CC: @krzysz00 @manupak

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Sep 16, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir-arith

Author: Daniel Hernandez-Juarez (dhernandez0)

Changes

Instead of hardcoding all fp smaller than 32 bits are unsupported we provide a way to pass supported floating point types as well as the target type. fp64 and fp32 are implicitly supported.


Patch is 24.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108815.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/Utils/Utils.h (+4)
  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+4-3)
  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.td (+10-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+3-24)
  • (modified) mlir/lib/Dialect/Arith/Utils/Utils.cpp (+19)
  • (modified) mlir/lib/Dialect/Math/Transforms/CMakeLists.txt (+1-1)
  • (added) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+164)
  • (removed) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (-118)
  • (added) mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir (+137)
  • (renamed) mlir/test/Dialect/Math/extend-to-supported-types.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 76f5825025739b..81da2d208d7269 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -130,6 +130,10 @@ namespace arith {
 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
                     Type resultType);
+                    
+// Map strings to float types.
+std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
+
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 2dd7f6431f03e1..b6ee5bc93ff694 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -56,10 +56,11 @@ void populateMathPolynomialApproximationPatterns(
 void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
 
 namespace math {
-void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
-void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
+void populateExtendToSupportedTypesTypeConverter(TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
+    Type targetType);
+void populateExtendToSupportedTypesConversionTarget(ConversionTarget &target,
                                            TypeConverter &typeConverter);
-void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
+void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns,
                                    TypeConverter &typeConverter);
 } // namespace math
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index e870e714bfda58..a84c89020d4f36 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -19,7 +19,7 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
   let dependentDialects = ["math::MathDialect"];
 }
 
-def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
+def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
   let summary = "Legalize floating-point math ops on low-precision floats";
   let description = [{
     On many targets, the math functions are not implemented for floating-point
@@ -28,11 +28,19 @@ def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
 
     This pass explicitly legalizes these math functions by inserting
     `arith.extf` and `arith.truncf` pairs around said op, which preserves
-    the original semantics while enabling lowering.
+    the original semantics while enabling lowering. The extra supported floating-point
+    types for the target are passed as arguments. Types f64 and f32 are implicitly 
+    supported.
 
     As an exception, this pass does not legalize `math.fma`, because
     that is an operation frequently implemented at low precisions.
   }];
+  let options = [
+    ListOption<"extraTypeStrs", "extra-types", "std::string",
+      "MLIR types with arithmetic support on a given target (f64 and f32 are implicitly supported)">,
+    Option<"targetTypeStr", "target-type", "std::string", "\"f32\"",
+      "MLIR type to convert the unsupported source types to">,
+  ];
   let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
 }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index a5ee6edc6320d5..eb8f74270d2ef8 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
@@ -49,28 +50,6 @@ struct EmulateFloatPattern final : ConversionPattern {
 };
 } // end namespace
 
-/// Map strings to float types. This function is here because no one else needs
-/// it yet, feel free to abstract it out.
-static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
-                                               StringRef name) {
-  Builder b(ctx);
-  return llvm::StringSwitch<std::optional<FloatType>>(name)
-      .Case("f6E3M2FN", b.getFloat6E3M2FNType())
-      .Case("f8E5M2", b.getFloat8E5M2Type())
-      .Case("f8E4M3", b.getFloat8E4M3Type())
-      .Case("f8E4M3FN", b.getFloat8E4M3FNType())
-      .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
-      .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
-      .Case("f8E3M4", b.getFloat8E3M4Type())
-      .Case("bf16", b.getBF16Type())
-      .Case("f16", b.getF16Type())
-      .Case("f32", b.getF32Type())
-      .Case("f64", b.getF64Type())
-      .Case("f80", b.getF80Type())
-      .Case("f128", b.getF128Type())
-      .Default(std::nullopt);
-}
-
 LogicalResult EmulateFloatPattern::match(Operation *op) const {
   if (getTypeConverter()->isLegal(op))
     return failure();
@@ -154,7 +133,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
   SmallVector<Type> sourceTypes;
   Type targetType;
 
-  std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
+  std::optional<FloatType> maybeTargetType = arith::parseFloatType(ctx, targetTypeStr);
   if (!maybeTargetType) {
     emitError(UnknownLoc::get(ctx), "could not map target type '" +
                                         targetTypeStr +
@@ -164,7 +143,7 @@ void EmulateUnsupportedFloatsPass::runOnOperation() {
   targetType = *maybeTargetType;
   for (StringRef sourceTypeStr : sourceTypeStrs) {
     std::optional<FloatType> maybeSourceType =
-        parseFloatType(ctx, sourceTypeStr);
+        arith::parseFloatType(ctx, sourceTypeStr);
     if (!maybeSourceType) {
       emitError(UnknownLoc::get(ctx), "could not map source type '" +
                                           sourceTypeStr +
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index e75db84b75e280..c4d11f5474a15e 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -357,4 +357,23 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
       [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
 }
 
+/// Map strings to float types.
+std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
+  Builder b(ctx);
+
+  return llvm::StringSwitch<std::optional<FloatType>>(name)
+      .Case("f8E5M2", b.getFloat8E5M2Type())
+      .Case("f8E4M3", b.getFloat8E4M3Type())
+      .Case("f8E4M3FN", b.getFloat8E4M3FNType())
+      .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
+      .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
+      .Case("bf16", b.getBF16Type())
+      .Case("f16", b.getF16Type())
+      .Case("f32", b.getF32Type())
+      .Case("f64", b.getF64Type())
+      .Case("f80", b.getF80Type())
+      .Case("f128", b.getF128Type())
+      .Default(std::nullopt);
+}
+
 } // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2a5b4fbcb52712..e1c0c2410c1269 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
   ExpandPatterns.cpp
-  LegalizeToF32.cpp
+  ExtendToSupportedTypes.cpp
   PolynomialApproximation.cpp
   UpliftToFMA.cpp
 
diff --git a/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
new file mode 100644
index 00000000000000..fff2a962333017
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp
@@ -0,0 +1,164 @@
+//===- ExtendToSupportedTypes.cpp - Legalize functions on unsupported floats
+//----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements legalizing math operations on unsupported floating-point
+// types through arith.extf and arith.truncf.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+
+namespace {
+struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
+  ExtendToSupportedTypesRewritePattern(TypeConverter &converter,
+                                       MLIRContext *context)
+      : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+struct ExtendToSupportedTypesPass
+    : mlir::math::impl::MathExtendToSupportedTypesBase<
+          ExtendToSupportedTypesPass> {
+  using math::impl::MathExtendToSupportedTypesBase<
+      ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void mlir::math::populateExtendToSupportedTypesTypeConverter(
+    TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
+    Type targetType) {
+
+  typeConverter.addConversion(
+      [](Type type) -> std::optional<Type> { return type; });
+  typeConverter.addConversion(
+      [&sourceTypes, targetType](FloatType type) -> std::optional<Type> {
+        if (!sourceTypes.contains(type))
+          return targetType;
+
+        return std::nullopt;
+      });
+  typeConverter.addConversion(
+      [&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
+        if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
+          if (!sourceTypes.contains(type))
+            return type.clone(targetType);
+
+        return std::nullopt;
+      });
+  typeConverter.addTargetMaterialization(
+      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
+        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+        extFOp.setFastmath(arith::FastMathFlags::contract);
+        return extFOp;
+      });
+}
+
+void mlir::math::populateExtendToSupportedTypesConversionTarget(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
+    if (isa<MathDialect>(op->getDialect()))
+      return typeConverter.isLegal(op);
+    return true;
+  });
+  target.addLegalOp<FmaOp>();
+  target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
+}
+
+LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  const TypeConverter *converter = getTypeConverter();
+  FailureOr<Operation *> legalized =
+      convertOpResultTypes(op, operands, *converter, rewriter);
+  if (failed(legalized))
+    return failure();
+
+  SmallVector<Value> results = (*legalized)->getResults();
+  for (auto [result, newType, origType] : llvm::zip_equal(
+           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
+    if (newType != origType) {
+      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
+      truncFOp.setFastmath(arith::FastMathFlags::contract);
+      result = truncFOp.getResult();
+    }
+  }
+  rewriter.replaceOp(op, results);
+  return success();
+}
+
+void mlir::math::populateExtendToSupportedTypesPatterns(
+    RewritePatternSet &patterns, TypeConverter &typeConverter) {
+  patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
+                                                     patterns.getContext());
+}
+
+void ExtendToSupportedTypesPass::runOnOperation() {
+  Operation *op = getOperation();
+  MLIRContext *ctx = &getContext();
+
+  // Parse target type
+  std::optional<Type> maybeTargetType =
+      arith::parseFloatType(ctx, targetTypeStr);
+  if (!maybeTargetType.has_value()) {
+    emitError(UnknownLoc::get(ctx), "could not map target type '" +
+                                        targetTypeStr +
+                                        "' to a known floating-point type");
+    return signalPassFailure();
+  }
+  Type targetType = maybeTargetType.value();
+
+  // Parse source types
+  llvm::SetVector<Type> sourceTypes;
+  for (const auto &extraTypeStr : extraTypeStrs) {
+    std::optional<FloatType> maybeExtraType =
+        arith::parseFloatType(ctx, extraTypeStr);
+    if (!maybeExtraType.has_value()) {
+      emitError(UnknownLoc::get(ctx), "could not map source type '" +
+                                          extraTypeStr +
+                                          "' to a known floating-point type");
+      return signalPassFailure();
+    }
+    sourceTypes.insert(maybeExtraType.value());
+  }
+  // f64 and f32 are implicitly supported
+  Builder b(ctx);
+  sourceTypes.insert(b.getF64Type());
+  sourceTypes.insert(b.getF32Type());
+
+  TypeConverter typeConverter;
+  math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes,
+                                                    targetType);
+  ConversionTarget target(*ctx);
+  math::populateExtendToSupportedTypesConversionTarget(target, typeConverter);
+  RewritePatternSet patterns(ctx);
+  math::populateExtendToSupportedTypesPatterns(patterns, typeConverter);
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
deleted file mode 100644
index 2e60fe455dcade..00000000000000
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ /dev/null
@@ -1,118 +0,0 @@
-//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements legalizing math operations on small floating-point
-// types through arith.extf and arith.truncf.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/STLExtras.h"
-
-namespace mlir::math {
-#define GEN_PASS_DEF_MATHLEGALIZETOF32
-#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
-} // namespace mlir::math
-
-using namespace mlir;
-namespace {
-struct LegalizeToF32RewritePattern final : ConversionPattern {
-  LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
-      : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-struct LegalizeToF32Pass final
-    : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
-  void runOnOperation() override;
-};
-} // namespace
-
-void mlir::math::populateLegalizeToF32TypeConverter(
-    TypeConverter &typeConverter) {
-  typeConverter.addConversion(
-      [](Type type) -> std::optional<Type> { return type; });
-  typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
-    if (type.getWidth() < 32)
-      return Float32Type::get(type.getContext());
-    return std::nullopt;
-  });
-  typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
-    if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
-      return type.clone(Float32Type::get(type.getContext()));
-    return std::nullopt;
-  });
-  typeConverter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
-}
-
-void mlir::math::populateLegalizeToF32ConversionTarget(
-    ConversionTarget &target, TypeConverter &typeConverter) {
-  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
-    if (isa<MathDialect>(op->getDialect()))
-      return typeConverter.isLegal(op);
-    return true;
-  });
-  target.addLegalOp<FmaOp>();
-  target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
-}
-
-LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  Location loc = op->getLoc();
-  const TypeConverter *converter = getTypeConverter();
-  FailureOr<Operation *> legalized =
-      convertOpResultTypes(op, operands, *converter, rewriter);
-  if (failed(legalized))
-    return failure();
-
-  SmallVector<Value> results = (*legalized)->getResults();
-  for (auto [result, newType, origType] : llvm::zip_equal(
-           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
-    if (newType != origType) {
-      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
-      truncFOp.setFastmath(arith::FastMathFlags::contract);
-      result = truncFOp.getResult();
-    }
-  }
-  rewriter.replaceOp(op, results);
-  return success();
-}
-
-void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
-                                               TypeConverter &typeConverter) {
-  patterns.add<LegalizeToF32RewritePattern>(typeConverter,
-                                            patterns.getContext());
-}
-
-void LegalizeToF32Pass::runOnOperation() {
-  Operation *op = getOperation();
-  MLIRContext &ctx = getContext();
-
-  TypeConverter typeConverter;
-  math::populateLegalizeToF32TypeConverter(typeConverter);
-  ConversionTarget target(ctx);
-  math::populateLegalizeToF32ConversionTarget(target, typeConverter);
-  RewritePatternSet patterns(&ctx);
-  math::populateLegalizeToF32Patterns(patterns, typeConverter);
-  if (failed(applyPartialConversion(op, target, std::move(patterns))))
-    return signalPassFailure();
-}
diff --git a/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
new file mode 100644
index 00000000000000..6a61b3ba6251fa
--- /dev/null
+++ b/mlir/test/Dialect/Math/extend-to-supported-types-f16.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt %s --split-input-file -math-extend-to-supported-types="extra-types=f16 target-type=f32" | FileCheck %s
+
+// CHECK-LABEL: @sin_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2)
+func.func @sin_f8E5M2(%arg0: f8E5M2) -> f8E5M2 {
+  // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+  // CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+  // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+  // CHECK: return [[TRUNCF]] : f8E5M2
+  %0 = math.sin %arg0 : f8E5M2
+  return %0 : f8E5M2
+}
+
+// CHECK-LABEL: @sin
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+func.func @sin(%arg0: f16) -> f16 {
+  // CHECK16: [[SIN:%.+]] = math.sin [[ARG0]] : f16
+  // CHECK16: return [[SIN]] : f16
+  %0 = math.sin %arg0 : f16
+  return %0 : f16
+}
+
+// CHECK-LABEL: @fpowi_f8E5M2
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2, [[ARG1:%.+]]: i32)
+func.func @fpowi_f8...
[truncated]

Copy link

github-actions bot commented Sep 16, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@dhernandez0 dhernandez0 force-pushed the legalize-f16-supported-fp branch 4 times, most recently from 413f1d4 to e2db2f4 Compare September 17, 2024 14:07
…get type as arguments

Instead of hardcoding all fp smaller than 32 bits are unsupported we provide a way to pass supported floating point types as well as the target type.
Copy link
Contributor

@sergey-kozub sergey-kozub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this pass actually used? I can't find any references in llvm-project

@krzysz00
Copy link
Contributor

There're downstream uses - rocMLIR's one of them but given that we've gotten PRs on that pass, it's got usage

@krzysz00 krzysz00 merged commit 1fd1f65 into llvm:main Sep 27, 2024
8 checks passed
Copy link

@dhernandez0 Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

@llvm-ci
Copy link
Collaborator

llvm-ci commented Sep 27, 2024

LLVM Buildbot has detected a new failure on builder clang-aarch64-sve-vls-2stage running on linaro-g3-02 while building mlir at step 11 "build stage 2".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/4/builds/2473

Here is the relevant piece of the build log for the reference
Step 11 (build stage 2) failure: 'ninja' (failure)
...
[7876/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/data-to-inits.cpp.o
[7877/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/VectorSubscripts.cpp.o
[7878/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/ConvertConstant.cpp.o
[7879/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/check-call.cpp.o
[7880/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/IterationSpace.cpp.o
[7881/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/ConvertArrayConstructor.cpp.o
[7882/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/check-do-forall.cpp.o
[7883/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/check-cuda.cpp.o
[7884/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/scope.cpp.o
[7885/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/OpenMP/Clauses.cpp.o
FAILED: tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/OpenMP/Clauses.cpp.o 
/home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/stage1.install/bin/clang++ -DFLANG_INCLUDE_TESTS=1 -DFLANG_LITTLE_ENDIAN=1 -DGTEST_HAS_RTTI=0 -D_DEBUG -D_GLIBCXX_ASSERTIONS -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -I/home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/stage2/tools/flang/lib/Lower -I/home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/llvm/flang/lib/Lower -I/home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/llvm/flang/include -I/home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/stage2/tools/flang/include -I/home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/stage2/include -I/home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/llvm/llvm/include -isystem /home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/llvm/llvm/../mlir/include -isystem /home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/stage2/tools/mlir/include -isystem /home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/stage2/tools/clang/include -isystem /home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/llvm/llvm/../clang/include -mcpu=neoverse-512tvb -msve-vector-bits=256 -mllvm -treat-scalable-fixed-error-as-warning=false -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -Werror=unguarded-availability-new -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wmissing-field-initializers -pedantic -Wno-long-long -Wc++98-compat-extra-semi -Wimplicit-fallthrough -Wcovered-switch-default -Wno-noexcept-type -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wsuggest-override -Wno-comment -Wstring-conversion -Wmisleading-indentation -Wctad-maybe-unsupported -fdiagnostics-color -ffunction-sections -fdata-sections -Wno-deprecated-copy -Wno-string-conversion -Wno-ctad-maybe-unsupported -Wno-unused-command-line-argument -Wstring-conversion           -Wcovered-switch-default -Wno-nested-anon-types -O3 -DNDEBUG -std=c++17  -fno-exceptions -funwind-tables -fno-rtti -UNDEBUG -MD -MT tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/OpenMP/Clauses.cpp.o -MF tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/OpenMP/Clauses.cpp.o.d -o tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/OpenMP/Clauses.cpp.o -c /home/tcwg-buildbot/worker/clang-aarch64-sve-vls-2stage/llvm/flang/lib/Lower/OpenMP/Clauses.cpp
Killed
[7886/8682] Building CXX object tools/flang/lib/Frontend/CMakeFiles/flangFrontend.dir/CompilerInstance.cpp.o
[7887/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/program-tree.cpp.o
[7888/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/unparse-with-symbols.cpp.o
[7889/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/symbol.cpp.o
[7890/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/tools.cpp.o
[7891/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/pointer-assignment.cpp.o
[7892/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/mod-file.cpp.o
[7893/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/resolve-names-utils.cpp.o
[7894/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/rewrite-directives.cpp.o
[7895/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/resolve-labels.cpp.o
[7896/8682] Building CXX object tools/flang/lib/Frontend/CMakeFiles/flangFrontend.dir/CompilerInvocation.cpp.o
[7897/8682] Building CXX object tools/flang/lib/Frontend/CMakeFiles/flangFrontend.dir/FrontendAction.cpp.o
[7898/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/runtime-type-info.cpp.o
[7899/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/definable.cpp.o
[7900/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/resolve-directives.cpp.o
[7901/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/type.cpp.o
[7902/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/rewrite-parse-tree.cpp.o
[7903/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/expression.cpp.o
[7904/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/semantics.cpp.o
[7905/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/resolve-names.cpp.o
[7906/8682] Building CXX object tools/flang/lib/Frontend/CMakeFiles/flangFrontend.dir/FrontendActions.cpp.o
[7907/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/CallInterface.cpp.o
[7908/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/OpenMP/Utils.cpp.o
[7909/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/IO.cpp.o
[7910/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/ConvertCall.cpp.o
[7911/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/ConvertVariable.cpp.o
[7912/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/ConvertProcedureDesignator.cpp.o
[7913/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/OpenMP/Decomposer.cpp.o
[7914/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/HostAssociations.cpp.o
[7915/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/Allocatable.cpp.o
[7916/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/Bridge.cpp.o
[7917/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/ConvertExprToHLFIR.cpp.o
[7918/8682] Building CXX object tools/flang/lib/Semantics/CMakeFiles/FortranSemantics.dir/check-omp-structure.cpp.o
[7919/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/ConvertType.cpp.o
[7920/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/OpenMP/ClauseProcessor.cpp.o
[7921/8682] Building CXX object tools/flang/lib/Lower/CMakeFiles/FortranLower.dir/ConvertExpr.cpp.o

Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Sep 27, 2024
… and target type as arguments (llvm#108815)

Instead of hardcoding all fp smaller than 32 bits are unsupported we
provide a way to pass supported floating point types as well as the
target type. fp64 and fp32 are implicitly supported.

CC: @krzysz00 @manupak
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants