@@ -50,10 +50,10 @@ template <typename Op>
50
50
struct ScalarOpToLibmCall : public OpRewritePattern <Op> {
51
51
public:
52
52
using OpRewritePattern<Op>::OpRewritePattern;
53
- ScalarOpToLibmCall (MLIRContext *context, StringRef floatFunc ,
54
- StringRef doubleFunc)
55
- : OpRewritePattern<Op>(context), floatFunc(floatFunc),
56
- doubleFunc (doubleFunc){};
53
+ ScalarOpToLibmCall (MLIRContext *context, PatternBenefit benefit ,
54
+ StringRef floatFunc, StringRef doubleFunc)
55
+ : OpRewritePattern<Op>(context, benefit ), floatFunc(floatFunc),
56
+ doubleFunc (doubleFunc) {};
57
57
58
58
LogicalResult matchAndRewrite (Op op, PatternRewriter &rewriter) const final ;
59
59
@@ -62,10 +62,11 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
62
62
};
63
63
64
64
template <typename OpTy>
65
- void populatePatternsForOp (RewritePatternSet &patterns, MLIRContext *ctx,
66
- StringRef floatFunc, StringRef doubleFunc) {
67
- patterns.add <VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
68
- patterns.add <ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
65
+ void populatePatternsForOp (RewritePatternSet &patterns, PatternBenefit benefit,
66
+ MLIRContext *ctx, StringRef floatFunc,
67
+ StringRef doubleFunc) {
68
+ patterns.add <VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
69
+ patterns.add <ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
69
70
}
70
71
71
72
} // namespace
@@ -159,42 +160,54 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
159
160
return success ();
160
161
}
161
162
162
- void mlir::populateMathToLibmConversionPatterns (RewritePatternSet &patterns) {
163
+ void mlir::populateMathToLibmConversionPatterns (RewritePatternSet &patterns,
164
+ PatternBenefit benefit) {
163
165
MLIRContext *ctx = patterns.getContext ();
164
166
165
- populatePatternsForOp<math::AbsFOp>(patterns, ctx, " fabsf" , " fabs" );
166
- populatePatternsForOp<math::AcosOp>(patterns, ctx, " acosf" , " acos" );
167
- populatePatternsForOp<math::AcoshOp>(patterns, ctx, " acoshf" , " acosh" );
168
- populatePatternsForOp<math::AsinOp>(patterns, ctx, " asinf" , " asin" );
169
- populatePatternsForOp<math::AsinhOp>(patterns, ctx, " asinhf" , " asinh" );
170
- populatePatternsForOp<math::Atan2Op>(patterns, ctx, " atan2f" , " atan2" );
171
- populatePatternsForOp<math::AtanOp>(patterns, ctx, " atanf" , " atan" );
172
- populatePatternsForOp<math::AtanhOp>(patterns, ctx, " atanhf" , " atanh" );
173
- populatePatternsForOp<math::CbrtOp>(patterns, ctx, " cbrtf" , " cbrt" );
174
- populatePatternsForOp<math::CeilOp>(patterns, ctx, " ceilf" , " ceil" );
175
- populatePatternsForOp<math::CosOp>(patterns, ctx, " cosf" , " cos" );
176
- populatePatternsForOp<math::CoshOp>(patterns, ctx, " coshf" , " cosh" );
177
- populatePatternsForOp<math::ErfOp>(patterns, ctx, " erff" , " erf" );
178
- populatePatternsForOp<math::ExpOp>(patterns, ctx, " expf" , " exp" );
179
- populatePatternsForOp<math::Exp2Op>(patterns, ctx, " exp2f" , " exp2" );
180
- populatePatternsForOp<math::ExpM1Op>(patterns, ctx, " expm1f" , " expm1" );
181
- populatePatternsForOp<math::FloorOp>(patterns, ctx, " floorf" , " floor" );
182
- populatePatternsForOp<math::FmaOp>(patterns, ctx, " fmaf" , " fma" );
183
- populatePatternsForOp<math::LogOp>(patterns, ctx, " logf" , " log" );
184
- populatePatternsForOp<math::Log2Op>(patterns, ctx, " log2f" , " log2" );
185
- populatePatternsForOp<math::Log10Op>(patterns, ctx, " log10f" , " log10" );
186
- populatePatternsForOp<math::Log1pOp>(patterns, ctx, " log1pf" , " log1p" );
187
- populatePatternsForOp<math::PowFOp>(patterns, ctx, " powf" , " pow" );
188
- populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, " roundevenf" ,
167
+ populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, " fabsf" , " fabs" );
168
+ populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, " acosf" , " acos" );
169
+ populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, " acoshf" ,
170
+ " acosh" );
171
+ populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, " asinf" , " asin" );
172
+ populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, " asinhf" ,
173
+ " asinh" );
174
+ populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, " atan2f" ,
175
+ " atan2" );
176
+ populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, " atanf" , " atan" );
177
+ populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, " atanhf" ,
178
+ " atanh" );
179
+ populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, " cbrtf" , " cbrt" );
180
+ populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, " ceilf" , " ceil" );
181
+ populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, " cosf" , " cos" );
182
+ populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, " coshf" , " cosh" );
183
+ populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, " erff" , " erf" );
184
+ populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, " expf" , " exp" );
185
+ populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, " exp2f" , " exp2" );
186
+ populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, " expm1f" ,
187
+ " expm1" );
188
+ populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, " floorf" ,
189
+ " floor" );
190
+ populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, " fmaf" , " fma" );
191
+ populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, " logf" , " log" );
192
+ populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, " log2f" , " log2" );
193
+ populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, " log10f" ,
194
+ " log10" );
195
+ populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, " log1pf" ,
196
+ " log1p" );
197
+ populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, " powf" , " pow" );
198
+ populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, " roundevenf" ,
189
199
" roundeven" );
190
- populatePatternsForOp<math::RoundOp>(patterns, ctx, " roundf" , " round" );
191
- populatePatternsForOp<math::SinOp>(patterns, ctx, " sinf" , " sin" );
192
- populatePatternsForOp<math::SinhOp>(patterns, ctx, " sinhf" , " sinh" );
193
- populatePatternsForOp<math::SqrtOp>(patterns, ctx, " sqrtf" , " sqrt" );
194
- populatePatternsForOp<math::RsqrtOp>(patterns, ctx, " rsqrtf" , " rsqrt" );
195
- populatePatternsForOp<math::TanOp>(patterns, ctx, " tanf" , " tan" );
196
- populatePatternsForOp<math::TanhOp>(patterns, ctx, " tanhf" , " tanh" );
197
- populatePatternsForOp<math::TruncOp>(patterns, ctx, " truncf" , " trunc" );
200
+ populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, " roundf" ,
201
+ " round" );
202
+ populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, " sinf" , " sin" );
203
+ populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, " sinhf" , " sinh" );
204
+ populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, " sqrtf" , " sqrt" );
205
+ populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, " rsqrtf" ,
206
+ " rsqrt" );
207
+ populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, " tanf" , " tan" );
208
+ populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, " tanhf" , " tanh" );
209
+ populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, " truncf" ,
210
+ " trunc" );
198
211
}
199
212
200
213
namespace {
0 commit comments