@@ -1667,28 +1667,125 @@ void mlir::populatePolynomialApproximateErfPattern(
1667
1667
patterns.add <ErfPolynomialApproximation>(patterns.getContext ());
1668
1668
}
1669
1669
1670
+ template <typename OpType>
1671
+ static void
1672
+ populateMathF32ExpansionPattern (RewritePatternSet &patterns,
1673
+ llvm::function_ref<bool (StringRef)> predicate) {
1674
+ if (predicate (OpType::getOperationName ())) {
1675
+ patterns.add <ReuseF32Expansion<OpType>>(patterns.getContext ());
1676
+ }
1677
+ }
1678
+
1679
+ void mlir::populateMathF32ExpansionPatterns (
1680
+ RewritePatternSet &patterns,
1681
+ llvm::function_ref<bool (StringRef)> predicate) {
1682
+ populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate);
1683
+ populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate);
1684
+ populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate);
1685
+ populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate);
1686
+ populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate);
1687
+ populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate);
1688
+ populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate);
1689
+ populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate);
1690
+ populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
1691
+ populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
1692
+ populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
1693
+ populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
1694
+ populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
1695
+ populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
1696
+ populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate);
1697
+ populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate);
1698
+ populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate);
1699
+ populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate);
1700
+ populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate);
1701
+ populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate);
1702
+ populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate);
1703
+ populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate);
1704
+ populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate);
1705
+ populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate);
1706
+ populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate);
1707
+ }
1708
+
1709
+ template <typename OpType, typename PatternType>
1710
+ static void populateMathPolynomialApproximationPattern (
1711
+ RewritePatternSet &patterns,
1712
+ llvm::function_ref<bool (StringRef)> predicate) {
1713
+ if (predicate (OpType::getOperationName ())) {
1714
+ patterns.add <PatternType>(patterns.getContext ());
1715
+ }
1716
+ }
1717
+
1718
+ void mlir::populateMathPolynomialApproximationPatterns (
1719
+ RewritePatternSet &patterns,
1720
+ llvm::function_ref<bool (StringRef)> predicate) {
1721
+ populateMathPolynomialApproximationPattern<AcosOp,
1722
+ AcosPolynomialApproximation>(
1723
+ patterns, predicate);
1724
+ populateMathPolynomialApproximationPattern<AsinOp,
1725
+ AsinPolynomialApproximation>(
1726
+ patterns, predicate);
1727
+ populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
1728
+ patterns, predicate);
1729
+ populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
1730
+ patterns, predicate);
1731
+ populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
1732
+ patterns, predicate);
1733
+ populateMathPolynomialApproximationPattern<
1734
+ CosOp, SinAndCosApproximation<false , math::CosOp>>(patterns, predicate);
1735
+ populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
1736
+ patterns, predicate);
1737
+ populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
1738
+ patterns, predicate);
1739
+ populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
1740
+ patterns, predicate);
1741
+ populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
1742
+ patterns, predicate);
1743
+ populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
1744
+ patterns, predicate);
1745
+ populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
1746
+ patterns, predicate);
1747
+ populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
1748
+ patterns, predicate);
1749
+ populateMathPolynomialApproximationPattern<
1750
+ SinOp, SinAndCosApproximation<true , math::SinOp>>(patterns, predicate);
1751
+ populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
1752
+ patterns, predicate);
1753
+ }
1754
+
1670
1755
void mlir::populateMathPolynomialApproximationPatterns (
1671
1756
RewritePatternSet &patterns,
1672
1757
const MathPolynomialApproximationOptions &options) {
1673
- // Patterns for leveraging existing f32 lowerings on other data types.
1674
- patterns
1675
- .add <ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1676
- ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1677
- ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1678
- ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1679
- ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1680
- ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1681
- patterns.getContext ());
1682
-
1683
- patterns
1684
- .add <AtanApproximation, Atan2Approximation, TanhApproximation,
1685
- LogApproximation, Log2Approximation, Log1pApproximation,
1686
- ErfPolynomialApproximation, AsinPolynomialApproximation,
1687
- AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1688
- CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1689
- SinAndCosApproximation<false , math::CosOp>>(patterns.getContext ());
1758
+ mlir::populateMathF32ExpansionPatterns (patterns, [](StringRef name) -> bool {
1759
+ return llvm::is_contained (
1760
+ {math::AtanOp::getOperationName (), math::Atan2Op::getOperationName (),
1761
+ math::TanhOp::getOperationName (), math::LogOp::getOperationName (),
1762
+ math::Log2Op::getOperationName (), math::Log1pOp::getOperationName (),
1763
+ math::ErfOp::getOperationName (), math::ExpOp::getOperationName (),
1764
+ math::ExpM1Op::getOperationName (), math::CbrtOp::getOperationName (),
1765
+ math::SinOp::getOperationName (), math::CosOp::getOperationName ()},
1766
+ name);
1767
+ });
1768
+
1769
+ populateMathPolynomialApproximationPatterns (
1770
+ patterns, [](StringRef name) -> bool {
1771
+ return llvm::is_contained (
1772
+ {math::AtanOp::getOperationName (),
1773
+ math::Atan2Op::getOperationName (),
1774
+ math::TanhOp::getOperationName (), math::LogOp::getOperationName (),
1775
+ math::Log2Op::getOperationName (),
1776
+ math::Log1pOp::getOperationName (), math::ErfOp::getOperationName (),
1777
+ math::AsinOp::getOperationName (), math::AcosOp::getOperationName (),
1778
+ math::ExpOp::getOperationName (), math::ExpM1Op::getOperationName (),
1779
+ math::CbrtOp::getOperationName (), math::SinOp::getOperationName (),
1780
+ math::CosOp::getOperationName ()},
1781
+ name);
1782
+ });
1783
+
1690
1784
if (options.enableAvx2 ) {
1691
- patterns.add <RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1692
- patterns.getContext ());
1785
+ auto predicateRsqrt = [](StringRef name) {
1786
+ return name == math::RsqrtOp::getOperationName ();
1787
+ };
1788
+ mlir::populateMathF32ExpansionPatterns (patterns, predicateRsqrt);
1789
+ mlir::populateMathPolynomialApproximationPatterns (patterns, predicateRsqrt);
1693
1790
}
1694
1791
}
0 commit comments