10
10
// that do not rely on any of the library functions.
11
11
//
12
12
// ===----------------------------------------------------------------------===//
13
-
14
13
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
15
14
#include " mlir/Dialect/LLVMIR/LLVMTypes.h"
16
15
#include " mlir/Dialect/Math/IR/Math.h"
20
19
#include " mlir/IR/ImplicitLocOpBuilder.h"
21
20
#include " mlir/Transforms/DialectConversion.h"
22
21
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
22
+ #include < limits.h>
23
23
24
24
using namespace mlir ;
25
25
using namespace mlir ::vector;
@@ -28,6 +28,8 @@ using TypePredicate = llvm::function_ref<bool(Type)>;
28
28
29
29
static bool isF32 (Type type) { return type.isF32 (); }
30
30
31
+ static bool isI32 (Type type) { return type.isInteger (32 ); }
32
+
31
33
// Returns vector width if the element type is matching the predicate (scalars
32
34
// that do match the predicate have width equal to `1`).
33
35
static Optional<int > vectorWidth (Type type, TypePredicate pred) {
@@ -153,6 +155,30 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
153
155
return {normalizedFraction, exponent};
154
156
}
155
157
158
+ // Computes exp2 for an i32 argument.
159
+ static Value exp2I32 (ImplicitLocOpBuilder &builder, Value arg) {
160
+ assert (isI32 (elementType (arg.getType ())) && " argument must be i32 type" );
161
+
162
+ int width = vectorWidth (arg.getType ());
163
+
164
+ auto bcast = [&](Value value) -> Value {
165
+ return broadcast (builder, value, width);
166
+ };
167
+
168
+ auto f32Vec = broadcast (builder.getF32Type (), width);
169
+ // The exponent of f32 located at 23-bit.
170
+ auto exponetBitLocation = bcast (i32Cst (builder, 23 ));
171
+ // Set the exponent bias to zero.
172
+ auto bias = bcast (i32Cst (builder, 127 ));
173
+
174
+ Value biasedArg = builder.create <AddIOp>(arg, bias);
175
+ Value exp2ValueInt =
176
+ builder.create <ShiftLeftOp>(biasedArg, exponetBitLocation);
177
+ Value exp2ValueF32 = builder.create <LLVM::BitcastOp>(f32Vec, exp2ValueInt);
178
+
179
+ return exp2ValueF32;
180
+ }
181
+
156
182
// ----------------------------------------------------------------------------//
157
183
// TanhOp approximation.
158
184
// ----------------------------------------------------------------------------//
@@ -230,6 +256,11 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
230
256
return success ();
231
257
}
232
258
259
+ #define LN2_VALUE \
260
+ 0 .693147180559945309417232121458176568075500134360255254120680009493393621L
261
+ #define LN2E_VALUE \
262
+ 1 .442695040888963407359924681001892137426645954152985934135449406931109219L
263
+
233
264
// ----------------------------------------------------------------------------//
234
265
// LogOp approximation.
235
266
// ----------------------------------------------------------------------------//
@@ -247,9 +278,6 @@ struct LogApproximation : public OpRewritePattern<math::LogOp> {
247
278
};
248
279
} // namespace
249
280
250
- #define LN2_VALUE \
251
- 0 .693147180559945309417232121458176568075500134360255254120680009493393621L
252
-
253
281
LogicalResult
254
282
LogApproximation::matchAndRewrite (math::LogOp op,
255
283
PatternRewriter &rewriter) const {
@@ -353,9 +381,125 @@ LogApproximation::matchAndRewrite(math::LogOp op,
353
381
return success ();
354
382
}
355
383
384
+ // ----------------------------------------------------------------------------//
385
+ // Exp approximation.
386
+ // ----------------------------------------------------------------------------//
387
+
388
+ namespace {
389
+
390
+ struct ExpApproximation : public OpRewritePattern <math::ExpOp> {
391
+ public:
392
+ using OpRewritePattern::OpRewritePattern;
393
+
394
+ LogicalResult matchAndRewrite (math::ExpOp op,
395
+ PatternRewriter &rewriter) const final ;
396
+ };
397
+ } // namespace
398
+
399
+ // Approximate exp(x) using its reduced range exp(y) where y is in the range
400
+ // [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
401
+ // = exp(y) * 2^k. exp(y).
402
+ LogicalResult
403
+ ExpApproximation::matchAndRewrite (math::ExpOp op,
404
+ PatternRewriter &rewriter) const {
405
+ auto width = vectorWidth (op.operand ().getType (), isF32);
406
+ if (!width.hasValue ())
407
+ return rewriter.notifyMatchFailure (op, " unsupported operand type" );
408
+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
409
+
410
+ // TODO: Consider a common pattern rewriter with all methods below to
411
+ // write the approximations.
412
+ auto bcast = [&](Value value) -> Value {
413
+ return broadcast (builder, value, *width);
414
+ };
415
+ auto fmla = [&](Value a, Value b, Value c) {
416
+ return builder.create <FmaFOp>(a, b, c);
417
+ };
418
+ auto mul = [&](Value a, Value b) -> Value {
419
+ return builder.create <MulFOp>(a, b);
420
+ };
421
+ auto sub = [&](Value a, Value b) -> Value {
422
+ return builder.create <SubFOp>(a, b);
423
+ };
424
+ auto floor = [&](Value a) { return builder.create <FloorFOp>(a); };
425
+
426
+ Value cstLn2 = bcast (f32Cst (builder, static_cast <float >(LN2_VALUE)));
427
+ Value cstLN2E = bcast (f32Cst (builder, static_cast <float >(LN2E_VALUE)));
428
+
429
+ // Polynomial coefficients.
430
+ Value cstCephesExpP0 = bcast (f32Cst (builder, 1.0 ));
431
+ Value cstCephesExpP1 = bcast (f32Cst (builder, 1.0 ));
432
+ Value cstCephesExpP2 = bcast (f32Cst (builder, 0 .49970514590562437052f ));
433
+ Value cstCephesExpP3 = bcast (f32Cst (builder, 0 .16873890085469545053f ));
434
+ Value cstCephesExpP4 = bcast (f32Cst (builder, 0 .03668965196652099192f ));
435
+ Value cstCephesExpP5 = bcast (f32Cst (builder, 0 .01314350012789660196f ));
436
+
437
+ Value x = op.operand ();
438
+
439
+ // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
440
+ Value xL2Inv = mul (x, cstLN2E);
441
+ Value kF32 = floor (xL2Inv);
442
+ Value kLn2 = mul (kF32 , cstLn2);
443
+ Value y = sub (x, kLn2 );
444
+
445
+ // Use Estrin's evaluation scheme with 3 independent parts:
446
+ // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
447
+ Value y2 = mul (y, y);
448
+ Value y4 = mul (y2, y2);
449
+
450
+ Value q0 = fmla (cstCephesExpP1, y, cstCephesExpP0);
451
+ Value q1 = fmla (cstCephesExpP3, y, cstCephesExpP2);
452
+ Value q2 = fmla (cstCephesExpP5, y, cstCephesExpP4);
453
+ Value expY = fmla (q1, y2, q0);
454
+ expY = fmla (q2, y4, expY);
455
+
456
+ auto i32Vec = broadcast (builder.getI32Type (), *width);
457
+
458
+ // exp2(k)
459
+ Value k = builder.create <FPToSIOp>(kF32 , i32Vec);
460
+ Value exp2KValue = exp2I32 (builder, k);
461
+
462
+ // exp(x) = exp(y) * exp2(k)
463
+ expY = mul (expY, exp2KValue);
464
+
465
+ // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its
466
+ // partitioned as the following:
467
+ // exp(x) = 0, x <= -inf
468
+ // exp(x) = underflow (min_float), x <= -88
469
+ // exp(x) = inf (min_float), x >= 88
470
+ // Note: |k| = 127 is the value where the 8-bits exponent saturates.
471
+ Value zerof32Const = bcast (f32Cst (builder, 0 ));
472
+ auto constPosInfinity =
473
+ bcast (f32Cst (builder, std::numeric_limits<float >::infinity ()));
474
+ auto constNegIfinity =
475
+ bcast (f32Cst (builder, -std::numeric_limits<float >::infinity ()));
476
+ auto underflow = bcast (f32Cst (builder, std::numeric_limits<float >::min ()));
477
+
478
+ Value kMaxConst = bcast (i32Cst (builder, 127 ));
479
+ Value kMaxNegConst = bcast (i32Cst (builder, -127 ));
480
+ Value rightBound = builder.create <CmpIOp>(CmpIPredicate::sle, k, kMaxConst );
481
+ Value leftBound = builder.create <CmpIOp>(CmpIPredicate::sge, k, kMaxNegConst );
482
+
483
+ Value isNegInfinityX =
484
+ builder.create <CmpFOp>(CmpFPredicate::OEQ, x, constNegIfinity);
485
+ Value isPostiveX =
486
+ builder.create <CmpFOp>(CmpFPredicate::OGT, x, zerof32Const);
487
+ Value isComputable = builder.create <AndOp>(rightBound, leftBound);
488
+
489
+ expY = builder.create <SelectOp>(
490
+ isComputable, expY,
491
+ builder.create <SelectOp>(
492
+ isPostiveX, constPosInfinity,
493
+ builder.create <SelectOp>(isNegInfinityX, zerof32Const, underflow)));
494
+
495
+ rewriter.replaceOp (op, expY);
496
+
497
+ return success ();
498
+ }
499
+
356
500
// ----------------------------------------------------------------------------//
357
501
358
502
void mlir::populateMathPolynomialApproximationPatterns (
359
503
OwningRewritePatternList &patterns, MLIRContext *ctx) {
360
- patterns.insert <TanhApproximation, LogApproximation>(ctx);
504
+ patterns.insert <TanhApproximation, LogApproximation, ExpApproximation >(ctx);
361
505
}
0 commit comments