Skip to content

Commit d39ac3a

Browse files
authored
[mlir][math] Reland 58ef9be (#85436)
The previous implementation decomposes tanh(x) into `(exp(2x) - 1)/(exp(2x) + 1)` for `x < 0` and `(1 - exp(-2x))/(1 + exp(-2x))` for `x >= 0`. This is fine as it avoids overflow with the exponential, but the whole decomposition is computed for both cases unconditionally, and the result is then chosen based on the sign of the input. This results in doing two expensive exp computations. The proposed change avoids doing the whole computation twice by exploiting the reflection symmetry `tanh(-x) = -tanh(x)`. We can "normalize" the input to be non-negative by setting `y = sign(x) * x`, where the sign of `x` is computed as `sign(x) = (float)(x < 0) * (-2) + 1`. Then compute `z = tanh(y)` with the decomposition above for `x >= 0` and "denormalize" the result as `z * sign(x)` to retain the sign. The reason it is done this way is that it is very amenable to vectorization. This method trades the duplicate decomposition computation (which takes 5 instructions, including an extra expensive exp and div) for 4 cheap instructions to compute the sign value: 1. `arith.cmpf` (which is a pre-existing instruction in the previous impl), 2. `arith.uitofp`, 3. `arith.mulf`, 4. `arith.addf`; and 1 more instruction to get the right sign in the result: 5. `arith.mulf`. Moreover, numerically, this implementation will yield the exact same results as the previous implementation. As part of the relanding, a casting issue from the original commit has been fixed, i.e. the bool is now cast to float with `uitofp`. Additionally, a correctness test with `mlir-cpu-runner` has been added.
1 parent 2c6fb7c commit d39ac3a

File tree

3 files changed

+54
-28
lines changed

3 files changed

+54
-28
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -91,34 +91,42 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
9191
}
9292

9393
/// Expands tanh op into
94-
/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
95-
/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
94+
/// (1 - exp^{-2x}) / (1 + exp^{-2x})
95+
/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
96+
/// We compute a "signs" value which is -1 if input is negative and +1
97+
/// otherwise. Then multiply the input by this value, guaranteeing that the
98+
/// result is >= 0, which also guarantees `exp^{-2x * sign(x)}` is in (0,
99+
/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
100+
/// result by `sign(x)` to retain sign of the real result.
96101
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
97102
auto floatType = op.getOperand().getType();
98103
Location loc = op.getLoc();
104+
Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
99105
Value one = createFloatConst(loc, floatType, 1.0, rewriter);
100-
Value two = createFloatConst(loc, floatType, 2.0, rewriter);
101-
Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
102-
103-
// Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
104-
Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
106+
Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
107+
108+
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
109+
Value isNegative = rewriter.create<arith::CmpFOp>(
110+
loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
111+
Value isNegativeFloat =
112+
rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
113+
Value isNegativeTimesNegTwo =
114+
rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
115+
Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
116+
117+
// Normalize input to positive value: y = sign(x) * x
118+
Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
119+
120+
// Decompose on normalized input
121+
Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
105122
Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
106123
Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
107124
Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
108125
Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
109126

110-
// Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
111-
exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
112-
dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
113-
divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
114-
Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
127+
// Multiply result by sign(x) to retain signs from negative inputs
128+
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
115129

116-
// tanh(x) = x >= 0 ? positiveRes : negativeRes
117-
Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
118-
Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
119-
op.getOperand(), zero);
120-
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
121-
negativeRes);
122130
return success();
123131
}
124132

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,18 @@ func.func @tanh(%arg: f32) -> f32 {
77
}
88
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
99
// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32
10-
// CHECK-DAG: %[[TWO:.+]] = arith.constant 2.000000e+00 : f32
11-
// CHECK: %[[DOUBLEDX:.+]] = arith.mulf %arg0, %[[TWO]] : f32
12-
// CHECK: %[[NEGDOUBLEDX:.+]] = arith.negf %[[DOUBLEDX]] : f32
10+
// CHECK-DAG: %[[TWO:.+]] = arith.constant -2.000000e+00 : f32
11+
// CHECK: %[[VAL0:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] : f32
12+
// CHECK: %[[VAL1:.+]] = arith.uitofp %[[VAL0]] : i1 to f32
13+
// CHECK: %[[VAL2:.+]] = arith.mulf %[[VAL1]], %[[TWO]] : f32
14+
// CHECK: %[[SIGN:.+]] = arith.addf %[[VAL2]], %[[ONE]] : f32
15+
// CHECK: %[[POSX:.+]] = arith.mulf %[[SIGN]], %arg0 : f32
16+
// CHECK: %[[NEGDOUBLEDX:.+]] = arith.mulf %[[POSX]], %[[TWO]] : f32
1317
// CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
1418
// CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
1519
// CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
16-
// CHECK: %[[RES1:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
17-
// CHECK: %[[EXP2:.+]] = math.exp %[[DOUBLEDX]] : f32
18-
// CHECK: %[[DIVIDEND2:.+]] = arith.subf %[[EXP2]], %[[ONE]] : f32
19-
// CHECK: %[[DIVISOR2:.+]] = arith.addf %[[EXP2]], %[[ONE]] : f32
20-
// CHECK: %[[RES2:.+]] = arith.divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
21-
// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
22-
// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
20+
// CHECK: %[[POSRES:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
21+
// CHECK: %[[RESULT:.+]] = arith.mulf %[[SIGN]], %[[POSRES]] : f32
2322
// CHECK: return %[[RESULT]]
2423

2524
// -----

mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,12 +683,31 @@ func.func @cosh() {
683683
return
684684
}
685685

686+
// -------------------------------------------------------------------------- //
687+
// Tanh.
688+
// -------------------------------------------------------------------------- //
689+
690+
func.func @tanh_8xf32(%a : vector<8xf32>) {
691+
%r = math.tanh %a : vector<8xf32>
692+
vector.print %r : vector<8xf32>
693+
return
694+
}
695+
696+
func.func @tanh() {
697+
// CHECK: -1, -0.761594, -0.291313, 0, 0.291313, 0.761594, 1, 1
698+
%v3 = arith.constant dense<[0xff800000, -1.0, -0.3, 0.0, 0.3, 1.0, 10.0, 0x7f800000]> : vector<8xf32>
699+
call @tanh_8xf32(%v3) : (vector<8xf32>) -> ()
700+
701+
return
702+
}
703+
686704
func.func @main() {
687705
call @exp2f() : () -> ()
688706
call @roundf() : () -> ()
689707
call @powf() : () -> ()
690708
call @roundeven() : () -> ()
691709
call @sinh() : () -> ()
692710
call @cosh() : () -> ()
711+
call @tanh() : () -> ()
693712
return
694713
}

0 commit comments

Comments
 (0)