Skip to content

Fix complex abs corner cases. #88373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 11, 2024
Merged

Fix complex abs corner cases. #88373

merged 1 commit into from
Apr 11, 2024

Conversation

jreiffers
Copy link
Member

The current implementation fails for very small and very large values. For example, (0, -inf) should return inf, but it returns -inf.

This ports the logic used in XLA. Tested with XLA's exhaustive_binary_test_f32_f64.

The current implementation fails for very small and very large values.
For example, (0, -inf) should return inf, but it returns -inf.

This ports the logic used in XLA.
@jreiffers jreiffers requested a review from akuegel April 11, 2024 09:59
@llvmbot llvmbot added the mlir label Apr 11, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 11, 2024

@llvm/pr-subscribers-mlir

Author: Johannes Reifferscheid (jreiffers)

Changes

The current implementation fails for very small and very large values. For example, (0, -inf) should return inf, but it returns -inf.

This ports the logic used in XLA. Tested with XLA's exhaustive_binary_test_f32_f64.


Patch is 39.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88373.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+16-38)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+134-214)
  • (modified) mlir/test/Conversion/ComplexToStandard/full-conversion.mlir (+12-22)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index a6fcf6a758c07f..462036e51a1f1c 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -26,7 +26,7 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
-// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
+
 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
 
@@ -35,49 +35,27 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
                   ConversionPatternRewriter &rewriter) const override {
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
 
     Type elementType = op.getType();
-    Value arg = adaptor.getComplex();
-
-    Value zero =
-        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
     Value one = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 1.0));
 
-    Value real = b.create<complex::ReOp>(elementType, arg);
-    Value imag = b.create<complex::ImOp>(elementType, arg);
-
-    Value realIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
-    Value imagIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+    Value real = b.create<complex::ReOp>(adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+    Value absReal = b.create<math::AbsFOp>(real, fmf);
+    Value absImag = b.create<math::AbsFOp>(imag, fmf);
 
-    // Real > Imag
-    Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
-    Value imagSq =
-        b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
-    Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
-    Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
-    Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
-    Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
-
-    // Real <= Imag
-    Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
-    Value realSq =
-        b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
-    Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
-    Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
-    Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
-    Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
-
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(
-        op, realIsZero, imagAbs,
-        b.create<arith::SelectOp>(
-            imagIsZero, realAbs,
-            b.create<arith::SelectOp>(
-                b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
-                absImag, absReal)));
+    Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
+    Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
+    Value ratio = b.create<arith::DivFOp>(min, max, fmf);
+    Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
+    Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
+    Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
+    Value result = b.create<arith::MulFOp>(max, sqrt, fmf);
+    Value isNaN =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, min, result);
 
     return success();
   }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 46dba04a88aa0c..a1de61d10bb226 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -8,29 +8,21 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
   return %abs : f32
 }
 
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: return %[[ABS3]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: return %[[ABS]] : f32
 
 // -----
 
@@ -258,29 +250,21 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg: complex<f32>
   return %log : complex<f32>
 }
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[RESULT:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RESULT]], %[[RESULT]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[RESULT]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = math.log %[[ABS]] : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
@@ -509,30 +493,22 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
-// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL2]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG2]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[ABS]] : f32
+// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[ABS]] : f32
 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
 // CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
@@ -725,29 +701,21 @@ func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] : f32
-// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[VAR3:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[VAR4:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[VAR5:.*]] = arith.cmpf oeq, %[[VAR3]], %[[CST0]] : f32
-// CHECK: %[[VAR6:.*]] = arith.cmpf oeq, %[[VAR4]], %[[CST0]] : f32
-// CHECK: %[[VAR7:.*]] = arith.divf %[[VAR4]], %[[VAR3]] : f32
-// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR7]], %[[VAR7]] : f32
-// CHECK: %[[VAR9:.*]] = arith.addf %[[VAR8]], %[[CST1]] : f32
-// CHECK: %[[VAR10:.*]] = math.sqrt %[[VAR9]] : f32
-// CHECK: %[[VAR11:.*]] = math.absf %[[VAR3]] : f32
-// CHECK: %[[VAR12:.*]] = arith.mulf %[[VAR10]], %[[VAR11]] : f32
-// CHECK: %[[VAR13:.*]] = arith.divf %[[VAR3]], %[[VAR4]] : f32
-// CHECK: %[[VAR14:.*]] = arith.mulf %[[VAR13]], %[[VAR13]] : f32
-// CHECK: %[[VAR15:.*]] = arith.addf %[[VAR14]], %[[CST1]] : f32
-// CHECK: %[[VAR16:.*]] = math.sqrt %[[VAR15]] : f32
-// CHECK: %[[VAR17:.*]] = math.absf %[[VAR4]] : f32
-// CHECK: %[[VAR18:.*]] = arith.mulf %[[VAR16]], %[[VAR17]] : f32
-// CHECK: %[[VAR19:.*]] = arith.cmpf ogt, %[[VAR3]], %[[VAR4]] : f32
-// CHECK: %[[VAR20:.*]] = arith.select %[[VAR19]], %[[VAR12]], %[[VAR18]] : f32
-// CHECK: %[[VAR21:.*]] = arith.select %[[VAR6]], %[[VAR11]], %[[VAR20]] : f32
-// CHECK: %[[VAR22:.*]] = arith.select %[[VAR5]], %[[VAR17]], %[[VAR21]] : f32
-// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[VAR22]] : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[ABS]] : f32
 // CHECK: %[[CST2:.*]]  = arith.constant 5.000000e-01 : f32
 // CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] : f32
 // CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] : f32
@@ -821,29 +789,21 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
   return %abs : f32
 }
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: return %[[ABS3]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: return %[[ABS]] : f32
 
 // -----
 
@@ -928,29 +888,21 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg fastmath<nnan,contract> : complex<f32>
   return %log : complex<f32>
 }
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,con...
[truncated]

@jreiffers jreiffers merged commit 9d9bb7b into llvm:main Apr 11, 2024
@joker-eph
Copy link
Collaborator

Thanks for the fix!

@Hardcode84
Copy link
Contributor

I believe, current implementation is incorrect for case (0, 0) and fastmath=nnan.
div(min, max) == div(0, 0) == nan == poison which is propagated for all the following ops
cmpf uno == either false or poison (not sure about actual llvm semantics here)
both for false and poison select will return poison.
In practice, we are getting nan results instead of expected 0s.

Not sure if previous version was formally correct, but at least we were getting correct results for this case.

@jreiffers
Copy link
Member Author

Sorry about that. You're right. Let me take a look.

@jreiffers
Copy link
Member Author

I believe jreiffers@e869e7e will fix this. Still need to update tests.

@jreiffers
Copy link
Member Author

PR here: #95080

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants