-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Math] Add fine-grained populate-patterns functions for math function rewrites. #126103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-math Author: Benoit Jacob (bjacob) ChangesThe existing
This PR adds finer-grained populate-patterns functions, which take a predicate lambda allowing the caller to control which math functions to apply rewrites to. Full diff: https://github.com/llvm/llvm-project/pull/126103.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f0f17c6adcb088e..5abdb90c45df00d 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -48,6 +48,27 @@ struct MathPolynomialApproximationOptions {
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
+// Adds patterns to convert to f32 around math functions for which `predicate`
+// returns true.
+void populateMathF32ExpansionPatterns(
+ RewritePatternSet &patterns,
+ const std::function<bool(StringRef)> &predicate);
+
+// Adds patterns to enable polynomial approximations for math functions for
+// which `predicate` returns true.
+void populateMathPolynomialApproximationPatterns(
+ RewritePatternSet &patterns,
+ const std::function<bool(StringRef)> &predicate);
+
+// Legacy. Calls both populateMathF32ExpansionPatterns and
+// populateMathPolynomialApproximationPatterns with predicates enabling a
+// certain set of math function rewrites, that probably can't be changed for
+// compatibility reasons. Notice that unlike
+// populateMathPolynomialApproximationPatterns(patterns, predicate), this
+// overload also calls populateMathF32ExpansionPatterns.
+// Prefer calling these functions directly:
+// * populateMathF32ExpansionPatterns(patterns, predicate)
+// * populateMathPolynomialApproximationPatterns(patterns, predicate)
void populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options = {});
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 24c892f68b50316..1db5c0cfca28988 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1667,28 +1667,158 @@ void mlir::populatePolynomialApproximateErfPattern(
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
}
+void mlir::populateMathF32ExpansionPatterns(
+ RewritePatternSet &patterns,
+ const std::function<bool(StringRef)> &predicate) {
+ MLIRContext *context = patterns.getContext();
+ if (predicate("acos")) {
+ patterns.add<ReuseF32Expansion<math::AcosOp>>(context);
+ }
+ if (predicate("acosh")) {
+ patterns.add<ReuseF32Expansion<math::AcoshOp>>(context);
+ }
+ if (predicate("asin")) {
+ patterns.add<ReuseF32Expansion<math::AsinOp>>(context);
+ }
+ if (predicate("asinh")) {
+ patterns.add<ReuseF32Expansion<math::AsinhOp>>(context);
+ }
+ if (predicate("atan")) {
+ patterns.add<ReuseF32Expansion<math::AtanOp>>(context);
+ }
+ if (predicate("atan2")) {
+ patterns.add<ReuseF32Expansion<math::Atan2Op>>(context);
+ }
+ if (predicate("atanh")) {
+ patterns.add<ReuseF32Expansion<math::AtanhOp>>(context);
+ }
+ if (predicate("cbrt")) {
+ patterns.add<ReuseF32Expansion<math::CbrtOp>>(context);
+ }
+ if (predicate("cos")) {
+ patterns.add<ReuseF32Expansion<math::CosOp>>(context);
+ }
+ if (predicate("cosh")) {
+ patterns.add<ReuseF32Expansion<math::CoshOp>>(context);
+ }
+ if (predicate("erf")) {
+ patterns.add<ReuseF32Expansion<math::ErfOp>>(context);
+ }
+ if (predicate("exp")) {
+ patterns.add<ReuseF32Expansion<math::ExpOp>>(context);
+ }
+ if (predicate("exp2")) {
+ patterns.add<ReuseF32Expansion<math::Exp2Op>>(context);
+ }
+ if (predicate("expm1")) {
+ patterns.add<ReuseF32Expansion<math::ExpM1Op>>(context);
+ }
+ if (predicate("log")) {
+ patterns.add<ReuseF32Expansion<math::LogOp>>(context);
+ }
+ if (predicate("log10")) {
+ patterns.add<ReuseF32Expansion<math::Log10Op>>(context);
+ }
+ if (predicate("log2")) {
+ patterns.add<ReuseF32Expansion<math::Log2Op>>(context);
+ }
+ if (predicate("log1p")) {
+ patterns.add<ReuseF32Expansion<math::Log1pOp>>(context);
+ }
+ if (predicate("powf")) {
+ patterns.add<ReuseF32Expansion<math::PowFOp>>(context);
+ }
+ if (predicate("rsqrt")) {
+ patterns.add<ReuseF32Expansion<math::RsqrtOp>>(context);
+ }
+ if (predicate("sin")) {
+ patterns.add<ReuseF32Expansion<math::SinOp>>(context);
+ }
+ if (predicate("sinh")) {
+ patterns.add<ReuseF32Expansion<math::SinhOp>>(context);
+ }
+ if (predicate("sqrt")) {
+ patterns.add<ReuseF32Expansion<math::SqrtOp>>(context);
+ }
+ if (predicate("tan")) {
+ patterns.add<ReuseF32Expansion<math::TanOp>>(context);
+ }
+ if (predicate("tanh")) {
+ patterns.add<ReuseF32Expansion<math::TanhOp>>(context);
+ }
+}
+
+void mlir::populateMathPolynomialApproximationPatterns(
+ RewritePatternSet &patterns,
+ const std::function<bool(StringRef)> &predicate) {
+ MLIRContext *context = patterns.getContext();
+ if (predicate("acos")) {
+ patterns.add<AcosPolynomialApproximation>(context);
+ }
+ if (predicate("asin")) {
+ patterns.add<AsinPolynomialApproximation>(context);
+ }
+ if (predicate("atan")) {
+ patterns.add<AtanApproximation>(context);
+ }
+ if (predicate("atan2")) {
+ patterns.add<Atan2Approximation>(context);
+ }
+ if (predicate("cbrt")) {
+ patterns.add<CbrtApproximation>(context);
+ }
+ if (predicate("cos")) {
+ patterns.add<SinAndCosApproximation<false, math::CosOp>>(context);
+ }
+ if (predicate("erf")) {
+ patterns.add<ErfPolynomialApproximation>(context);
+ }
+ if (predicate("exp")) {
+ patterns.add<ExpApproximation>(context);
+ }
+ if (predicate("expm1")) {
+ patterns.add<ExpM1Approximation>(context);
+ }
+ if (predicate("log")) {
+ patterns.add<LogApproximation>(context);
+ }
+ if (predicate("log2")) {
+ patterns.add<Log2Approximation>(context);
+ }
+ if (predicate("log1p")) {
+ patterns.add<Log1pApproximation>(context);
+ }
+ if (predicate("rsqrt")) {
+ patterns.add<RsqrtApproximation>(context);
+ }
+ if (predicate("sin")) {
+ patterns.add<SinAndCosApproximation<true, math::SinOp>>(context);
+ }
+ if (predicate("tanh")) {
+ patterns.add<TanhApproximation>(context);
+ }
+}
+
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {
- // Patterns for leveraging existing f32 lowerings on other data types.
- patterns
- .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
- ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
- ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
- ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
- ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
- ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
- patterns.getContext());
-
- patterns
- .add<AtanApproximation, Atan2Approximation, TanhApproximation,
- LogApproximation, Log2Approximation, Log1pApproximation,
- ErfPolynomialApproximation, AsinPolynomialApproximation,
- AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
- SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
+ mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) {
+ return name == "atan" || name == "atan2" || name == "tanh" ||
+ name == "log" || name == "log2" || name == "log1p" ||
+ name == "erf" || name == "exp" || name == "expm1" ||
+ name == "cbrt" || name == "sin" || name == "cos";
+ });
+
+ populateMathPolynomialApproximationPatterns(patterns, [](StringRef name) {
+ return name == "atan" || name == "atan2" || name == "tanh" ||
+ name == "log" || name == "log2" || name == "log1p" ||
+ name == "erf" || name == "asin" || name == "acos" || name == "exp" ||
+ name == "expm1" || name == "cbrt" || name == "sin" || name == "cos";
+ });
+
if (options.enableAvx2) {
- patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
- patterns.getContext());
+ auto predicateRsqrt = [](StringRef name) { return name == "rsqrt"; };
+ mlir::populateMathF32ExpansionPatterns(patterns, predicateRsqrt);
+ mlir::populateMathPolynomialApproximationPatterns(patterns, predicateRsqrt);
}
}
|
Signed-off-by: Benoit Jacob <[email protected]>
caa1603
to
bd15e6b
Compare
Updated. @kuhar @MaheshRavishankar |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…nction rewrites. (llvm#126103) The existing `mlir::populateMathPolynomialApproximationPatterns` is coarse-grained and inflexible: - It populates 2 distinct classes of patterns: (1) polynomial approximations, (2) expansions of operands to f32. - It does not offer knobs to select which math functions to apply the rewrites to. This PR adds finer-grained populate-patterns functions, which take a predicate lambda allowing the caller to control which math functions to apply rewrites to. Signed-off-by: Benoit Jacob <[email protected]>
…nction rewrites. (llvm#126103) The existing `mlir::populateMathPolynomialApproximationPatterns` is coarse-grained and inflexible: - It populates 2 distinct classes of patterns: (1) polynomial approximations, (2) expansions of operands to f32. - It does not offer knobs to select which math functions to apply the rewrites to. This PR adds finer-grained populate-patterns functions, which take a predicate lambda allowing the caller to control which math functions to apply rewrites to. Signed-off-by: Benoit Jacob <[email protected]>
…nction rewrites. (llvm#126103) The existing `mlir::populateMathPolynomialApproximationPatterns` is coarse-grained and inflexible: - It populates 2 distinct classes of patterns: (1) polynomial approximations, (2) expansions of operands to f32. - It does not offer knobs to select which math functions to apply the rewrites to. This PR adds finer-grained populate-patterns functions, which take a predicate lambda allowing the caller to control which math functions to apply rewrites to. Signed-off-by: Benoit Jacob <[email protected]>
…ns (#130782) This is a follow-up to #127291, which added the benefit arg to lowerings to intrinsics and libm. In this change we add the benefit arg to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings, namely approximations, intrinsics and libm. Note that we're only updating the new API added in #126103. The legacy one (`mlir::populateMathPolynomialApproximationPatterns`) is left unmodified to encourage users to move out of it.
The existing
mlir::populateMathPolynomialApproximationPatterns
is coarse-grained and inflexible:This PR adds finer-grained populate-patterns functions, which take a predicate lambda allowing the caller to control which math functions to apply rewrites to.