@@ -1667,28 +1667,158 @@ void mlir::populatePolynomialApproximateErfPattern(
1667
1667
patterns.add <ErfPolynomialApproximation>(patterns.getContext ());
1668
1668
}
1669
1669
1670
+ void mlir::populateMathF32ExpansionPatterns (
1671
+ RewritePatternSet &patterns,
1672
+ const std::function<bool (StringRef)> &predicate) {
1673
+ MLIRContext *context = patterns.getContext ();
1674
+ if (predicate (" acos" )) {
1675
+ patterns.add <ReuseF32Expansion<math::AcosOp>>(context);
1676
+ }
1677
+ if (predicate (" acosh" )) {
1678
+ patterns.add <ReuseF32Expansion<math::AcoshOp>>(context);
1679
+ }
1680
+ if (predicate (" asin" )) {
1681
+ patterns.add <ReuseF32Expansion<math::AsinOp>>(context);
1682
+ }
1683
+ if (predicate (" asinh" )) {
1684
+ patterns.add <ReuseF32Expansion<math::AsinhOp>>(context);
1685
+ }
1686
+ if (predicate (" atan" )) {
1687
+ patterns.add <ReuseF32Expansion<math::AtanOp>>(context);
1688
+ }
1689
+ if (predicate (" atan2" )) {
1690
+ patterns.add <ReuseF32Expansion<math::Atan2Op>>(context);
1691
+ }
1692
+ if (predicate (" atanh" )) {
1693
+ patterns.add <ReuseF32Expansion<math::AtanhOp>>(context);
1694
+ }
1695
+ if (predicate (" cbrt" )) {
1696
+ patterns.add <ReuseF32Expansion<math::CbrtOp>>(context);
1697
+ }
1698
+ if (predicate (" cos" )) {
1699
+ patterns.add <ReuseF32Expansion<math::CosOp>>(context);
1700
+ }
1701
+ if (predicate (" cosh" )) {
1702
+ patterns.add <ReuseF32Expansion<math::CoshOp>>(context);
1703
+ }
1704
+ if (predicate (" erf" )) {
1705
+ patterns.add <ReuseF32Expansion<math::ErfOp>>(context);
1706
+ }
1707
+ if (predicate (" exp" )) {
1708
+ patterns.add <ReuseF32Expansion<math::ExpOp>>(context);
1709
+ }
1710
+ if (predicate (" exp2" )) {
1711
+ patterns.add <ReuseF32Expansion<math::Exp2Op>>(context);
1712
+ }
1713
+ if (predicate (" expm1" )) {
1714
+ patterns.add <ReuseF32Expansion<math::ExpM1Op>>(context);
1715
+ }
1716
+ if (predicate (" log" )) {
1717
+ patterns.add <ReuseF32Expansion<math::LogOp>>(context);
1718
+ }
1719
+ if (predicate (" log10" )) {
1720
+ patterns.add <ReuseF32Expansion<math::Log10Op>>(context);
1721
+ }
1722
+ if (predicate (" log2" )) {
1723
+ patterns.add <ReuseF32Expansion<math::Log2Op>>(context);
1724
+ }
1725
+ if (predicate (" log1p" )) {
1726
+ patterns.add <ReuseF32Expansion<math::Log1pOp>>(context);
1727
+ }
1728
+ if (predicate (" powf" )) {
1729
+ patterns.add <ReuseF32Expansion<math::PowFOp>>(context);
1730
+ }
1731
+ if (predicate (" rsqrt" )) {
1732
+ patterns.add <ReuseF32Expansion<math::RsqrtOp>>(context);
1733
+ }
1734
+ if (predicate (" sin" )) {
1735
+ patterns.add <ReuseF32Expansion<math::SinOp>>(context);
1736
+ }
1737
+ if (predicate (" sinh" )) {
1738
+ patterns.add <ReuseF32Expansion<math::SinhOp>>(context);
1739
+ }
1740
+ if (predicate (" sqrt" )) {
1741
+ patterns.add <ReuseF32Expansion<math::SqrtOp>>(context);
1742
+ }
1743
+ if (predicate (" tan" )) {
1744
+ patterns.add <ReuseF32Expansion<math::TanOp>>(context);
1745
+ }
1746
+ if (predicate (" tanh" )) {
1747
+ patterns.add <ReuseF32Expansion<math::TanhOp>>(context);
1748
+ }
1749
+ }
1750
+
1751
+ void mlir::populateMathPolynomialApproximationPatterns (
1752
+ RewritePatternSet &patterns,
1753
+ const std::function<bool (StringRef)> &predicate) {
1754
+ MLIRContext *context = patterns.getContext ();
1755
+ if (predicate (" acos" )) {
1756
+ patterns.add <AcosPolynomialApproximation>(context);
1757
+ }
1758
+ if (predicate (" asin" )) {
1759
+ patterns.add <AsinPolynomialApproximation>(context);
1760
+ }
1761
+ if (predicate (" atan" )) {
1762
+ patterns.add <AtanApproximation>(context);
1763
+ }
1764
+ if (predicate (" atan2" )) {
1765
+ patterns.add <Atan2Approximation>(context);
1766
+ }
1767
+ if (predicate (" cbrt" )) {
1768
+ patterns.add <CbrtApproximation>(context);
1769
+ }
1770
+ if (predicate (" cos" )) {
1771
+ patterns.add <SinAndCosApproximation<false , math::CosOp>>(context);
1772
+ }
1773
+ if (predicate (" erf" )) {
1774
+ patterns.add <ErfPolynomialApproximation>(context);
1775
+ }
1776
+ if (predicate (" exp" )) {
1777
+ patterns.add <ExpApproximation>(context);
1778
+ }
1779
+ if (predicate (" expm1" )) {
1780
+ patterns.add <ExpM1Approximation>(context);
1781
+ }
1782
+ if (predicate (" log" )) {
1783
+ patterns.add <LogApproximation>(context);
1784
+ }
1785
+ if (predicate (" log2" )) {
1786
+ patterns.add <Log2Approximation>(context);
1787
+ }
1788
+ if (predicate (" log1p" )) {
1789
+ patterns.add <Log1pApproximation>(context);
1790
+ }
1791
+ if (predicate (" rsqrt" )) {
1792
+ patterns.add <RsqrtApproximation>(context);
1793
+ }
1794
+ if (predicate (" sin" )) {
1795
+ patterns.add <SinAndCosApproximation<true , math::SinOp>>(context);
1796
+ }
1797
+ if (predicate (" tanh" )) {
1798
+ patterns.add <TanhApproximation>(context);
1799
+ }
1800
+ }
1801
+
1670
1802
void mlir::populateMathPolynomialApproximationPatterns (
1671
1803
RewritePatternSet &patterns,
1672
1804
const MathPolynomialApproximationOptions &options) {
1673
- // Patterns for leveraging existing f32 lowerings on other data types.
1674
- patterns
1675
- .add <ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1676
- ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1677
- ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1678
- ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1679
- ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1680
- ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1681
- patterns.getContext ());
1682
-
1683
- patterns
1684
- .add <AtanApproximation, Atan2Approximation, TanhApproximation,
1685
- LogApproximation, Log2Approximation, Log1pApproximation,
1686
- ErfPolynomialApproximation, AsinPolynomialApproximation,
1687
- AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1688
- CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1689
- SinAndCosApproximation<false , math::CosOp>>(patterns.getContext ());
1805
+ mlir::populateMathF32ExpansionPatterns (patterns, [](StringRef name) {
1806
+ return name == " atan" || name == " atan2" || name == " tanh" ||
1807
+ name == " log" || name == " log2" || name == " log1p" ||
1808
+ name == " erf" || name == " exp" || name == " expm1" ||
1809
+ name == " cbrt" || name == " sin" || name == " cos" ;
1810
+ });
1811
+
1812
+ populateMathPolynomialApproximationPatterns (patterns, [](StringRef name) {
1813
+ return name == " atan" || name == " atan2" || name == " tanh" ||
1814
+ name == " log" || name == " log2" || name == " log1p" ||
1815
+ name == " erf" || name == " asin" || name == " acos" || name == " exp" ||
1816
+ name == " expm1" || name == " cbrt" || name == " sin" || name == " cos" ;
1817
+ });
1818
+
1690
1819
if (options.enableAvx2 ) {
1691
- patterns.add <RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1692
- patterns.getContext ());
1820
+ auto predicateRsqrt = [](StringRef name) { return name == " rsqrt" ; };
1821
+ mlir::populateMathF32ExpansionPatterns (patterns, predicateRsqrt);
1822
+ mlir::populateMathPolynomialApproximationPatterns (patterns, predicateRsqrt);
1693
1823
}
1694
1824
}
0 commit comments