Skip to content

[mlir][complex] Support Fastmath flag in conversion of complex.div to standard #82729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 27, 2024

Conversation

Lewuathe
Copy link
Member

Support Fastmath flag to convert complex.div to standard dialects.

See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981

@llvmbot
Copy link
Member

llvmbot commented Feb 23, 2024

@llvm/pr-subscribers-mlir

Author: Kai Sasaki (Lewuathe)

Changes

Support Fastmath flag to convert complex.div to standard dialects.

See: https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981


Patch is 23.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82729.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+79-51)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+112-1)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index cc315110f9be20..33b94b5042e378 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -195,6 +195,7 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value real =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -206,11 +207,13 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
     // implementation in the subclass to combine them.
     Value half = rewriter.create<arith::ConstantOp>(
         loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
-    Value exp = rewriter.create<math::ExpOp>(loc, imag);
-    Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
-    Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
-    Value sin = rewriter.create<math::SinOp>(loc, real);
-    Value cos = rewriter.create<math::CosOp>(loc, real);
+    Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf.getValue());
+    Value scaledExp =
+        rewriter.create<arith::MulFOp>(loc, half, exp, fmf.getValue());
+    Value reciprocalExp =
+        rewriter.create<arith::DivFOp>(loc, half, exp, fmf.getValue());
+    Value sin = rewriter.create<math::SinOp>(loc, real, fmf.getValue());
+    Value cos = rewriter.create<math::CosOp>(loc, real, fmf.getValue());
 
     auto resultPair =
         combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
@@ -257,6 +260,7 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getLhs().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value lhsReal =
         rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
@@ -290,45 +294,59 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
     //
     // See https://dl.acm.org/citation.cfm?id=368661 for more details.
     Value rhsRealImagRatio =
-        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
+        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf.getValue());
     Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
         loc, rhsImag,
-        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
+        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal,
+                                       fmf.getValue()),
+        fmf.getValue());
     Value realNumerator1 = rewriter.create<arith::AddFOp>(
-        loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
-        lhsImag);
-    Value resultReal1 =
-        rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
+        loc,
+        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio,
+                                       fmf.getValue()),
+        lhsImag, fmf.getValue());
+    Value resultReal1 = rewriter.create<arith::DivFOp>(
+        loc, realNumerator1, rhsRealImagDenom, fmf.getValue());
     Value imagNumerator1 = rewriter.create<arith::SubFOp>(
-        loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
-        lhsReal);
-    Value resultImag1 =
-        rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
+        loc,
+        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio,
+                                       fmf.getValue()),
+        lhsReal, fmf.getValue());
+    Value resultImag1 = rewriter.create<arith::DivFOp>(
+        loc, imagNumerator1, rhsRealImagDenom, fmf.getValue());
 
     Value rhsImagRealRatio =
-        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
+        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf.getValue());
     Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
         loc, rhsReal,
-        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
+        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag,
+                                       fmf.getValue()),
+        fmf.getValue());
     Value realNumerator2 = rewriter.create<arith::AddFOp>(
         loc, lhsReal,
-        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
-    Value resultReal2 =
-        rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
+        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio,
+                                       fmf.getValue()),
+        fmf.getValue());
+    Value resultReal2 = rewriter.create<arith::DivFOp>(
+        loc, realNumerator2, rhsImagRealDenom, fmf.getValue());
     Value imagNumerator2 = rewriter.create<arith::SubFOp>(
         loc, lhsImag,
-        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
-    Value resultImag2 =
-        rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
+        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio,
+                                       fmf.getValue()),
+        fmf.getValue());
+    Value resultImag2 = rewriter.create<arith::DivFOp>(
+        loc, imagNumerator2, rhsImagRealDenom, fmf.getValue());
 
     // Consider corner cases.
     // Case 1. Zero denominator, numerator contains at most one NaN value.
     Value zero = rewriter.create<arith::ConstantOp>(
         loc, elementType, rewriter.getZeroAttr(elementType));
-    Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal);
+    Value rhsRealAbs =
+        rewriter.create<math::AbsFOp>(loc, rhsReal, fmf.getValue());
     Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
         loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
-    Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag);
+    Value rhsImagAbs =
+        rewriter.create<math::AbsFOp>(loc, rhsImag, fmf.getValue());
     Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
         loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
     Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
@@ -346,10 +364,10 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
             elementType, APFloat::getInf(elementType.getFloatSemantics())));
     Value infWithSignOfRhsReal =
         rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
-    Value infinityResultReal =
-        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
-    Value infinityResultImag =
-        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
+    Value infinityResultReal = rewriter.create<arith::MulFOp>(
+        loc, infWithSignOfRhsReal, lhsReal, fmf.getValue());
+    Value infinityResultImag = rewriter.create<arith::MulFOp>(
+        loc, infWithSignOfRhsReal, lhsImag, fmf.getValue());
 
     // Case 2. Infinite numerator, finite denominator.
     Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
@@ -358,10 +376,12 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
         loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
     Value rhsFinite =
         rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
-    Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal);
+    Value lhsRealAbs =
+        rewriter.create<math::AbsFOp>(loc, lhsReal, fmf.getValue());
     Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
         loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
-    Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag);
+    Value lhsImagAbs =
+        rewriter.create<math::AbsFOp>(loc, lhsImag, fmf.getValue());
     Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
         loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
     Value lhsInfinite =
@@ -376,22 +396,26 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
     Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
         loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
         lhsImag);
-    Value lhsRealIsInfWithSignTimesRhsReal =
-        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
-    Value lhsImagIsInfWithSignTimesRhsImag =
-        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
+    Value lhsRealIsInfWithSignTimesRhsReal = rewriter.create<arith::MulFOp>(
+        loc, lhsRealIsInfWithSign, rhsReal, fmf.getValue());
+    Value lhsImagIsInfWithSignTimesRhsImag = rewriter.create<arith::MulFOp>(
+        loc, lhsImagIsInfWithSign, rhsImag, fmf.getValue());
     Value resultReal3 = rewriter.create<arith::MulFOp>(
         loc, inf,
         rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
-                                       lhsImagIsInfWithSignTimesRhsImag));
-    Value lhsRealIsInfWithSignTimesRhsImag =
-        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
-    Value lhsImagIsInfWithSignTimesRhsReal =
-        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
+                                       lhsImagIsInfWithSignTimesRhsImag,
+                                       fmf.getValue()),
+        fmf.getValue());
+    Value lhsRealIsInfWithSignTimesRhsImag = rewriter.create<arith::MulFOp>(
+        loc, lhsRealIsInfWithSign, rhsImag, fmf.getValue());
+    Value lhsImagIsInfWithSignTimesRhsReal = rewriter.create<arith::MulFOp>(
+        loc, lhsImagIsInfWithSign, rhsReal, fmf.getValue());
     Value resultImag3 = rewriter.create<arith::MulFOp>(
         loc, inf,
         rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
-                                       lhsRealIsInfWithSignTimesRhsImag));
+                                       lhsRealIsInfWithSignTimesRhsImag,
+                                       fmf.getValue()),
+        fmf.getValue());
 
     // Case 3: Finite numerator, infinite denominator.
     Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
@@ -414,22 +438,26 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
     Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
         loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
         rhsImag);
-    Value rhsRealIsInfWithSignTimesLhsReal =
-        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
-    Value rhsImagIsInfWithSignTimesLhsImag =
-        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
+    Value rhsRealIsInfWithSignTimesLhsReal = rewriter.create<arith::MulFOp>(
+        loc, lhsReal, rhsRealIsInfWithSign, fmf.getValue());
+    Value rhsImagIsInfWithSignTimesLhsImag = rewriter.create<arith::MulFOp>(
+        loc, lhsImag, rhsImagIsInfWithSign, fmf.getValue());
     Value resultReal4 = rewriter.create<arith::MulFOp>(
         loc, zero,
         rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
-                                       rhsImagIsInfWithSignTimesLhsImag));
-    Value rhsRealIsInfWithSignTimesLhsImag =
-        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
-    Value rhsImagIsInfWithSignTimesLhsReal =
-        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
+                                       rhsImagIsInfWithSignTimesLhsImag,
+                                       fmf.getValue()),
+        fmf.getValue());
+    Value rhsRealIsInfWithSignTimesLhsImag = rewriter.create<arith::MulFOp>(
+        loc, lhsImag, rhsRealIsInfWithSign, fmf.getValue());
+    Value rhsImagIsInfWithSignTimesLhsReal = rewriter.create<arith::MulFOp>(
+        loc, lhsReal, rhsImagIsInfWithSign, fmf.getValue());
     Value resultImag4 = rewriter.create<arith::MulFOp>(
         loc, zero,
         rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
-                                       rhsImagIsInfWithSignTimesLhsReal));
+                                       rhsImagIsInfWithSignTimesLhsReal,
+                                       fmf.getValue()),
+        fmf.getValue());
 
     Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
         loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 1fe843b1447ab3..39af7dd02a62d3 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1045,4 +1045,115 @@ func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
 // CHECK: %[[FINAL_IMAG:.*]] = arith.select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32
 
 // CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
-// CHECK: return %[[RESULT]] : complex<f32>
\ No newline at end of file
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_div_with_fmf
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func.func @complex_div_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %div = complex.div %lhs, %rhs fastmath<nnan,contract> : complex<f32>
+  return %div : complex<f32>
+}
+// CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
+
+// CHECK: %[[RHS_REAL_IMAG_RATIO:.*]] = arith.divf %[[RHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = arith.mulf %[[RHS_REAL_IMAG_RATIO]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_REAL_IMAG_DENOM:.*]] = arith.addf %[[RHS_IMAG]], %[[RHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL_IMAG_RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_NUMERATOR_1:.*]] = arith.addf %[[LHS_REAL_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL_1:.*]] = arith.divf %[[REAL_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL_IMAG_RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_NUMERATOR_1:.*]] = arith.subf %[[LHS_IMAG_TIMES_RHS_REAL_IMAG_RATIO]], %[[LHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IMAG_1:.*]] = arith.divf %[[IMAG_NUMERATOR_1]], %[[RHS_REAL_IMAG_DENOM]] fastmath<nnan,contract> : f32
+
+// CHECK: %[[RHS_IMAG_REAL_RATIO:.*]] = arith.divf %[[RHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = arith.mulf %[[RHS_IMAG_REAL_RATIO]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_IMAG_REAL_DENOM:.*]] = arith.addf %[[RHS_REAL]], %[[RHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG_REAL_RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_NUMERATOR_2:.*]] = arith.addf %[[LHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG_REAL_RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL_2:.*]] = arith.divf %[[REAL_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG_REAL_RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_NUMERATOR_2:.*]] = arith.subf %[[LHS_IMAG]], %[[LHS_REAL_TIMES_RHS_IMAG_REAL_RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IMAG_2:.*]] = arith.divf %[[IMAG_NUMERATOR_2]], %[[RHS_IMAG_REAL_DENOM]] fastmath<nnan,contract> : f32
+
+// Case 1. Zero denominator, numerator contains at most one NaN value.
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_REAL_ABS_IS_ZERO:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[ZERO]] : f32
+// CHECK: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_IMAG_ABS_IS_ZERO:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[ZERO]] : f32
+// CHECK: %[[LHS_REAL_IS_NOT_NAN:.*]] = arith.cmpf ord, %[[LHS_REAL]], %[[ZERO]] : f32
+// CHECK: %[[LHS_IMAG_IS_NOT_NAN:.*]] = arith.cmpf ord, %[[LHS_IMAG]], %[[ZERO]] : f32
+// CHECK: %[[LHS_CONTAINS_NOT_NAN_VALUE:.*]] = arith.ori %[[LHS_REAL_IS_NOT_NAN]], %[[LHS_IMAG_IS_NOT_NAN]] : i1
+// CHECK: %[[RHS_IS_ZERO:.*]] = arith.andi %[[RHS_REAL_ABS_IS_ZERO]], %[[RHS_IMAG_ABS_IS_ZERO]] : i1
+// CHECK: %[[RESULT_IS_INFINITY:.*]] = arith.andi %[[LHS_CONTAINS_NOT_NAN_VALUE]], %[[RHS_IS_ZERO]] : i1
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[INF_WITH_SIGN_OF_RHS_REAL:.*]] = math.copysign %[[INF]], %[[RHS_REAL]] : f32
+// CHECK: %[[INFINITY_RESULT_REAL:.*]] = arith.mulf %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[INFINITY_RESULT_IMAG:.*]] = arith.mulf %[[INF_WITH_SIGN_OF_RHS_REAL]], %[[LHS_IMAG]] fastmath<nnan,contract> : f32
+
+// Case 2. Infinite numerator, finite denominator.
+// CHECK: %[[RHS_REAL_FINITE:.*]] = arith.cmpf one, %[[RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IMAG_FINITE:.*]] = arith.cmpf one, %[[RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IS_FINITE:.*]] = arith.andi %[[RHS_REAL_FINITE]], %[[RHS_IMAG_FINITE]] : i1
+// CHECK: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_INFINITE:.*]] = arith.cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_INFINITE:.*]] = arith.cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IS_INFINITE:.*]] = arith.ori %[[LHS_REAL_INFINITE]], %[[LHS_IMAG_INFINITE]] : i1
+// CHECK: %[[INF_NUM_FINITE_DENOM:.*]] = arith.andi %[[LHS_IS_INFINITE]], %[[RHS_IS_FINITE]] : i1
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[LHS_REAL_IS_INF:.*]] = arith.select %[[LHS_REAL_INFINITE]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[LHS_REAL_IS_INF_WITH_SIGN:.*]] = math.copysign %[[LHS_REAL_IS_INF]], %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF:.*]] = arith.select %[[LHS_IMAG_INFINITE]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF_WITH_SIGN:.*]] = math.copysign %[[LHS_IMAG_IS_INF]], %[[LHS_IMAG]] : f32
+// CHECK: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[INF_MULTIPLICATOR_1:.*]] = arith.addf %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL_3:.*]] = arith.mulf %[[INF]], %[[INF_MULTIPLICATOR_1]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL_IS_INF_WITH_SIGN]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG_IS_INF_WITH_SIGN]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[INF_MULTIPLICATOR_2:.*]] = arith.subf %[[LHS_IMAG_IS_INF_WITH_SIGN_TIMES_RHS_REAL]], %[[LHS_REAL_IS_INF_WITH_SIGN_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IMAG_3:.*]] = arith.mulf %[[INF]], %[[INF_MULTIPLICATOR_2]] fastmath<nnan,contract> : f32
+
+// Case 3. Finite numerator, infinite denominator.
+// CHECK: %[[LHS_REAL_FINITE:.*]] = arith.cmpf one, %[[LHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_FINITE:.*]] = arith.cmpf one, %[[LHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IS_FINITE:.*]] = arith.andi %[[LHS_REAL_FINITE]], ...
[truncated]

@Lewuathe Lewuathe requested a review from joker-eph February 23, 2024 07:24
@Lewuathe Lewuathe force-pushed the support-fastmath-flag-complex-div branch from 0a81f0b to 81bc4ac Compare February 24, 2024 07:42
@Lewuathe Lewuathe force-pushed the support-fastmath-flag-complex-div branch from 81bc4ac to 03bd29e Compare February 24, 2024 07:49
Copy link
Member Author

@Lewuathe Lewuathe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joker-eph Thank you for the review!

It looks like flang spec failure is not directly related to this change. I'll merge this.
https://buildkite.com/llvm-project/github-pull-requests

@Lewuathe Lewuathe merged commit 288d317 into llvm:main Feb 27, 2024
@Lewuathe Lewuathe deleted the support-fastmath-flag-complex-div branch February 27, 2024 09:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants