Skip to content

Commit 5c93eb5

Browse files
wsmosesivanradanov
andauthored
[MLIR][Math] Add optional benefit arg to populate math lowering patterns (llvm#127291)
Co-authored-by: Ivan R. Ivanov <[email protected]>
1 parent 61ad087 commit 5c93eb5

File tree

4 files changed

+62
-46
lines changed

4 files changed

+62
-46
lines changed

mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
1010
#define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
1111

12+
#include "mlir/IR/PatternMatch.h"
1213
#include <memory>
1314

1415
namespace mlir {
@@ -23,7 +24,8 @@ class Pass;
2324

2425
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
2526
RewritePatternSet &patterns,
26-
bool approximateLog1p = true);
27+
bool approximateLog1p = true,
28+
PatternBenefit benefit = 1);
2729

2830
void registerConvertMathToLLVMInterface(DialectRegistry &registry);
2931

mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ class OperationPass;
1919

2020
/// Populate the given list with patterns that convert from Math to Libm calls.
2121
/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
22-
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns);
22+
void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
23+
PatternBenefit benefit = 1);
2324

2425
/// Create a pass to convert Math operations to libm calls.
2526
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,9 @@ struct ConvertMathToLLVMPass
304304

305305
void mlir::populateMathToLLVMConversionPatterns(
306306
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
307-
bool approximateLog1p) {
307+
bool approximateLog1p, PatternBenefit benefit) {
308308
if (approximateLog1p)
309-
patterns.add<Log1pOpLowering>(converter);
309+
patterns.add<Log1pOpLowering>(converter, benefit);
310310
// clang-format off
311311
patterns.add<
312312
AbsFOpLowering,
@@ -337,7 +337,7 @@ void mlir::populateMathToLLVMConversionPatterns(
337337
FTruncOpLowering,
338338
TanOpLowering,
339339
TanhOpLowering
340-
>(converter);
340+
>(converter, benefit);
341341
// clang-format on
342342
}
343343

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Lines changed: 54 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ template <typename Op>
5050
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
5151
public:
5252
using OpRewritePattern<Op>::OpRewritePattern;
53-
ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
54-
StringRef doubleFunc)
55-
: OpRewritePattern<Op>(context), floatFunc(floatFunc),
56-
doubleFunc(doubleFunc){};
53+
ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
54+
StringRef floatFunc, StringRef doubleFunc)
55+
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
56+
doubleFunc(doubleFunc) {};
5757

5858
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
5959

@@ -62,10 +62,11 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
6262
};
6363

6464
template <typename OpTy>
65-
void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
66-
StringRef floatFunc, StringRef doubleFunc) {
67-
patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
68-
patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
65+
void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
66+
MLIRContext *ctx, StringRef floatFunc,
67+
StringRef doubleFunc) {
68+
patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
69+
patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
6970
}
7071

7172
} // namespace
@@ -159,42 +160,54 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
159160
return success();
160161
}
161162

162-
void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
163+
void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
164+
PatternBenefit benefit) {
163165
MLIRContext *ctx = patterns.getContext();
164166

165-
populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
166-
populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
167-
populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
168-
populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
169-
populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
170-
populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
171-
populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
172-
populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
173-
populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
174-
populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
175-
populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
176-
populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
177-
populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
178-
populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
179-
populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
180-
populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
181-
populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
182-
populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
183-
populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
184-
populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
185-
populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
186-
populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
187-
populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
188-
populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
167+
populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
168+
populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, "acosf", "acos");
169+
populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, "acoshf",
170+
"acosh");
171+
populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, "asinf", "asin");
172+
populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, "asinhf",
173+
"asinh");
174+
populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, "atan2f",
175+
"atan2");
176+
populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, "atanf", "atan");
177+
populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, "atanhf",
178+
"atanh");
179+
populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, "cbrtf", "cbrt");
180+
populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, "ceilf", "ceil");
181+
populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
182+
populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
183+
populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
184+
populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
185+
populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
186+
populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f",
187+
"expm1");
188+
populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, "floorf",
189+
"floor");
190+
populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, "fmaf", "fma");
191+
populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, "logf", "log");
192+
populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, "log2f", "log2");
193+
populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, "log10f",
194+
"log10");
195+
populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, "log1pf",
196+
"log1p");
197+
populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, "powf", "pow");
198+
populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "roundevenf",
189199
"roundeven");
190-
populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
191-
populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
192-
populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
193-
populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
194-
populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
195-
populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
196-
populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
197-
populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
200+
populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, "roundf",
201+
"round");
202+
populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, "sinf", "sin");
203+
populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, "sinhf", "sinh");
204+
populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, "sqrtf", "sqrt");
205+
populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, "rsqrtf",
206+
"rsqrt");
207+
populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, "tanf", "tan");
208+
populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, "tanhf", "tanh");
209+
populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, "truncf",
210+
"trunc");
198211
}
199212

200213
namespace {

0 commit comments

Comments
 (0)