Skip to content

Commit f43deca

Browse files
authored
Fix Tan inaccuracies on extreme complex inputs. (#92443)
Specifically, those with small/large absolute values. This ports openxla/xla#10525 and was verified with XLA's test suite.
1 parent bc9823c commit f43deca

File tree

2 files changed

+98
-280
lines changed

2 files changed

+98
-280
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -956,27 +956,12 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
956956
}
957957
};
958958

959-
struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
960-
using OpConversionPattern<complex::TanOp>::OpConversionPattern;
959+
template <typename Op>
960+
struct TanTanhOpConversion : public OpConversionPattern<Op> {
961+
using OpConversionPattern<Op>::OpConversionPattern;
961962

962963
LogicalResult
963-
matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
964-
ConversionPatternRewriter &rewriter) const override {
965-
auto loc = op.getLoc();
966-
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
967-
968-
Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex(), fmf);
969-
Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex(), fmf);
970-
rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos, fmf);
971-
return success();
972-
}
973-
};
974-
975-
struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
976-
using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
977-
978-
LogicalResult
979-
matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
964+
matchAndRewrite(Op op, Op::Adaptor adaptor,
980965
ConversionPatternRewriter &rewriter) const override {
981966
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
982967
auto loc = op.getLoc();
@@ -989,14 +974,20 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
989974
b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
990975
Value imag =
991976
b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
977+
Value negOne = b.create<arith::ConstantOp>(
978+
elementType, b.getFloatAttr(elementType, -1.0));
979+
980+
if constexpr (std::is_same_v<Op, complex::TanOp>) {
981+
// tan(x+yi) = -i*tanh(-y + xi)
982+
std::swap(real, imag);
983+
real = b.create<arith::MulFOp>(real, negOne, fmf);
984+
}
992985

993986
auto cst = [&](APFloat v) {
994987
return b.create<arith::ConstantOp>(elementType,
995988
b.getFloatAttr(elementType, v));
996989
};
997990
Value inf = cst(APFloat::getInf(floatSemantics));
998-
Value negOne = b.create<arith::ConstantOp>(
999-
elementType, b.getFloatAttr(elementType, -1.0));
1000991
Value four = b.create<arith::ConstantOp>(elementType,
1001992
b.getFloatAttr(elementType, 4.0));
1002993
Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
@@ -1054,6 +1045,12 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
10541045
b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
10551046
}
10561047

1048+
if constexpr (std::is_same_v<Op, complex::TanOp>) {
1049+
// tan(x+yi) = -i*tanh(-y + xi)
1050+
std::swap(resultReal, resultImag);
1051+
resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
1052+
}
1053+
10571054
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
10581055
resultImag);
10591056
return success();
@@ -1327,8 +1324,8 @@ void mlir::populateComplexToStandardConversionPatterns(
13271324
SignOpConversion,
13281325
SinOpConversion,
13291326
SqrtOpConversion,
1330-
TanOpConversion,
1331-
TanhOpConversion,
1327+
TanTanhOpConversion<complex::TanOp>,
1328+
TanTanhOpConversion<complex::TanhOp>,
13321329
PowOpConversion,
13331330
RsqrtOpConversion
13341331
>(patterns.getContext());

0 commit comments

Comments
 (0)