@@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
73
73
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
74
74
Value operand = op.getOperand ();
75
75
Type opType = operand.getType ();
76
- Value exp = b.create <math::ExpOp>(operand);
77
76
78
- Value one = createFloatConst (op->getLoc (), opType, 1.0 , rewriter);
79
- Value nexp = b.create <arith::DivFOp>(one, exp);
77
+ Value exp = b.create <math::ExpOp>(operand);
78
+ Value neg = b.create <arith::NegFOp>(operand);
79
+ Value nexp = b.create <math::ExpOp>(neg);
80
80
Value sub = b.create <arith::SubFOp>(exp, nexp);
81
- Value two = createFloatConst (op->getLoc (), opType, 2.0 , rewriter);
82
- Value div = b.create <arith::DivFOp >(sub, two );
83
- rewriter.replaceOp (op, div );
81
+ Value half = createFloatConst (op->getLoc (), opType, 0.5 , rewriter);
82
+ Value res = b.create <arith::MulFOp >(sub, half );
83
+ rewriter.replaceOp (op, res );
84
84
return success ();
85
85
}
86
86
@@ -89,54 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
89
89
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
90
90
Value operand = op.getOperand ();
91
91
Type opType = operand.getType ();
92
- Value exp = b.create <math::ExpOp>(operand);
93
92
94
- Value one = createFloatConst (op->getLoc (), opType, 1.0 , rewriter);
95
- Value nexp = b.create <arith::DivFOp>(one, exp);
93
+ Value exp = b.create <math::ExpOp>(operand);
94
+ Value neg = b.create <arith::NegFOp>(operand);
95
+ Value nexp = b.create <math::ExpOp>(neg);
96
96
Value add = b.create <arith::AddFOp>(exp, nexp);
97
- Value two = createFloatConst (op->getLoc (), opType, 2.0 , rewriter);
98
- Value div = b.create <arith::DivFOp>(add, two);
99
- rewriter.replaceOp (op, div);
100
- return success ();
101
- }
102
-
103
- // / Expands tanh op into
104
- // / 1-exp^{-2x} / 1+exp^{-2x}
105
- // / To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
106
- // / We compute a "signs" value which is -1 if input is negative and +1 if input
107
- // / is positive. Then multiply the input by this value, guaranteeing that the
108
- // / result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
109
- // / 1]. Expand the computation on the input `x * sign(x)`, then multiply the
110
- // / result by `sign(x)` to retain sign of the real result.
111
- static LogicalResult convertTanhOp (math::TanhOp op, PatternRewriter &rewriter) {
112
- auto floatType = op.getOperand ().getType ();
113
- Location loc = op.getLoc ();
114
- Value zero = createFloatConst (loc, floatType, 0.0 , rewriter);
115
- Value one = createFloatConst (loc, floatType, 1.0 , rewriter);
116
- Value negTwo = createFloatConst (loc, floatType, -2.0 , rewriter);
117
-
118
- // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
119
- Value isNegative = rewriter.create <arith::CmpFOp>(
120
- loc, arith::CmpFPredicate::OLT, op.getOperand (), zero);
121
- Value isNegativeFloat =
122
- rewriter.create <arith::UIToFPOp>(loc, floatType, isNegative);
123
- Value isNegativeTimesNegTwo =
124
- rewriter.create <arith::MulFOp>(loc, isNegativeFloat, negTwo);
125
- Value sign = rewriter.create <arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
126
-
127
- // Normalize input to positive value: y = sign(x) * x
128
- Value positiveX = rewriter.create <arith::MulFOp>(loc, sign, op.getOperand ());
129
-
130
- // Decompose on normalized input
131
- Value negDoubledX = rewriter.create <arith::MulFOp>(loc, negTwo, positiveX);
132
- Value exp2x = rewriter.create <math::ExpOp>(loc, negDoubledX);
133
- Value dividend = rewriter.create <arith::SubFOp>(loc, one, exp2x);
134
- Value divisor = rewriter.create <arith::AddFOp>(loc, one, exp2x);
135
- Value positiveRes = rewriter.create <arith::DivFOp>(loc, dividend, divisor);
136
-
137
- // Multiply result by sign(x) to retain signs from negative inputs
138
- rewriter.replaceOpWithNewOp <arith::MulFOp>(op, sign, positiveRes);
139
-
97
+ Value half = createFloatConst (op->getLoc (), opType, 0.5 , rewriter);
98
+ Value res = b.create <arith::MulFOp>(add, half);
99
+ rewriter.replaceOp (op, res);
140
100
return success ();
141
101
}
142
102
@@ -152,6 +112,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
152
112
return success ();
153
113
}
154
114
115
+ // asinh(float x) -> log(x + sqrt(x**2 + 1))
116
+ static LogicalResult convertAsinhOp (math::AsinhOp op,
117
+ PatternRewriter &rewriter) {
118
+ ImplicitLocOpBuilder b (op->getLoc (), rewriter);
119
+ Value operand = op.getOperand ();
120
+ Type opType = operand.getType ();
121
+
122
+ Value one = createFloatConst (op->getLoc (), opType, 1.0 , rewriter);
123
+ Value fma = b.create <math::FmaOp>(operand, operand, one);
124
+ Value sqrt = b.create <math::SqrtOp>(fma);
125
+ Value add = b.create <arith::AddFOp>(operand, sqrt);
126
+ Value res = b.create <math::LogOp>(add);
127
+ rewriter.replaceOp (op, res);
128
+ return success ();
129
+ }
130
+
131
+ // acosh(float x) -> log(x + sqrt(x**2 - 1))
132
+ static LogicalResult convertAcoshOp (math::AcoshOp op,
133
+ PatternRewriter &rewriter) {
134
+ ImplicitLocOpBuilder b (op->getLoc (), rewriter);
135
+ Value operand = op.getOperand ();
136
+ Type opType = operand.getType ();
137
+
138
+ Value negOne = createFloatConst (op->getLoc (), opType, -1.0 , rewriter);
139
+ Value fma = b.create <math::FmaOp>(operand, operand, negOne);
140
+ Value sqrt = b.create <math::SqrtOp>(fma);
141
+ Value add = b.create <arith::AddFOp>(operand, sqrt);
142
+ Value res = b.create <math::LogOp>(add);
143
+ rewriter.replaceOp (op, res);
144
+ return success ();
145
+ }
146
+
147
+ // atanh(float x) -> log((1 + x) / (1 - x)) / 2
148
+ static LogicalResult convertAtanhOp (math::AtanhOp op,
149
+ PatternRewriter &rewriter) {
150
+ ImplicitLocOpBuilder b (op->getLoc (), rewriter);
151
+ Value operand = op.getOperand ();
152
+ Type opType = operand.getType ();
153
+
154
+ Value one = createFloatConst (op->getLoc (), opType, 1.0 , rewriter);
155
+ Value add = b.create <arith::AddFOp>(operand, one);
156
+ Value neg = b.create <arith::NegFOp>(operand);
157
+ Value sub = b.create <arith::AddFOp>(neg, one);
158
+ Value div = b.create <arith::DivFOp>(add, sub);
159
+ Value log = b.create <math::LogOp>(div);
160
+ Value half = createFloatConst (op->getLoc (), opType, 0.5 , rewriter);
161
+ Value res = b.create <arith::MulFOp>(log, half);
162
+ rewriter.replaceOp (op, res);
163
+ return success ();
164
+ }
165
+
155
166
static LogicalResult convertFmaFOp (math::FmaOp op, PatternRewriter &rewriter) {
156
167
ImplicitLocOpBuilder b (op->getLoc (), rewriter);
157
168
Value operandA = op.getOperand (0 );
@@ -580,8 +591,16 @@ void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
580
591
patterns.add (convertTanOp);
581
592
}
582
593
583
- void mlir::populateExpandTanhPattern (RewritePatternSet &patterns) {
584
- patterns.add (convertTanhOp);
594
+ void mlir::populateExpandAsinhPattern (RewritePatternSet &patterns) {
595
+ patterns.add (convertAsinhOp);
596
+ }
597
+
598
+ void mlir::populateExpandAcoshPattern (RewritePatternSet &patterns) {
599
+ patterns.add (convertAcoshOp);
600
+ }
601
+
602
+ void mlir::populateExpandAtanhPattern (RewritePatternSet &patterns) {
603
+ patterns.add (convertAtanhOp);
585
604
}
586
605
587
606
void mlir::populateExpandFmaFPattern (RewritePatternSet &patterns) {
0 commit comments