Skip to content

[clang][Interp] Implement Complex-complex multiplication #94891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2024

Conversation

tbaederr
Copy link
Contributor

@tbaederr tbaederr commented Jun 9, 2024

Share the implementation for floating-point complex-complex multiplication with the current interpreter. This means we need a new opcode for this, but there's no good way around that.

Share the implementation for floating-point complex-complex
multiplication with the current interpreter. This means we need a new
opcode for this, but there's no good way around that.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" labels Jun 9, 2024
@llvmbot
Copy link
Member

llvmbot commented Jun 9, 2024

@llvm/pr-subscribers-clang

Author: Timm Baeder (tbaederr)

Changes

Share the implementation for floating-point complex-complex multiplication with the current interpreter. This means we need a new opcode for this, but there's no good way around that.


Full diff: https://github.com/llvm/llvm-project/pull/94891.diff

6 Files Affected:

  • (modified) clang/lib/AST/ExprConstShared.h (+7)
  • (modified) clang/lib/AST/ExprConstant.cpp (+57-49)
  • (modified) clang/lib/AST/Interp/ByteCodeExprGen.cpp (+45-10)
  • (modified) clang/lib/AST/Interp/Interp.h (+57)
  • (modified) clang/lib/AST/Interp/Opcodes.td (+4)
  • (modified) clang/test/AST/Interp/complex.cpp (+31)
diff --git a/clang/lib/AST/ExprConstShared.h b/clang/lib/AST/ExprConstShared.h
index a97eac85abc69..9decd47e41767 100644
--- a/clang/lib/AST/ExprConstShared.h
+++ b/clang/lib/AST/ExprConstShared.h
@@ -14,6 +14,9 @@
 #ifndef LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
 #define LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
 
+namespace llvm {
+class APFloat;
+}
 namespace clang {
 class QualType;
 class LangOptions;
@@ -56,4 +59,8 @@ enum class GCCTypeClass {
 GCCTypeClass EvaluateBuiltinClassifyType(QualType T,
                                          const LangOptions &LangOpts);
 
+void HandleComplexComplexMul(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
+                             llvm::APFloat D, llvm::APFloat &ResR,
+                             llvm::APFloat &ResI);
+
 #endif
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 86fb396fabe2d..7c597a238f041 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15126,6 +15126,62 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
   llvm_unreachable("unknown cast resulting in complex value");
 }
 
+void HandleComplexComplexMul(APFloat A, APFloat B, APFloat C, APFloat D,
+                             APFloat &ResR, APFloat &ResI) {
+  // This is an implementation of complex multiplication according to the
+  // constraints laid out in C11 Annex G. The implementation uses the
+  // following naming scheme:
+  //   (a + ib) * (c + id)
+
+  APFloat AC = A * C;
+  APFloat BD = B * D;
+  APFloat AD = A * D;
+  APFloat BC = B * C;
+  ResR = AC - BD;
+  ResI = AD + BC;
+  if (ResR.isNaN() && ResI.isNaN()) {
+    bool Recalc = false;
+    if (A.isInfinity() || B.isInfinity()) {
+      A = APFloat::copySign(APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0),
+                            A);
+      B = APFloat::copySign(APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0),
+                            B);
+      if (C.isNaN())
+        C = APFloat::copySign(APFloat(C.getSemantics()), C);
+      if (D.isNaN())
+        D = APFloat::copySign(APFloat(D.getSemantics()), D);
+      Recalc = true;
+    }
+    if (C.isInfinity() || D.isInfinity()) {
+      C = APFloat::copySign(APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0),
+                            C);
+      D = APFloat::copySign(APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0),
+                            D);
+      if (A.isNaN())
+        A = APFloat::copySign(APFloat(A.getSemantics()), A);
+      if (B.isNaN())
+        B = APFloat::copySign(APFloat(B.getSemantics()), B);
+      Recalc = true;
+    }
+    if (!Recalc && (AC.isInfinity() || BD.isInfinity() || AD.isInfinity() ||
+                    BC.isInfinity())) {
+      if (A.isNaN())
+        A = APFloat::copySign(APFloat(A.getSemantics()), A);
+      if (B.isNaN())
+        B = APFloat::copySign(APFloat(B.getSemantics()), B);
+      if (C.isNaN())
+        C = APFloat::copySign(APFloat(C.getSemantics()), C);
+      if (D.isNaN())
+        D = APFloat::copySign(APFloat(D.getSemantics()), D);
+      Recalc = true;
+    }
+    if (Recalc) {
+      ResR = APFloat::getInf(A.getSemantics()) * (A * C - B * D);
+      ResI = APFloat::getInf(A.getSemantics()) * (A * D + B * C);
+    }
+  }
+}
+
 bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
   if (E->isPtrMemOp() || E->isAssignmentOp() || E->getOpcode() == BO_Comma)
     return ExprEvaluatorBaseTy::VisitBinaryOperator(E);
@@ -15225,55 +15281,7 @@ bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
             !handleFloatFloatBinOp(Info, E, ResI, BO_Mul, B))
           return false;
       } else {
-        // In the fully general case, we need to handle NaNs and infinities
-        // robustly.
-        APFloat AC = A * C;
-        APFloat BD = B * D;
-        APFloat AD = A * D;
-        APFloat BC = B * C;
-        ResR = AC - BD;
-        ResI = AD + BC;
-        if (ResR.isNaN() && ResI.isNaN()) {
-          bool Recalc = false;
-          if (A.isInfinity() || B.isInfinity()) {
-            A = APFloat::copySign(
-                APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0), A);
-            B = APFloat::copySign(
-                APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0), B);
-            if (C.isNaN())
-              C = APFloat::copySign(APFloat(C.getSemantics()), C);
-            if (D.isNaN())
-              D = APFloat::copySign(APFloat(D.getSemantics()), D);
-            Recalc = true;
-          }
-          if (C.isInfinity() || D.isInfinity()) {
-            C = APFloat::copySign(
-                APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0), C);
-            D = APFloat::copySign(
-                APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0), D);
-            if (A.isNaN())
-              A = APFloat::copySign(APFloat(A.getSemantics()), A);
-            if (B.isNaN())
-              B = APFloat::copySign(APFloat(B.getSemantics()), B);
-            Recalc = true;
-          }
-          if (!Recalc && (AC.isInfinity() || BD.isInfinity() ||
-                          AD.isInfinity() || BC.isInfinity())) {
-            if (A.isNaN())
-              A = APFloat::copySign(APFloat(A.getSemantics()), A);
-            if (B.isNaN())
-              B = APFloat::copySign(APFloat(B.getSemantics()), B);
-            if (C.isNaN())
-              C = APFloat::copySign(APFloat(C.getSemantics()), C);
-            if (D.isNaN())
-              D = APFloat::copySign(APFloat(D.getSemantics()), D);
-            Recalc = true;
-          }
-          if (Recalc) {
-            ResR = APFloat::getInf(A.getSemantics()) * (A * C - B * D);
-            ResI = APFloat::getInf(A.getSemantics()) * (A * D + B * C);
-          }
-        }
+        HandleComplexComplexMul(A, B, C, D, ResR, ResI);
       }
     } else {
       ComplexValue LHS = Result;
diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index ff2b51e3fb6fa..2fa479b818064 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -854,6 +854,22 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
   if (const auto *AT = RHSType->getAs<AtomicType>())
     RHSType = AT->getValueType();
 
+  // For ComplexComplex Mul, we have special ops to make their implementation
+  // easier.
+  BinaryOperatorKind Op = E->getOpcode();
+  if (Op == BO_Mul && LHSType->isAnyComplexType() &&
+      RHSType->isAnyComplexType()) {
+    assert(classifyPrim(LHSType->getAs<ComplexType>()->getElementType()) ==
+           classifyPrim(RHSType->getAs<ComplexType>()->getElementType()));
+    PrimType ElemT =
+        classifyPrim(LHSType->getAs<ComplexType>()->getElementType());
+    if (!this->visit(LHS))
+      return false;
+    if (!this->visit(RHS))
+      return false;
+    return this->emitMulc(ElemT, E);
+  }
+
   // Evaluate LHS and save value to LHSOffset.
   bool LHSIsComplex;
   unsigned LHSOffset;
@@ -897,22 +913,22 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
   // For both LHS and RHS, either load the value from the complex pointer, or
   // directly from the local variable. For index 1 (i.e. the imaginary part),
   // just load 0 and do the operation anyway.
-  auto loadComplexValue = [this](bool IsComplex, unsigned ElemIndex,
-                                 unsigned Offset, const Expr *E) -> bool {
+  auto loadComplexValue = [this](bool IsComplex, bool LoadZero,
+                                 unsigned ElemIndex, unsigned Offset,
+                                 const Expr *E) -> bool {
     if (IsComplex) {
       if (!this->emitGetLocal(PT_Ptr, Offset, E))
         return false;
       return this->emitArrayElemPop(classifyComplexElementType(E->getType()),
                                     ElemIndex, E);
     }
-    if (ElemIndex == 0)
+    if (ElemIndex == 0 || !LoadZero)
       return this->emitGetLocal(classifyPrim(E->getType()), Offset, E);
     return this->visitZeroInitializer(classifyPrim(E->getType()), E->getType(),
                                       E);
   };
 
   // Now we can get pointers to the LHS and RHS from the offsets above.
-  BinaryOperatorKind Op = E->getOpcode();
   for (unsigned ElemIndex = 0; ElemIndex != 2; ++ElemIndex) {
     // Result pointer for the store later.
     if (!this->DiscardResult) {
@@ -920,15 +936,14 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
         return false;
     }
 
-    if (!loadComplexValue(LHSIsComplex, ElemIndex, LHSOffset, LHS))
-      return false;
-
-    if (!loadComplexValue(RHSIsComplex, ElemIndex, RHSOffset, RHS))
-      return false;
-
     // The actual operation.
     switch (Op) {
     case BO_Add:
+      if (!loadComplexValue(LHSIsComplex, true, ElemIndex, LHSOffset, LHS))
+        return false;
+
+      if (!loadComplexValue(RHSIsComplex, true, ElemIndex, RHSOffset, RHS))
+        return false;
       if (ResultElemT == PT_Float) {
         if (!this->emitAddf(getRoundingMode(E), E))
           return false;
@@ -938,6 +953,11 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
       }
       break;
     case BO_Sub:
+      if (!loadComplexValue(LHSIsComplex, true, ElemIndex, LHSOffset, LHS))
+        return false;
+
+      if (!loadComplexValue(RHSIsComplex, true, ElemIndex, RHSOffset, RHS))
+        return false;
       if (ResultElemT == PT_Float) {
         if (!this->emitSubf(getRoundingMode(E), E))
           return false;
@@ -946,6 +966,21 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
           return false;
       }
       break;
+    case BO_Mul:
+      if (!loadComplexValue(LHSIsComplex, false, ElemIndex, LHSOffset, LHS))
+        return false;
+
+      if (!loadComplexValue(RHSIsComplex, false, ElemIndex, RHSOffset, RHS))
+        return false;
+
+      if (ResultElemT == PT_Float) {
+        if (!this->emitMulf(getRoundingMode(E), E))
+          return false;
+      } else {
+        if (!this->emitMul(ResultElemT, E))
+          return false;
+      }
+      break;
 
     default:
       return false;
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index f63711da90c7e..116a9c799a639 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_AST_INTERP_INTERP_H
 #define LLVM_CLANG_AST_INTERP_INTERP_H
 
+#include "../ExprConstShared.h"
 #include "Boolean.h"
 #include "Floating.h"
 #include "Function.h"
@@ -368,6 +369,62 @@ inline bool Mulf(InterpState &S, CodePtr OpPC, llvm::RoundingMode RM) {
   S.Stk.push<Floating>(Result);
   return CheckFloatResult(S, OpPC, Result, Status);
 }
+
+template <PrimType Name, class T = typename PrimConv<Name>::T>
+inline bool Mulc(InterpState &S, CodePtr OpPC) {
+  const Pointer &RHS = S.Stk.pop<Pointer>();
+  const Pointer &LHS = S.Stk.pop<Pointer>();
+  const Pointer &Result = S.Stk.peek<Pointer>();
+
+  if constexpr (std::is_same_v<T, Floating>) {
+    APFloat A = LHS.atIndex(0).deref<Floating>().getAPFloat();
+    APFloat B = LHS.atIndex(1).deref<Floating>().getAPFloat();
+    APFloat C = RHS.atIndex(0).deref<Floating>().getAPFloat();
+    APFloat D = RHS.atIndex(1).deref<Floating>().getAPFloat();
+
+    APFloat ResR(A.getSemantics());
+    APFloat ResI(A.getSemantics());
+    HandleComplexComplexMul(A, B, C, D, ResR, ResI);
+
+    // Copy into the result.
+    Result.atIndex(0).deref<Floating>() = Floating(ResR);
+    Result.atIndex(0).initialize();
+    Result.atIndex(1).deref<Floating>() = Floating(ResI);
+    Result.atIndex(1).initialize();
+    Result.initialize();
+  } else {
+    // Integer element type.
+    const T &LHSR = LHS.atIndex(0).deref<T>();
+    const T &LHSI = LHS.atIndex(1).deref<T>();
+    const T &RHSR = RHS.atIndex(0).deref<T>();
+    const T &RHSI = RHS.atIndex(1).deref<T>();
+    unsigned Bits = LHSR.bitWidth();
+
+    // real(Result) = (real(LHS) * real(RHS)) - (imag(LHS) * imag(RHS))
+    T A;
+    if (T::mul(LHSR, RHSR, Bits, &A))
+      return false;
+    T B;
+    if (T::mul(LHSI, RHSI, Bits, &B))
+      return false;
+    if (T::sub(A, B, Bits, &Result.atIndex(0).deref<T>()))
+      return false;
+    Result.atIndex(0).initialize();
+
+    // imag(Result) = (real(LHS) * imag(RHS)) + (imag(LHS) * real(RHS))
+    if (T::mul(LHSR, RHSI, Bits, &A))
+      return false;
+    if (T::mul(LHSI, RHSR, Bits, &B))
+      return false;
+    if (T::add(A, B, Bits, &Result.atIndex(1).deref<T>()))
+      return false;
+    Result.atIndex(1).initialize();
+    Result.initialize();
+  }
+
+  return true;
+}
+
 /// 1) Pops the RHS from the stack.
 /// 2) Pops the LHS from the stack.
 /// 3) Pushes 'LHS & RHS' on the stack
diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index a5ac8206104c8..c9884476e48b9 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -537,6 +537,10 @@ def Sub  : AluOpcode;
 def Subf : FloatOpcode;
 def Mul  : AluOpcode;
 def Mulf : FloatOpcode;
+def Mulc : Opcode {
+  let Types = [NumberTypeClass];
+  let HasGroup = 1;
+}
 def Rem  : IntegerOpcode;
 def Div  : IntegerOpcode;
 def Divf : FloatOpcode;
diff --git a/clang/test/AST/Interp/complex.cpp b/clang/test/AST/Interp/complex.cpp
index 09cb620d7b7c3..f6ed9a643a99e 100644
--- a/clang/test/AST/Interp/complex.cpp
+++ b/clang/test/AST/Interp/complex.cpp
@@ -9,6 +9,37 @@ static_assert(&__imag z1 == &__real z1 + 1, "");
 static_assert((*(&__imag z1)) == __imag z1, "");
 static_assert((*(&__real z1)) == __real z1, "");
 
+
+static_assert(((1.25 + 0.5j) * (0.25 - 0.75j)) == (0.6875 - 0.8125j), "");
+static_assert(((1.25 + 0.5j) * 0.25) == (0.3125 + 0.125j), "");
+static_assert((1.25 * (0.25 - 0.75j)) == (0.3125 - 0.9375j), "");
+constexpr _Complex float InfC = {1.0, __builtin_inf()};
+constexpr _Complex float InfInf = __builtin_inf() + InfC;
+static_assert(__real__(InfInf) == __builtin_inf());
+static_assert(__imag__(InfInf) == __builtin_inf());
+static_assert(__builtin_isnan(__real__(InfInf * InfInf)));
+static_assert(__builtin_isinf_sign(__imag__(InfInf * InfInf)) == 1);
+
+static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * 1.0)) == 1);
+static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) * 1.0)) == 1);
+static_assert(__builtin_isinf_sign(__real__(1.0 * (__builtin_inf() + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__imag__(1.0 * (1.0 + InfC))) == 1);
+static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * (1.0 + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) * (__builtin_inf() + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * (__builtin_inf() + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__real__((1.0 + InfC) * (1.0 + 1.0j))) == -1);
+static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) * (1.0 + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) * (1.0 + InfC))) == -1);
+static_assert(__builtin_isinf_sign(__imag__((1.0 + 1.0j) * (1.0 + InfC))) == 1);
+static_assert(__builtin_isinf_sign(__real__((1.0 + InfC) * (1.0 + InfC))) == -1);
+static_assert(__builtin_isinf_sign(__real__(InfInf * InfInf)) == 0);
+
+constexpr _Complex int IIMA = {1,2};
+constexpr _Complex int IIMB = {10,20};
+constexpr _Complex int IIMC = IIMA * IIMB;
+static_assert(__real(IIMC) == -30, "");
+static_assert(__imag(IIMC) == 40, "");
+
 constexpr _Complex int Comma1 = {1, 2};
 constexpr _Complex int Comma2 = (0, Comma1);
 static_assert(Comma1 == Comma1, "");

@tbaederr
Copy link
Contributor Author

Ping

Copy link
Collaborator

@AaronBallman AaronBallman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@tbaederr tbaederr merged commit 4bf160e into llvm:main Jun 17, 2024
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants