1
- // ===- ExpandTanh .cpp - Code to perform expanding tanh op ---------- -------===//
1
+ // ===- ExpandPatterns .cpp - Code to expand various math operations. -------===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
- // This file implements expansion of tanh op .
9
+ // This file implements expansion of various math operations .
10
10
//
11
11
// ===----------------------------------------------------------------------===//
12
12
23
23
using namespace mlir ;
24
24
25
25
// / Create a float constant.
26
- static Value createFloatConst (Location loc, Type type, double value,
26
+ static Value createFloatConst (Location loc, Type type, APFloat value,
27
27
OpBuilder &b) {
28
- auto attr = b.getFloatAttr (getElementTypeOrSelf (type), value);
28
+ bool losesInfo = false ;
29
+ auto eltType = getElementTypeOrSelf (type);
30
+ // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
31
+ value.convert (cast<FloatType>(eltType).getFloatSemantics (),
32
+ APFloat::rmNearestTiesToEven, &losesInfo);
33
+ auto attr = b.getFloatAttr (eltType, value);
29
34
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
30
35
return b.create <arith::ConstantOp>(loc,
31
36
DenseElementsAttr::get (shapedTy, attr));
@@ -34,7 +39,12 @@ static Value createFloatConst(Location loc, Type type, double value,
34
39
return b.create <arith::ConstantOp>(loc, attr);
35
40
}
36
41
37
- // / Create a float constant.
42
+ static Value createFloatConst (Location loc, Type type, double value,
43
+ OpBuilder &b) {
44
+ return createFloatConst (loc, type, APFloat (value), b);
45
+ }
46
+
47
+ // / Create an integer constant.
38
48
static Value createIntConst (Location loc, Type type, int64_t value,
39
49
OpBuilder &b) {
40
50
auto attr = b.getIntegerAttr (getElementTypeOrSelf (type), value);
@@ -202,6 +212,69 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
202
212
rewriter.replaceOp (op, ret);
203
213
return success ();
204
214
}
215
+
216
+ // Convert `math.fpowi` to a series of `arith.mulf` operations.
217
+ // If the power is negative, we divide one by the result.
218
+ // If both the base and power are zero, the result is 1.
219
+ static LogicalResult convertFPowICstOp (math::FPowIOp op,
220
+ PatternRewriter &rewriter) {
221
+ ImplicitLocOpBuilder b (op->getLoc (), rewriter);
222
+ Value base = op.getOperand (0 );
223
+ Value power = op.getOperand (1 );
224
+ Type baseType = base.getType ();
225
+
226
+ Attribute cstAttr;
227
+ if (!matchPattern (power, m_Constant (&cstAttr)))
228
+ return failure ();
229
+
230
+ APInt value;
231
+ if (!matchPattern (cstAttr, m_ConstantInt (&value)))
232
+ return failure ();
233
+
234
+ int64_t powerInt = value.getSExtValue ();
235
+ bool isNegative = powerInt < 0 ;
236
+ int64_t absPower = std::abs (powerInt);
237
+ Value one = createFloatConst (op->getLoc (), baseType, 1.00 , rewriter);
238
+ Value res = createFloatConst (op->getLoc (), baseType, 1.00 , rewriter);
239
+
240
+ while (absPower > 0 ) {
241
+ if (absPower & 1 )
242
+ res = b.create <arith::MulFOp>(baseType, base, res);
243
+ absPower >>= 1 ;
244
+ base = b.create <arith::MulFOp>(baseType, base, base);
245
+ }
246
+
247
+ // Make sure not to introduce UB in case of negative power.
248
+ if (isNegative) {
249
+ auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf (baseType))
250
+ .getFloatSemantics ();
251
+ Value zero =
252
+ createFloatConst (op->getLoc (), baseType,
253
+ APFloat::getZero (sem, /* Negative=*/ false ), rewriter);
254
+ Value negZero =
255
+ createFloatConst (op->getLoc (), baseType,
256
+ APFloat::getZero (sem, /* Negative=*/ true ), rewriter);
257
+ Value posInfinity =
258
+ createFloatConst (op->getLoc (), baseType,
259
+ APFloat::getInf (sem, /* Negative=*/ false ), rewriter);
260
+ Value negInfinity =
261
+ createFloatConst (op->getLoc (), baseType,
262
+ APFloat::getInf (sem, /* Negative=*/ true ), rewriter);
263
+ Value zeroEqCheck =
264
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
265
+ Value negZeroEqCheck =
266
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
267
+ res = b.create <arith::DivFOp>(baseType, one, res);
268
+ res =
269
+ b.create <arith::SelectOp>(op->getLoc (), zeroEqCheck, posInfinity, res);
270
+ res = b.create <arith::SelectOp>(op->getLoc (), negZeroEqCheck, negInfinity,
271
+ res);
272
+ }
273
+
274
+ rewriter.replaceOp (op, res);
275
+ return success ();
276
+ }
277
+
205
278
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
206
279
static LogicalResult convertPowfOp (math::PowFOp op, PatternRewriter &rewriter) {
207
280
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
@@ -517,6 +590,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
517
590
patterns.add (convertPowfOp);
518
591
}
519
592
593
+ void mlir::populateExpandFPowIPattern (RewritePatternSet &patterns) {
594
+ patterns.add (convertFPowICstOp);
595
+ }
596
+
520
597
/// Adds the `convertRoundOp` rewrite to `patterns`.
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
  patterns.add(&convertRoundOp);
}
0 commit comments