@@ -91,34 +91,40 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
91
91
}
92
92
93
93
// / Expands tanh op into
94
- // / 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
95
- // / 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
94
+ // / 1-exp^{-2x} / 1+exp^{-2x}
95
+ // / To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
96
+ // / We compute a "signs" value which is -1 if input is negative and +1 if input
97
+ // / is positive. Then multiply the input by this value, guaranteeing that the
98
+ // / result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
99
+ // / 1]. Expand the computation on the input `x * sign(x)`, then multiply the
100
+ // / result by `sign(x)` to retain sign of the real result.
96
101
static LogicalResult convertTanhOp (math::TanhOp op, PatternRewriter &rewriter) {
97
102
auto floatType = op.getOperand ().getType ();
98
103
Location loc = op.getLoc ();
104
+ Value zero = createFloatConst (loc, floatType, 0.0 , rewriter);
99
105
Value one = createFloatConst (loc, floatType, 1.0 , rewriter);
100
- Value two = createFloatConst (loc, floatType, 2.0 , rewriter);
101
- Value doubledX = rewriter.create <arith::MulFOp>(loc, op.getOperand (), two);
106
+ Value negTwo = createFloatConst (loc, floatType, -2.0 , rewriter);
107
+
108
+ // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
109
+ Value sign = rewriter.create <arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
110
+ op.getOperand (), zero);
111
+ sign = rewriter.create <arith::SIToFPOp>(loc, floatType, sign);
112
+ sign = rewriter.create <arith::MulFOp>(loc, sign, negTwo);
113
+ sign = rewriter.create <arith::AddFOp>(loc, sign, one);
102
114
103
- // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
104
- Value negDoubledX = rewriter.create <arith::NegFOp>(loc, doubledX);
115
+ // Normalize input to positive value: y = sign(x) * x
116
+ Value positiveX = rewriter.create <arith::MulFOp>(loc, sign, op.getOperand ());
117
+
118
+ // Decompose on normalized input
119
+ Value negDoubledX = rewriter.create <arith::MulFOp>(loc, negTwo, positiveX);
105
120
Value exp2x = rewriter.create <math::ExpOp>(loc, negDoubledX);
106
121
Value dividend = rewriter.create <arith::SubFOp>(loc, one, exp2x);
107
122
Value divisor = rewriter.create <arith::AddFOp>(loc, one, exp2x);
108
123
Value positiveRes = rewriter.create <arith::DivFOp>(loc, dividend, divisor);
109
124
110
- // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
111
- exp2x = rewriter.create <math::ExpOp>(loc, doubledX);
112
- dividend = rewriter.create <arith::SubFOp>(loc, exp2x, one);
113
- divisor = rewriter.create <arith::AddFOp>(loc, exp2x, one);
114
- Value negativeRes = rewriter.create <arith::DivFOp>(loc, dividend, divisor);
125
+ // Multiply result by sign(x) to retain signs from negative inputs
126
+ rewriter.replaceOpWithNewOp <arith::MulFOp>(op, sign, positiveRes);
115
127
116
- // tanh(x) = x >= 0 ? positiveRes : negativeRes
117
- Value zero = createFloatConst (loc, floatType, 0.0 , rewriter);
118
- Value cmpRes = rewriter.create <arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
119
- op.getOperand (), zero);
120
- rewriter.replaceOpWithNewOp <arith::SelectOp>(op, cmpRes, positiveRes,
121
- negativeRes);
122
128
return success ();
123
129
}
124
130
0 commit comments