@@ -27,9 +27,11 @@ using namespace mlir;
27
27
28
28
namespace {
29
29
30
- // Returns the absolute value or its square root.
30
// Selects which function of the complex magnitude |z| computeAbs returns.
enum class AbsFn {
  abs,   // |z|
  sqrt,  // sqrt(|z|)
  rsqrt, // 1 / sqrt(|z|)
};
31
+
32
+ // Returns the absolute value, its square root or its reciprocal square root.
31
33
Value computeAbs (Value real, Value imag, arith::FastMathFlags fmf,
32
- ImplicitLocOpBuilder &b, bool returnSqrt = false ) {
34
+ ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs ) {
33
35
Value one = b.create <arith::ConstantOp>(real.getType (),
34
36
b.getFloatAttr (real.getType (), 1.0 ));
35
37
@@ -43,7 +45,13 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
43
45
Value ratioSqPlusOne = b.create <arith::AddFOp>(ratioSq, one, fmf);
44
46
Value result;
45
47
46
- if (returnSqrt) {
48
+ if (fn == AbsFn::rsqrt) {
49
+ ratioSqPlusOne = b.create <math::RsqrtOp>(ratioSqPlusOne, fmf);
50
+ min = b.create <math::RsqrtOp>(min, fmf);
51
+ max = b.create <math::RsqrtOp>(max, fmf);
52
+ }
53
+
54
+ if (fn == AbsFn::sqrt) {
47
55
Value quarter = b.create <arith::ConstantOp>(
48
56
real.getType (), b.getFloatAttr (real.getType (), 0.25 ));
49
57
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
@@ -863,7 +871,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
863
871
864
872
Value real = b.create <complex::ReOp>(elementType, adaptor.getComplex ());
865
873
Value imag = b.create <complex::ImOp>(elementType, adaptor.getComplex ());
866
- Value absSqrt = computeAbs (real, imag, fmf, b, /* returnSqrt= */ true );
874
+ Value absSqrt = computeAbs (real, imag, fmf, b, AbsFn::sqrt );
867
875
Value argArg = b.create <math::Atan2Op>(imag, real, fmf);
868
876
Value sqrtArg = b.create <arith::MulFOp>(argArg, half, fmf);
869
877
Value cos = b.create <math::CosOp>(sqrtArg, fmf);
@@ -1147,18 +1155,74 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1147
1155
// Lowers complex.rsqrt through the polar form
//   rsqrt(z) = |z|^(-1/2) * (cos(-arg(z) / 2) + i * sin(-arg(z) / 2)).
// Two groups of inputs are then patched up with selects:
//   * unless the fastmath flags contain both nnan and ninf: when real is
//     +/-inf, or imag is +/-inf while real is NaN, the result becomes
//     (copysign(0, real), -copysign(0, imag));
//   * unconditionally: when real == 0 and imag == 0, the result becomes
//     (inf, nan).
LogicalResult
matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  auto type = cast<ComplexType>(adaptor.getComplex().getType());
  auto elementType = cast<FloatType>(type.getElementType());

  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();

  // Emits an arith.constant of the element type holding the given APFloat.
  auto cst = [&](APFloat v) {
    return b.create<arith::ConstantOp>(elementType,
                                       b.getFloatAttr(elementType, v));
  };
  const auto &floatSemantics = elementType.getFloatSemantics();
  Value zero = cst(APFloat::getZero(floatSemantics));
  Value inf = cst(APFloat::getInf(floatSemantics));
  Value negHalf = b.create<arith::ConstantOp>(
      elementType, b.getFloatAttr(elementType, -0.5));
  Value nan = cst(APFloat::getNaN(floatSemantics));

  // Polar decomposition: magnitude |z|^(-1/2) and angle -arg(z) / 2.
  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
  Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
  Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
  Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
  Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
  Value sin = b.create<math::SinOp>(rsqrtArg, fmf);

  Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
  Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);

  // The inf/NaN fix-ups are only emitted when the fastmath flags do not
  // contain both nnan and ninf.
  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                          arith::FastMathFlags::ninf)) {
    Value negOne = b.create<arith::ConstantOp>(
        elementType, b.getFloatAttr(elementType, -1));

    // Zeros carrying the sign of the input parts; the imaginary one is
    // negated before being used as the result.
    Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
    Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
    Value negImagSignedZero =
        b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);

    Value absReal = b.create<math::AbsFOp>(real, fmf);
    Value absImag = b.create<math::AbsFOp>(imag, fmf);

    Value absImagIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
    // An unordered self-comparison is true exactly when real is NaN.
    Value realIsNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
    Value realIsInf =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
    Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);

    // Signed-zero result when real is +/-inf, or the input is (NaN, +/-inf).
    Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);

    resultReal =
        b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
    resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
                                           resultImag);
  }

  // A zero input (both parts compare equal to zero) yields (inf, nan). This
  // select is emitted regardless of the fastmath flags.
  Value isRealZero =
      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
  Value isImagZero =
      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
  Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);

  resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
  resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);

  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                 resultImag);
  return success();
}
1164
1228
};
0 commit comments