Skip to content

Commit caa1603

Browse files
committed
polynomial-approx
Signed-off-by: Benoit Jacob <[email protected]>
1 parent 4717bab commit caa1603

File tree

2 files changed

+170
-19
lines changed

2 files changed

+170
-19
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@ struct MathPolynomialApproximationOptions {
4848
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
4949
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
5050

51+
// Adds patterns to convert to f32 around math functions for which `predicate`
52+
// returns true.
53+
void populateMathF32ExpansionPatterns(
54+
RewritePatternSet &patterns,
55+
const std::function<bool(StringRef)> &predicate);
56+
57+
// Adds patterns to enable polynomial approximations for math functions for
58+
// which `predicate` returns true.
59+
void populateMathPolynomialApproximationPatterns(
60+
RewritePatternSet &patterns,
61+
const std::function<bool(StringRef)> &predicate);
62+
63+
// Legacy. Calls both populateMathF32ExpansionPatterns and
64+
// populateMathPolynomialApproximationPatterns with predicates enabling a
65+
// certain set of math function rewrites, that probably can't be changed for
66+
// compatibility reasons. Notice that unlike
67+
// populateMathPolynomialApproximationPatterns(patterns, predicate), this
68+
// overload also calls populateMathF32ExpansionPatterns.
69+
// Prefer calling these functions directly:
70+
// * populateMathF32ExpansionPatterns(patterns, predicate)
71+
// * populateMathPolynomialApproximationPatterns(patterns, predicate)
5172
void populateMathPolynomialApproximationPatterns(
5273
RewritePatternSet &patterns,
5374
const MathPolynomialApproximationOptions &options = {});

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 149 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,28 +1667,158 @@ void mlir::populatePolynomialApproximateErfPattern(
16671667
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
16681668
}
16691669

1670+
void mlir::populateMathF32ExpansionPatterns(
1671+
RewritePatternSet &patterns,
1672+
const std::function<bool(StringRef)> &predicate) {
1673+
MLIRContext *context = patterns.getContext();
1674+
if (predicate("acos")) {
1675+
patterns.add<ReuseF32Expansion<math::AcosOp>>(context);
1676+
}
1677+
if (predicate("acosh")) {
1678+
patterns.add<ReuseF32Expansion<math::AcoshOp>>(context);
1679+
}
1680+
if (predicate("asin")) {
1681+
patterns.add<ReuseF32Expansion<math::AsinOp>>(context);
1682+
}
1683+
if (predicate("asinh")) {
1684+
patterns.add<ReuseF32Expansion<math::AsinhOp>>(context);
1685+
}
1686+
if (predicate("atan")) {
1687+
patterns.add<ReuseF32Expansion<math::AtanOp>>(context);
1688+
}
1689+
if (predicate("atan2")) {
1690+
patterns.add<ReuseF32Expansion<math::Atan2Op>>(context);
1691+
}
1692+
if (predicate("atanh")) {
1693+
patterns.add<ReuseF32Expansion<math::AtanhOp>>(context);
1694+
}
1695+
if (predicate("cbrt")) {
1696+
patterns.add<ReuseF32Expansion<math::CbrtOp>>(context);
1697+
}
1698+
if (predicate("cos")) {
1699+
patterns.add<ReuseF32Expansion<math::CosOp>>(context);
1700+
}
1701+
if (predicate("cosh")) {
1702+
patterns.add<ReuseF32Expansion<math::CoshOp>>(context);
1703+
}
1704+
if (predicate("erf")) {
1705+
patterns.add<ReuseF32Expansion<math::ErfOp>>(context);
1706+
}
1707+
if (predicate("exp")) {
1708+
patterns.add<ReuseF32Expansion<math::ExpOp>>(context);
1709+
}
1710+
if (predicate("exp2")) {
1711+
patterns.add<ReuseF32Expansion<math::Exp2Op>>(context);
1712+
}
1713+
if (predicate("expm1")) {
1714+
patterns.add<ReuseF32Expansion<math::ExpM1Op>>(context);
1715+
}
1716+
if (predicate("log")) {
1717+
patterns.add<ReuseF32Expansion<math::LogOp>>(context);
1718+
}
1719+
if (predicate("log10")) {
1720+
patterns.add<ReuseF32Expansion<math::Log10Op>>(context);
1721+
}
1722+
if (predicate("log2")) {
1723+
patterns.add<ReuseF32Expansion<math::Log2Op>>(context);
1724+
}
1725+
if (predicate("log1p")) {
1726+
patterns.add<ReuseF32Expansion<math::Log1pOp>>(context);
1727+
}
1728+
if (predicate("powf")) {
1729+
patterns.add<ReuseF32Expansion<math::PowFOp>>(context);
1730+
}
1731+
if (predicate("rsqrt")) {
1732+
patterns.add<ReuseF32Expansion<math::RsqrtOp>>(context);
1733+
}
1734+
if (predicate("sin")) {
1735+
patterns.add<ReuseF32Expansion<math::SinOp>>(context);
1736+
}
1737+
if (predicate("sinh")) {
1738+
patterns.add<ReuseF32Expansion<math::SinhOp>>(context);
1739+
}
1740+
if (predicate("sqrt")) {
1741+
patterns.add<ReuseF32Expansion<math::SqrtOp>>(context);
1742+
}
1743+
if (predicate("tan")) {
1744+
patterns.add<ReuseF32Expansion<math::TanOp>>(context);
1745+
}
1746+
if (predicate("tanh")) {
1747+
patterns.add<ReuseF32Expansion<math::TanhOp>>(context);
1748+
}
1749+
}
1750+
1751+
void mlir::populateMathPolynomialApproximationPatterns(
1752+
RewritePatternSet &patterns,
1753+
const std::function<bool(StringRef)> &predicate) {
1754+
MLIRContext *context = patterns.getContext();
1755+
if (predicate("acos")) {
1756+
patterns.add<AcosPolynomialApproximation>(context);
1757+
}
1758+
if (predicate("asin")) {
1759+
patterns.add<AsinPolynomialApproximation>(context);
1760+
}
1761+
if (predicate("atan")) {
1762+
patterns.add<AtanApproximation>(context);
1763+
}
1764+
if (predicate("atan2")) {
1765+
patterns.add<Atan2Approximation>(context);
1766+
}
1767+
if (predicate("cbrt")) {
1768+
patterns.add<CbrtApproximation>(context);
1769+
}
1770+
if (predicate("cos")) {
1771+
patterns.add<SinAndCosApproximation<false, math::CosOp>>(context);
1772+
}
1773+
if (predicate("erf")) {
1774+
patterns.add<ErfPolynomialApproximation>(context);
1775+
}
1776+
if (predicate("exp")) {
1777+
patterns.add<ExpApproximation>(context);
1778+
}
1779+
if (predicate("expm1")) {
1780+
patterns.add<ExpM1Approximation>(context);
1781+
}
1782+
if (predicate("log")) {
1783+
patterns.add<LogApproximation>(context);
1784+
}
1785+
if (predicate("log2")) {
1786+
patterns.add<Log2Approximation>(context);
1787+
}
1788+
if (predicate("log1p")) {
1789+
patterns.add<Log1pApproximation>(context);
1790+
}
1791+
if (predicate("rsqrt")) {
1792+
patterns.add<RsqrtApproximation>(context);
1793+
}
1794+
if (predicate("sin")) {
1795+
patterns.add<SinAndCosApproximation<true, math::SinOp>>(context);
1796+
}
1797+
if (predicate("tanh")) {
1798+
patterns.add<TanhApproximation>(context);
1799+
}
1800+
}
1801+
16701802
void mlir::populateMathPolynomialApproximationPatterns(
16711803
RewritePatternSet &patterns,
16721804
const MathPolynomialApproximationOptions &options) {
1673-
// Patterns for leveraging existing f32 lowerings on other data types.
1674-
patterns
1675-
.add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1676-
ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1677-
ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1678-
ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1679-
ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1680-
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1681-
patterns.getContext());
1682-
1683-
patterns
1684-
.add<AtanApproximation, Atan2Approximation, TanhApproximation,
1685-
LogApproximation, Log2Approximation, Log1pApproximation,
1686-
ErfPolynomialApproximation, AsinPolynomialApproximation,
1687-
AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1688-
CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1689-
SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
1805+
mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) {
1806+
return name == "atan" || name == "atan2" || name == "tanh" ||
1807+
name == "log" || name == "log2" || name == "log1p" ||
1808+
name == "erf" || name == "exp" || name == "expm1" ||
1809+
name == "cbrt" || name == "sin" || name == "cos";
1810+
});
1811+
1812+
populateMathPolynomialApproximationPatterns(patterns, [](StringRef name) {
1813+
return name == "atan" || name == "atan2" || name == "tanh" ||
1814+
name == "log" || name == "log2" || name == "log1p" ||
1815+
name == "erf" || name == "asin" || name == "acos" || name == "exp" ||
1816+
name == "expm1" || name == "cbrt" || name == "sin" || name == "cos";
1817+
});
1818+
16901819
if (options.enableAvx2) {
1691-
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1692-
patterns.getContext());
1820+
auto predicateRsqrt = [](StringRef name) { return name == "rsqrt"; };
1821+
mlir::populateMathF32ExpansionPatterns(patterns, predicateRsqrt);
1822+
mlir::populateMathPolynomialApproximationPatterns(patterns, predicateRsqrt);
16931823
}
16941824
}

0 commit comments

Comments
 (0)