Skip to content

Commit 5670262

Browse files
committed
[mlir][math] Add expansion patterns for acosh, asinh, atanh ops
1 parent 1241e76 commit 5670262

File tree

5 files changed

+189
-99
lines changed

5 files changed

+189
-99
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ void populateExpandCtlzPattern(RewritePatternSet &patterns);
3030
void populateExpandTanPattern(RewritePatternSet &patterns);
3131
void populateExpandSinhPattern(RewritePatternSet &patterns);
3232
void populateExpandCoshPattern(RewritePatternSet &patterns);
33-
void populateExpandTanhPattern(RewritePatternSet &patterns);
33+
void populateExpandAsinhPattern(RewritePatternSet &patterns);
34+
void populateExpandAcoshPattern(RewritePatternSet &patterns);
35+
void populateExpandAtanhPattern(RewritePatternSet &patterns);
3436
void populateExpandFmaFPattern(RewritePatternSet &patterns);
3537
void populateExpandFloorFPattern(RewritePatternSet &patterns);
3638
void populateExpandCeilFPattern(RewritePatternSet &patterns);

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 73 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
7373
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
7474
Value operand = op.getOperand();
7575
Type opType = operand.getType();
76-
Value exp = b.create<math::ExpOp>(operand);
7776

78-
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
79-
Value nexp = b.create<arith::DivFOp>(one, exp);
77+
Value exp = b.create<math::ExpOp>(operand);
78+
Value neg = b.create<arith::NegFOp>(operand);
79+
Value nexp = b.create<math::ExpOp>(neg);
8080
Value sub = b.create<arith::SubFOp>(exp, nexp);
81-
Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
82-
Value div = b.create<arith::DivFOp>(sub, two);
83-
rewriter.replaceOp(op, div);
81+
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
82+
Value res = b.create<arith::MulFOp>(sub, half);
83+
rewriter.replaceOp(op, res);
8484
return success();
8585
}
8686

@@ -89,54 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
8989
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
9090
Value operand = op.getOperand();
9191
Type opType = operand.getType();
92-
Value exp = b.create<math::ExpOp>(operand);
9392

94-
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
95-
Value nexp = b.create<arith::DivFOp>(one, exp);
93+
Value exp = b.create<math::ExpOp>(operand);
94+
Value neg = b.create<arith::NegFOp>(operand);
95+
Value nexp = b.create<math::ExpOp>(neg);
9696
Value add = b.create<arith::AddFOp>(exp, nexp);
97-
Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
98-
Value div = b.create<arith::DivFOp>(add, two);
99-
rewriter.replaceOp(op, div);
100-
return success();
101-
}
102-
103-
/// Expands tanh op into
104-
/// 1-exp^{-2x} / 1+exp^{-2x}
105-
/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
106-
/// We compute a "signs" value which is -1 if input is negative and +1 if input
107-
/// is positive. Then multiply the input by this value, guaranteeing that the
108-
/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
109-
/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
110-
/// result by `sign(x)` to retain sign of the real result.
111-
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
112-
auto floatType = op.getOperand().getType();
113-
Location loc = op.getLoc();
114-
Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
115-
Value one = createFloatConst(loc, floatType, 1.0, rewriter);
116-
Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
117-
118-
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
119-
Value isNegative = rewriter.create<arith::CmpFOp>(
120-
loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
121-
Value isNegativeFloat =
122-
rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
123-
Value isNegativeTimesNegTwo =
124-
rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
125-
Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
126-
127-
// Normalize input to positive value: y = sign(x) * x
128-
Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
129-
130-
// Decompose on normalized input
131-
Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
132-
Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
133-
Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
134-
Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
135-
Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
136-
137-
// Multiply result by sign(x) to retain signs from negative inputs
138-
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
139-
97+
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
98+
Value res = b.create<arith::MulFOp>(add, half);
99+
rewriter.replaceOp(op, res);
140100
return success();
141101
}
142102

@@ -152,6 +112,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
152112
return success();
153113
}
154114

115+
// asinh(float x) -> log(x + sqrt(x**2 + 1))
116+
static LogicalResult convertAsinhOp(math::AsinhOp op,
117+
PatternRewriter &rewriter) {
118+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
119+
Value operand = op.getOperand();
120+
Type opType = operand.getType();
121+
122+
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
123+
Value fma = b.create<math::FmaOp>(operand, operand, one);
124+
Value sqrt = b.create<math::SqrtOp>(fma);
125+
Value add = b.create<arith::AddFOp>(operand, sqrt);
126+
Value res = b.create<math::LogOp>(add);
127+
rewriter.replaceOp(op, res);
128+
return success();
129+
}
130+
131+
// acosh(float x) -> log(x + sqrt(x**2 - 1))
132+
static LogicalResult convertAcoshOp(math::AcoshOp op,
133+
PatternRewriter &rewriter) {
134+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
135+
Value operand = op.getOperand();
136+
Type opType = operand.getType();
137+
138+
Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
139+
Value fma = b.create<math::FmaOp>(operand, operand, negOne);
140+
Value sqrt = b.create<math::SqrtOp>(fma);
141+
Value add = b.create<arith::AddFOp>(operand, sqrt);
142+
Value res = b.create<math::LogOp>(add);
143+
rewriter.replaceOp(op, res);
144+
return success();
145+
}
146+
147+
// atanh(float x) -> log((1 + x) / (1 - x)) / 2
148+
static LogicalResult convertAtanhOp(math::AtanhOp op,
149+
PatternRewriter &rewriter) {
150+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
151+
Value operand = op.getOperand();
152+
Type opType = operand.getType();
153+
154+
Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
155+
Value add = b.create<arith::AddFOp>(operand, one);
156+
Value neg = b.create<arith::NegFOp>(operand);
157+
Value sub = b.create<arith::AddFOp>(neg, one);
158+
Value div = b.create<arith::DivFOp>(add, sub);
159+
Value log = b.create<math::LogOp>(div);
160+
Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
161+
Value res = b.create<arith::MulFOp>(log, half);
162+
rewriter.replaceOp(op, res);
163+
return success();
164+
}
165+
155166
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
156167
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
157168
Value operandA = op.getOperand(0);
@@ -580,8 +591,16 @@ void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
580591
patterns.add(convertTanOp);
581592
}
582593

583-
void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
584-
patterns.add(convertTanhOp);
594+
/// Adds the function-based rewrite `convertAsinhOp` (expansion of
/// `math.asinh`) to `patterns`.
void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
  patterns.add(&convertAsinhOp);
}
597+
598+
/// Adds the function-based rewrite `convertAcoshOp` (expansion of
/// `math.acosh`) to `patterns`.
void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
  patterns.add(&convertAcoshOp);
}
601+
602+
/// Adds the function-based rewrite `convertAtanhOp` (expansion of
/// `math.atanh`) to `patterns`.
void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
  patterns.add(&convertAtanhOp);
}
586605

587606
void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,5 @@
11
// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
22

3-
// CHECK-LABEL: func @tanh
4-
func.func @tanh(%arg: f32) -> f32 {
5-
%res = math.tanh %arg : f32
6-
return %res : f32
7-
}
8-
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
9-
// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32
10-
// CHECK-DAG: %[[TWO:.+]] = arith.constant -2.000000e+00 : f32
11-
// CHECK: %[[VAL0:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] : f32
12-
// CHECK: %[[VAL1:.+]] = arith.uitofp %[[VAL0]] : i1 to f32
13-
// CHECK: %[[VAL2:.+]] = arith.mulf %[[VAL1]], %[[TWO]] : f32
14-
// CHECK: %[[SIGN:.+]] = arith.addf %[[VAL2]], %[[ONE]] : f32
15-
// CHECK: %[[POSX:.+]] = arith.mulf %[[SIGN]], %arg0 : f32
16-
// CHECK: %[[NEGDOUBLEDX:.+]] = arith.mulf %[[POSX]], %[[TWO]] : f32
17-
// CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
18-
// CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
19-
// CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
20-
// CHECK: %[[POSRES:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
21-
// CHECK: %[[RESULT:.+]] = arith.mulf %[[SIGN]], %[[POSRES]] : f32
22-
// CHECK: return %[[RESULT]]
23-
24-
// -----
25-
26-
27-
// CHECK-LABEL: func @vector_tanh
28-
func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
29-
// CHECK-NOT: math.tanh
30-
%res = math.tanh %arg : vector<4xf32>
31-
return %res : vector<4xf32>
32-
}
33-
34-
// -----
35-
363
// CHECK-LABEL: func @tan
374
func.func @tan(%arg: f32) -> f32 {
385
%res = math.tan %arg : f32

mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ void TestExpandMathPass::runOnOperation() {
4141
populateExpandTanPattern(patterns);
4242
populateExpandSinhPattern(patterns);
4343
populateExpandCoshPattern(patterns);
44-
populateExpandTanhPattern(patterns);
44+
populateExpandAsinhPattern(patterns);
45+
populateExpandAcoshPattern(patterns);
46+
populateExpandAtanhPattern(patterns);
4547
populateExpandFmaFPattern(patterns);
4648
populateExpandFloorFPattern(patterns);
4749
populateExpandCeilFPattern(patterns);

mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Lines changed: 110 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -700,21 +700,119 @@ func.func @cosh() {
700700
}
701701

702702
// -------------------------------------------------------------------------- //
703-
// Tanh.
703+
// Asinh.
704704
// -------------------------------------------------------------------------- //
705705

706-
func.func @tanh_8xf32(%a : vector<8xf32>) {
707-
%r = math.tanh %a : vector<8xf32>
708-
vector.print %r : vector<8xf32>
706+
func.func @asinh_f32(%a : f32) {
707+
%r = math.asinh %a : f32
708+
vector.print %r : f32
709+
return
710+
}
711+
712+
func.func @asinh_3xf32(%a : vector<3xf32>) {
713+
%r = math.asinh %a : vector<3xf32>
714+
vector.print %r : vector<3xf32>
709715
return
710716
}
711717

712-
func.func @tanh() {
713-
// CHECK: -1, -0.761594, -0.291313, 0, 0.291313, 0.761594, 1, 1
714-
%v3 = arith.constant dense<[0xff800000, -1.0, -0.3, 0.0, 0.3, 1.0, 10.0, 0x7f800000]> : vector<8xf32>
715-
call @tanh_8xf32(%v3) : (vector<8xf32>) -> ()
718+
func.func @asinh() {
719+
// CHECK: 0
720+
%zero = arith.constant 0.0 : f32
721+
call @asinh_f32(%zero) : (f32) -> ()
716722

717-
return
723+
// CHECK: 0.881374
724+
%cst1 = arith.constant 1.0 : f32
725+
call @asinh_f32(%cst1) : (f32) -> ()
726+
727+
// CHECK: -0.881374
728+
%cst2 = arith.constant -1.0 : f32
729+
call @asinh_f32(%cst2) : (f32) -> ()
730+
731+
// CHECK: 1.81845
732+
%cst3 = arith.constant 3.0 : f32
733+
call @asinh_f32(%cst3) : (f32) -> ()
734+
735+
// CHECK: 0.247466, 0.790169, 1.44364
736+
%vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
737+
call @asinh_3xf32(%vec_x) : (vector<3xf32>) -> ()
738+
739+
return
740+
}
741+
742+
// -------------------------------------------------------------------------- //
743+
// Acosh.
744+
// -------------------------------------------------------------------------- //
745+
746+
func.func @acosh_f32(%a : f32) {
747+
%r = math.acosh %a : f32
748+
vector.print %r : f32
749+
return
750+
}
751+
752+
func.func @acosh_3xf32(%a : vector<3xf32>) {
753+
%r = math.acosh %a : vector<3xf32>
754+
vector.print %r : vector<3xf32>
755+
return
756+
}
757+
758+
func.func @acosh() {
  // Lower edge of the domain: the input is 1.0 and the expected output is 0.
  // CHECK: 0
  %one = arith.constant 1.0 : f32
  call @acosh_f32(%one) : (f32) -> ()

  // CHECK: 1.31696
  %cst1 = arith.constant 2.0 : f32
  call @acosh_f32(%cst1) : (f32) -> ()

  // CHECK: 2.99322
  %cst2 = arith.constant 10.0 : f32
  call @acosh_f32(%cst2) : (f32) -> ()

  // CHECK: 0.962424, 1.76275, 2.47789
  %vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
  call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> ()

  return
}
777+
778+
// -------------------------------------------------------------------------- //
779+
// Atanh.
780+
// -------------------------------------------------------------------------- //
781+
782+
func.func @atanh_f32(%a : f32) {
783+
%r = math.atanh %a : f32
784+
vector.print %r : f32
785+
return
786+
}
787+
788+
func.func @atanh_3xf32(%a : vector<3xf32>) {
789+
%r = math.atanh %a : vector<3xf32>
790+
vector.print %r : vector<3xf32>
791+
return
792+
}
793+
794+
func.func @atanh() {
795+
// CHECK: 0
796+
%zero = arith.constant 0.0 : f32
797+
call @atanh_f32(%zero) : (f32) -> ()
798+
799+
// CHECK: 0.549306
800+
%cst1 = arith.constant 0.5 : f32
801+
call @atanh_f32(%cst1) : (f32) -> ()
802+
803+
// CHECK: -0.549306
804+
%cst2 = arith.constant -0.5 : f32
805+
call @atanh_f32(%cst2) : (f32) -> ()
806+
807+
// CHECK: inf
808+
%cst3 = arith.constant 1.0 : f32
809+
call @atanh_f32(%cst3) : (f32) -> ()
810+
811+
// CHECK: 0.255413, 0.394229, 2.99448
812+
%vec_x = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
813+
call @atanh_3xf32(%vec_x) : (vector<3xf32>) -> ()
814+
815+
return
718816
}
719817

720818
func.func @main() {
@@ -724,6 +822,8 @@ func.func @main() {
724822
call @roundeven() : () -> ()
725823
call @sinh() : () -> ()
726824
call @cosh() : () -> ()
727-
call @tanh() : () -> ()
825+
call @asinh() : () -> ()
826+
call @acosh() : () -> ()
827+
call @atanh() : () -> ()
728828
return
729829
}

0 commit comments

Comments
 (0)