Skip to content

Fix overflows in complex sqrt lowering. #88480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2024
Merged

Fix overflows in complex sqrt lowering. #88480

merged 1 commit into from
Apr 12, 2024

Conversation

jreiffers
Copy link
Member

This ports XLA's complex sqrt lowering. The accuracy was tested with its exhaustive_unary_test_complex test.

Note: rsqrt is still broken.

This ports XLA's complex sqrt lowering. The accuracy was tested with its
exhaustive_unary_test_complex test.

Note: rsqrt is still broken.
@jreiffers jreiffers requested review from akuegel and pifon2a April 12, 2024 07:16
@llvmbot llvmbot added the mlir label Apr 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-mlir

Author: Johannes Reifferscheid (jreiffers)

Changes

This ports XLA's complex sqrt lowering. The accuracy was tested with its exhaustive_unary_test_complex test.

Note: rsqrt is still broken.


Patch is 39.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88480.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+97-69)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+182-114)
  • (modified) mlir/test/Conversion/ComplexToStandard/full-conversion.mlir (+1-1)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9c82e8105f06e5..0664b053fc9e67 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -27,35 +27,52 @@ using namespace mlir;
 
 namespace {
 
+// Returns the absolute value or its square root.
+Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
+                 ImplicitLocOpBuilder &b, bool returnSqrt = false) {
+  Value one = b.create<arith::ConstantOp>(real.getType(),
+                                          b.getFloatAttr(real.getType(), 1.0));
+
+  Value absReal = b.create<math::AbsFOp>(real, fmf);
+  Value absImag = b.create<math::AbsFOp>(imag, fmf);
+
+  Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
+  Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
+  Value ratio = b.create<arith::DivFOp>(min, max, fmf);
+  Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
+  Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
+  Value result;
+
+  if (returnSqrt) {
+    Value quarter = b.create<arith::ConstantOp>(
+        real.getType(), b.getFloatAttr(real.getType(), 0.25));
+    // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
+    Value sqrt = b.create<math::SqrtOp>(max, fmf);
+    Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmf);
+    result = b.create<arith::MulFOp>(sqrt, p025, fmf);
+  } else {
+    Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
+    result = b.create<arith::MulFOp>(max, sqrt, fmf);
+  }
+
+  Value isNaN =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
+  return b.create<arith::SelectOp>(isNaN, min, result);
+}
+
 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
 
-    Type elementType = op.getType();
-    Value one = b.create<arith::ConstantOp>(elementType,
-                                            b.getFloatAttr(elementType, 1.0));
-
     Value real = b.create<complex::ReOp>(adaptor.getComplex());
     Value imag = b.create<complex::ImOp>(adaptor.getComplex());
-    Value absReal = b.create<math::AbsFOp>(real, fmf);
-    Value absImag = b.create<math::AbsFOp>(imag, fmf);
-
-    Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
-    Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
-    Value ratio = b.create<arith::DivFOp>(min, max, fmf);
-    Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
-    Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
-    Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
-    Value result = b.create<arith::MulFOp>(max, sqrt, fmf);
-    Value isNaN =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, min, result);
+    rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
 
     return success();
   }
@@ -829,60 +846,71 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
   LogicalResult
   matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     auto type = cast<ComplexType>(op.getType());
-    Type elementType = type.getElementType();
-    Value arg = adaptor.getComplex();
-    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
-
-    Value zero =
-        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
-
-    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
-    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-
-    Value absLhs = b.create<math::AbsFOp>(real, fmf);
-    Value absArg = b.create<complex::AbsOp>(elementType, arg, fmf);
-    Value addAbs = b.create<arith::AddFOp>(absLhs, absArg, fmf);
+    auto elementType = type.getElementType().cast<FloatType>();
+    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
 
+    auto cst = [&](APFloat v) {
+      return b.create<arith::ConstantOp>(elementType,
+                                         b.getFloatAttr(elementType, v));
+    };
+    const auto &floatSemantics = elementType.getFloatSemantics();
+    Value zero = cst(APFloat::getZero(floatSemantics));
     Value half = b.create<arith::ConstantOp>(elementType,
                                              b.getFloatAttr(elementType, 0.5));
-    Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half, fmf);
-    Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs, fmf);
-
-    Value realIsNegative =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
-    Value imagIsNegative =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
-
-    Value resultReal = sqrtAddAbs;
-
-    Value imagDivTwoResultReal = b.create<arith::DivFOp>(
-        imag, b.create<arith::AddFOp>(resultReal, resultReal, fmf), fmf);
-
-    Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
 
+    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+    Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true);
+    Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
+    Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
+    Value cos = b.create<math::CosOp>(sqrtArg, fmf);
+    Value sin = b.create<math::SinOp>(sqrtArg, fmf);
+    // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
+    // 0 * inf.
+    Value sinIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
+
+    Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
     Value resultImag = b.create<arith::SelectOp>(
-        realIsNegative,
-        b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
-                                  resultReal),
-        imagDivTwoResultReal);
-
-    resultReal = b.create<arith::SelectOp>(
-        realIsNegative,
-        b.create<arith::DivFOp>(
-            imag, b.create<arith::AddFOp>(resultImag, resultImag, fmf), fmf),
-        resultReal);
-
-    Value realIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
-    Value imagIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
-    Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
-
-    resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
-    resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
+        sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
+    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
+                                            arith::FastMathFlags::ninf)) {
+      Value inf = cst(APFloat::getInf(floatSemantics));
+      Value negInf = cst(APFloat::getInf(floatSemantics, true));
+      Value nan = cst(APFloat::getNaN(floatSemantics));
+      Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
+
+      Value absImagIsInf =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
+      Value absImagIsNotInf =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
+      Value realIsInf =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
+      Value realIsNegInf =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
+
+      resultReal = b.create<arith::SelectOp>(
+          b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
+          resultReal);
+      resultReal = b.create<arith::SelectOp>(
+          b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
+
+      Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
+      resultImag = b.create<arith::SelectOp>(
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
+          nan, resultImag);
+      resultImag = b.create<arith::SelectOp>(
+          b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
+          resultImag);
+    }
+
+    Value resultIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
+    resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
+    resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
@@ -1065,7 +1093,7 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
   // Case 2:
   // 1^(c + d*i) = 1 + 0*i
   Value lhsEqOne = builder.create<arith::AndIOp>(
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one),
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
       bEqZero);
   Value cutoff2 =
       builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
@@ -1073,11 +1101,11 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
   // Case 3:
   // inf^(c + 0*i) = inf + 0*i, c > 0
   Value lhsEqInf = builder.create<arith::AndIOp>(
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf),
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
       bEqZero);
   Value rhsGt0 = builder.create<arith::AndIOp>(
       dEqZero,
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero));
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
   Value cutoff3 = builder.create<arith::SelectOp>(
       builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
 
@@ -1085,7 +1113,7 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
   // inf^(c + 0*i) = 0 + 0*i, c < 0
   Value rhsLt0 = builder.create<arith::AndIOp>(
       dEqZero,
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero));
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
   Value cutoff4 = builder.create<arith::SelectOp>(
       builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
 
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8d2fb09daa87b6..b22c1acacaea18 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -8,9 +8,9 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
   return %abs : f32
 }
 
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
@@ -250,9 +250,9 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg: complex<f32>
   return %log : complex<f32>
 }
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
@@ -493,9 +493,9 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL2]] : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG2]] : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
@@ -697,45 +697,95 @@ func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
   return %sqrt : complex<f32>
 }
 
-// CHECK: %[[CST:.*]]  = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
-// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] : f32
 // CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
 // CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
 // CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
-// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[ABS]] : f32
-// CHECK: %[[CST2:.*]]  = arith.constant 5.000000e-01 : f32
-// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] : f32
-// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] : f32
-// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] : f32
-// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] : f32
-// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
-// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
-// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
-// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] : f32
-// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] : f32
-// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
-// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
-// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
-// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
-// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
-// CHECK: return %[[VAR41]] : complex<f32>
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] : f32
+// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] : f32
+// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[SQRTARG]] : f32
+// CHECK: %[[SIN_ZERO:.*]] = arith.cmpf oeq, %[[SIN]], %[[ZERO]] : f32
+// CHECK: %[[RESULT_RE:.*]] = arith.mulf %[[SQRT_ABS]], %[[COS]] : f32
+// CHECK: %[[RESULT_IM:.*]] = arith.mulf %[[SQRT_ABS]], %[[SIN]] : f32
+// CHECK: %[[RESULT_IM2:.*]] = arith.select %[[SIN_ZERO]], %[[ZERO]], %[[RESULT_IM]] : f32
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[NINF:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] : f32
+// CHECK: %[[ABSIMINF:.*]] = arith.cmpf oeq, %[[ABSIM]], %[[INF]] : f32
+// CHECK: %[[ABSIMNOTINF:.*]] = arith.cmpf one, %[[ABSIM]], %[[INF]] : f32
+// CHECK: %[[REINF:.*]] = arith.cmpf oeq, %[[RE]], %[[INF]] : f32
+// CHECK: %[[RENINF:.*]] = arith.cmpf oeq, %[[RE]], %[[NINF]] : f32
+// CHECK: %[[RESULT_RE_ZERO:.*]] = arith.andi %[[RENINF]], %[[ABSIMNOTINF]] : i1
+// CHECK: %[[RESULT_RE2:.*]] = arith.select %[[RESULT_RE_ZERO]], %[[ZERO]], %[[RESULT_RE]] : f32
+// CHECK: %[[RESUL_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[REINF]] : i1
+// CHECK: %[[RESULT_RE3:.*]] = arith.select %[[RESUL_IM_INF]], %[[INF]], %[[RESULT_RE2]] : f32
+// CHECK: %[[INF_IM_SIGN:.*]] = math.copysign %[[INF]], %[[IM]] : f32
+// CHECK: %[[RESULT_IM_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS]], %[[SQRT_ABS]] : f32
+// CHECK: %[[RESULT_IM3:.*]] = arith.select %[[RESULT_IM_NAN]], %[[NAN]], %[[RESULT_IM2]] : f32
+// CHECK: %[[RESULT_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[RENINF]] : i1
+// CHECK: %[[RESULT_IM4:.*]] = arith.select %[[RESULT_IM_INF]], %[[INF_IM_SIGN]], %[[RESULT_IM3]] : f32
+// CHECK: %[[RESULT_ZERO:.*]] = arith.cmpf oeq, %[[SQRT_ABS]], %[[ZERO]] : f32
+// CHECK: %[[RESULT_RE4:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_RE3]] : f32
+// CHECK: %[[RESULT_IM5:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_IM4]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_RE4]], %[[RESULT_IM5]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_sqrt_nnan_ninf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
+  %sqrt = complex.sqrt %arg fastmath<nnan,ninf> : complex<f32>
+  return %sqrt : complex<f32>
+}
+
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] fastmath<nnan,ninf> : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,ninf> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,ninf> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,ninf> : f32
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,ninf> : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,ninf> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fast...
[truncated]

@jreiffers jreiffers merged commit ff9bc3a into llvm:main Apr 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants