Skip to content

[clang][Interp] Implement Complex-complex multiplication #94891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions clang/lib/AST/ExprConstShared.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#ifndef LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H
#define LLVM_CLANG_LIB_AST_EXPRCONSTSHARED_H

namespace llvm {
class APFloat;
}
namespace clang {
class QualType;
class LangOptions;
Expand Down Expand Up @@ -56,4 +59,8 @@ enum class GCCTypeClass {
GCCTypeClass EvaluateBuiltinClassifyType(QualType T,
const LangOptions &LangOpts);

void HandleComplexComplexMul(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
llvm::APFloat D, llvm::APFloat &ResR,
llvm::APFloat &ResI);

#endif
106 changes: 57 additions & 49 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15126,6 +15126,62 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
llvm_unreachable("unknown cast resulting in complex value");
}

void HandleComplexComplexMul(APFloat A, APFloat B, APFloat C, APFloat D,
APFloat &ResR, APFloat &ResI) {
// This is an implementation of complex multiplication according to the
// constraints laid out in C11 Annex G. The implementation uses the
// following naming scheme:
// (a + ib) * (c + id)

APFloat AC = A * C;
APFloat BD = B * D;
APFloat AD = A * D;
APFloat BC = B * C;
ResR = AC - BD;
ResI = AD + BC;
if (ResR.isNaN() && ResI.isNaN()) {
bool Recalc = false;
if (A.isInfinity() || B.isInfinity()) {
A = APFloat::copySign(APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0),
A);
B = APFloat::copySign(APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0),
B);
if (C.isNaN())
C = APFloat::copySign(APFloat(C.getSemantics()), C);
if (D.isNaN())
D = APFloat::copySign(APFloat(D.getSemantics()), D);
Recalc = true;
}
if (C.isInfinity() || D.isInfinity()) {
C = APFloat::copySign(APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0),
C);
D = APFloat::copySign(APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0),
D);
if (A.isNaN())
A = APFloat::copySign(APFloat(A.getSemantics()), A);
if (B.isNaN())
B = APFloat::copySign(APFloat(B.getSemantics()), B);
Recalc = true;
}
if (!Recalc && (AC.isInfinity() || BD.isInfinity() || AD.isInfinity() ||
BC.isInfinity())) {
if (A.isNaN())
A = APFloat::copySign(APFloat(A.getSemantics()), A);
if (B.isNaN())
B = APFloat::copySign(APFloat(B.getSemantics()), B);
if (C.isNaN())
C = APFloat::copySign(APFloat(C.getSemantics()), C);
if (D.isNaN())
D = APFloat::copySign(APFloat(D.getSemantics()), D);
Recalc = true;
}
if (Recalc) {
ResR = APFloat::getInf(A.getSemantics()) * (A * C - B * D);
ResI = APFloat::getInf(A.getSemantics()) * (A * D + B * C);
}
}
}

bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
if (E->isPtrMemOp() || E->isAssignmentOp() || E->getOpcode() == BO_Comma)
return ExprEvaluatorBaseTy::VisitBinaryOperator(E);
Expand Down Expand Up @@ -15225,55 +15281,7 @@ bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
!handleFloatFloatBinOp(Info, E, ResI, BO_Mul, B))
return false;
} else {
// In the fully general case, we need to handle NaNs and infinities
// robustly.
APFloat AC = A * C;
APFloat BD = B * D;
APFloat AD = A * D;
APFloat BC = B * C;
ResR = AC - BD;
ResI = AD + BC;
if (ResR.isNaN() && ResI.isNaN()) {
bool Recalc = false;
if (A.isInfinity() || B.isInfinity()) {
A = APFloat::copySign(
APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0), A);
B = APFloat::copySign(
APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0), B);
if (C.isNaN())
C = APFloat::copySign(APFloat(C.getSemantics()), C);
if (D.isNaN())
D = APFloat::copySign(APFloat(D.getSemantics()), D);
Recalc = true;
}
if (C.isInfinity() || D.isInfinity()) {
C = APFloat::copySign(
APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0), C);
D = APFloat::copySign(
APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0), D);
if (A.isNaN())
A = APFloat::copySign(APFloat(A.getSemantics()), A);
if (B.isNaN())
B = APFloat::copySign(APFloat(B.getSemantics()), B);
Recalc = true;
}
if (!Recalc && (AC.isInfinity() || BD.isInfinity() ||
AD.isInfinity() || BC.isInfinity())) {
if (A.isNaN())
A = APFloat::copySign(APFloat(A.getSemantics()), A);
if (B.isNaN())
B = APFloat::copySign(APFloat(B.getSemantics()), B);
if (C.isNaN())
C = APFloat::copySign(APFloat(C.getSemantics()), C);
if (D.isNaN())
D = APFloat::copySign(APFloat(D.getSemantics()), D);
Recalc = true;
}
if (Recalc) {
ResR = APFloat::getInf(A.getSemantics()) * (A * C - B * D);
ResI = APFloat::getInf(A.getSemantics()) * (A * D + B * C);
}
}
HandleComplexComplexMul(A, B, C, D, ResR, ResI);
}
} else {
ComplexValue LHS = Result;
Expand Down
55 changes: 45 additions & 10 deletions clang/lib/AST/Interp/ByteCodeExprGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,22 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
if (const auto *AT = RHSType->getAs<AtomicType>())
RHSType = AT->getValueType();

// For ComplexComplex Mul, we have special ops to make their implementation
// easier.
BinaryOperatorKind Op = E->getOpcode();
if (Op == BO_Mul && LHSType->isAnyComplexType() &&
RHSType->isAnyComplexType()) {
assert(classifyPrim(LHSType->getAs<ComplexType>()->getElementType()) ==
classifyPrim(RHSType->getAs<ComplexType>()->getElementType()));
PrimType ElemT =
classifyPrim(LHSType->getAs<ComplexType>()->getElementType());
if (!this->visit(LHS))
return false;
if (!this->visit(RHS))
return false;
return this->emitMulc(ElemT, E);
}

// Evaluate LHS and save value to LHSOffset.
bool LHSIsComplex;
unsigned LHSOffset;
Expand Down Expand Up @@ -897,38 +913,37 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
// For both LHS and RHS, either load the value from the complex pointer, or
// directly from the local variable. For index 1 (i.e. the imaginary part),
// just load 0 and do the operation anyway.
auto loadComplexValue = [this](bool IsComplex, unsigned ElemIndex,
unsigned Offset, const Expr *E) -> bool {
auto loadComplexValue = [this](bool IsComplex, bool LoadZero,
unsigned ElemIndex, unsigned Offset,
const Expr *E) -> bool {
if (IsComplex) {
if (!this->emitGetLocal(PT_Ptr, Offset, E))
return false;
return this->emitArrayElemPop(classifyComplexElementType(E->getType()),
ElemIndex, E);
}
if (ElemIndex == 0)
if (ElemIndex == 0 || !LoadZero)
return this->emitGetLocal(classifyPrim(E->getType()), Offset, E);
return this->visitZeroInitializer(classifyPrim(E->getType()), E->getType(),
E);
};

// Now we can get pointers to the LHS and RHS from the offsets above.
BinaryOperatorKind Op = E->getOpcode();
for (unsigned ElemIndex = 0; ElemIndex != 2; ++ElemIndex) {
// Result pointer for the store later.
if (!this->DiscardResult) {
if (!this->emitGetLocal(PT_Ptr, ResultOffset, E))
return false;
}

if (!loadComplexValue(LHSIsComplex, ElemIndex, LHSOffset, LHS))
return false;

if (!loadComplexValue(RHSIsComplex, ElemIndex, RHSOffset, RHS))
return false;

// The actual operation.
switch (Op) {
case BO_Add:
if (!loadComplexValue(LHSIsComplex, true, ElemIndex, LHSOffset, LHS))
return false;

if (!loadComplexValue(RHSIsComplex, true, ElemIndex, RHSOffset, RHS))
return false;
if (ResultElemT == PT_Float) {
if (!this->emitAddf(getRoundingMode(E), E))
return false;
Expand All @@ -938,6 +953,11 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
}
break;
case BO_Sub:
if (!loadComplexValue(LHSIsComplex, true, ElemIndex, LHSOffset, LHS))
return false;

if (!loadComplexValue(RHSIsComplex, true, ElemIndex, RHSOffset, RHS))
return false;
if (ResultElemT == PT_Float) {
if (!this->emitSubf(getRoundingMode(E), E))
return false;
Expand All @@ -946,6 +966,21 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
return false;
}
break;
case BO_Mul:
if (!loadComplexValue(LHSIsComplex, false, ElemIndex, LHSOffset, LHS))
return false;

if (!loadComplexValue(RHSIsComplex, false, ElemIndex, RHSOffset, RHS))
return false;

if (ResultElemT == PT_Float) {
if (!this->emitMulf(getRoundingMode(E), E))
return false;
} else {
if (!this->emitMul(ResultElemT, E))
return false;
}
break;

default:
return false;
Expand Down
57 changes: 57 additions & 0 deletions clang/lib/AST/Interp/Interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef LLVM_CLANG_AST_INTERP_INTERP_H
#define LLVM_CLANG_AST_INTERP_INTERP_H

#include "../ExprConstShared.h"
#include "Boolean.h"
#include "Floating.h"
#include "Function.h"
Expand Down Expand Up @@ -368,6 +369,62 @@ inline bool Mulf(InterpState &S, CodePtr OpPC, llvm::RoundingMode RM) {
S.Stk.push<Floating>(Result);
return CheckFloatResult(S, OpPC, Result, Status);
}

template <PrimType Name, class T = typename PrimConv<Name>::T>
inline bool Mulc(InterpState &S, CodePtr OpPC) {
const Pointer &RHS = S.Stk.pop<Pointer>();
const Pointer &LHS = S.Stk.pop<Pointer>();
const Pointer &Result = S.Stk.peek<Pointer>();

if constexpr (std::is_same_v<T, Floating>) {
APFloat A = LHS.atIndex(0).deref<Floating>().getAPFloat();
APFloat B = LHS.atIndex(1).deref<Floating>().getAPFloat();
APFloat C = RHS.atIndex(0).deref<Floating>().getAPFloat();
APFloat D = RHS.atIndex(1).deref<Floating>().getAPFloat();

APFloat ResR(A.getSemantics());
APFloat ResI(A.getSemantics());
HandleComplexComplexMul(A, B, C, D, ResR, ResI);

// Copy into the result.
Result.atIndex(0).deref<Floating>() = Floating(ResR);
Result.atIndex(0).initialize();
Result.atIndex(1).deref<Floating>() = Floating(ResI);
Result.atIndex(1).initialize();
Result.initialize();
} else {
// Integer element type.
const T &LHSR = LHS.atIndex(0).deref<T>();
const T &LHSI = LHS.atIndex(1).deref<T>();
const T &RHSR = RHS.atIndex(0).deref<T>();
const T &RHSI = RHS.atIndex(1).deref<T>();
unsigned Bits = LHSR.bitWidth();

// real(Result) = (real(LHS) * real(RHS)) - (imag(LHS) * imag(RHS))
T A;
if (T::mul(LHSR, RHSR, Bits, &A))
return false;
T B;
if (T::mul(LHSI, RHSI, Bits, &B))
return false;
if (T::sub(A, B, Bits, &Result.atIndex(0).deref<T>()))
return false;
Result.atIndex(0).initialize();

// imag(Result) = (real(LHS) * imag(RHS)) + (imag(LHS) * real(RHS))
if (T::mul(LHSR, RHSI, Bits, &A))
return false;
if (T::mul(LHSI, RHSR, Bits, &B))
return false;
if (T::add(A, B, Bits, &Result.atIndex(1).deref<T>()))
return false;
Result.atIndex(1).initialize();
Result.initialize();
}

return true;
}

/// 1) Pops the RHS from the stack.
/// 2) Pops the LHS from the stack.
/// 3) Pushes 'LHS & RHS' on the stack
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/AST/Interp/Opcodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,10 @@ def Sub : AluOpcode;
def Subf : FloatOpcode;
def Mul : AluOpcode;
def Mulf : FloatOpcode;
def Mulc : Opcode {
let Types = [NumberTypeClass];
let HasGroup = 1;
}
def Rem : IntegerOpcode;
def Div : IntegerOpcode;
def Divf : FloatOpcode;
Expand Down
31 changes: 31 additions & 0 deletions clang/test/AST/Interp/complex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,37 @@ static_assert(&__imag z1 == &__real z1 + 1, "");
static_assert((*(&__imag z1)) == __imag z1, "");
static_assert((*(&__real z1)) == __real z1, "");


static_assert(((1.25 + 0.5j) * (0.25 - 0.75j)) == (0.6875 - 0.8125j), "");
static_assert(((1.25 + 0.5j) * 0.25) == (0.3125 + 0.125j), "");
static_assert((1.25 * (0.25 - 0.75j)) == (0.3125 - 0.9375j), "");
constexpr _Complex float InfC = {1.0, __builtin_inf()};
constexpr _Complex float InfInf = __builtin_inf() + InfC;
static_assert(__real__(InfInf) == __builtin_inf());
static_assert(__imag__(InfInf) == __builtin_inf());
static_assert(__builtin_isnan(__real__(InfInf * InfInf)));
static_assert(__builtin_isinf_sign(__imag__(InfInf * InfInf)) == 1);

static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * 1.0)) == 1);
static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) * 1.0)) == 1);
static_assert(__builtin_isinf_sign(__real__(1.0 * (__builtin_inf() + 1.0j))) == 1);
static_assert(__builtin_isinf_sign(__imag__(1.0 * (1.0 + InfC))) == 1);
static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * (1.0 + 1.0j))) == 1);
static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) * (__builtin_inf() + 1.0j))) == 1);
static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) * (__builtin_inf() + 1.0j))) == 1);
static_assert(__builtin_isinf_sign(__real__((1.0 + InfC) * (1.0 + 1.0j))) == -1);
static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) * (1.0 + 1.0j))) == 1);
static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) * (1.0 + InfC))) == -1);
static_assert(__builtin_isinf_sign(__imag__((1.0 + 1.0j) * (1.0 + InfC))) == 1);
static_assert(__builtin_isinf_sign(__real__((1.0 + InfC) * (1.0 + InfC))) == -1);
static_assert(__builtin_isinf_sign(__real__(InfInf * InfInf)) == 0);

constexpr _Complex int IIMA = {1,2};
constexpr _Complex int IIMB = {10,20};
constexpr _Complex int IIMC = IIMA * IIMB;
static_assert(__real(IIMC) == -30, "");
static_assert(__imag(IIMC) == 40, "");

constexpr _Complex int Comma1 = {1, 2};
constexpr _Complex int Comma2 = (0, Comma1);
static_assert(Comma1 == Comma1, "");
Expand Down
Loading