Skip to content

Commit ced23aa

Browse files
authored
[MLIR][Math] Add fine-grained populate-patterns functions for math function rewrites. (#126103)
The existing `mlir::populateMathPolynomialApproximationPatterns` is coarse-grained and inflexible: - It populates 2 distinct classes of patterns: (1) polynomial approximations, (2) expansions of operands to f32. - It does not offer knobs to select which math functions to apply the rewrites to. This PR adds finer-grained populate-patterns functions, which take a predicate lambda allowing the caller to control which math functions to apply rewrites to. Signed-off-by: Benoit Jacob <[email protected]>
1 parent ad61e53 commit ced23aa

File tree

2 files changed

+135
-19
lines changed

2 files changed

+135
-19
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,25 @@ struct MathPolynomialApproximationOptions {
4848
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
4949
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
5050

51+
// Adds patterns to convert to f32 around math functions for which `predicate`
52+
// returns true.
53+
void populateMathF32ExpansionPatterns(
54+
RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
55+
56+
// Adds patterns to enable polynomial approximations for math functions for
57+
// which `predicate` returns true.
58+
void populateMathPolynomialApproximationPatterns(
59+
RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
60+
61+
// Legacy. Calls both populateMathF32ExpansionPatterns and
62+
// populateMathPolynomialApproximationPatterns with predicates enabling a
63+
// certain set of math function rewrites, that probably can't be changed for
64+
// compatibility reasons. Notice that unlike
65+
// populateMathPolynomialApproximationPatterns(patterns, predicate), this
66+
// overload also calls populateMathF32ExpansionPatterns.
67+
// Prefer calling these functions directly:
68+
// * populateMathF32ExpansionPatterns(patterns, predicate)
69+
// * populateMathPolynomialApproximationPatterns(patterns, predicate)
5170
void populateMathPolynomialApproximationPatterns(
5271
RewritePatternSet &patterns,
5372
const MathPolynomialApproximationOptions &options = {});

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 116 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,28 +1667,125 @@ void mlir::populatePolynomialApproximateErfPattern(
16671667
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
16681668
}
16691669

1670+
template <typename OpType>
1671+
static void
1672+
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
1673+
llvm::function_ref<bool(StringRef)> predicate) {
1674+
if (predicate(OpType::getOperationName())) {
1675+
patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext());
1676+
}
1677+
}
1678+
1679+
void mlir::populateMathF32ExpansionPatterns(
1680+
RewritePatternSet &patterns,
1681+
llvm::function_ref<bool(StringRef)> predicate) {
1682+
populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate);
1683+
populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate);
1684+
populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate);
1685+
populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate);
1686+
populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate);
1687+
populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate);
1688+
populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate);
1689+
populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate);
1690+
populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
1691+
populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
1692+
populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
1693+
populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
1694+
populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
1695+
populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
1696+
populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate);
1697+
populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate);
1698+
populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate);
1699+
populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate);
1700+
populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate);
1701+
populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate);
1702+
populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate);
1703+
populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate);
1704+
populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate);
1705+
populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate);
1706+
populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate);
1707+
}
1708+
1709+
template <typename OpType, typename PatternType>
1710+
static void populateMathPolynomialApproximationPattern(
1711+
RewritePatternSet &patterns,
1712+
llvm::function_ref<bool(StringRef)> predicate) {
1713+
if (predicate(OpType::getOperationName())) {
1714+
patterns.add<PatternType>(patterns.getContext());
1715+
}
1716+
}
1717+
1718+
void mlir::populateMathPolynomialApproximationPatterns(
1719+
RewritePatternSet &patterns,
1720+
llvm::function_ref<bool(StringRef)> predicate) {
1721+
populateMathPolynomialApproximationPattern<AcosOp,
1722+
AcosPolynomialApproximation>(
1723+
patterns, predicate);
1724+
populateMathPolynomialApproximationPattern<AsinOp,
1725+
AsinPolynomialApproximation>(
1726+
patterns, predicate);
1727+
populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
1728+
patterns, predicate);
1729+
populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
1730+
patterns, predicate);
1731+
populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
1732+
patterns, predicate);
1733+
populateMathPolynomialApproximationPattern<
1734+
CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
1735+
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
1736+
patterns, predicate);
1737+
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
1738+
patterns, predicate);
1739+
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
1740+
patterns, predicate);
1741+
populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
1742+
patterns, predicate);
1743+
populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
1744+
patterns, predicate);
1745+
populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
1746+
patterns, predicate);
1747+
populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
1748+
patterns, predicate);
1749+
populateMathPolynomialApproximationPattern<
1750+
SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate);
1751+
populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
1752+
patterns, predicate);
1753+
}
1754+
16701755
void mlir::populateMathPolynomialApproximationPatterns(
16711756
RewritePatternSet &patterns,
16721757
const MathPolynomialApproximationOptions &options) {
1673-
// Patterns for leveraging existing f32 lowerings on other data types.
1674-
patterns
1675-
.add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1676-
ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1677-
ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1678-
ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1679-
ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1680-
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1681-
patterns.getContext());
1682-
1683-
patterns
1684-
.add<AtanApproximation, Atan2Approximation, TanhApproximation,
1685-
LogApproximation, Log2Approximation, Log1pApproximation,
1686-
ErfPolynomialApproximation, AsinPolynomialApproximation,
1687-
AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1688-
CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1689-
SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
1758+
mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) -> bool {
1759+
return llvm::is_contained(
1760+
{math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
1761+
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1762+
math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
1763+
math::ErfOp::getOperationName(), math::ExpOp::getOperationName(),
1764+
math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
1765+
math::SinOp::getOperationName(), math::CosOp::getOperationName()},
1766+
name);
1767+
});
1768+
1769+
populateMathPolynomialApproximationPatterns(
1770+
patterns, [](StringRef name) -> bool {
1771+
return llvm::is_contained(
1772+
{math::AtanOp::getOperationName(),
1773+
math::Atan2Op::getOperationName(),
1774+
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1775+
math::Log2Op::getOperationName(),
1776+
math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
1777+
math::AsinOp::getOperationName(), math::AcosOp::getOperationName(),
1778+
math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
1779+
math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1780+
math::CosOp::getOperationName()},
1781+
name);
1782+
});
1783+
16901784
if (options.enableAvx2) {
1691-
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1692-
patterns.getContext());
1785+
auto predicateRsqrt = [](StringRef name) {
1786+
return name == math::RsqrtOp::getOperationName();
1787+
};
1788+
mlir::populateMathF32ExpansionPatterns(patterns, predicateRsqrt);
1789+
mlir::populateMathPolynomialApproximationPatterns(patterns, predicateRsqrt);
16931790
}
16941791
}

0 commit comments

Comments
 (0)