-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Math] Add pass to legalize math functions to f32-or-higher #78361
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-math Author: Krzysztof Drewniak (krzysz00) ChangesSince most of the operations in the Versions of this lowering are already implicitly present in some passes, like ConvertGPUToROCDL. However, because those are implicit rewrites, they hide the floating-point extension and truncation, preventing anyone from writing passes that operate on those implitic extf/truncf pairs. Exposing this legalization explicitly is needed to allow lowening 8-bit floats on AMD GPUs, as the implementation of extf and truncf on that platform requires the complex logic found in ArithToAMDGPU, which runs before the GPU to ROCDL lowering. Full diff: https://github.com/llvm/llvm-project/pull/78361.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 9e6759ef229d6f..010dde5ea73847 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -16,12 +16,15 @@ namespace math {
#define GEN_PASS_DECL
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_DECL_MATHUPLIFTTOFMA
+#define GEN_PASS_DECL_MATHLEGALIZETOF32
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace math
+class ConversionTarget;
class RewritePatternSet;
+class TypeConverter;
void populateExpandCtlzPattern(RewritePatternSet &patterns);
void populateExpandTanPattern(RewritePatternSet &patterns);
@@ -48,6 +51,13 @@ void populateMathPolynomialApproximationPatterns(
void populateUpliftToFMAPatterns(RewritePatternSet &patterns);
+namespace math {
+void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter);
+void populateLegalizeToF32ConversionTarget(ConversionTarget &target,
+ TypeConverter &typeConverter);
+void populateLegalizeToF32Patterns(RewritePatternSet &patterns,
+ TypeConverter &typeConverter);
+} // namespace math
} // namespace mlir
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index d81a92b0371e31..e870e714bfda58 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -19,4 +19,21 @@ def MathUpliftToFMA : Pass<"math-uplift-to-fma"> {
let dependentDialects = ["math::MathDialect"];
}
+def MathLegalizeToF32 : Pass<"math-legalize-to-f32"> {
+ let summary = "Legalize floating-point math ops on low-precision floats";
+ let description = [{
+ On many targets, the math functions are not implemented for floating-point
+ types less precise than IEEE single-precision (aka f32), such as half-floats,
+ bfloat16, or 8-bit floats.
+
+ This pass explicitly legalizes these math functions by inserting
+ `arith.extf` and `arith.truncf` pairs around said op, which preserves
+ the original semantics while enabling lowering.
+
+ As an exception, this pass does not legalize `math.fma`, because
+ that is an operation frequently implemented at low precisions.
+ }];
+ let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 2d446b453edc91..2a5b4fbcb52712 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
ExpandPatterns.cpp
+ LegalizeToF32.cpp
PolynomialApproximation.cpp
UpliftToFMA.cpp
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
new file mode 100644
index 00000000000000..d281790e877152
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -0,0 +1,118 @@
+//===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements legalizing math operations on small floating-point
+// types through arith.extf and arith.truncf.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHLEGALIZETOF32
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+using namespace mlir;
+namespace {
+struct LegalizeToF32RewritePattern final : ConversionPattern {
+ LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
+ : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+struct LegalizeToF32Pass final
+ : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void mlir::math::populateLegalizeToF32TypeConverter(
+ TypeConverter &typeConverter) {
+ typeConverter.addConversion(
+ [](Type type) -> std::optional<Type> { return type; });
+ typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
+ if (type.getWidth() < 32)
+ return Float32Type::get(type.getContext());
+ return std::nullopt;
+ });
+ typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
+ if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
+ return type.clone(Float32Type::get(type.getContext()));
+ return std::nullopt;
+ });
+ typeConverter.addTargetMaterialization(
+ [](OpBuilder &b, Type target, ValueRange input, Location loc) {
+ return b.create<arith::ExtFOp>(loc, target, input);
+ });
+}
+
+void mlir::math::populateLegalizeToF32ConversionTarget(
+ ConversionTarget &target, TypeConverter &typeConverter) {
+ target.addDynamicallyLegalDialect<MathDialect>(
+ [&typeConverter](Operation *op) -> bool {
+ return typeConverter.isLegal(op);
+ });
+ target.addLegalOp<FmaOp>();
+ target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
+}
+
+LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op->getLoc();
+ const TypeConverter *converter = getTypeConverter();
+ if (converter->isLegal(op))
+ return rewriter.notifyMatchFailure(loc, "op already legal");
+ OperationState newOp(loc, op->getName());
+ newOp.addOperands(operands);
+
+ SmallVector<Type> newResultTypes;
+ if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
+ return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+ newOp.addTypes(newResultTypes);
+ newOp.addAttributes(op->getAttrs());
+ Operation *legalized = rewriter.create(newOp);
+ SmallVector<Value> results = legalized->getResults();
+ for (auto [result, newType, origType] :
+ llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+ if (newType != origType)
+ result = rewriter.create<arith::TruncFOp>(loc, origType, result);
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+}
+
+void mlir::math::populateLegalizeToF32Patterns(RewritePatternSet &patterns,
+ TypeConverter &typeConverter) {
+ patterns.add<LegalizeToF32RewritePattern>(typeConverter,
+ patterns.getContext());
+}
+
+void LegalizeToF32Pass::runOnOperation() {
+ Operation *op = getOperation();
+ MLIRContext &ctx = getContext();
+
+ TypeConverter typeConverter;
+ math::populateLegalizeToF32TypeConverter(typeConverter);
+ ConversionTarget target(ctx);
+ math::populateLegalizeToF32ConversionTarget(target, typeConverter);
+ RewritePatternSet patterns(&ctx);
+ math::populateLegalizeToF32Patterns(patterns, typeConverter);
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ return signalPassFailure();
+}
diff --git a/mlir/test/Dialect/Math/legalize-to-f32.mlir b/mlir/test/Dialect/Math/legalize-to-f32.mlir
new file mode 100644
index 00000000000000..3f648c9379955b
--- /dev/null
+++ b/mlir/test/Dialect/Math/legalize-to-f32.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s --split-input-file -math-legalize-to-f32 | FileCheck %s
+
+// CHECK-LABEL: @sin
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+func.func @sin(%arg0: f16) -> f16 {
+ // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+ // CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+ // CHECK: return [[TRUNCF]] : f16
+ %0 = math.sin %arg0 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: @fpowi
+// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: i32)
+func.func @fpowi(%arg0: f16, %arg1: i32) -> f16 {
+ // CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+ // CHECK: [[FPOWI:%.+]] = math.fpowi [[EXTF]], [[ARG1]]
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf [[FPOWI]]
+ // CHECK: return [[TRUNCF]] : f16
+ %0 = math.fpowi %arg0, %arg1 : f16, i32
+ return %0 : f16
+}
+
+// CHECK-LABEL: @fma
+// CHECK-SAME: ([[ARG0:%.+]]: f16, [[ARG1:%.+]]: f16, [[ARG2:%.+]]: f16)
+// CHECK: [[FMA:%.+]] = math.fma [[ARG0]], [[ARG1]], [[ARG2]]
+// CHECK: return [[FMA]] : f16
+func.func @fma(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
+ %0 = math.fma %arg0, %arg1, %arg2 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: @absf_f32
+// CHECK-SAME: ([[ARG0:%.+]]: f32)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f32
+func.func @absf_f32(%arg0: f32) -> f32 {
+ %0 = math.absf %arg0 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @absf_f64
+// CHECK-SAME: ([[ARG0:%.+]]: f64)
+// CHECK: [[ABSF:%.+]] = math.absf [[ARG0]]
+// CHECK: return [[ABSF]] : f64
+func.func @absf_f64(%arg0: f64) -> f64 {
+ %0 = math.absf %arg0 : f64
+ return %0 : f64
+}
+
+// CHECK-LABEL: @sin_vector
+// CHECK-SAME: ([[ARG0:%.+]]: vector<2xbf16>)
+// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF]]
+// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF]] : vector<2xbf16>
+func.func @sin_vector(%arg0: vector<2xbf16>) -> vector<2xbf16> {
+ %0 = math.sin %arg0 : vector<2xbf16>
+ return %0 : vector<2xbf16>
+}
+
+// CHECK-LABEL: @fastmath
+// CHECK: math.sin %{{.+}} fastmath<nsz>
+func.func @fastmath(%arg0: f16) -> f16 {
+ %0 = math.sin %arg0 fastmath<nsz> : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: @sequences
+// CHECK-SAME: ([[ARG0:%.+]]: f16)
+// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
+// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
+// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[ABSF]]
+// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF0]]
+// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
+// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[SIN]]
+// CHECK: return [[TRUNCF1]] : f16
+func.func @sequences(%arg0: f16) -> f16 {
+ %0 = math.absf %arg0 : f16
+ %1 = math.sin %0 : f16
+ return %1 : f16
+}
|
return rewriter.notifyMatchFailure(loc, "couldn't convert return types"); | ||
newOp.addTypes(newResultTypes); | ||
newOp.addAttributes(op->getAttrs()); | ||
Operation *legalized = rewriter.create(newOp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wondering, can we, instead of using low-level API to construct op, just clone original op and update args/return types on it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went for your suggestion - minus cloning the op, just mutating in place - what do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not 100% sure it's allowed in dialect conversion framework without cloning (or even with cloning, actually) as it should be able to 'revert' the changes if necessary. But with cloning we can at least use rewriter.replaceOp
instead of RAU. Ping @joker-eph I guess?
Regarding Intintionally(typo) don't tell the rewriter we're doing this to prevent spurious attempts to legalize the consumer
comment, I think you can avoid this by using addDynamicallyLegalOp
and always return true for non-math ops.
Also, I think updateRootInPlace
was recently renamed, can you rebase?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean, I had the OperationState
version so I could copy over attributes, properties, etc. while creating a new operation.
I'm not sure clone()
+ mutations really buys us anything over just the mutations. And looking at dialect conversion's resets, the only rollback you'd want to be able to do is cancelOpModification
which will reset the operands (though not result types, apparently)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I thought clone version will be shorter and will help to avoid using OperationState
builder, which is unconventional outside the builders, but it opens more questions on what is actually allowed in conversion framework. You can revert to old version if you like, at least I'm sure it should work with conversion framework, sorry.
Since most of the operations in the `math` dialect don't have low-precision implementations, add the -math-legalize-to-f32 pass that goes through and brackets low-precision math funcitons (like `math.sin %0 : f16`) with `arith.extf` and `arith.truncf`. This preserves the original semantics of the math operation but allows lowering to proceed. Versions of this lowering are already implicitly present in some passes, like ConvertGPUToROCDL. However, because those are implicit rewrites, they hide the floating-point extension and truncation, preventing anyone from writing passes that operate on those implitic extf/truncf pairs. Exposing this legalization explicitly is needed to allow lowening 8-bit floats on AMD GPUs, as the implementation of extf and truncf on that platform requires the complex logic found in ArithToAMDGPU, which runs before the GPU to ROCDL lowering.
695a98e
to
fd57c20
Compare
Since most of the operations in the
math
dialect don't have low-precision implementations, add the -math-legalize-to-f32 pass that goes through and brackets low-precision math funcitons (likemath.sin %0 : f16
) witharith.extf
andarith.truncf
. This preserves the original semantics of the math operation but allows lowering to proceed.Versions of this lowering are already implicitly present in some passes, like ConvertGPUToROCDL. However, because those are implicit rewrites, they hide the floating-point extension and truncation, preventing anyone from writing passes that operate on those implitic extf/truncf pairs.
Exposing this legalization explicitly is needed to allow lowening 8-bit floats on AMD GPUs, as the implementation of extf and truncf on that platform requires the complex logic found in ArithToAMDGPU, which runs before the GPU to ROCDL lowering.