Skip to content

Commit ff9bc3a

Browse files
authored
Fix overflows in complex sqrt lowering. (#88480)
This ports XLA's complex sqrt lowering. The accuracy was tested with its exhaustive_unary_test_complex test. Note: rsqrt is still broken.
1 parent 1b8830c commit ff9bc3a

File tree

3 files changed

+280
-184
lines changed

3 files changed

+280
-184
lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 97 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -27,35 +27,52 @@ using namespace mlir;
2727

2828
namespace {
2929

30+
// Returns the absolute value or its square root.
31+
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
32+
ImplicitLocOpBuilder &b, bool returnSqrt = false) {
33+
Value one = b.create<arith::ConstantOp>(real.getType(),
34+
b.getFloatAttr(real.getType(), 1.0));
35+
36+
Value absReal = b.create<math::AbsFOp>(real, fmf);
37+
Value absImag = b.create<math::AbsFOp>(imag, fmf);
38+
39+
Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
40+
Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
41+
Value ratio = b.create<arith::DivFOp>(min, max, fmf);
42+
Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
43+
Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
44+
Value result;
45+
46+
if (returnSqrt) {
47+
Value quarter = b.create<arith::ConstantOp>(
48+
real.getType(), b.getFloatAttr(real.getType(), 0.25));
49+
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
50+
Value sqrt = b.create<math::SqrtOp>(max, fmf);
51+
Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmf);
52+
result = b.create<arith::MulFOp>(sqrt, p025, fmf);
53+
} else {
54+
Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
55+
result = b.create<arith::MulFOp>(max, sqrt, fmf);
56+
}
57+
58+
Value isNaN =
59+
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
60+
return b.create<arith::SelectOp>(isNaN, min, result);
61+
}
62+
3063
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
3164
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
3265

3366
LogicalResult
3467
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
3568
ConversionPatternRewriter &rewriter) const override {
36-
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
69+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
3770

3871
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
3972

40-
Type elementType = op.getType();
41-
Value one = b.create<arith::ConstantOp>(elementType,
42-
b.getFloatAttr(elementType, 1.0));
43-
4473
Value real = b.create<complex::ReOp>(adaptor.getComplex());
4574
Value imag = b.create<complex::ImOp>(adaptor.getComplex());
46-
Value absReal = b.create<math::AbsFOp>(real, fmf);
47-
Value absImag = b.create<math::AbsFOp>(imag, fmf);
48-
49-
Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
50-
Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
51-
Value ratio = b.create<arith::DivFOp>(min, max, fmf);
52-
Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
53-
Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
54-
Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
55-
Value result = b.create<arith::MulFOp>(max, sqrt, fmf);
56-
Value isNaN =
57-
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
58-
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, min, result);
75+
rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
5976

6077
return success();
6178
}
@@ -829,60 +846,71 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
829846
LogicalResult
830847
matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
831848
ConversionPatternRewriter &rewriter) const override {
832-
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
849+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
833850

834851
auto type = cast<ComplexType>(op.getType());
835-
Type elementType = type.getElementType();
836-
Value arg = adaptor.getComplex();
837-
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
838-
839-
Value zero =
840-
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
841-
842-
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
843-
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
844-
845-
Value absLhs = b.create<math::AbsFOp>(real, fmf);
846-
Value absArg = b.create<complex::AbsOp>(elementType, arg, fmf);
847-
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg, fmf);
852+
auto elementType = type.getElementType().cast<FloatType>();
853+
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
848854

855+
auto cst = [&](APFloat v) {
856+
return b.create<arith::ConstantOp>(elementType,
857+
b.getFloatAttr(elementType, v));
858+
};
859+
const auto &floatSemantics = elementType.getFloatSemantics();
860+
Value zero = cst(APFloat::getZero(floatSemantics));
849861
Value half = b.create<arith::ConstantOp>(elementType,
850862
b.getFloatAttr(elementType, 0.5));
851-
Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half, fmf);
852-
Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs, fmf);
853-
854-
Value realIsNegative =
855-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
856-
Value imagIsNegative =
857-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
858-
859-
Value resultReal = sqrtAddAbs;
860-
861-
Value imagDivTwoResultReal = b.create<arith::DivFOp>(
862-
imag, b.create<arith::AddFOp>(resultReal, resultReal, fmf), fmf);
863-
864-
Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
865863

864+
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
865+
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
866+
Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true);
867+
Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
868+
Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
869+
Value cos = b.create<math::CosOp>(sqrtArg, fmf);
870+
Value sin = b.create<math::SinOp>(sqrtArg, fmf);
871+
// sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
872+
// 0 * inf.
873+
Value sinIsZero =
874+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
875+
876+
Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
866877
Value resultImag = b.create<arith::SelectOp>(
867-
realIsNegative,
868-
b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
869-
resultReal),
870-
imagDivTwoResultReal);
871-
872-
resultReal = b.create<arith::SelectOp>(
873-
realIsNegative,
874-
b.create<arith::DivFOp>(
875-
imag, b.create<arith::AddFOp>(resultImag, resultImag, fmf), fmf),
876-
resultReal);
877-
878-
Value realIsZero =
879-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
880-
Value imagIsZero =
881-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
882-
Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
883-
884-
resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
885-
resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
878+
sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
879+
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
880+
arith::FastMathFlags::ninf)) {
881+
Value inf = cst(APFloat::getInf(floatSemantics));
882+
Value negInf = cst(APFloat::getInf(floatSemantics, true));
883+
Value nan = cst(APFloat::getNaN(floatSemantics));
884+
Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
885+
886+
Value absImagIsInf =
887+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
888+
Value absImagIsNotInf =
889+
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
890+
Value realIsInf =
891+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
892+
Value realIsNegInf =
893+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
894+
895+
resultReal = b.create<arith::SelectOp>(
896+
b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
897+
resultReal);
898+
resultReal = b.create<arith::SelectOp>(
899+
b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
900+
901+
Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
902+
resultImag = b.create<arith::SelectOp>(
903+
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
904+
nan, resultImag);
905+
resultImag = b.create<arith::SelectOp>(
906+
b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
907+
resultImag);
908+
}
909+
910+
Value resultIsZero =
911+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
912+
resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
913+
resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
886914

887915
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
888916
resultImag);
@@ -1065,27 +1093,27 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
10651093
// Case 2:
10661094
// 1^(c + d*i) = 1 + 0*i
10671095
Value lhsEqOne = builder.create<arith::AndIOp>(
1068-
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one),
1096+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
10691097
bEqZero);
10701098
Value cutoff2 =
10711099
builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
10721100

10731101
// Case 3:
10741102
// inf^(c + 0*i) = inf + 0*i, c > 0
10751103
Value lhsEqInf = builder.create<arith::AndIOp>(
1076-
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf),
1104+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
10771105
bEqZero);
10781106
Value rhsGt0 = builder.create<arith::AndIOp>(
10791107
dEqZero,
1080-
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero));
1108+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
10811109
Value cutoff3 = builder.create<arith::SelectOp>(
10821110
builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
10831111

10841112
// Case 4:
10851113
// inf^(c + 0*i) = 0 + 0*i, c < 0
10861114
Value rhsLt0 = builder.create<arith::AndIOp>(
10871115
dEqZero,
1088-
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero));
1116+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
10891117
Value cutoff4 = builder.create<arith::SelectOp>(
10901118
builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
10911119

0 commit comments

Comments
 (0)