-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][math] add benefit arg to populate math approximations/expansions #130782
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This is a follow-up to llvm#127291, which added the benefit arg to lowerings to intrinsics and libm. In this change we add the benefit arg to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings, namely approximations, intrinsics and libm. Note that we're only updating the new API added in llvm#126103. The legacy one (`mlir::populateMathPolynomialApproximationPatterns`) is left unmodified to encourage users to move out of it.
@llvm/pr-subscribers-mlir-math Author: Emilio Cota (cota) ChangesThis is a follow-up to #127291, which added the benefit arg to lowerings to intrinsics and libm. In this change we add the benefit arg to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings, namely approximations, intrinsics and libm. Note that we're only updating the new API added in #126103. The legacy one ( Full diff: https://github.com/llvm/llvm-project/pull/130782.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 9adc1c6940a15..c0fe5d3be448a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -52,12 +53,14 @@ void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
// Adds patterns to convert to f32 around math functions for which `predicate`
// returns true.
void populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Adds patterns to enable polynomial approximations for math functions for
// which `predicate` returns true.
void populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Legacy. Calls both populateMathF32ExpansionPatterns and
// populateMathPolynomialApproximationPatterns with predicates enabling a
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 167eebd786dba..a26e380232a91 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1776,90 +1776,93 @@ void mlir::populatePolynomialApproximateErfcPattern(
template <typename OpType>
static void
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext());
+ patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
}
}
void mlir::populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
- populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
+ populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate, benefit);
}
template <typename OpType, typename PatternType>
static void populateMathPolynomialApproximationPattern(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<PatternType>(patterns.getContext());
+ patterns.add<PatternType>(patterns.getContext(), benefit);
}
}
void mlir::populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
populateMathPolynomialApproximationPattern<AcosOp,
AcosPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AsinOp,
AsinPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
+ CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ErfcOp,
ErfcPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate);
+ SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
}
void mlir::populateMathPolynomialApproximationPatterns(
|
@llvm/pr-subscribers-mlir Author: Emilio Cota (cota) ChangesThis is a follow-up to #127291, which added the benefit arg to lowerings to intrinsics and libm. In this change we add the benefit arg to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings, namely approximations, intrinsics and libm. Note that we're only updating the new API added in #126103. The legacy one ( Full diff: https://github.com/llvm/llvm-project/pull/130782.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 9adc1c6940a15..c0fe5d3be448a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -52,12 +53,14 @@ void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
// Adds patterns to convert to f32 around math functions for which `predicate`
// returns true.
void populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Adds patterns to enable polynomial approximations for math functions for
// which `predicate` returns true.
void populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Legacy. Calls both populateMathF32ExpansionPatterns and
// populateMathPolynomialApproximationPatterns with predicates enabling a
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 167eebd786dba..a26e380232a91 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1776,90 +1776,93 @@ void mlir::populatePolynomialApproximateErfcPattern(
template <typename OpType>
static void
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext());
+ patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
}
}
void mlir::populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
- populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
+ populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate, benefit);
}
template <typename OpType, typename PatternType>
static void populateMathPolynomialApproximationPattern(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<PatternType>(patterns.getContext());
+ patterns.add<PatternType>(patterns.getContext(), benefit);
}
}
void mlir::populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
populateMathPolynomialApproximationPattern<AcosOp,
AcosPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AsinOp,
AsinPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
+ CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ErfcOp,
ErfcPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate);
+ SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
}
void mlir::populateMathPolynomialApproximationPatterns(
|
This is a follow-up to #127291, which added the benefit arg to lowerings to intrinsics and libm.
In this change we add the benefit arg to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings, namely approximations, intrinsics and libm.
Note that we're only updating the new API added in #126103. The legacy one (
mlir::populateMathPolynomialApproximationPatterns
) is left unmodified to encourage users to move out of it.