Skip to content

Commit ea7f211

Browse files
committed
[mlir] Add polynomial approximation for math::ExpOp
Similar to fast_exp in https://github.com/boulos/syrah Differential Revision: https://reviews.llvm.org/D97599
1 parent 74c883f commit ea7f211

File tree

3 files changed

+200
-5
lines changed

3 files changed

+200
-5
lines changed

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 149 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
// that do not rely on any of the library functions.
1111
//
1212
//===----------------------------------------------------------------------===//
13-
1413
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1514
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1615
#include "mlir/Dialect/Math/IR/Math.h"
@@ -20,6 +19,7 @@
2019
#include "mlir/IR/ImplicitLocOpBuilder.h"
2120
#include "mlir/Transforms/DialectConversion.h"
2221
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22+
#include <limits.h>
#include <limits>
2323

2424
using namespace mlir;
2525
using namespace mlir::vector;
@@ -28,6 +28,8 @@ using TypePredicate = llvm::function_ref<bool(Type)>;
2828

2929
static bool isF32(Type type) { return type.isF32(); }
3030

31+
// Type predicate: returns true iff `type` is a 32-bit integer type.
static bool isI32(Type type) {
  return type.isInteger(32);
}
32+
3133
// Returns vector width if the element type is matching the predicate (scalars
3234
// that do match the predicate have width equal to `1`).
3335
static Optional<int> vectorWidth(Type type, TypePredicate pred) {
@@ -153,6 +155,30 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
153155
return {normalizedFraction, exponent};
154156
}
155157

158+
// Computes exp2 for an i32 argument.
159+
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
160+
assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
161+
162+
int width = vectorWidth(arg.getType());
163+
164+
auto bcast = [&](Value value) -> Value {
165+
return broadcast(builder, value, width);
166+
};
167+
168+
auto f32Vec = broadcast(builder.getF32Type(), width);
169+
// The exponent of f32 located at 23-bit.
170+
auto exponetBitLocation = bcast(i32Cst(builder, 23));
171+
// Set the exponent bias to zero.
172+
auto bias = bcast(i32Cst(builder, 127));
173+
174+
Value biasedArg = builder.create<AddIOp>(arg, bias);
175+
Value exp2ValueInt =
176+
builder.create<ShiftLeftOp>(biasedArg, exponetBitLocation);
177+
Value exp2ValueF32 = builder.create<LLVM::BitcastOp>(f32Vec, exp2ValueInt);
178+
179+
return exp2ValueF32;
180+
}
181+
156182
//----------------------------------------------------------------------------//
157183
// TanhOp approximation.
158184
//----------------------------------------------------------------------------//
@@ -230,6 +256,11 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
230256
return success();
231257
}
232258

259+
// Natural logarithm of 2, ln(2), as a long double literal.
#define LN2_VALUE \
260+
0.693147180559945309417232121458176568075500134360255254120680009493393621L
261+
// Base-2 logarithm of e, log2(e) = 1 / ln(2), as a long double literal.
#define LN2E_VALUE \
262+
1.442695040888963407359924681001892137426645954152985934135449406931109219L
263+
233264
//----------------------------------------------------------------------------//
234265
// LogOp approximation.
235266
//----------------------------------------------------------------------------//
@@ -247,9 +278,6 @@ struct LogApproximation : public OpRewritePattern<math::LogOp> {
247278
};
248279
} // namespace
249280

250-
#define LN2_VALUE \
251-
0.693147180559945309417232121458176568075500134360255254120680009493393621L
252-
253281
LogicalResult
254282
LogApproximation::matchAndRewrite(math::LogOp op,
255283
PatternRewriter &rewriter) const {
@@ -353,9 +381,125 @@ LogApproximation::matchAndRewrite(math::LogOp op,
353381
return success();
354382
}
355383

384+
//----------------------------------------------------------------------------//
385+
// Exp approximation.
386+
//----------------------------------------------------------------------------//
387+
388+
namespace {
389+
390+
struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
391+
public:
392+
using OpRewritePattern::OpRewritePattern;
393+
394+
LogicalResult matchAndRewrite(math::ExpOp op,
395+
PatternRewriter &rewriter) const final;
396+
};
397+
} // namespace
398+
399+
// Approximate exp(x) using its reduced range exp(y) where y is in the range
400+
// [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x)
401+
// = exp(y) * 2^k. exp(y).
402+
LogicalResult
403+
ExpApproximation::matchAndRewrite(math::ExpOp op,
404+
PatternRewriter &rewriter) const {
405+
auto width = vectorWidth(op.operand().getType(), isF32);
406+
if (!width.hasValue())
407+
return rewriter.notifyMatchFailure(op, "unsupported operand type");
408+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
409+
410+
// TODO: Consider a common pattern rewriter with all methods below to
411+
// write the approximations.
412+
auto bcast = [&](Value value) -> Value {
413+
return broadcast(builder, value, *width);
414+
};
415+
auto fmla = [&](Value a, Value b, Value c) {
416+
return builder.create<FmaFOp>(a, b, c);
417+
};
418+
auto mul = [&](Value a, Value b) -> Value {
419+
return builder.create<MulFOp>(a, b);
420+
};
421+
auto sub = [&](Value a, Value b) -> Value {
422+
return builder.create<SubFOp>(a, b);
423+
};
424+
auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
425+
426+
Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
427+
Value cstLN2E = bcast(f32Cst(builder, static_cast<float>(LN2E_VALUE)));
428+
429+
// Polynomial coefficients.
430+
Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
431+
Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
432+
Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
433+
Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
434+
Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
435+
Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
436+
437+
Value x = op.operand();
438+
439+
// Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
440+
Value xL2Inv = mul(x, cstLN2E);
441+
Value kF32 = floor(xL2Inv);
442+
Value kLn2 = mul(kF32, cstLn2);
443+
Value y = sub(x, kLn2);
444+
445+
// Use Estrin's evaluation scheme with 3 independent parts:
446+
// P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
447+
Value y2 = mul(y, y);
448+
Value y4 = mul(y2, y2);
449+
450+
Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
451+
Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
452+
Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
453+
Value expY = fmla(q1, y2, q0);
454+
expY = fmla(q2, y4, expY);
455+
456+
auto i32Vec = broadcast(builder.getI32Type(), *width);
457+
458+
// exp2(k)
459+
Value k = builder.create<FPToSIOp>(kF32, i32Vec);
460+
Value exp2KValue = exp2I32(builder, k);
461+
462+
// exp(x) = exp(y) * exp2(k)
463+
expY = mul(expY, exp2KValue);
464+
465+
// Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its
466+
// partitioned as the following:
467+
// exp(x) = 0, x <= -inf
468+
// exp(x) = underflow (min_float), x <= -88
469+
// exp(x) = inf (min_float), x >= 88
470+
// Note: |k| = 127 is the value where the 8-bits exponent saturates.
471+
Value zerof32Const = bcast(f32Cst(builder, 0));
472+
auto constPosInfinity =
473+
bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
474+
auto constNegIfinity =
475+
bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
476+
auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
477+
478+
Value kMaxConst = bcast(i32Cst(builder, 127));
479+
Value kMaxNegConst = bcast(i32Cst(builder, -127));
480+
Value rightBound = builder.create<CmpIOp>(CmpIPredicate::sle, k, kMaxConst);
481+
Value leftBound = builder.create<CmpIOp>(CmpIPredicate::sge, k, kMaxNegConst);
482+
483+
Value isNegInfinityX =
484+
builder.create<CmpFOp>(CmpFPredicate::OEQ, x, constNegIfinity);
485+
Value isPostiveX =
486+
builder.create<CmpFOp>(CmpFPredicate::OGT, x, zerof32Const);
487+
Value isComputable = builder.create<AndOp>(rightBound, leftBound);
488+
489+
expY = builder.create<SelectOp>(
490+
isComputable, expY,
491+
builder.create<SelectOp>(
492+
isPostiveX, constPosInfinity,
493+
builder.create<SelectOp>(isNegInfinityX, zerof32Const, underflow)));
494+
495+
rewriter.replaceOp(op, expY);
496+
497+
return success();
498+
}
499+
356500
//----------------------------------------------------------------------------//
357501

358502
void mlir::populateMathPolynomialApproximationPatterns(
359503
OwningRewritePatternList &patterns, MLIRContext *ctx) {
360-
patterns.insert<TanhApproximation, LogApproximation>(ctx);
504+
patterns.insert<TanhApproximation, LogApproximation, ExpApproximation>(ctx);
361505
}

mlir/test/Dialect/Math/polynomial-approximation.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,16 @@ func @vector(%arg0: vector<8xf32>) -> vector<8xf32> {
2020
%1 = math.log %0 : vector<8xf32>
2121
return %1 : vector<8xf32>
2222
}
23+
24+
// CHECK-LABEL: @exp_scalar
25+
func @exp_scalar(%arg0: f32) -> f32 {
  // The approximation pass must rewrite the scalar op as well.
  // CHECK-NOT: math.exp
  %0 = math.exp %arg0 : f32
  return %0 : f32
}
29+
30+
// CHECK-LABEL: @exp_vector
31+
func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
  // No math.exp op may remain in the output for the vector case.
  // CHECK-NOT: math.exp
  %result = math.exp %arg0 : vector<8xf32>
  return %result : vector<8xf32>
}

mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,46 @@ func @log() {
7171
return
7272
}
7373

74+
// -------------------------------------------------------------------------- //
75+
// Exp.
76+
// -------------------------------------------------------------------------- //
77+
func @exp() {
  // Scalar input.
  // CHECK: 2.71828
  %one = constant 1.0 : f32
  %exp_one = math.exp %one : f32
  vector.print %exp_one : f32

  // Vector input.
  // CHECK: 0.778802, 2.117, 2.71828, 3.85742
  %vec = constant dense<[-0.25, 0.75, 1.0, 1.35]> : vector<4xf32>
  %exp_vec = math.exp %vec : vector<4xf32>
  vector.print %exp_vec : vector<4xf32>

  // Zero input.
  // CHECK: 1
  %zero = constant 0.0 : f32
  %exp_zero = math.exp %zero : f32
  vector.print %exp_zero : f32

  // Inputs around the underflow and overflow thresholds.
  // CHECK: 1.17549e-38, 1.38879e-11, 7.20049e+10, inf
  %special_vec = constant dense<[-89.0, -25.0, 25.0, 89.0]> : vector<4xf32>
  %exp_special_vec = math.exp %special_vec : vector<4xf32>
  vector.print %exp_special_vec : vector<4xf32>

  // Positive infinity (0x7f800000).
  // CHECK: inf
  %pos_inf = constant 0x7f800000 : f32
  %exp_pos_inf = math.exp %pos_inf : f32
  vector.print %exp_pos_inf : f32

  // Negative infinity (0xff800000).
  // CHECK: 0
  %neg_inf = constant 0xff800000 : f32
  %exp_neg_inf = math.exp %neg_inf : f32
  vector.print %exp_neg_inf : f32

  return
}
110+
74111
// Entry point: runs the tanh, log and exp test functions in order.
func @main() {
  call @tanh() : () -> ()
  call @log() : () -> ()
  call @exp() : () -> ()
  return
}

0 commit comments

Comments
 (0)