@@ -956,27 +956,12 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
956
956
}
957
957
};
958
958
959
- struct TanOpConversion : public OpConversionPattern <complex::TanOp> {
960
- using OpConversionPattern<complex::TanOp>::OpConversionPattern;
959
+ template <typename Op>
960
+ struct TanTanhOpConversion : public OpConversionPattern <Op> {
961
+ using OpConversionPattern<Op>::OpConversionPattern;
961
962
962
963
LogicalResult
963
- matchAndRewrite (complex::TanOp op, OpAdaptor adaptor,
964
- ConversionPatternRewriter &rewriter) const override {
965
- auto loc = op.getLoc ();
966
- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
967
-
968
- Value cos = rewriter.create <complex::CosOp>(loc, adaptor.getComplex (), fmf);
969
- Value sin = rewriter.create <complex::SinOp>(loc, adaptor.getComplex (), fmf);
970
- rewriter.replaceOpWithNewOp <complex::DivOp>(op, sin, cos, fmf);
971
- return success ();
972
- }
973
- };
974
-
975
- struct TanhOpConversion : public OpConversionPattern <complex::TanhOp> {
976
- using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
977
-
978
- LogicalResult
979
- matchAndRewrite (complex::TanhOp op, OpAdaptor adaptor,
964
+ matchAndRewrite (Op op, Op::Adaptor adaptor,
980
965
ConversionPatternRewriter &rewriter) const override {
981
966
ImplicitLocOpBuilder b (op.getLoc (), rewriter);
982
967
auto loc = op.getLoc ();
@@ -989,14 +974,20 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
989
974
b.create <complex::ReOp>(loc, elementType, adaptor.getComplex ());
990
975
Value imag =
991
976
b.create <complex::ImOp>(loc, elementType, adaptor.getComplex ());
977
+ Value negOne = b.create <arith::ConstantOp>(
978
+ elementType, b.getFloatAttr (elementType, -1.0 ));
979
+
980
+ if constexpr (std::is_same_v<Op, complex::TanOp>) {
981
+ // tan(x+yi) = -i*tanh(-y + xi)
982
+ std::swap (real, imag);
983
+ real = b.create <arith::MulFOp>(real, negOne, fmf);
984
+ }
992
985
993
986
auto cst = [&](APFloat v) {
994
987
return b.create <arith::ConstantOp>(elementType,
995
988
b.getFloatAttr (elementType, v));
996
989
};
997
990
Value inf = cst (APFloat::getInf (floatSemantics));
998
- Value negOne = b.create <arith::ConstantOp>(
999
- elementType, b.getFloatAttr (elementType, -1.0 ));
1000
991
Value four = b.create <arith::ConstantOp>(elementType,
1001
992
b.getFloatAttr (elementType, 4.0 ));
1002
993
Value twoReal = b.create <arith::AddFOp>(real, real, fmf);
@@ -1054,6 +1045,12 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
1054
1045
b.create <arith::SelectOp>(resultImagIsZero, zero, resultImag);
1055
1046
}
1056
1047
1048
+ if constexpr (std::is_same_v<Op, complex::TanOp>) {
1049
+ // tan(x+yi) = -i*tanh(-y + xi)
1050
+ std::swap (resultReal, resultImag);
1051
+ resultImag = b.create <arith::MulFOp>(resultImag, negOne, fmf);
1052
+ }
1053
+
1057
1054
rewriter.replaceOpWithNewOp <complex::CreateOp>(op, type, resultReal,
1058
1055
resultImag);
1059
1056
return success ();
@@ -1327,8 +1324,8 @@ void mlir::populateComplexToStandardConversionPatterns(
1327
1324
SignOpConversion,
1328
1325
SinOpConversion,
1329
1326
SqrtOpConversion,
1330
- TanOpConversion ,
1331
- TanhOpConversion ,
1327
+ TanTanhOpConversion<complex::TanOp> ,
1328
+ TanTanhOpConversion<complex::TanhOp> ,
1332
1329
PowOpConversion,
1333
1330
RsqrtOpConversion
1334
1331
>(patterns.getContext ());
0 commit comments