-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Math] Add optional benefit arg to populate math lowering patterns #127291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: William Moses (wsmoses) ChangesFull diff: https://github.com/llvm/llvm-project/pull/127291.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index 93cd780bba438..b7883fe9a55ff 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -23,6 +23,7 @@ class Pass;
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
+ PatternBenefit benefit = 1,
bool approximateLog1p = true);
void registerConvertMathToLLVMInterface(DialectRegistry ®istry);
diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index ab9a1cef20cab..6db661a7b5748 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -19,7 +19,7 @@ class OperationPass;
/// Populate the given list with patterns that convert from Math to Libm calls.
/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
-void populateMathToLibmConversionPatterns(RewritePatternSet &patterns);
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Create a pass to convert Math operations to libm calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 98680773e00d2..196fad2d8367b 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -304,9 +304,10 @@ struct ConvertMathToLLVMPass
void mlir::populateMathToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ PatternBenefit benefit,
bool approximateLog1p) {
if (approximateLog1p)
- patterns.add<Log1pOpLowering>(converter);
+ patterns.add<Log1pOpLowering>(converter, benefit);
// clang-format off
patterns.add<
AbsFOpLowering,
@@ -337,7 +338,7 @@ void mlir::populateMathToLLVMConversionPatterns(
FTruncOpLowering,
TanOpLowering,
TanhOpLowering
- >(converter);
+ >(converter, benefit);
// clang-format on
}
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a2488dc600f51..97ec5cf178f5e 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -50,9 +50,9 @@ template <typename Op>
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
- ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
+ ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit, StringRef floatFunc,
StringRef doubleFunc)
- : OpRewritePattern<Op>(context), floatFunc(floatFunc),
+ : OpRewritePattern<Op>(context, benegit, ), floatFunc(floatFunc),
doubleFunc(doubleFunc){};
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
@@ -62,10 +62,10 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
};
template <typename OpTy>
-void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
+void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit, MLIRContext *ctx,
StringRef floatFunc, StringRef doubleFunc) {
- patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
- patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
+ patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
+ patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
}
} // namespace
@@ -159,42 +159,42 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
return success();
}
-void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1) {
MLIRContext *ctx = patterns.getContext();
- populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
- populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
- populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
- populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
- populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
- populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
- populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
- populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
- populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
- populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
- populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
- populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
- populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
- populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
- populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
- populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
- populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
- populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
- populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
- populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
- populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
- populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
- populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
- populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
+ populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
+ populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, "acosf", "acos");
+ populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, "acoshf", "acosh");
+ populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, "asinf", "asin");
+ populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, "asinhf", "asinh");
+ populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, "atan2f", "atan2");
+ populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, "atanf", "atan");
+ populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, "atanhf", "atanh");
+ populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, "cbrtf", "cbrt");
+ populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, "ceilf", "ceil");
+ populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
+ populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
+ populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
+ populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
+ populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
+ populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f", "expm1");
+ populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, "floorf", "floor");
+ populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, "fmaf", "fma");
+ populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, "logf", "log");
+ populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, "log2f", "log2");
+ populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, "log10f", "log10");
+ populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, "log1pf", "log1p");
+ populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, "powf", "pow");
+ populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "roundevenf",
"roundeven");
- populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
- populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
- populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
- populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
- populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
- populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
- populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
- populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
+ populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, "roundf", "round");
+ populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, "sinf", "sin");
+ populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, "sinhf", "sinh");
+ populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, "sqrtf", "sqrt");
+ populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, "rsqrtf", "rsqrt");
+ populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, "tanf", "tan");
+ populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, "tanhf", "tanh");
+ populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, "truncf", "trunc");
}
namespace {
|
You can test this locally with the following command:git-clang-format --diff 625cb5a18576dd5d193da8d0249585cb5245da5c 304c33d153289297e52a3dc94c456e990e8076e6 --extensions h,cpp -- mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp mlir/lib/Conversion/MathToLibm/MathToLibm.cpp View the diff from clang-format here.diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index c5f97c818f..12a6d9c345 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -53,7 +53,7 @@ public:
ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
StringRef floatFunc, StringRef doubleFunc)
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
- doubleFunc(doubleFunc){};
+ doubleFunc(doubleFunc) {};
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
|
Co-authored-by: Ivan R. Ivanov <[email protected]>
PatternBenefit benefit = 1, | ||
bool approximateLog1p = true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I would prefer if you made PatternBenefit the last argument. Just prevents churn for existing uses and generally PatternBenefit is the last argumenet in populate functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed up arg order
…rns (llvm#127291) Co-authored-by: Ivan R. Ivanov <[email protected]>
…ns (#130782) This is a follow-up to #127291, which added the benefit arg to lowerings to intrinsics and libm. In this change we add the benefit arg to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings, namely approximations, intrinsics and libm. Note that we're only updating the new API added in #126103. The legacy one (`mlir::populateMathPolynomialApproximationPatterns`) is left unmodified to encourage users to move out of it.
No description provided.