Skip to content

Commit 3692fb6

Browse files
authored
[mlir][math] add benefit arg to populate math approximations/expansions (#130782)
This is a follow-up to #127291, which added the benefit arg to lowerings to intrinsics and libm. In this change we add the benefit arg to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings, namely approximations, intrinsics and libm. Note that we're only updating the new API added in #126103. The legacy one (`mlir::populateMathPolynomialApproximationPatterns`) is left unmodified to encourage users to move out of it.
1 parent d547005 commit 3692fb6

File tree

2 files changed

+59
-53
lines changed

2 files changed

+59
-53
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
1010
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
1111

12+
#include "mlir/IR/PatternMatch.h"
1213
#include "mlir/Pass/Pass.h"
1314

1415
namespace mlir {
@@ -52,12 +53,14 @@ void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
5253
// Adds patterns to convert to f32 around math functions for which `predicate`
5354
// returns true.
5455
void populateMathF32ExpansionPatterns(
55-
RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
56+
RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
57+
PatternBenefit = 1);
5658

5759
// Adds patterns to enable polynomial approximations for math functions for
5860
// which `predicate` returns true.
5961
void populateMathPolynomialApproximationPatterns(
60-
RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
62+
RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
63+
PatternBenefit = 1);
6164

6265
// Legacy. Calls both populateMathF32ExpansionPatterns and
6366
// populateMathPolynomialApproximationPatterns with predicates enabling a

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,90 +1776,93 @@ void mlir::populatePolynomialApproximateErfcPattern(
17761776
template <typename OpType>
17771777
static void
17781778
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
1779-
llvm::function_ref<bool(StringRef)> predicate) {
1779+
llvm::function_ref<bool(StringRef)> predicate,
1780+
PatternBenefit benefit) {
17801781
if (predicate(OpType::getOperationName())) {
1781-
patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext());
1782+
patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
17821783
}
17831784
}
17841785

17851786
void mlir::populateMathF32ExpansionPatterns(
1786-
RewritePatternSet &patterns,
1787-
llvm::function_ref<bool(StringRef)> predicate) {
1788-
populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate);
1789-
populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate);
1790-
populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate);
1791-
populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate);
1792-
populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate);
1793-
populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate);
1794-
populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate);
1795-
populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate);
1796-
populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
1797-
populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
1798-
populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
1799-
populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
1800-
populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
1801-
populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
1802-
populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
1803-
populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate);
1804-
populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate);
1805-
populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate);
1806-
populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate);
1807-
populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate);
1808-
populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate);
1809-
populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate);
1810-
populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate);
1811-
populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate);
1812-
populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate);
1813-
populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate);
1787+
RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1788+
PatternBenefit benefit) {
1789+
populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate, benefit);
1790+
populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate, benefit);
1791+
populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate, benefit);
1792+
populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate, benefit);
1793+
populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate, benefit);
1794+
populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate, benefit);
1795+
populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate, benefit);
1796+
populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate, benefit);
1797+
populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate, benefit);
1798+
populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate, benefit);
1799+
populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate, benefit);
1800+
populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate, benefit);
1801+
populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate, benefit);
1802+
populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate, benefit);
1803+
populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate, benefit);
1804+
populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate, benefit);
1805+
populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate, benefit);
1806+
populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate, benefit);
1807+
populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate, benefit);
1808+
populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate, benefit);
1809+
populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate, benefit);
1810+
populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate, benefit);
1811+
populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate, benefit);
1812+
populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate, benefit);
1813+
populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate, benefit);
1814+
populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate, benefit);
18141815
}
18151816

18161817
template <typename OpType, typename PatternType>
18171818
static void populateMathPolynomialApproximationPattern(
1818-
RewritePatternSet &patterns,
1819-
llvm::function_ref<bool(StringRef)> predicate) {
1819+
RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1820+
PatternBenefit benefit) {
18201821
if (predicate(OpType::getOperationName())) {
1821-
patterns.add<PatternType>(patterns.getContext());
1822+
patterns.add<PatternType>(patterns.getContext(), benefit);
18221823
}
18231824
}
18241825

18251826
void mlir::populateMathPolynomialApproximationPatterns(
1826-
RewritePatternSet &patterns,
1827-
llvm::function_ref<bool(StringRef)> predicate) {
1827+
RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1828+
PatternBenefit benefit) {
18281829
populateMathPolynomialApproximationPattern<AcosOp,
18291830
AcosPolynomialApproximation>(
1830-
patterns, predicate);
1831+
patterns, predicate, benefit);
18311832
populateMathPolynomialApproximationPattern<AsinOp,
18321833
AsinPolynomialApproximation>(
1833-
patterns, predicate);
1834+
patterns, predicate, benefit);
18341835
populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
1835-
patterns, predicate);
1836+
patterns, predicate, benefit);
18361837
populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
1837-
patterns, predicate);
1838+
patterns, predicate, benefit);
18381839
populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
1839-
patterns, predicate);
1840+
patterns, predicate, benefit);
18401841
populateMathPolynomialApproximationPattern<
1841-
CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
1842+
CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
1843+
benefit);
18421844
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
1843-
patterns, predicate);
1845+
patterns, predicate, benefit);
18441846
populateMathPolynomialApproximationPattern<ErfcOp,
18451847
ErfcPolynomialApproximation>(
1846-
patterns, predicate);
1848+
patterns, predicate, benefit);
18471849
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
1848-
patterns, predicate);
1850+
patterns, predicate, benefit);
18491851
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
1850-
patterns, predicate);
1852+
patterns, predicate, benefit);
18511853
populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
1852-
patterns, predicate);
1854+
patterns, predicate, benefit);
18531855
populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
1854-
patterns, predicate);
1856+
patterns, predicate, benefit);
18551857
populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
1856-
patterns, predicate);
1858+
patterns, predicate, benefit);
18571859
populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
1858-
patterns, predicate);
1860+
patterns, predicate, benefit);
18591861
populateMathPolynomialApproximationPattern<
1860-
SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate);
1862+
SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
1863+
benefit);
18611864
populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
1862-
patterns, predicate);
1865+
patterns, predicate, benefit);
18631866
}
18641867

18651868
void mlir::populateMathPolynomialApproximationPatterns(

0 commit comments

Comments
 (0)