Skip to content

[MLIR][Math] Add optional benefit arg to populate math lowering patterns #127291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 15, 2025

Conversation

wsmoses
Copy link
Member

@wsmoses wsmoses commented Feb 15, 2025

No description provided.

@wsmoses wsmoses requested a review from ivanradanov February 15, 2025 01:22
@llvmbot llvmbot added the mlir label Feb 15, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 15, 2025

@llvm/pr-subscribers-mlir

Author: William Moses (wsmoses)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/127291.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h (+1)
  • (modified) mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h (+1-1)
  • (modified) mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp (+3-2)
  • (modified) mlir/lib/Conversion/MathToLibm/MathToLibm.cpp (+38-38)
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index 93cd780bba438..b7883fe9a55ff 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -23,6 +23,7 @@ class Pass;
 
 void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
                                           RewritePatternSet &patterns,
+					  PatternBenefit benefit = 1,
                                           bool approximateLog1p = true);
 
 void registerConvertMathToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index ab9a1cef20cab..6db661a7b5748 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -19,7 +19,7 @@ class OperationPass;
 
 /// Populate the given list with patterns that convert from Math to Libm calls.
 /// If log1pBenefit is present, use it instead of benefit for the Log1p op.
-void populateMathToLibmConversionPatterns(RewritePatternSet &patterns);
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
 /// Create a pass to convert Math operations to libm calls.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 98680773e00d2..196fad2d8367b 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -304,9 +304,10 @@ struct ConvertMathToLLVMPass
 
 void mlir::populateMathToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit,
     bool approximateLog1p) {
   if (approximateLog1p)
-    patterns.add<Log1pOpLowering>(converter);
+    patterns.add<Log1pOpLowering>(converter, benefit);
   // clang-format off
   patterns.add<
     AbsFOpLowering,
@@ -337,7 +338,7 @@ void mlir::populateMathToLLVMConversionPatterns(
     FTruncOpLowering,
     TanOpLowering,
     TanhOpLowering
-  >(converter);
+  >(converter, benefit);
   // clang-format on
 }
 
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a2488dc600f51..97ec5cf178f5e 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -50,9 +50,9 @@ template <typename Op>
 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
 public:
   using OpRewritePattern<Op>::OpRewritePattern;
-  ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
+  ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit, StringRef floatFunc,
                      StringRef doubleFunc)
-      : OpRewritePattern<Op>(context), floatFunc(floatFunc),
+      : OpRewritePattern<Op>(context, benegit, ), floatFunc(floatFunc),
         doubleFunc(doubleFunc){};
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
@@ -62,10 +62,10 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
 };
 
 template <typename OpTy>
-void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
+void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit, MLIRContext *ctx,
                            StringRef floatFunc, StringRef doubleFunc) {
-  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
-  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
+  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
+  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
 }
 
 } // namespace
@@ -159,42 +159,42 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
   return success();
 }
 
-void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1) {
   MLIRContext *ctx = patterns.getContext();
 
-  populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
-  populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
-  populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
-  populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
-  populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
-  populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
-  populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
-  populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
-  populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
-  populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
-  populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
-  populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
-  populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
-  populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
-  populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
-  populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
-  populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
-  populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
-  populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
-  populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
-  populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
-  populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
-  populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
-  populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
+  populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
+  populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, "acosf", "acos");
+  populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, "acoshf", "acosh");
+  populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, "asinf", "asin");
+  populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, "asinhf", "asinh");
+  populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, "atan2f", "atan2");
+  populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, "atanf", "atan");
+  populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, "atanhf", "atanh");
+  populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, "cbrtf", "cbrt");
+  populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, "ceilf", "ceil");
+  populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
+  populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
+  populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
+  populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
+  populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
+  populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f", "expm1");
+  populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, "floorf", "floor");
+  populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, "fmaf", "fma");
+  populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, "logf", "log");
+  populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, "log2f", "log2");
+  populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, "log10f", "log10");
+  populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, "log1pf", "log1p");
+  populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, "powf", "pow");
+  populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "roundevenf",
                                            "roundeven");
-  populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
-  populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
-  populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
-  populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
-  populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
-  populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
-  populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
-  populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
+  populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, "roundf", "round");
+  populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, "sinf", "sin");
+  populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, "sinhf", "sinh");
+  populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, "sqrtf", "sqrt");
+  populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, "rsqrtf", "rsqrt");
+  populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, "tanf", "tan");
+  populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, "tanhf", "tanh");
+  populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, "truncf", "trunc");
 }
 
 namespace {

Copy link

github-actions bot commented Feb 15, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 625cb5a18576dd5d193da8d0249585cb5245da5c 304c33d153289297e52a3dc94c456e990e8076e6 --extensions h,cpp -- mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index c5f97c818f..12a6d9c345 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -53,7 +53,7 @@ public:
   ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
                      StringRef floatFunc, StringRef doubleFunc)
       : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
-        doubleFunc(doubleFunc){};
+        doubleFunc(doubleFunc) {};
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
 

Comment on lines 27 to 28
PatternBenefit benefit = 1,
bool approximateLog1p = true);
Copy link
Member

@Groverkss Groverkss Feb 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would prefer if you made PatternBenefit the last argument. Just prevents churn for existing uses and generally PatternBenefit is the last argumenet in populate functions.

Copy link
Member Author

@wsmoses wsmoses left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed up arg order

@wsmoses wsmoses merged commit 5c93eb5 into main Feb 15, 2025
5 of 7 checks passed
@wsmoses wsmoses deleted the users/wm/mathbenefit branch February 15, 2025 04:38
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
cota added a commit that referenced this pull request Mar 11, 2025
…ns (#130782)

This is a follow-up to #127291, which added the benefit arg to lowerings
to intrinsics and libm.

In this change we add the benefit arg to the math approximation and
expansion lowerings, which allows users to establish a preferred order
among all three math lowerings, namely approximations, intrinsics and
libm.

Note that we're only updating the new API added in #126103. The legacy
one (`mlir::populateMathPolynomialApproximationPatterns`) is left
unmodified to encourage users to move out of it.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants