-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[clang][Interp] Implement Complex-complex multiplication #94891
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Share the implementation for floating-point complex-complex multiplication with the current interpreter. This means we need a new opcode for this, but there's no good way around that.
@llvm/pr-subscribers-clang Author: Timm Baeder (tbaederr) ChangesShare the implementation for floating-point complex-complex multiplication with the current interpreter. This means we need a new opcode for this, but there's no good way around that. Full diff: https://github.com/llvm/llvm-project/pull/94891.diff 6 Files Affected:
diff --git a/clang/lib/AST/ExprConstShared.h b/clang/lib/AST/ExprConstShared.h
index a97eac85abc69..9decd47e41767 100644
--- a/clang/lib/AST/ExprConstShared.h
+++ b/clang/lib/AST/ExprConstShared.h
@@ -14,6 +14,9 @@
#ifndef LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
#define LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
+namespace llvm {
+class APFloat;
+}
namespace clang {
class QualType;
class LangOptions;
@@ -56,4 +59,8 @@ enum class GCCTypeClass {
GCCTypeClass EvaluateBuiltinClassifyType(QualType T,
const LangOptions &LangOpts);
+void HandleComplexComplexMul(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
+ llvm::APFloat D, llvm::APFloat &ResR,
+ llvm::APFloat &ResI);
+
#endif
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 86fb396fabe2d..7c597a238f041 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15126,6 +15126,62 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
llvm_unreachable("unknown cast resulting in complex value");
}
+void HandleComplexComplexMul(APFloat A, APFloat B, APFloat C, APFloat D,
+ APFloat &ResR, APFloat &ResI) {
+ // This is an implementation of complex multiplication according to the
+ // constraints laid out in C11 Annex G. The implementation uses the
+ // following naming scheme:
+ // (a + ib) * (c + id)
+
+ APFloat AC = A * C;
+ APFloat BD = B * D;
+ APFloat AD = A * D;
+ APFloat BC = B * C;
+ ResR = AC - BD;
+ ResI = AD + BC;
+ if (ResR.isNaN() && ResI.isNaN()) {
+ bool Recalc = false;
+ if (A.isInfinity() || B.isInfinity()) {
+ A = APFloat::copySign(APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0),
+ A);
+ B = APFloat::copySign(APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0),
+ B);
+ if (C.isNaN())
+ C = APFloat::copySign(APFloat(C.getSemantics()), C);
+ if (D.isNaN())
+ D = APFloat::copySign(APFloat(D.getSemantics()), D);
+ Recalc = true;
+ }
+ if (C.isInfinity() || D.isInfinity()) {
+ C = APFloat::copySign(APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0),
+ C);
+ D = APFloat::copySign(APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0),
+ D);
+ if (A.isNaN())
+ A = APFloat::copySign(APFloat(A.getSemantics()), A);
+ if (B.isNaN())
+ B = APFloat::copySign(APFloat(B.getSemantics()), B);
+ Recalc = true;
+ }
+ if (!Recalc && (AC.isInfinity() || BD.isInfinity() || AD.isInfinity() ||
+ BC.isInfinity())) {
+ if (A.isNaN())
+ A = APFloat::copySign(APFloat(A.getSemantics()), A);
+ if (B.isNaN())
+ B = APFloat::copySign(APFloat(B.getSemantics()), B);
+ if (C.isNaN())
+ C = APFloat::copySign(APFloat(C.getSemantics()), C);
+ if (D.isNaN())
+ D = APFloat::copySign(APFloat(D.getSemantics()), D);
+ Recalc = true;
+ }
+ if (Recalc) {
+ ResR = APFloat::getInf(A.getSemantics()) * (A * C - B * D);
+ ResI = APFloat::getInf(A.getSemantics()) * (A * D + B * C);
+ }
+ }
+}
+
bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
if (E->isPtrMemOp() || E->isAssignmentOp() || E->getOpcode() == BO_Comma)
return ExprEvaluatorBaseTy::VisitBinaryOperator(E);
@@ -15225,55 +15281,7 @@ bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
!handleFloatFloatBinOp(Info, E, ResI, BO_Mul, B))
return false;
} else {
- // In the fully general case, we need to handle NaNs and infinities
- // robustly.
- APFloat AC = A * C;
- APFloat BD = B * D;
- APFloat AD = A * D;
- APFloat BC = B * C;
- ResR = AC - BD;
- ResI = AD + BC;
- if (ResR.isNaN() && ResI.isNaN()) {
- bool Recalc = false;
- if (A.isInfinity() || B.isInfinity()) {
- A = APFloat::copySign(
- APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0), A);
- B = APFloat::copySign(
- APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0), B);
- if (C.isNaN())
- C = APFloat::copySign(APFloat(C.getSemantics()), C);
- if (D.isNaN())
- D = APFloat::copySign(APFloat(D.getSemantics()), D);
- Recalc = true;
- }
- if (C.isInfinity() || D.isInfinity()) {
- C = APFloat::copySign(
- APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0), C);
- D = APFloat::copySign(
- APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0), D);
- if (A.isNaN())
- A = APFloat::copySign(APFloat(A.getSemantics()), A);
- if (B.isNaN())
- B = APFloat::copySign(APFloat(B.getSemantics()), B);
- Recalc = true;
- }
- if (!Recalc && (AC.isInfinity() || BD.isInfinity() ||
- AD.isInfinity() || BC.isInfinity())) {
- if (A.isNaN())
- A = APFloat::copySign(APFloat(A.getSemantics()), A);
- if (B.isNaN())
- B = APFloat::copySign(APFloat(B.getSemantics()), B);
- if (C.isNaN())
- C = APFloat::copySign(APFloat(C.getSemantics()), C);
- if (D.isNaN())
- D = APFloat::copySign(APFloat(D.getSemantics()), D);
- Recalc = true;
- }
- if (Recalc) {
- ResR = APFloat::getInf(A.getSemantics()) * (A * C - B * D);
- ResI = APFloat::getInf(A.getSemantics()) * (A * D + B * C);
- }
- }
+ HandleComplexComplexMul(A, B, C, D, ResR, ResI);
}
} else {
ComplexValue LHS = Result;
diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index ff2b51e3fb6fa..2fa479b818064 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -854,6 +854,22 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
if (const auto *AT = RHSType->getAs<AtomicType>())
RHSType = AT->getValueType();
+ // For ComplexComplex Mul, we have special ops to make their implementation
+ // easier.
+ BinaryOperatorKind Op = E->getOpcode();
+ if (Op == BO_Mul && LHSType->isAnyComplexType() &&
+ RHSType->isAnyComplexType()) {
+ assert(classifyPrim(LHSType->getAs<ComplexType>()->getElementType()) ==
+ classifyPrim(RHSType->getAs<ComplexType>()->getElementType()));
+ PrimType ElemT =
+ classifyPrim(LHSType->getAs<ComplexType>()->getElementType());
+ if (!this->visit(LHS))
+ return false;
+ if (!this->visit(RHS))
+ return false;
+ return this->emitMulc(ElemT, E);
+ }
+
// Evaluate LHS and save value to LHSOffset.
bool LHSIsComplex;
unsigned LHSOffset;
@@ -897,22 +913,22 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
// For both LHS and RHS, either load the value from the complex pointer, or
// directly from the local variable. For index 1 (i.e. the imaginary part),
// just load 0 and do the operation anyway.
- auto loadComplexValue = [this](bool IsComplex, unsigned ElemIndex,
- unsigned Offset, const Expr *E) -> bool {
+ auto loadComplexValue = [this](bool IsComplex, bool LoadZero,
+ unsigned ElemIndex, unsigned Offset,
+ const Expr *E) -> bool {
if (IsComplex) {
if (!this->emitGetLocal(PT_Ptr, Offset, E))
return false;
return this->emitArrayElemPop(classifyComplexElementType(E->getType()),
ElemIndex, E);
}
- if (ElemIndex == 0)
+ if (ElemIndex == 0 || !LoadZero)
return this->emitGetLocal(classifyPrim(E->getType()), Offset, E);
return this->visitZeroInitializer(classifyPrim(E->getType()), E->getType(),
E);
};
// Now we can get pointers to the LHS and RHS from the offsets above.
- BinaryOperatorKind Op = E->getOpcode();
for (unsigned ElemIndex = 0; ElemIndex != 2; ++ElemIndex) {
// Result pointer for the store later.
if (!this->DiscardResult) {
@@ -920,15 +936,14 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
return false;
}
- if (!loadComplexValue(LHSIsComplex, ElemIndex, LHSOffset, LHS))
- return false;
-
- if (!loadComplexValue(RHSIsComplex, ElemIndex, RHSOffset, RHS))
- return false;
-
// The actual operation.
switch (Op) {
case BO_Add:
+ if (!loadComplexValue(LHSIsComplex, true, ElemIndex, LHSOffset, LHS))
+ return false;
+
+ if (!loadComplexValue(RHSIsComplex, true, ElemIndex, RHSOffset, RHS))
+ return false;
if (ResultElemT == PT_Float) {
if (!this->emitAddf(getRoundingMode(E), E))
return false;
@@ -938,6 +953,11 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
}
break;
case BO_Sub:
+ if (!loadComplexValue(LHSIsComplex, true, ElemIndex, LHSOffset, LHS))
+ return false;
+
+ if (!loadComplexValue(RHSIsComplex, true, ElemIndex, RHSOffset, RHS))
+ return false;
if (ResultElemT == PT_Float) {
if (!this->emitSubf(getRoundingMode(E), E))
return false;
@@ -946,6 +966,21 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
return false;
}
break;
+ case BO_Mul:
+ if (!loadComplexValue(LHSIsComplex, false, ElemIndex, LHSOffset, LHS))
+ return false;
+
+ if (!loadComplexValue(RHSIsComplex, false, ElemIndex, RHSOffset, RHS))
+ return false;
+
+ if (ResultElemT == PT_Float) {
+ if (!this->emitMulf(getRoundingMode(E), E))
+ return false;
+ } else {
+ if (!this->emitMul(ResultElemT, E))
+ return false;
+ }
+ break;
default:
return false;
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index f63711da90c7e..116a9c799a639 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -13,6 +13,7 @@
#ifndef LLVM_CLANG_AST_INTERP_INTERP_H
#define LLVM_CLANG_AST_INTERP_INTERP_H
+#include "../ExprConstShared.h"
#include "Boolean.h"
#include "Floating.h"
#include "Function.h"
@@ -368,6 +369,62 @@ inline bool Mulf(InterpState &S, CodePtr OpPC, llvm::RoundingMode RM) {
S.Stk.push<Floating>(Result);
return CheckFloatResult(S, OpPC, Result, Status);
}
+
+template <PrimType Name, class T = typename PrimConv<Name>::T>
+inline bool Mulc(InterpState &S, CodePtr OpPC) {
+ const Pointer &RHS = S.Stk.pop<Pointer>();
+ const Pointer &LHS = S.Stk.pop<Pointer>();
+ const Pointer &Result = S.Stk.peek<Pointer>();
+
+ if constexpr (std::is_same_v<T, Floating>) {
+ APFloat A = LHS.atIndex(0).deref<Floating>().getAPFloat();
+ APFloat B = LHS.atIndex(1).deref<Floating>().getAPFloat();
+ APFloat C = RHS.atIndex(0).deref<Floating>().getAPFloat();
+ APFloat D = RHS.atIndex(1).deref<Floating>().getAPFloat();
+
+ APFloat ResR(A.getSemantics());
+ APFloat ResI(A.getSemantics());
+ HandleComplexComplexMul(A, B, C, D, ResR, ResI);
+
+ // Copy into the result.
+ Result.atIndex(0).deref<Floating>() = Floating(ResR);
+ Result.atIndex(0).initialize();
+ Result.atIndex(1).deref<Floating>() = Floating(ResI);
+ Result.atIndex(1).initialize();
+ Result.initialize();
+ } else {
+ // Integer element type.
+ const T &LHSR = LHS.atIndex(0).deref<T>();
+ const T &LHSI = LHS.atIndex(1).deref<T>();
+ const T &RHSR = RHS.atIndex(0).deref<T>();
+ const T &RHSI = RHS.atIndex(1).deref<T>();
+ unsigned Bits = LHSR.bitWidth();
+
+ // real(Result) = (real(LHS) * real(RHS)) - (imag(LHS) * imag(RHS))
+ T A;
+ if (T::mul(LHSR, RHSR, Bits, &A))
+ return false;
+ T B;
+ if (T::mul(LHSI, RHSI, Bits, &B))
+ return false;
+ if (T::sub(A, B, Bits, &Result.atIndex(0).deref<T>()))
+ return false;
+ Result.atIndex(0).initialize();
+
+ // imag(Result) = (real(LHS) * imag(RHS)) + (imag(LHS) * real(RHS))
+ if (T::mul(LHSR, RHSI, Bits, &A))
+ return false;
+ if (T::mul(LHSI, RHSR, Bits, &B))
+ return false;
+ if (T::add(A, B, Bits, &Result.atIndex(1).deref<T>()))
+ return false;
+ Result.atIndex(1).initialize();
+ Result.initialize();
+ }
+
+ return true;
+}
+
/// 1) Pops the RHS from the stack.
/// 2) Pops the LHS from the stack.
/// 3) Pushes 'LHS & RHS' on the stack
diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index a5ac8206104c8..c9884476e48b9 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -537,6 +537,10 @@ def Sub : AluOpcode;
def Subf : FloatOpcode;
def Mul : AluOpcode;
def Mulf : FloatOpcode;
+def Mulc : Opcode {
+ let Types = [NumberTypeClass];
+ let HasGroup = 1;
+}
def Rem : IntegerOpcode;
def Div : IntegerOpcode;
def Divf : FloatOpcode;
diff --git a/clang/test/AST/Interp/complex.cpp b/clang/test/AST/Interp/complex.cpp
index 09cb620d7b7c3..f6ed9a643a99e 100644
--- a/clang/test/AST/Interp/complex.cpp
+++ b/clang/test/AST/Interp/complex.cpp
@@ -9,6 +9,37 @@ static_assert(&__imag z1 == &__real z1 + 1, "");
static_assert((*(&__imag z1)) == __imag z1, "");
static_assert((*(&__real z1)) == __real z1, "");
+
+static_assert(((1.25 + 0.5j) * (0.25 - 0.75j)) == (0.6875 - 0.8125j), "");
+static_assert(((1.25 + 0.5j) * 0.25) == (0.3125 + 0.125j), "");
+static_assert((1.25 * (0.25 - 0.75j)) == (0.3125 - 0.9375j), "");
+constexpr _Complex float InfC = {1.0, __builtin_inf()};
+constexpr _Complex float InfInf = __builtin_inf() + InfC;
+static_assert(__real__(InfInf) == __builtin_inf());
+static_assert(__imag__(InfInf) == __builtin_inf());
+static_assert(__builtin_isnan(__real__(InfInf * InfInf)));
+static_assert(__builtin_isinf_sign(__imag__(InfInf * InfInf)) == 1);
+
+static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * 1.0)) == 1);
+static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) * 1.0)) == 1);
+static_assert(__builtin_isinf_sign(__real__(1.0 * (__builtin_inf() + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__imag__(1.0 * (1.0 + InfC))) == 1);
+static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * (1.0 + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) * (__builtin_inf() + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * (__builtin_inf() + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__real__((1.0 + InfC) * (1.0 + 1.0j))) == -1);
+static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) * (1.0 + 1.0j))) == 1);
+static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) * (1.0 + InfC))) == -1);
+static_assert(__builtin_isinf_sign(__imag__((1.0 + 1.0j) * (1.0 + InfC))) == 1);
+static_assert(__builtin_isinf_sign(__real__((1.0 + InfC) * (1.0 + InfC))) == -1);
+static_assert(__builtin_isinf_sign(__real__(InfInf * InfInf)) == 0);
+
+constexpr _Complex int IIMA = {1,2};
+constexpr _Complex int IIMB = {10,20};
+constexpr _Complex int IIMC = IIMA * IIMB;
+static_assert(__real(IIMC) == -30, "");
+static_assert(__imag(IIMC) == 40, "");
+
constexpr _Complex int Comma1 = {1, 2};
constexpr _Complex int Comma2 = (0, Comma1);
static_assert(Comma1 == Comma1, "");
|
Ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Share the implementation for floating-point complex-complex multiplication with the current interpreter. This means we need a new opcode for this, but there's no good way around that.