Skip to content

Commit de4dab2

Browse files
author
git apple-llvm automerger
committed
Merge commit '58ef9bec0713' from llvm.org/main into next
2 parents cb78f77 + 58ef9be commit de4dab2

File tree

2 files changed

+32
-27
lines changed

2 files changed

+32
-27
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,34 +91,40 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
9191
}
9292

9393
/// Expands tanh op into
94-
/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
95-
/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
94+
/// 1-exp^{-2x} / 1+exp^{-2x}
95+
/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
96+
/// We compute a "signs" value which is -1 if input is negative and +1 if input
97+
/// is positive. Then multiply the input by this value, guaranteeing that the
98+
/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
99+
/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
100+
/// result by `sign(x)` to retain sign of the real result.
96101
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
97102
auto floatType = op.getOperand().getType();
98103
Location loc = op.getLoc();
104+
Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
99105
Value one = createFloatConst(loc, floatType, 1.0, rewriter);
100-
Value two = createFloatConst(loc, floatType, 2.0, rewriter);
101-
Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
106+
Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
107+
108+
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
109+
Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
110+
op.getOperand(), zero);
111+
sign = rewriter.create<arith::SIToFPOp>(loc, floatType, sign);
112+
sign = rewriter.create<arith::MulFOp>(loc, sign, negTwo);
113+
sign = rewriter.create<arith::AddFOp>(loc, sign, one);
102114

103-
// Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
104-
Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
115+
// Normalize input to positive value: y = sign(x) * x
116+
Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
117+
118+
// Decompose on normalized input
119+
Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
105120
Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
106121
Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
107122
Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
108123
Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
109124

110-
// Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
111-
exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
112-
dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
113-
divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
114-
Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
125+
// Multiply result by sign(x) to retain signs from negative inputs
126+
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
115127

116-
// tanh(x) = x >= 0 ? positiveRes : negativeRes
117-
Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
118-
Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
119-
op.getOperand(), zero);
120-
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
121-
negativeRes);
122128
return success();
123129
}
124130

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,18 @@ func.func @tanh(%arg: f32) -> f32 {
77
}
88
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
99
// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32
10-
// CHECK-DAG: %[[TWO:.+]] = arith.constant 2.000000e+00 : f32
11-
// CHECK: %[[DOUBLEDX:.+]] = arith.mulf %arg0, %[[TWO]] : f32
12-
// CHECK: %[[NEGDOUBLEDX:.+]] = arith.negf %[[DOUBLEDX]] : f32
10+
// CHECK-DAG: %[[TWO:.+]] = arith.constant -2.000000e+00 : f32
11+
// CHECK: %[[VAL0:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] : f32
12+
// CHECK: %[[VAL1:.+]] = arith.sitofp %[[VAL0]] : i1 to f32
13+
// CHECK: %[[VAL2:.+]] = arith.mulf %[[VAL1]], %[[TWO]] : f32
14+
// CHECK: %[[SIGN:.+]] = arith.addf %[[VAL2]], %[[ONE]] : f32
15+
// CHECK: %[[POSX:.+]] = arith.mulf %[[SIGN]], %arg0 : f32
16+
// CHECK: %[[NEGDOUBLEDX:.+]] = arith.mulf %[[POSX]], %[[TWO]] : f32
1317
// CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
1418
// CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
1519
// CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
16-
// CHECK: %[[RES1:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
17-
// CHECK: %[[EXP2:.+]] = math.exp %[[DOUBLEDX]] : f32
18-
// CHECK: %[[DIVIDEND2:.+]] = arith.subf %[[EXP2]], %[[ONE]] : f32
19-
// CHECK: %[[DIVISOR2:.+]] = arith.addf %[[EXP2]], %[[ONE]] : f32
20-
// CHECK: %[[RES2:.+]] = arith.divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
21-
// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
22-
// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
20+
// CHECK: %[[POSRES:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
21+
// CHECK: %[[RESULT:.+]] = arith.mulf %[[SIGN]], %[[POSRES]] : f32
2322
// CHECK: return %[[RESULT]]
2423

2524
// -----

0 commit comments

Comments
 (0)