Skip to content

[MLIR] Fix ComplexToStandard lowering of complex::MulOp #119591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 12, 2024

Conversation

bjacob
Copy link
Contributor

@bjacob bjacob commented Dec 11, 2024

A complex multiplication should lower simply to the familiar 4 real multiplications, 1 real addition, 1 real subtraction. No special-casing of infinite or NaN values should be made, instead the complex numbers should be thought as just vectors of two reals, naturally bottoming out on the reals' semantics, IEEE754 or otherwise. That is what nearly everybody else is doing ("nearly" because at the end of this PR description we pinpoint the actual source of this in C99 _Complex), and this pattern, by trying to do something different, was generating much larger code, which was much slower and a departure from the naturally expected floating-point behavior.

This code had originally been introduced in https://reviews.llvm.org/D105270, which stated this rationale:

The lowering handles special cases with NaN or infinity like C++.

I don't think that the C++ standard is a particularly important thing to follow in this instance. What matters more is what people actually do in practice with complex numbers, which rarely involves the C++ std::complex library type.

But out of curiosity, I checked, and the above statement seems incorrect. The current C++ standard library specification for std::complex does not say anything about the implementation of complex multiplication: paragraph [complex.ops] falls back on [complex.member.ops] which says:

Effects: Multiplies the complex value rhs by the complex value *this and stores the product in *this.

I also checked cppreference which often has useful information in case something changed in a c++ language revision, but likewise, nothing at all there:
https://en.cppreference.com/w/cpp/numeric/complex/operator_arith3

Finally, I checked in Compiler Explorer what Clang 19 currently generates:
https://godbolt.org/z/oY7Ks4j95
That is just the familiar 4 multiplications.... and then there is some weird check (fcmp) and conditionally a call to an external __mulsc3. Googled that, found this StackOverflow answer:
https://stackoverflow.com/a/49438578

Summary: this is not about C++ (this post confirms my reading of the C++ standard not mandating anything about this). This is about C, and it just happens that this C++ standard library implementation bottoms out on code shared with the C _Complex implementation.

Another nuance missing in that SO answer: this is actually implementation-defined behavior. There are two modes, controlled by

#pragma STDC CX_LIMITED_RANGE {ON,OFF,DEFAULT}

It is implementation-defined which is the default. Clang defaults to OFF, but that's just Clang. In that mode, the check is required:
https://en.cppreference.com/w/c/language/arithmetic_types#Complex_floating_types
And the specific point in the C99 standard is: G.5.1 Multiplicative operators.

But set it to ON and the check is gone:
https://godbolt.org/z/aG8fnbYoP

Summary: the argument has moved from C++ to C --- and even there, to implementation-defined behavior with a standard opt-out mechanism.

Like with C++, I maintain that the C standard is not a particularly meaningful thing for MLIR to follow here, because people doing business with complex numbers tend to lower them to real numbers themselves, or have their own specialized complex types, either way not relying on C99's _Complex type --- and the very poor performance of the CX_LIMITED_RANGE OFF behavior (default in Clang) is certainly a key reason why people who care prefer to stay away from _Complex and std::complex.

A good example that's relevant to MLIR's space is CUDA's cuComplex type (used in the cuBLAS CGEMM interface). Here is its multiplication function. The comment about competitiveness is interesting: it's not a quirk of this particular function, it's the spirit underpinning numerical code that matters.
https://github.com/tpn/cuda-samples/blob/1bf5cd15c51ce80fc9b387c0ff89a9f535b42bf5/v8.0/include/cuComplex.h#L106-L120

/* This implementation could suffer from intermediate overflow even though
 * the final result would be in range. However, various implementations do
 * not guard against this (presumably to avoid losing performance), so we 
 * don't do it either to stay competitive.
 */
__host__ __device__ static __inline__ cuFloatComplex cuCmulf (cuFloatComplex x,
                                                              cuFloatComplex y)
{
    cuFloatComplex prod;
    prod = make_cuFloatComplex  ((cuCrealf(x) * cuCrealf(y)) - 
                                 (cuCimagf(x) * cuCimagf(y)),
                                 (cuCrealf(x) * cuCimagf(y)) + 
                                 (cuCimagf(x) * cuCrealf(y)));
    return prod;
}

Another instance in CUTLASS: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/complex.h#L231-L236

Signed-off-by: Benoit Jacob <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2024

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Changes

A complex multiplication should lower simply to the familiar 4 real multiplications, 1 real addition, 1 real subtraction. No special-casing of infinite or NaN values should be made, instead the complex numbers should be thought as just vectors of two reals, naturally bottoming out on the reals' semantics, IEEE754 or otherwise. That is what everybody else is doing, and this pattern, by trying to do something different, was generating much larger code, which was much slower and a departure from the naturally expected floating-point behavior.

This code had originally been introduced in https://reviews.llvm.org/D105270, which stated this rationale:
> The lowering handles special cases with NaN or infinity like C++.

I don't think that the C++ standard is a particularly important thing to follow in this instance. What matters more is what people actually do in practice with complex numbers, which rarely involves the C++ std::complex library type.

But out of curiosity, I checked, and the above statement seems incorrect. The current C++ standard library specification for std::complex does not say anything about the implementation of complex multiplication: paragraph [complex.ops] falls back on [complex.member.ops] which says:
> Effects: Multiplies the complex value rhs by the complex value *this and stores the product in *this.

I also checked cppreference which often has useful information in case something changed in a c++ language revision, but likewise, nothing at all there:
https://en.cppreference.com/w/cpp/numeric/complex/operator_arith3

Finally, I checked in Compiler Explorer what Clang 19 currently generates:
https://godbolt.org/z/oY7Ks4j95
That is just the familiar 4 multiplications.


Patch is 57.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119591.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (-155)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+6-533)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 807beebe4fb22a..473b1da4f701c7 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -696,177 +696,22 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
     auto elementType = cast<FloatType>(type.getElementType());
     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
     auto fmfValue = fmf.getValue();
-
     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
-    Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
-    Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
-    Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
-    Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
-
     Value lhsRealTimesRhsReal =
         b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
-    Value lhsRealTimesRhsRealAbs =
-        b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
     Value lhsImagTimesRhsImag =
         b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
-    Value lhsImagTimesRhsImagAbs =
-        b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
     Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
                                          lhsImagTimesRhsImag, fmfValue);
-
     Value lhsImagTimesRhsReal =
         b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
-    Value lhsImagTimesRhsRealAbs =
-        b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
     Value lhsRealTimesRhsImag =
         b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
-    Value lhsRealTimesRhsImagAbs =
-        b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
     Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
                                          lhsRealTimesRhsImag, fmfValue);
-
-    // Handle cases where the "naive" calculation results in NaN values.
-    Value realIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
-    Value imagIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
-    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
-
-    Value inf = b.create<arith::ConstantOp>(
-        elementType,
-        b.getFloatAttr(elementType,
-                       APFloat::getInf(elementType.getFloatSemantics())));
-
-    // Case 1. `lhsReal` or `lhsImag` are infinite.
-    Value lhsRealIsInf =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
-    Value lhsImagIsInf =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
-    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
-    Value rhsRealIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
-    Value rhsImagIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
-    Value zero =
-        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
-    Value one = b.create<arith::ConstantOp>(elementType,
-                                            b.getFloatAttr(elementType, 1));
-    Value lhsRealIsInfFloat =
-        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
-    lhsReal = b.create<arith::SelectOp>(
-        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
-        lhsReal);
-    Value lhsImagIsInfFloat =
-        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
-    lhsImag = b.create<arith::SelectOp>(
-        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
-        lhsImag);
-    Value lhsIsInfAndRhsRealIsNan =
-        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
-    rhsReal = b.create<arith::SelectOp>(
-        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
-        rhsReal);
-    Value lhsIsInfAndRhsImagIsNan =
-        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
-    rhsImag = b.create<arith::SelectOp>(
-        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
-        rhsImag);
-
-    // Case 2. `rhsReal` or `rhsImag` are infinite.
-    Value rhsRealIsInf =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
-    Value rhsImagIsInf =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
-    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
-    Value lhsRealIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
-    Value lhsImagIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
-    Value rhsRealIsInfFloat =
-        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
-    rhsReal = b.create<arith::SelectOp>(
-        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
-        rhsReal);
-    Value rhsImagIsInfFloat =
-        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
-    rhsImag = b.create<arith::SelectOp>(
-        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
-        rhsImag);
-    Value rhsIsInfAndLhsRealIsNan =
-        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
-    lhsReal = b.create<arith::SelectOp>(
-        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
-        lhsReal);
-    Value rhsIsInfAndLhsImagIsNan =
-        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
-    lhsImag = b.create<arith::SelectOp>(
-        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
-        lhsImag);
-    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
-
-    // Case 3. One of the pairwise products of left hand side with right hand
-    // side is infinite.
-    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
-        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
-    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
-        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
-    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
-                                                 lhsImagTimesRhsImagIsInf);
-    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
-        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
-    isSpecialCase =
-        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
-    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
-        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
-    isSpecialCase =
-        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
-    Type i1Type = b.getI1Type();
-    Value notRecalc = b.create<arith::XOrIOp>(
-        recalc,
-        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
-    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
-    Value isSpecialCaseAndLhsRealIsNan =
-        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
-    lhsReal = b.create<arith::SelectOp>(
-        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
-        lhsReal);
-    Value isSpecialCaseAndLhsImagIsNan =
-        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
-    lhsImag = b.create<arith::SelectOp>(
-        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
-        lhsImag);
-    Value isSpecialCaseAndRhsRealIsNan =
-        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
-    rhsReal = b.create<arith::SelectOp>(
-        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
-        rhsReal);
-    Value isSpecialCaseAndRhsImagIsNan =
-        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
-    rhsImag = b.create<arith::SelectOp>(
-        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
-        rhsImag);
-    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
-    recalc = b.create<arith::AndIOp>(isNan, recalc);
-
-    // Recalculate real part.
-    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
-    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
-    Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
-                                            lhsImagTimesRhsImag, fmfValue);
-    real = b.create<arith::SelectOp>(
-        recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
-
-    // Recalculate imag part.
-    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
-    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
-    Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
-                                            lhsRealTimesRhsImag, fmfValue);
-    imag = b.create<arith::SelectOp>(
-        recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
-
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
     return success();
   }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 3d73292e6b8868..a4ddabbd0821ac 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -339,115 +339,19 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
   return %mul : complex<f32>
 }
 // CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
-// CHECK: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] : f32
 // CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
-// CHECK: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] : f32
 // CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
-// CHECK: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] : f32
 // CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
-// CHECK: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] : f32
 
 // CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_REAL]] : f32
 // CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
 // CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
 
 // CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_REAL]] : f32
 // CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
 // CHECK: %[[IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
 
-// Handle cases where the "naive" calculation results in NaN values.
-// CHECK: %[[REAL_IS_NAN:.*]] = arith.cmpf uno, %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[IS_NAN:.*]] = arith.andi %[[REAL_IS_NAN]], %[[IMAG_IS_NAN]] : i1
-// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
-
-// Case 1. LHS_REAL or LHS_IMAG are infinite.
-// CHECK: %[[LHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32
-// CHECK: %[[LHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32
-// CHECK: %[[LHS_IS_INF:.*]] = arith.ori %[[LHS_REAL_IS_INF]], %[[LHS_IMAG_IS_INF]] : i1
-// CHECK:  %[[RHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_REAL]], %[[RHS_REAL]] : f32
-// CHECK: %[[RHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_IMAG]], %[[RHS_IMAG]] : f32
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[LHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[LHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
-// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_REAL_IS_INF_FLOAT]], %[[LHS_REAL]] : f32
-// CHECK: %[[LHS_REAL1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_REAL]] : f32
-// CHECK: %[[LHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[LHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
-// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_IMAG_IS_INF_FLOAT]], %[[LHS_IMAG]] : f32
-// CHECK: %[[LHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_IMAG]] : f32
-// CHECK: %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_REAL_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL]] : f32
-// CHECK: %[[RHS_REAL1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL]] : f32
-// CHECK: %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_IMAG_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG]] : f32
-// CHECK: %[[RHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG]] : f32
-
-// Case 2. RHS_REAL or RHS_IMAG are infinite.
-// CHECK: %[[RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32
-// CHECK: %[[RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32
-// CHECK: %[[RHS_IS_INF:.*]] = arith.ori %[[RHS_REAL_IS_INF]], %[[RHS_IMAG_IS_INF]] : i1
-// CHECK: %[[LHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_REAL1]], %[[LHS_REAL1]] : f32
-// CHECK: %[[LHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_IMAG1]], %[[LHS_IMAG1]] : f32
-// CHECK: %[[RHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[RHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
-// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_REAL_IS_INF_FLOAT]], %[[RHS_REAL1]] : f32
-// CHECK: %[[RHS_REAL2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_REAL1]] : f32
-// CHECK: %[[RHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[RHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
-// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_IMAG_IS_INF_FLOAT]], %[[RHS_IMAG1]] : f32
-// CHECK: %[[RHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_IMAG1]] : f32
-// CHECK: %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_REAL_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL1]] : f32
-// CHECK: %[[LHS_REAL2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL1]] : f32
-// CHECK: %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_IMAG_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG1]] : f32
-// CHECK: %[[LHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG1]] : f32
-// CHECK: %[[RECALC:.*]] = arith.ori %[[LHS_IS_INF]], %[[RHS_IS_INF]] : i1
-
-// Case 3. One of the pairwise products of left hand side with right hand side
-// is infinite.
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
-// CHECK: %[[IS_SPECIAL_CASE:.*]] = arith.ori %[[LHS_REAL_TIMES_RHS_REAL_IS_INF]], %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF]] : i1
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
-// CHECK: %[[IS_SPECIAL_CASE1:.*]] = arith.ori %[[IS_SPECIAL_CASE]], %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF]] : i1
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
-// CHECK: %[[IS_SPECIAL_CASE2:.*]] = arith.ori %[[IS_SPECIAL_CASE1]], %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF]] : i1
-// CHECK: %[[TRUE:.*]] = arith.constant true
-// CHECK: %[[NOT_RECALC:.*]] = arith.xori %[[RECALC]], %[[TRUE]] : i1
-// CHECK: %[[IS_SPECIAL_CASE3:.*]] = arith.andi %[[IS_SPECIAL_CASE2]], %[[NOT_RECALC]] : i1
-// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_REAL_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL2]] : f32
-// CHECK: %[[LHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL2]] : f32
-// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_IMAG_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG2]] : f32
-// CHECK: %[[LHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG2]] : f32
-// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_REAL_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL2]] : f32
-// CHECK: %[[RHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL2]] : f32
-// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_IMAG_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG2]] : f32
-// CHECK: %[[RHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG2]] : f32
-// CHECK: %[[RECALC2:.*]] = arith.ori %[[RECALC]], %[[IS_SPECIAL_CASE3]] : i1
-// CHECK: %[[RECALC3:.*]] = arith.andi %[[IS_NAN]], %[[RECALC2]] : i1
-
- // Recalculate real part.
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_REAL3]] : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_IMAG3]] : f32
-// CHECK: %[[NEW_REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
-// CHECK: %[[NEW_REAL_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_REAL]] : f32
-// CHECK: %[[FINAL_REAL:.*]] = arith.select %[[RECALC3]], %[[NEW_REAL_TIMES_INF]], %[[REAL]] : f32
-
-// Recalculate imag part.
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_REAL3]] : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_IMAG3]] : f32
-// CHECK: %[[NEW_IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
-// CHECK: %[[NEW_IMAG_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_IMAG]] : f32
-// CHECK: %[[FINAL_IMAG:.*]] = arith.select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32
-
-// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
 // -----
@@ -977,115 +881,16 @@ func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
   return %mul : complex<f32>
 }
 // CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
-// CHECK: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] fastmath<nnan,contract> : f32
 // CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
-// CHECK: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] fastmath<nnan,contract> : f32
 // CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
-// CHECK: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] fastmath<nnan,contract> : f32
 // CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
-// CHECK: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-
 // CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
 // CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
 // CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
-
 // CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
 // CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_REAL_TIMES_RH...
[truncated]

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not familiar with all the spec details, but +1 for simpler lowering of complex math by default. Looks good to me.

@bjacob
Copy link
Contributor Author

bjacob commented Dec 11, 2024

Giving some time for reviewers in other time zones to see this, by default will merge tomorrow.

@bjacob bjacob merged commit bdd3658 into llvm:main Dec 12, 2024
10 checks passed
@rupprecht
Copy link
Collaborator

I found some tests in TF that fail due to this change. See cases MulComplex64SpecialCases and MulComplex128SpecialCases: https://github.com/tensorflow/tensorflow/blob/913caa8102bc4a7940f1aa205803e1188737c301/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc#L1052

The failures are all inf/nan types of differences, so it seem expected based on this patch, e.g.

tensorflow/core/framework/tensor_testutil.cc:213
Value of: IsClose(Tx[i], Ty[i], typed_atol, typed_rtol)
  Actual: false ((-inf,-nan) not close to (nan,nan))
Expected: true
i = 65 Tx[i] = (-inf,-nan) Ty[i] = (nan,nan)

I expect @akuegel is best suited to explain if we still need this behavior, but may be on vacation right now.

@bjacob
Copy link
Contributor Author

bjacob commented Dec 13, 2024

Happy to discuss this any time! Consistently with the above explanation, my opinion is that this test simply needs to be updated to expect the new results.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants