[MLIR] Fix ComplexToStandard lowering of complex::MulOp #119591
Signed-off-by: Benoit Jacob <[email protected]>
@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Changes

A complex multiplication should lower simply to the familiar 4 real multiplications, 1 real addition, 1 real subtraction. No special-casing of infinite or NaN values should be made; instead, the complex numbers should be thought of as just vectors of two reals, naturally bottoming out on the reals' semantics, IEEE754 or otherwise. That is what everybody else is doing, and this pattern, by trying to do something different, was generating much larger code, which was much slower and a departure from the naturally expected floating-point behavior.

This code had originally been introduced in https://reviews.llvm.org/D105270, which stated this rationale:

I don't think that the C++ standard is a particularly important thing to follow in this instance. What matters more is what people actually do in practice with complex numbers, which rarely involves the C++ std::complex library type.

But out of curiosity, I checked, and the above statement seems incorrect. The current C++ standard library specification for std::complex does not say anything about the implementation of complex multiplication.

I also checked cppreference, which often has useful information in case something changed in a C++ language revision, but likewise, nothing at all there:

Finally, I checked in Compiler Explorer what Clang 19 currently generates:

Patch is 57.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119591.diff

2 Files Affected:
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 807beebe4fb22a..473b1da4f701c7 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -696,177 +696,22 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
auto elementType = cast<FloatType>(type.getElementType());
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
auto fmfValue = fmf.getValue();
-
Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
- Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
- Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
- Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
- Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
-
Value lhsRealTimesRhsReal =
b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
- Value lhsRealTimesRhsRealAbs =
- b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
Value lhsImagTimesRhsImag =
b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
- Value lhsImagTimesRhsImagAbs =
- b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
lhsImagTimesRhsImag, fmfValue);
-
Value lhsImagTimesRhsReal =
b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
- Value lhsImagTimesRhsRealAbs =
- b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
Value lhsRealTimesRhsImag =
b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
- Value lhsRealTimesRhsImagAbs =
- b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
lhsRealTimesRhsImag, fmfValue);
-
- // Handle cases where the "naive" calculation results in NaN values.
- Value realIsNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
- Value imagIsNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
- Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
-
- Value inf = b.create<arith::ConstantOp>(
- elementType,
- b.getFloatAttr(elementType,
- APFloat::getInf(elementType.getFloatSemantics())));
-
- // Case 1. `lhsReal` or `lhsImag` are infinite.
- Value lhsRealIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
- Value lhsImagIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
- Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
- Value rhsRealIsNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
- Value rhsImagIsNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
- Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value lhsRealIsInfFloat =
- b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
- lhsReal = b.create<arith::SelectOp>(
- lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
- lhsReal);
- Value lhsImagIsInfFloat =
- b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
- lhsImag = b.create<arith::SelectOp>(
- lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
- lhsImag);
- Value lhsIsInfAndRhsRealIsNan =
- b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
- rhsReal = b.create<arith::SelectOp>(
- lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
- rhsReal);
- Value lhsIsInfAndRhsImagIsNan =
- b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
- rhsImag = b.create<arith::SelectOp>(
- lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
- rhsImag);
-
- // Case 2. `rhsReal` or `rhsImag` are infinite.
- Value rhsRealIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
- Value rhsImagIsInf =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
- Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
- Value lhsRealIsNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
- Value lhsImagIsNan =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
- Value rhsRealIsInfFloat =
- b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
- rhsReal = b.create<arith::SelectOp>(
- rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
- rhsReal);
- Value rhsImagIsInfFloat =
- b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
- rhsImag = b.create<arith::SelectOp>(
- rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
- rhsImag);
- Value rhsIsInfAndLhsRealIsNan =
- b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
- lhsReal = b.create<arith::SelectOp>(
- rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
- lhsReal);
- Value rhsIsInfAndLhsImagIsNan =
- b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
- lhsImag = b.create<arith::SelectOp>(
- rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
- lhsImag);
- Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
-
- // Case 3. One of the pairwise products of left hand side with right hand
- // side is infinite.
- Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
- arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
- Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
- arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
- Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
- lhsImagTimesRhsImagIsInf);
- Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
- arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
- isSpecialCase =
- b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
- Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
- arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
- isSpecialCase =
- b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
- Type i1Type = b.getI1Type();
- Value notRecalc = b.create<arith::XOrIOp>(
- recalc,
- b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
- isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
- Value isSpecialCaseAndLhsRealIsNan =
- b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
- lhsReal = b.create<arith::SelectOp>(
- isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
- lhsReal);
- Value isSpecialCaseAndLhsImagIsNan =
- b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
- lhsImag = b.create<arith::SelectOp>(
- isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
- lhsImag);
- Value isSpecialCaseAndRhsRealIsNan =
- b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
- rhsReal = b.create<arith::SelectOp>(
- isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
- rhsReal);
- Value isSpecialCaseAndRhsImagIsNan =
- b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
- rhsImag = b.create<arith::SelectOp>(
- isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
- rhsImag);
- recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
- recalc = b.create<arith::AndIOp>(isNan, recalc);
-
- // Recalculate real part.
- lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
- lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
- Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
- lhsImagTimesRhsImag, fmfValue);
- real = b.create<arith::SelectOp>(
- recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
-
- // Recalculate imag part.
- lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
- lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
- Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
- lhsRealTimesRhsImag, fmfValue);
- imag = b.create<arith::SelectOp>(
- recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
-
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
return success();
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 3d73292e6b8868..a4ddabbd0821ac 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -339,115 +339,19 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
return %mul : complex<f32>
}
// CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
-// CHECK: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] : f32
// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
-// CHECK: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] : f32
// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
-// CHECK: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] : f32
// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
-// CHECK: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] : f32
// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_REAL]] : f32
// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
// CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_REAL]] : f32
// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
// CHECK: %[[IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
-// Handle cases where the "naive" calculation results in NaN values.
-// CHECK: %[[REAL_IS_NAN:.*]] = arith.cmpf uno, %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[IS_NAN:.*]] = arith.andi %[[REAL_IS_NAN]], %[[IMAG_IS_NAN]] : i1
-// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
-
-// Case 1. LHS_REAL or LHS_IMAG are infinite.
-// CHECK: %[[LHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32
-// CHECK: %[[LHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32
-// CHECK: %[[LHS_IS_INF:.*]] = arith.ori %[[LHS_REAL_IS_INF]], %[[LHS_IMAG_IS_INF]] : i1
-// CHECK: %[[RHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_REAL]], %[[RHS_REAL]] : f32
-// CHECK: %[[RHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_IMAG]], %[[RHS_IMAG]] : f32
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[LHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[LHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
-// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_REAL_IS_INF_FLOAT]], %[[LHS_REAL]] : f32
-// CHECK: %[[LHS_REAL1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_REAL]] : f32
-// CHECK: %[[LHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[LHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
-// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_IMAG_IS_INF_FLOAT]], %[[LHS_IMAG]] : f32
-// CHECK: %[[LHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_IMAG]] : f32
-// CHECK: %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_REAL_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL]] : f32
-// CHECK: %[[RHS_REAL1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL]] : f32
-// CHECK: %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_IMAG_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG]] : f32
-// CHECK: %[[RHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG]] : f32
-
-// Case 2. RHS_REAL or RHS_IMAG are infinite.
-// CHECK: %[[RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32
-// CHECK: %[[RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32
-// CHECK: %[[RHS_IS_INF:.*]] = arith.ori %[[RHS_REAL_IS_INF]], %[[RHS_IMAG_IS_INF]] : i1
-// CHECK: %[[LHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_REAL1]], %[[LHS_REAL1]] : f32
-// CHECK: %[[LHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_IMAG1]], %[[LHS_IMAG1]] : f32
-// CHECK: %[[RHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[RHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
-// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_REAL_IS_INF_FLOAT]], %[[RHS_REAL1]] : f32
-// CHECK: %[[RHS_REAL2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_REAL1]] : f32
-// CHECK: %[[RHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[RHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
-// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_IMAG_IS_INF_FLOAT]], %[[RHS_IMAG1]] : f32
-// CHECK: %[[RHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_IMAG1]] : f32
-// CHECK: %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_REAL_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL1]] : f32
-// CHECK: %[[LHS_REAL2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL1]] : f32
-// CHECK: %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_IMAG_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG1]] : f32
-// CHECK: %[[LHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG1]] : f32
-// CHECK: %[[RECALC:.*]] = arith.ori %[[LHS_IS_INF]], %[[RHS_IS_INF]] : i1
-
-// Case 3. One of the pairwise products of left hand side with right hand side
-// is infinite.
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
-// CHECK: %[[IS_SPECIAL_CASE:.*]] = arith.ori %[[LHS_REAL_TIMES_RHS_REAL_IS_INF]], %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF]] : i1
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
-// CHECK: %[[IS_SPECIAL_CASE1:.*]] = arith.ori %[[IS_SPECIAL_CASE]], %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF]] : i1
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
-// CHECK: %[[IS_SPECIAL_CASE2:.*]] = arith.ori %[[IS_SPECIAL_CASE1]], %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF]] : i1
-// CHECK: %[[TRUE:.*]] = arith.constant true
-// CHECK: %[[NOT_RECALC:.*]] = arith.xori %[[RECALC]], %[[TRUE]] : i1
-// CHECK: %[[IS_SPECIAL_CASE3:.*]] = arith.andi %[[IS_SPECIAL_CASE2]], %[[NOT_RECALC]] : i1
-// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_REAL_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL2]] : f32
-// CHECK: %[[LHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL2]] : f32
-// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_IMAG_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG2]] : f32
-// CHECK: %[[LHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG2]] : f32
-// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_REAL_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL2]] : f32
-// CHECK: %[[RHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL2]] : f32
-// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_IMAG_IS_NAN]] : i1
-// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG2]] : f32
-// CHECK: %[[RHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG2]] : f32
-// CHECK: %[[RECALC2:.*]] = arith.ori %[[RECALC]], %[[IS_SPECIAL_CASE3]] : i1
-// CHECK: %[[RECALC3:.*]] = arith.andi %[[IS_NAN]], %[[RECALC2]] : i1
-
- // Recalculate real part.
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_REAL3]] : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_IMAG3]] : f32
-// CHECK: %[[NEW_REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
-// CHECK: %[[NEW_REAL_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_REAL]] : f32
-// CHECK: %[[FINAL_REAL:.*]] = arith.select %[[RECALC3]], %[[NEW_REAL_TIMES_INF]], %[[REAL]] : f32
-
-// Recalculate imag part.
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_REAL3]] : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_IMAG3]] : f32
-// CHECK: %[[NEW_IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
-// CHECK: %[[NEW_IMAG_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_IMAG]] : f32
-// CHECK: %[[FINAL_IMAG:.*]] = arith.select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32
-
-// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
// -----
@@ -977,115 +881,16 @@ func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
return %mul : complex<f32>
}
// CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
-// CHECK: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
-// CHECK: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
-// CHECK: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
-// CHECK: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-
// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
-
// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_REAL_TIMES_RH...
[truncated]
Not familiar with all the spec details, but +1 for simpler lowering of complex math by default. Looks good to me.
Giving some time for reviewers in other time zones to see this, by default will merge tomorrow.

I found some tests in TF that fail due to this change. See cases MulComplex64SpecialCases and MulComplex128SpecialCases: https://github.com/tensorflow/tensorflow/blob/913caa8102bc4a7940f1aa205803e1188737c301/tensorflow/core/kernels/mlir_generated/gpu_binary_ops_test.cc#L1052 The failures are all inf/nan types of differences, so it seems expected based on this patch, e.g.
I expect @akuegel is best suited to explain if we still need this behavior, but may be on vacation right now.

Happy to discuss this any time! Consistent with the above explanation, my opinion is that this test simply needs to be updated to expect the new results.
A complex multiplication should lower simply to the familiar 4 real multiplications, 1 real addition, 1 real subtraction. No special-casing of infinite or NaN values should be made; instead, the complex numbers should be thought of as just vectors of two reals, naturally bottoming out on the reals' semantics, IEEE754 or otherwise. That is what nearly everybody else is doing ("nearly" because at the end of this PR description we pinpoint the actual source of this in C99 _Complex), and this pattern, by trying to do something different, was generating much larger code, which was much slower and a departure from the naturally expected floating-point behavior.

This code had originally been introduced in https://reviews.llvm.org/D105270, which stated this rationale:
I don't think that the C++ standard is a particularly important thing to follow in this instance. What matters more is what people actually do in practice with complex numbers, which rarely involves the C++ std::complex library type.

But out of curiosity, I checked, and the above statement seems incorrect. The current C++ standard library specification for std::complex does not say anything about the implementation of complex multiplication: paragraph [complex.ops] falls back on [complex.member.ops] which says:

I also checked cppreference, which often has useful information in case something changed in a C++ language revision, but likewise, nothing at all there:
https://en.cppreference.com/w/cpp/numeric/complex/operator_arith3
Finally, I checked in Compiler Explorer what Clang 19 currently generates:
https://godbolt.org/z/oY7Ks4j95
That is just the familiar 4 multiplications... and then there is some weird check (fcmp) and conditionally a call to an external __mulsc3. Googled that, found this StackOverflow answer:
https://stackoverflow.com/a/49438578
Summary: this is not about C++ (this post confirms my reading of the C++ standard not mandating anything about this). This is about C, and it just happens that this C++ standard library implementation bottoms out on code shared with the C _Complex implementation.

Another nuance missing in that SO answer: this is actually implementation-defined behavior. There are two modes, controlled by
#pragma STDC CX_LIMITED_RANGE {ON,OFF,DEFAULT}
It is implementation-defined which is the default. Clang defaults to OFF, but that's just Clang. In that mode, the check is required:
https://en.cppreference.com/w/c/language/arithmetic_types#Complex_floating_types
And the specific point in the C99 standard is: G.5.1 Multiplicative operators.

But set it to ON and the check is gone:
https://godbolt.org/z/aG8fnbYoP
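
For readers who want to see the check in scalar form, here is a minimal C++ sketch of the special-case handling, written to mirror the three cases in the lines this patch removes from MulOpConversion. It is an illustration under assumptions, not the code Clang or compiler-rt actually ships: the names mulWithInfRecovery and demoRecovery are hypothetical, and the sketch uses branches where the removed MLIR pattern used selects.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <utility>

// Hypothetical helper: returns {real, imag} of (a + bi) * (c + di), with the
// infinity recovery that the removed lowering performed.
std::pair<float, float> mulWithInfRecovery(float a, float b, float c, float d) {
  float ac = a * c, bd = b * d, bc = b * c, ad = a * d;
  float x = ac - bd; // "naive" real part
  float y = bc + ad; // "naive" imaginary part
  // The slow path only matters when both naive results are NaN.
  if (std::isnan(x) && std::isnan(y)) {
    bool recalc = false;
    // Case 1: the left-hand side is infinite.
    if (std::isinf(a) || std::isinf(b)) {
      a = std::copysign(std::isinf(a) ? 1.0f : 0.0f, a);
      b = std::copysign(std::isinf(b) ? 1.0f : 0.0f, b);
      if (std::isnan(c)) c = std::copysign(0.0f, c);
      if (std::isnan(d)) d = std::copysign(0.0f, d);
      recalc = true;
    }
    // Case 2: the right-hand side is infinite.
    if (std::isinf(c) || std::isinf(d)) {
      c = std::copysign(std::isinf(c) ? 1.0f : 0.0f, c);
      d = std::copysign(std::isinf(d) ? 1.0f : 0.0f, d);
      if (std::isnan(a)) a = std::copysign(0.0f, a);
      if (std::isnan(b)) b = std::copysign(0.0f, b);
      recalc = true;
    }
    // Case 3: one of the pairwise products overflowed to infinity.
    if (!recalc && (std::isinf(ac) || std::isinf(bd) || std::isinf(ad) ||
                    std::isinf(bc))) {
      if (std::isnan(a)) a = std::copysign(0.0f, a);
      if (std::isnan(b)) b = std::copysign(0.0f, b);
      if (std::isnan(c)) c = std::copysign(0.0f, c);
      if (std::isnan(d)) d = std::copysign(0.0f, d);
      recalc = true;
    }
    if (recalc) {
      const float inf = std::numeric_limits<float>::infinity();
      x = inf * (a * c - b * d);
      y = inf * (b * c + a * d);
    }
  }
  return {x, y};
}

// Usage: (1 + 0i) * (inf + inf*i) is NaN + NaN*i under the naive formula,
// and the recovery path turns it into an infinity.
inline void demoRecovery() {
  const float inf = std::numeric_limits<float>::infinity();
  std::pair<float, float> r = mulWithInfRecovery(1.0f, 0.0f, inf, inf);
  assert(std::isinf(r.first) && std::isinf(r.second));
}
```

Every multiplication pays for the NaN test, and the recovery path adds many comparisons and selects on top, which is the code-size and speed cost discussed above.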
Summary: the argument has moved from C++ to C --- and even there, to implementation-defined behavior with a standard opt-out mechanism.
Like with C++, I maintain that the C standard is not a particularly meaningful thing for MLIR to follow here, because people doing business with complex numbers tend to lower them to real numbers themselves, or have their own specialized complex types, either way not relying on C99's _Complex type --- and the very poor performance of the CX_LIMITED_RANGE OFF behavior (default in Clang) is certainly a key reason why people who care prefer to stay away from _Complex and std::complex.

A good example that's relevant to MLIR's space is CUDA's cuComplex type (used in the cuBLAS CGEMM interface). Here is its multiplication function. The comment about competitiveness is interesting: it's not a quirk of this particular function, it's the spirit underpinning numerical code that matters.
https://github.com/tpn/cuda-samples/blob/1bf5cd15c51ce80fc9b387c0ff89a9f535b42bf5/v8.0/include/cuComplex.h#L106-L120
Another instance in CUTLASS: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/complex.h#L231-L236
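
As a rough illustration of the "vector of two reals" view, the sketch below spells out the plain formula that the new lowering emits (4 multiplications, 1 subtraction, 1 addition, in the same operand order as the kept lines of the diff). The Complex32 struct and the function names are hypothetical and are not taken from cuComplex.h or CUTLASS.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical minimal complex type: just a pair of reals.
struct Complex32 {
  float re;
  float im;
};

// Plain complex multiplication with no special-casing of inf or NaN.
inline Complex32 mulNaive(Complex32 lhs, Complex32 rhs) {
  float real = lhs.re * rhs.re - lhs.im * rhs.im;
  float imag = lhs.im * rhs.re + lhs.re * rhs.im;
  return {real, imag};
}

inline void demoNaive() {
  // Finite inputs: (1 + 2i) * (3 + 4i) = -5 + 10i.
  Complex32 p = mulNaive({1.0f, 2.0f}, {3.0f, 4.0f});
  assert(p.re == -5.0f && p.im == 10.0f);

  // Infinite input: 0 * inf is NaN for the underlying floats, so both
  // components come out as NaN. This is the kind of inf/nan difference
  // that the updated test expectations would need to reflect.
  const float inf = std::numeric_limits<float>::infinity();
  Complex32 q = mulNaive({1.0f, 0.0f}, {inf, inf});
  assert(std::isnan(q.re) && std::isnan(q.im));
}
```

The behavior on infinities and NaNs is simply whatever the underlying real arithmetic produces, with no extra branches or selects.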