Commit bdd3658

[MLIR] Fix ComplexToStandard lowering of complex::MulOp (#119591)
A complex multiplication should lower simply to the familiar 4 real multiplications, 1 real addition, 1 real subtraction. No special-casing of infinite or NaN values should be made; instead, the complex numbers should be thought of as just vectors of two reals, naturally bottoming out on the reals' semantics, IEEE754 or otherwise. That is what nearly everybody else is doing ("nearly" because at the end of this PR description we pinpoint the actual source of this in C99 `_Complex`), and this pattern, by trying to do something different, was generating much larger code, which was much slower and a departure from the naturally expected floating-point behavior.

This code had originally been introduced in https://reviews.llvm.org/D105270, which stated this rationale:

> The lowering handles special cases with NaN or infinity like C++.

I don't think that the C++ standard is a particularly important thing to follow in this instance. What matters more is what people actually do in practice with complex numbers, which rarely involves the C++ `std::complex` library type.

But out of curiosity, I checked, and the above statement seems incorrect. The [current C++ standard](https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/n4928.pdf) library specification for `std::complex` does not say anything about the implementation of complex multiplication: paragraph `[complex.ops]` falls back on `[complex.member.ops]` which says:

> Effects: Multiplies the complex value rhs by the complex value *this and stores the product in *this.

I also checked cppreference, which often has useful information in case something changed in a C++ language revision, but likewise, nothing at all there: https://en.cppreference.com/w/cpp/numeric/complex/operator_arith3

Finally, I checked in Compiler Explorer what Clang 19 currently generates: https://godbolt.org/z/oY7Ks4j95

That is just the familiar 4 multiplications... and then there is some weird check (`fcmp`) and conditionally a call to an external `__mulsc3`. Googled that, found this StackOverflow answer: https://stackoverflow.com/a/49438578

Summary: this is not about C++ (this post confirms my reading of the C++ standard not mandating anything about this). This is about C, and it just happens that this C++ standard library implementation bottoms out on code shared with the C `_Complex` implementation.

Another nuance missing in that SO answer: this is actually [implementation-defined behavior](https://en.cppreference.com/w/c/preprocessor/impl). There are two modes, controlled by

```c
#pragma STDC CX_LIMITED_RANGE {ON,OFF,DEFAULT}
```

It is implementation-defined which is the default. Clang defaults to OFF, but that's just Clang. In that mode, the check is required: https://en.cppreference.com/w/c/language/arithmetic_types#Complex_floating_types

And the specific point in the [C99 standard](https://www.open-std.org/jtc1/sc22/wg14/www/docs/n1256.pdf) is: `G.5.1 Multiplicative operators`.

But set it to ON and the check is gone: https://godbolt.org/z/aG8fnbYoP

Summary: the argument has moved from C++ to C --- and even there, to implementation-defined behavior with a standard opt-out mechanism.

Like with C++, I maintain that the C standard is not a particularly meaningful thing for MLIR to follow here, because people doing business with complex numbers tend to lower them to real numbers themselves, or have their own specialized complex types, either way not relying on C99's `_Complex` type --- and the very poor performance of the `CX_LIMITED_RANGE OFF` behavior (default in Clang) is certainly a key reason why people who care prefer to stay away from `_Complex` and `std::complex`.

A good example that's relevant to MLIR's space is CUDA's `cuComplex` type (used in the cuBLAS CGEMM interface). Here is its multiplication function. The comment about competitiveness is interesting: it's not a quirk of this particular function, it's the spirit underpinning numerical code that matters.

https://github.com/tpn/cuda-samples/blob/1bf5cd15c51ce80fc9b387c0ff89a9f535b42bf5/v8.0/include/cuComplex.h#L106-L120

```c
/* This implementation could suffer from intermediate overflow even though
 * the final result would be in range. However, various implementations do
 * not guard against this (presumably to avoid losing performance), so we
 * don't do it either to stay competitive.
 */
__host__ __device__ static __inline__ cuFloatComplex cuCmulf (cuFloatComplex x,
                                                              cuFloatComplex y)
{
    cuFloatComplex prod;
    prod = make_cuFloatComplex  ((cuCrealf(x) * cuCrealf(y)) -
                                 (cuCimagf(x) * cuCimagf(y)),
                                 (cuCrealf(x) * cuCimagf(y)) +
                                 (cuCimagf(x) * cuCrealf(y)));
    return prod;
}
```

Another instance in CUTLASS: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/complex.h#L231-L236

Signed-off-by: Benoit Jacob <[email protected]>
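
For readers who want to see the "4 real multiplications, 1 real addition, 1 real subtraction" form outside of MLIR, here is a minimal C++ sketch. The `Complex` struct and the `mulNaive` and `checkMulNaive` names are hypothetical and exist only for this illustration; the sketch assumes IEEE 754 arithmetic and a build without fast-math style flags, so that `inf * 0` is NaN and `std::isnan` is not optimized away.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical stand-in for a complex value: just a pair of reals.
struct Complex {
  float re;
  float im;
};

// Plain lowering: 4 real multiplications, 1 subtraction, 1 addition.
// Infinities and NaNs get no special treatment; they simply follow the
// semantics of the underlying real arithmetic.
inline Complex mulNaive(Complex lhs, Complex rhs) {
  float real = lhs.re * rhs.re - lhs.im * rhs.im;
  float imag = lhs.im * rhs.re + lhs.re * rhs.im;
  return {real, imag};
}

// Small self-check covering an ordinary and a non-finite input.
inline void checkMulNaive() {
  // (1 + 2i) * (3 + 4i) = (3 - 8) + (6 + 4)i = -5 + 10i
  Complex p = mulNaive({1.0f, 2.0f}, {3.0f, 4.0f});
  assert(p.re == -5.0f && p.im == 10.0f);

  // (inf + inf*i) * (1 + 0i): inf * 0 is NaN under IEEE 754, so both
  // components come out as NaN. C99 Annex G would instead recover an
  // infinity here, which is the extra work the old lowering emitted.
  const float inf = std::numeric_limits<float>::infinity();
  Complex q = mulNaive({inf, inf}, {1.0f, 0.0f});
  assert(std::isnan(q.re) && std::isnan(q.im));
}
```

The second check is the kind of input where the two behaviors differ: the plain form propagates NaN, which is the "naturally expected floating-point behavior" this change opts for.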
1 parent 60d9e6f commit bdd3658

2 files changed, +6 -688 lines changed

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 0 additions & 155 deletions
@@ -696,177 +696,22 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
     auto elementType = cast<FloatType>(type.getElementType());
     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
     auto fmfValue = fmf.getValue();
-
     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
-    Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
-    Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
-    Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
-    Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
-
     Value lhsRealTimesRhsReal =
         b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
-    Value lhsRealTimesRhsRealAbs =
-        b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
     Value lhsImagTimesRhsImag =
         b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
-    Value lhsImagTimesRhsImagAbs =
-        b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
     Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
                                          lhsImagTimesRhsImag, fmfValue);
-
     Value lhsImagTimesRhsReal =
         b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
-    Value lhsImagTimesRhsRealAbs =
-        b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
     Value lhsRealTimesRhsImag =
         b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
-    Value lhsRealTimesRhsImagAbs =
-        b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
     Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
                                          lhsRealTimesRhsImag, fmfValue);
-
-    // Handle cases where the "naive" calculation results in NaN values.
-    Value realIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
-    Value imagIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
-    Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
-
-    Value inf = b.create<arith::ConstantOp>(
-        elementType,
-        b.getFloatAttr(elementType,
-                       APFloat::getInf(elementType.getFloatSemantics())));
-
-    // Case 1. `lhsReal` or `lhsImag` are infinite.
-    Value lhsRealIsInf =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
-    Value lhsImagIsInf =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
-    Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
-    Value rhsRealIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
-    Value rhsImagIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
-    Value zero =
-        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
-    Value one = b.create<arith::ConstantOp>(elementType,
-                                            b.getFloatAttr(elementType, 1));
-    Value lhsRealIsInfFloat =
-        b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
-    lhsReal = b.create<arith::SelectOp>(
-        lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
-        lhsReal);
-    Value lhsImagIsInfFloat =
-        b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
-    lhsImag = b.create<arith::SelectOp>(
-        lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
-        lhsImag);
-    Value lhsIsInfAndRhsRealIsNan =
-        b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
-    rhsReal = b.create<arith::SelectOp>(
-        lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
-        rhsReal);
-    Value lhsIsInfAndRhsImagIsNan =
-        b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
-    rhsImag = b.create<arith::SelectOp>(
-        lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
-        rhsImag);
-
-    // Case 2. `rhsReal` or `rhsImag` are infinite.
-    Value rhsRealIsInf =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
-    Value rhsImagIsInf =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
-    Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
-    Value lhsRealIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
-    Value lhsImagIsNan =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
-    Value rhsRealIsInfFloat =
-        b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
-    rhsReal = b.create<arith::SelectOp>(
-        rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
-        rhsReal);
-    Value rhsImagIsInfFloat =
-        b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
-    rhsImag = b.create<arith::SelectOp>(
-        rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
-        rhsImag);
-    Value rhsIsInfAndLhsRealIsNan =
-        b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
-    lhsReal = b.create<arith::SelectOp>(
-        rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
-        lhsReal);
-    Value rhsIsInfAndLhsImagIsNan =
-        b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
-    lhsImag = b.create<arith::SelectOp>(
-        rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
-        lhsImag);
-    Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
-
-    // Case 3. One of the pairwise products of left hand side with right hand
-    // side is infinite.
-    Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
-        arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
-    Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
-        arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
-    Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
-                                                 lhsImagTimesRhsImagIsInf);
-    Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
-        arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
-    isSpecialCase =
-        b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
-    Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
-        arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
-    isSpecialCase =
-        b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
-    Type i1Type = b.getI1Type();
-    Value notRecalc = b.create<arith::XOrIOp>(
-        recalc,
-        b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
-    isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
-    Value isSpecialCaseAndLhsRealIsNan =
-        b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
-    lhsReal = b.create<arith::SelectOp>(
-        isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
-        lhsReal);
-    Value isSpecialCaseAndLhsImagIsNan =
-        b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
-    lhsImag = b.create<arith::SelectOp>(
-        isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
-        lhsImag);
-    Value isSpecialCaseAndRhsRealIsNan =
-        b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
-    rhsReal = b.create<arith::SelectOp>(
-        isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
-        rhsReal);
-    Value isSpecialCaseAndRhsImagIsNan =
-        b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
-    rhsImag = b.create<arith::SelectOp>(
-        isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
-        rhsImag);
-    recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
-    recalc = b.create<arith::AndIOp>(isNan, recalc);
-
-    // Recalculate real part.
-    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
-    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
-    Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
-                                            lhsImagTimesRhsImag, fmfValue);
-    real = b.create<arith::SelectOp>(
-        recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
-
-    // Recalculate imag part.
-    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
-    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
-    Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
-                                            lhsRealTimesRhsImag, fmfValue);
-    imag = b.create<arith::SelectOp>(
-        recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
-
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
     return success();
   }
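
To make the removed logic easier to follow, here is a rough scalar C++ sketch of what the deleted lines appear to compute, written with ordinary branches instead of `arith.select` ops. It follows the Case 1 / Case 2 / Case 3 comments in the removed code, which in turn resemble the recovery scheme described in C99 `G.5.1`. The `Complex` struct and the `boxInf`, `zeroIfNan`, `mulWithRecovery`, and `checkMulWithRecovery` names are hypothetical, and this is an approximation for reading purposes, not a verified equivalent of the old lowering.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical stand-in for a complex value: just a pair of reals.
struct Complex {
  double re;
  double im;
};

// "Box" a component: +-1 if it is infinite, +-0 otherwise, keeping the sign.
inline double boxInf(double x) {
  return std::copysign(std::isinf(x) ? 1.0 : 0.0, x);
}

// Replace a NaN component by a zero carrying the same sign bit.
inline double zeroIfNan(double x) {
  return std::isnan(x) ? std::copysign(0.0, x) : x;
}

inline Complex mulWithRecovery(Complex lhs, Complex rhs) {
  double a = lhs.re, b = lhs.im, c = rhs.re, d = rhs.im;
  double ac = a * c, bd = b * d, ad = a * d, bc = b * c;
  double x = ac - bd; // "naive" real part
  double y = ad + bc; // "naive" imaginary part
  if (std::isnan(x) && std::isnan(y)) {
    bool recalc = false;
    if (std::isinf(a) || std::isinf(b)) { // Case 1: lhs is infinite
      a = boxInf(a);
      b = boxInf(b);
      c = zeroIfNan(c);
      d = zeroIfNan(d);
      recalc = true;
    }
    if (std::isinf(c) || std::isinf(d)) { // Case 2: rhs is infinite
      c = boxInf(c);
      d = boxInf(d);
      a = zeroIfNan(a);
      b = zeroIfNan(b);
      recalc = true;
    }
    if (!recalc && (std::isinf(ac) || std::isinf(bd) || std::isinf(ad) ||
                    std::isinf(bc))) { // Case 3: a pairwise product overflowed
      a = zeroIfNan(a);
      b = zeroIfNan(b);
      c = zeroIfNan(c);
      d = zeroIfNan(d);
      recalc = true;
    }
    if (recalc) {
      const double inf = std::numeric_limits<double>::infinity();
      x = inf * (a * c - b * d);
      y = inf * (a * d + b * c);
    }
  }
  return {x, y};
}

// Small self-check: finite inputs are unaffected, and an infinite operand
// that would give NaN + NaN*i in the plain form is recovered as an infinity.
inline void checkMulWithRecovery() {
  Complex p = mulWithRecovery({1.0, 2.0}, {3.0, 4.0});
  assert(p.re == -5.0 && p.im == 10.0);

  const double inf = std::numeric_limits<double>::infinity();
  Complex q = mulWithRecovery({inf, inf}, {1.0, 0.0});
  assert(std::isinf(q.re) && std::isinf(q.im));
}
```

The old lowering had to emit every one of these steps unconditionally, as compares and selects, for each `complex.mul`, which is where the extra code size and runtime cost came from.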
