Skip to content

[MLIR][Math] Add fine-grained populate-patterns functions for math function rewrites. #126103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 10, 2025

Conversation

bjacob
Copy link
Contributor

@bjacob bjacob commented Feb 6, 2025

The existing mlir::populateMathPolynomialApproximationPatterns is coarse-grained and inflexible:

  • It populates 2 distinct classes of patterns: (1) polynomial approximations, (2) expansions of operands to f32.
  • It does not offer knobs to select which math functions to apply the rewrites to.

This PR adds finer-grained populate-patterns functions, which take a predicate lambda allowing the caller to control which math functions to apply rewrites to.

@llvmbot
Copy link
Member

llvmbot commented Feb 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Benoit Jacob (bjacob)

Changes

The existing mlir::populateMathPolynomialApproximationPatterns is coarse-grained and inflexible:

  • It populates 2 distinct classes of patterns: (1) polynomial approximations, (2) expansions of operands to f32.
  • It does not offer knobs to select which math functions to apply the rewrites to.

This PR adds finer-grained populate-patterns functions, which take a predicate lambda allowing the caller to control which math functions to apply rewrites to.


Full diff: https://github.com/llvm/llvm-project/pull/126103.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+21)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+149-19)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f0f17c6adcb088e..5abdb90c45df00d 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -48,6 +48,27 @@ struct MathPolynomialApproximationOptions {
 void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
 void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
 
+// Adds patterns to convert to f32 around math functions for which `predicate`
+// returns true.
+void populateMathF32ExpansionPatterns(
+    RewritePatternSet &patterns,
+    const std::function<bool(StringRef)> &predicate);
+
+// Adds patterns to enable polynomial approximations for math functions for
+// which `predicate` returns true.
+void populateMathPolynomialApproximationPatterns(
+    RewritePatternSet &patterns,
+    const std::function<bool(StringRef)> &predicate);
+
+// Legacy. Calls both populateMathF32ExpansionPatterns and
+// populateMathPolynomialApproximationPatterns with predicates enabling a
+// certain set of math function rewrites, that probably can't be changed for
+// compatibility reasons. Notice that unlike
+// populateMathPolynomialApproximationPatterns(patterns, predicate), this
+// overload also calls populateMathF32ExpansionPatterns.
+// Prefer calling these functions directly:
+// * populateMathF32ExpansionPatterns(patterns, predicate)
+// * populateMathPolynomialApproximationPatterns(patterns, predicate)
 void populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options = {});
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 24c892f68b50316..1db5c0cfca28988 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1667,28 +1667,158 @@ void mlir::populatePolynomialApproximateErfPattern(
   patterns.add<ErfPolynomialApproximation>(patterns.getContext());
 }
 
+void mlir::populateMathF32ExpansionPatterns(
+    RewritePatternSet &patterns,
+    const std::function<bool(StringRef)> &predicate) {
+  MLIRContext *context = patterns.getContext();
+  if (predicate("acos")) {
+    patterns.add<ReuseF32Expansion<math::AcosOp>>(context);
+  }
+  if (predicate("acosh")) {
+    patterns.add<ReuseF32Expansion<math::AcoshOp>>(context);
+  }
+  if (predicate("asin")) {
+    patterns.add<ReuseF32Expansion<math::AsinOp>>(context);
+  }
+  if (predicate("asinh")) {
+    patterns.add<ReuseF32Expansion<math::AsinhOp>>(context);
+  }
+  if (predicate("atan")) {
+    patterns.add<ReuseF32Expansion<math::AtanOp>>(context);
+  }
+  if (predicate("atan2")) {
+    patterns.add<ReuseF32Expansion<math::Atan2Op>>(context);
+  }
+  if (predicate("atanh")) {
+    patterns.add<ReuseF32Expansion<math::AtanhOp>>(context);
+  }
+  if (predicate("cbrt")) {
+    patterns.add<ReuseF32Expansion<math::CbrtOp>>(context);
+  }
+  if (predicate("cos")) {
+    patterns.add<ReuseF32Expansion<math::CosOp>>(context);
+  }
+  if (predicate("cosh")) {
+    patterns.add<ReuseF32Expansion<math::CoshOp>>(context);
+  }
+  if (predicate("erf")) {
+    patterns.add<ReuseF32Expansion<math::ErfOp>>(context);
+  }
+  if (predicate("exp")) {
+    patterns.add<ReuseF32Expansion<math::ExpOp>>(context);
+  }
+  if (predicate("exp2")) {
+    patterns.add<ReuseF32Expansion<math::Exp2Op>>(context);
+  }
+  if (predicate("expm1")) {
+    patterns.add<ReuseF32Expansion<math::ExpM1Op>>(context);
+  }
+  if (predicate("log")) {
+    patterns.add<ReuseF32Expansion<math::LogOp>>(context);
+  }
+  if (predicate("log10")) {
+    patterns.add<ReuseF32Expansion<math::Log10Op>>(context);
+  }
+  if (predicate("log2")) {
+    patterns.add<ReuseF32Expansion<math::Log2Op>>(context);
+  }
+  if (predicate("log1p")) {
+    patterns.add<ReuseF32Expansion<math::Log1pOp>>(context);
+  }
+  if (predicate("powf")) {
+    patterns.add<ReuseF32Expansion<math::PowFOp>>(context);
+  }
+  if (predicate("rsqrt")) {
+    patterns.add<ReuseF32Expansion<math::RsqrtOp>>(context);
+  }
+  if (predicate("sin")) {
+    patterns.add<ReuseF32Expansion<math::SinOp>>(context);
+  }
+  if (predicate("sinh")) {
+    patterns.add<ReuseF32Expansion<math::SinhOp>>(context);
+  }
+  if (predicate("sqrt")) {
+    patterns.add<ReuseF32Expansion<math::SqrtOp>>(context);
+  }
+  if (predicate("tan")) {
+    patterns.add<ReuseF32Expansion<math::TanOp>>(context);
+  }
+  if (predicate("tanh")) {
+    patterns.add<ReuseF32Expansion<math::TanhOp>>(context);
+  }
+}
+
+void mlir::populateMathPolynomialApproximationPatterns(
+    RewritePatternSet &patterns,
+    const std::function<bool(StringRef)> &predicate) {
+  MLIRContext *context = patterns.getContext();
+  if (predicate("acos")) {
+    patterns.add<AcosPolynomialApproximation>(context);
+  }
+  if (predicate("asin")) {
+    patterns.add<AsinPolynomialApproximation>(context);
+  }
+  if (predicate("atan")) {
+    patterns.add<AtanApproximation>(context);
+  }
+  if (predicate("atan2")) {
+    patterns.add<Atan2Approximation>(context);
+  }
+  if (predicate("cbrt")) {
+    patterns.add<CbrtApproximation>(context);
+  }
+  if (predicate("cos")) {
+    patterns.add<SinAndCosApproximation<false, math::CosOp>>(context);
+  }
+  if (predicate("erf")) {
+    patterns.add<ErfPolynomialApproximation>(context);
+  }
+  if (predicate("exp")) {
+    patterns.add<ExpApproximation>(context);
+  }
+  if (predicate("expm1")) {
+    patterns.add<ExpM1Approximation>(context);
+  }
+  if (predicate("log")) {
+    patterns.add<LogApproximation>(context);
+  }
+  if (predicate("log2")) {
+    patterns.add<Log2Approximation>(context);
+  }
+  if (predicate("log1p")) {
+    patterns.add<Log1pApproximation>(context);
+  }
+  if (predicate("rsqrt")) {
+    patterns.add<RsqrtApproximation>(context);
+  }
+  if (predicate("sin")) {
+    patterns.add<SinAndCosApproximation<true, math::SinOp>>(context);
+  }
+  if (predicate("tanh")) {
+    patterns.add<TanhApproximation>(context);
+  }
+}
+
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options) {
-  // Patterns for leveraging existing f32 lowerings on other data types.
-  patterns
-      .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
-           ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
-           ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
-           ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
-           ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
-           ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
-          patterns.getContext());
-
-  patterns
-      .add<AtanApproximation, Atan2Approximation, TanhApproximation,
-           LogApproximation, Log2Approximation, Log1pApproximation,
-           ErfPolynomialApproximation, AsinPolynomialApproximation,
-           AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-           CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
-           SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
+  mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) {
+    return name == "atan" || name == "atan2" || name == "tanh" ||
+           name == "log" || name == "log2" || name == "log1p" ||
+           name == "erf" || name == "exp" || name == "expm1" ||
+           name == "cbrt" || name == "sin" || name == "cos";
+  });
+
+  populateMathPolynomialApproximationPatterns(patterns, [](StringRef name) {
+    return name == "atan" || name == "atan2" || name == "tanh" ||
+           name == "log" || name == "log2" || name == "log1p" ||
+           name == "erf" || name == "asin" || name == "acos" || name == "exp" ||
+           name == "expm1" || name == "cbrt" || name == "sin" || name == "cos";
+  });
+
   if (options.enableAvx2) {
-    patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
-        patterns.getContext());
+    auto predicateRsqrt = [](StringRef name) { return name == "rsqrt"; };
+    mlir::populateMathF32ExpansionPatterns(patterns, predicateRsqrt);
+    mlir::populateMathPolynomialApproximationPatterns(patterns, predicateRsqrt);
   }
 }

Signed-off-by: Benoit Jacob <[email protected]>
@bjacob bjacob requested a review from kuhar February 10, 2025 11:58
@bjacob
Copy link
Contributor Author

bjacob commented Feb 10, 2025

Updated. @kuhar @MaheshRavishankar

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@MaheshRavishankar MaheshRavishankar merged commit ced23aa into llvm:main Feb 10, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…nction rewrites. (llvm#126103)

The existing `mlir::populateMathPolynomialApproximationPatterns` is
coarse-grained and inflexible:
- It populates 2 distinct classes of patterns: (1) polynomial
approximations, (2) expansions of operands to f32.
- It does not offer knobs to select which math functions to apply the
rewrites to.

This PR adds finer-grained populate-patterns functions, which take a
predicate lambda allowing the caller to control which math functions to
apply rewrites to.

Signed-off-by: Benoit Jacob <[email protected]>
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
…nction rewrites. (llvm#126103)

The existing `mlir::populateMathPolynomialApproximationPatterns` is
coarse-grained and inflexible:
- It populates 2 distinct classes of patterns: (1) polynomial
approximations, (2) expansions of operands to f32.
- It does not offer knobs to select which math functions to apply the
rewrites to.

This PR adds finer-grained populate-patterns functions, which take a
predicate lambda allowing the caller to control which math functions to
apply rewrites to.

Signed-off-by: Benoit Jacob <[email protected]>
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
…nction rewrites. (llvm#126103)

The existing `mlir::populateMathPolynomialApproximationPatterns` is
coarse-grained and inflexible:
- It populates 2 distinct classes of patterns: (1) polynomial
approximations, (2) expansions of operands to f32.
- It does not offer knobs to select which math functions to apply the
rewrites to.

This PR adds finer-grained populate-patterns functions, which take a
predicate lambda allowing the caller to control which math functions to
apply rewrites to.

Signed-off-by: Benoit Jacob <[email protected]>
cota added a commit that referenced this pull request Mar 11, 2025
…ns (#130782)

This is a follow-up to #127291, which added the benefit arg to lowerings
to intrinsics and libm.

In this change we add the benefit arg to the math approximation and
expansion lowerings, which allows users to establish a preferred order
among all three math lowerings, namely approximations, intrinsics and
libm.

Note that we're only updating the new API added in #126103. The legacy
one (`mlir::populateMathPolynomialApproximationPatterns`) is left
unmodified to encourage users to move out of it.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants