@@ -27,35 +27,52 @@ using namespace mlir;
27
27
28
28
namespace {
29
29
30
+ // Returns the absolute value or its square root.
31
+ Value computeAbs (Value real, Value imag, arith::FastMathFlags fmf,
32
+ ImplicitLocOpBuilder &b, bool returnSqrt = false ) {
33
+ Value one = b.create <arith::ConstantOp>(real.getType (),
34
+ b.getFloatAttr (real.getType (), 1.0 ));
35
+
36
+ Value absReal = b.create <math::AbsFOp>(real, fmf);
37
+ Value absImag = b.create <math::AbsFOp>(imag, fmf);
38
+
39
+ Value max = b.create <arith::MaximumFOp>(absReal, absImag, fmf);
40
+ Value min = b.create <arith::MinimumFOp>(absReal, absImag, fmf);
41
+ Value ratio = b.create <arith::DivFOp>(min, max, fmf);
42
+ Value ratioSq = b.create <arith::MulFOp>(ratio, ratio, fmf);
43
+ Value ratioSqPlusOne = b.create <arith::AddFOp>(ratioSq, one, fmf);
44
+ Value result;
45
+
46
+ if (returnSqrt) {
47
+ Value quarter = b.create <arith::ConstantOp>(
48
+ real.getType (), b.getFloatAttr (real.getType (), 0.25 ));
49
+ // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
50
+ Value sqrt = b.create <math::SqrtOp>(max, fmf);
51
+ Value p025 = b.create <math::PowFOp>(ratioSqPlusOne, quarter, fmf);
52
+ result = b.create <arith::MulFOp>(sqrt, p025, fmf);
53
+ } else {
54
+ Value sqrt = b.create <math::SqrtOp>(ratioSqPlusOne, fmf);
55
+ result = b.create <arith::MulFOp>(max, sqrt, fmf);
56
+ }
57
+
58
+ Value isNaN =
59
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
60
+ return b.create <arith::SelectOp>(isNaN, min, result);
61
+ }
62
+
30
63
struct AbsOpConversion : public OpConversionPattern <complex::AbsOp> {
31
64
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
32
65
33
66
/// Rewrites `complex.abs` by splitting the operand into its real and
/// imaginary parts and replacing the op with the value built by `computeAbs`,
/// forwarding the op's fast-math flags. Always reports success.
LogicalResult
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
  arith::FastMathFlags flags = op.getFastMathFlagsAttr().getValue();

  // Extract the real part first, then the imaginary part.
  Value re = builder.create<complex::ReOp>(adaptor.getComplex());
  Value im = builder.create<complex::ImOp>(adaptor.getComplex());

  Value magnitude = computeAbs(re, im, flags, builder);
  rewriter.replaceOp(op, magnitude);
  return success();
}
@@ -829,60 +846,71 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
829
846
LogicalResult
830
847
matchAndRewrite (complex::SqrtOp op, OpAdaptor adaptor,
831
848
ConversionPatternRewriter &rewriter) const override {
832
- mlir:: ImplicitLocOpBuilder b (op.getLoc (), rewriter);
849
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
833
850
834
851
auto type = cast<ComplexType>(op.getType ());
835
- Type elementType = type.getElementType ();
836
- Value arg = adaptor.getComplex ();
837
- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
838
-
839
- Value zero =
840
- b.create <arith::ConstantOp>(elementType, b.getZeroAttr (elementType));
841
-
842
- Value real = b.create <complex::ReOp>(elementType, adaptor.getComplex ());
843
- Value imag = b.create <complex::ImOp>(elementType, adaptor.getComplex ());
844
-
845
- Value absLhs = b.create <math::AbsFOp>(real, fmf);
846
- Value absArg = b.create <complex::AbsOp>(elementType, arg, fmf);
847
- Value addAbs = b.create <arith::AddFOp>(absLhs, absArg, fmf);
852
+ auto elementType = type.getElementType ().cast <FloatType>();
853
+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr ().getValue ();
848
854
855
+ auto cst = [&](APFloat v) {
856
+ return b.create <arith::ConstantOp>(elementType,
857
+ b.getFloatAttr (elementType, v));
858
+ };
859
+ const auto &floatSemantics = elementType.getFloatSemantics ();
860
+ Value zero = cst (APFloat::getZero (floatSemantics));
849
861
Value half = b.create <arith::ConstantOp>(elementType,
850
862
b.getFloatAttr (elementType, 0.5 ));
851
- Value halfAddAbs = b.create <arith::MulFOp>(addAbs, half, fmf);
852
- Value sqrtAddAbs = b.create <math::SqrtOp>(halfAddAbs, fmf);
853
-
854
- Value realIsNegative =
855
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
856
- Value imagIsNegative =
857
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
858
-
859
- Value resultReal = sqrtAddAbs;
860
-
861
- Value imagDivTwoResultReal = b.create <arith::DivFOp>(
862
- imag, b.create <arith::AddFOp>(resultReal, resultReal, fmf), fmf);
863
-
864
- Value negativeResultReal = b.create <arith::NegFOp>(resultReal);
865
863
864
+ Value real = b.create <complex::ReOp>(elementType, adaptor.getComplex ());
865
+ Value imag = b.create <complex::ImOp>(elementType, adaptor.getComplex ());
866
+ Value absSqrt = computeAbs (real, imag, fmf, b, /* returnSqrt=*/ true );
867
+ Value argArg = b.create <math::Atan2Op>(imag, real, fmf);
868
+ Value sqrtArg = b.create <arith::MulFOp>(argArg, half, fmf);
869
+ Value cos = b.create <math::CosOp>(sqrtArg, fmf);
870
+ Value sin = b.create <math::SinOp>(sqrtArg, fmf);
871
+ // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
872
+ // 0 * inf.
873
+ Value sinIsZero =
874
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
875
+
876
+ Value resultReal = b.create <arith::MulFOp>(absSqrt, cos, fmf);
866
877
Value resultImag = b.create <arith::SelectOp>(
867
- realIsNegative,
868
- b.create <arith::SelectOp>(imagIsNegative, negativeResultReal,
869
- resultReal),
870
- imagDivTwoResultReal);
871
-
872
- resultReal = b.create <arith::SelectOp>(
873
- realIsNegative,
874
- b.create <arith::DivFOp>(
875
- imag, b.create <arith::AddFOp>(resultImag, resultImag, fmf), fmf),
876
- resultReal);
877
-
878
- Value realIsZero =
879
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
880
- Value imagIsZero =
881
- b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
882
- Value argIsZero = b.create <arith::AndIOp>(realIsZero, imagIsZero);
883
-
884
- resultReal = b.create <arith::SelectOp>(argIsZero, zero, resultReal);
885
- resultImag = b.create <arith::SelectOp>(argIsZero, zero, resultImag);
878
+ sinIsZero, zero, b.create <arith::MulFOp>(absSqrt, sin, fmf));
879
+ if (!arith::bitEnumContainsAll (fmf, arith::FastMathFlags::nnan |
880
+ arith::FastMathFlags::ninf)) {
881
+ Value inf = cst (APFloat::getInf (floatSemantics));
882
+ Value negInf = cst (APFloat::getInf (floatSemantics, true ));
883
+ Value nan = cst (APFloat::getNaN (floatSemantics));
884
+ Value absImag = b.create <math::AbsFOp>(elementType, imag, fmf);
885
+
886
+ Value absImagIsInf =
887
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
888
+ Value absImagIsNotInf =
889
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
890
+ Value realIsInf =
891
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
892
+ Value realIsNegInf =
893
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
894
+
895
+ resultReal = b.create <arith::SelectOp>(
896
+ b.create <arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
897
+ resultReal);
898
+ resultReal = b.create <arith::SelectOp>(
899
+ b.create <arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
900
+
901
+ Value imagSignInf = b.create <math::CopySignOp>(inf, imag, fmf);
902
+ resultImag = b.create <arith::SelectOp>(
903
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
904
+ nan, resultImag);
905
+ resultImag = b.create <arith::SelectOp>(
906
+ b.create <arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
907
+ resultImag);
908
+ }
909
+
910
+ Value resultIsZero =
911
+ b.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
912
+ resultReal = b.create <arith::SelectOp>(resultIsZero, zero, resultReal);
913
+ resultImag = b.create <arith::SelectOp>(resultIsZero, zero, resultImag);
886
914
887
915
rewriter.replaceOpWithNewOp <complex::CreateOp>(op, type, resultReal,
888
916
resultImag);
@@ -1065,27 +1093,27 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
1065
1093
// Case 2:
1066
1094
// 1^(c + d*i) = 1 + 0*i
1067
1095
Value lhsEqOne = builder.create <arith::AndIOp>(
1068
- builder.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one),
1096
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf ),
1069
1097
bEqZero);
1070
1098
Value cutoff2 =
1071
1099
builder.create <arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
1072
1100
1073
1101
// Case 3:
1074
1102
// inf^(c + 0*i) = inf + 0*i, c > 0
1075
1103
Value lhsEqInf = builder.create <arith::AndIOp>(
1076
- builder.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf),
1104
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf ),
1077
1105
bEqZero);
1078
1106
Value rhsGt0 = builder.create <arith::AndIOp>(
1079
1107
dEqZero,
1080
- builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero));
1108
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf ));
1081
1109
Value cutoff3 = builder.create <arith::SelectOp>(
1082
1110
builder.create <arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
1083
1111
1084
1112
// Case 4:
1085
1113
// inf^(c + 0*i) = 0 + 0*i, c < 0
1086
1114
Value rhsLt0 = builder.create <arith::AndIOp>(
1087
1115
dEqZero,
1088
- builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero));
1116
+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf ));
1089
1117
Value cutoff4 = builder.create <arith::SelectOp>(
1090
1118
builder.create <arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
1091
1119
0 commit comments