Skip to content

[flang] Don't blow up when combining mixed COMPLEX operations #66235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,6 @@ common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(

Expr<SomeType> Parenthesize(Expr<SomeType> &&);

Expr<SomeReal> GetComplexPart(
const Expr<SomeComplex> &, bool isImaginary = false);
Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&, bool isImaginary = false);

template <int KIND>
Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
Expr<Type<TypeCategory::Real, KIND>> &&im) {
return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
}

template <typename A> constexpr bool IsNumericCategoryExpr() {
if constexpr (common::HasMember<A, TypelessExpression>) {
return false;
Expand Down
263 changes: 175 additions & 88 deletions flang/lib/Evaluate/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,9 @@ std::optional<Expr<SomeType>> Package(
std::optional<Expr<SomeKind<CAT>>> &&catExpr) {
if (catExpr) {
return {AsGenericExpr(std::move(*catExpr))};
} else {
return std::nullopt;
}
return NoExpr();
}

// Mixed REAL+INTEGER operations. REAL**INTEGER is a special case that
Expand All @@ -204,6 +205,12 @@ std::optional<Expr<SomeType>> MixedRealLeft(
std::move(rx.u)));
}

template <int KIND>
Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
Expr<Type<TypeCategory::Real, KIND>> &&im) {
return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
}

std::optional<Expr<SomeComplex>> ConstructComplex(
parser::ContextualMessages &messages, Expr<SomeType> &&real,
Expr<SomeType> &&imaginary, int defaultRealKind) {
Expand All @@ -228,24 +235,87 @@ std::optional<Expr<SomeComplex>> ConstructComplex(
return std::nullopt;
}

Expr<SomeReal> GetComplexPart(const Expr<SomeComplex> &z, bool isImaginary) {
return common::visit(
[&](const auto &zk) {
static constexpr int kind{ResultType<decltype(zk)>::kind};
return AsCategoryExpr(ComplexComponent<kind>{isImaginary, zk});
},
z.u);
}
// Extracts the real or imaginary part of the result of a COMPLEX
// expression, when that expression is simple enough to be duplicated.
template <bool GET_IMAGINARY> struct ComplexPartExtractor {
template <typename A> static std::optional<Expr<SomeReal>> Get(const A &) {
return std::nullopt;
}

Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&z, bool isImaginary) {
return common::visit(
[&](auto &&zk) {
static constexpr int kind{ResultType<decltype(zk)>::kind};
return AsCategoryExpr(
ComplexComponent<kind>{isImaginary, std::move(zk)});
},
z.u);
}
template <int KIND>
static std::optional<Expr<SomeReal>> Get(
const Parentheses<Type<TypeCategory::Complex, KIND>> &kz) {
if (auto x{Get(kz.left())}) {
return AsGenericExpr(AsSpecificExpr(
Parentheses<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
} else {
return std::nullopt;
}
}

template <int KIND>
static std::optional<Expr<SomeReal>> Get(
const Negate<Type<TypeCategory::Complex, KIND>> &kz) {
if (auto x{Get(kz.left())}) {
return AsGenericExpr(AsSpecificExpr(
Negate<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
} else {
return std::nullopt;
}
}

template <int KIND>
static std::optional<Expr<SomeReal>> Get(
const Convert<Type<TypeCategory::Complex, KIND>, TypeCategory::Complex>
&kz) {
if (auto x{Get(kz.left())}) {
return AsGenericExpr(AsSpecificExpr(
Convert<Type<TypeCategory::Real, KIND>, TypeCategory::Real>{
AsGenericExpr(std::move(*x))}));
} else {
return std::nullopt;
}
}

template <int KIND>
static std::optional<Expr<SomeReal>> Get(const ComplexConstructor<KIND> &kz) {
return GET_IMAGINARY ? Get(kz.right()) : Get(kz.left());
}

template <int KIND>
static std::optional<Expr<SomeReal>> Get(
const Constant<Type<TypeCategory::Complex, KIND>> &kz) {
if (auto cz{kz.GetScalarValue()}) {
return AsGenericExpr(
AsSpecificExpr(GET_IMAGINARY ? cz->AIMAG() : cz->REAL()));
} else {
return std::nullopt;
}
}

template <int KIND>
static std::optional<Expr<SomeReal>> Get(
const Designator<Type<TypeCategory::Complex, KIND>> &kz) {
if (const auto *symbolRef{std::get_if<SymbolRef>(&kz.u)}) {
return AsGenericExpr(AsSpecificExpr(
Designator<Type<TypeCategory::Complex, KIND>>{ComplexPart{
DataRef{*symbolRef},
GET_IMAGINARY ? ComplexPart::Part::IM : ComplexPart::Part::RE}}));
} else {
return std::nullopt;
}
}

template <int KIND>
static std::optional<Expr<SomeReal>> Get(
const Expr<Type<TypeCategory::Complex, KIND>> &kz) {
return Get(kz.u);
}

static std::optional<Expr<SomeReal>> Get(const Expr<SomeComplex> &z) {
return Get(z.u);
}
};

// Convert REAL to COMPLEX of the same kind. Preserving the real operand kind
// and then applying complex operand promotion rules allows the result to have
Expand All @@ -266,56 +336,48 @@ Expr<SomeComplex> PromoteRealToComplex(Expr<SomeReal> &&someX) {
// corresponding COMPLEX+COMPLEX operation.
template <template <typename> class OPR, TypeCategory RCAT>
std::optional<Expr<SomeType>> MixedComplexLeft(
parser::ContextualMessages &messages, Expr<SomeComplex> &&zx,
Expr<SomeKind<RCAT>> &&iry, [[maybe_unused]] int defaultRealKind) {
Expr<SomeReal> zr{GetComplexPart(zx, false)};
Expr<SomeReal> zi{GetComplexPart(zx, true)};
if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
parser::ContextualMessages &messages, const Expr<SomeComplex> &zx,
const Expr<SomeKind<RCAT>> &iry, [[maybe_unused]] int defaultRealKind) {
if constexpr (RCAT == TypeCategory::Integer &&
std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
// COMPLEX**INTEGER is a special case that doesn't convert the exponent.
return Package(common::visit(
[&](const auto &zxk) {
using Ty = ResultType<decltype(zxk)>;
return AsCategoryExpr(AsExpr(
RealToIntPower<Ty>{common::Clone(zxk), common::Clone(iry)}));
},
zx.u));
}
std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zx)};
std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zx)};
if (!zr || !zi) {
} else if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
std::is_same_v<OPR<LargestReal>, Subtract<LargestReal>>) {
// (a,b) + x -> (a+x, b)
// (a,b) - x -> (a-x, b)
if (std::optional<Expr<SomeType>> rr{
NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
AsGenericExpr(std::move(iry)), defaultRealKind)}) {
NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
AsGenericExpr(common::Clone(iry)), defaultRealKind)}) {
return Package(ConstructComplex(messages, std::move(*rr),
AsGenericExpr(std::move(zi)), defaultRealKind));
AsGenericExpr(std::move(*zi)), defaultRealKind));
}
} else if constexpr (allowOperandDuplication &&
(std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>> ||
std::is_same_v<OPR<LargestReal>, Divide<LargestReal>>)) {
// (a,b) * x -> (a*x, b*x)
// (a,b) / x -> (a/x, b/x)
auto copy{iry};
auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
AsGenericExpr(std::move(iry)), defaultRealKind)};
auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zi)),
auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
AsGenericExpr(common::Clone(iry)), defaultRealKind)};
auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zi)),
AsGenericExpr(std::move(copy)), defaultRealKind)};
if (auto parts{common::AllPresent(std::move(rr), std::move(ri))}) {
return Package(ConstructComplex(messages, std::get<0>(std::move(*parts)),
std::get<1>(std::move(*parts)), defaultRealKind));
}
} else if constexpr (RCAT == TypeCategory::Integer &&
std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
// COMPLEX**INTEGER is a special case that doesn't convert the exponent.
static_assert(RCAT == TypeCategory::Integer);
return Package(common::visit(
[&](auto &&zxk) {
using Ty = ResultType<decltype(zxk)>;
return AsCategoryExpr(
AsExpr(RealToIntPower<Ty>{std::move(zxk), std::move(iry)}));
},
std::move(zx.u)));
} else {
// (a,b) ** x -> (a,b) ** (x,0)
if constexpr (RCAT == TypeCategory::Integer) {
Expr<SomeComplex> zy{ConvertTo(zx, std::move(iry))};
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
} else {
Expr<SomeComplex> zy{PromoteRealToComplex(std::move(iry))};
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
}
}
return NoExpr();
return std::nullopt;
}

// Mixed COMPLEX operations with the COMPLEX operand on the right.
Expand All @@ -325,39 +387,49 @@ std::optional<Expr<SomeType>> MixedComplexLeft(
// x / (a,b) -> (x,0) / (a,b) (and **)
template <template <typename> class OPR, TypeCategory LCAT>
std::optional<Expr<SomeType>> MixedComplexRight(
parser::ContextualMessages &messages, Expr<SomeKind<LCAT>> &&irx,
Expr<SomeComplex> &&zy, [[maybe_unused]] int defaultRealKind) {
parser::ContextualMessages &messages, const Expr<SomeKind<LCAT>> &irx,
const Expr<SomeComplex> &zy, [[maybe_unused]] int defaultRealKind) {
if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>>) {
// x + (a,b) -> (a,b) + x -> (a+x, b)
return MixedComplexLeft<OPR, LCAT>(
messages, std::move(zy), std::move(irx), defaultRealKind);
return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
} else if constexpr (allowOperandDuplication &&
std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>>) {
// x * (a,b) -> (a,b) * x -> (a*x, b*x)
return MixedComplexLeft<OPR, LCAT>(
messages, std::move(zy), std::move(irx), defaultRealKind);
return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
} else if constexpr (std::is_same_v<OPR<LargestReal>,
Subtract<LargestReal>>) {
// x - (a,b) -> (x-a, -b)
Expr<SomeReal> zr{GetComplexPart(zy, false)};
Expr<SomeReal> zi{GetComplexPart(zy, true)};
if (std::optional<Expr<SomeType>> rr{
NumericOperation<Subtract>(messages, AsGenericExpr(std::move(irx)),
AsGenericExpr(std::move(zr)), defaultRealKind)}) {
return Package(ConstructComplex(messages, std::move(*rr),
AsGenericExpr(-std::move(zi)), defaultRealKind));
}
} else {
// x / (a,b) -> (x,0) / (a,b)
if constexpr (LCAT == TypeCategory::Integer) {
Expr<SomeComplex> zx{ConvertTo(zy, std::move(irx))};
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
} else {
Expr<SomeComplex> zx{PromoteRealToComplex(std::move(irx))};
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zy)};
std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zy)};
if (zr && zi) {
if (std::optional<Expr<SomeType>> rr{NumericOperation<Subtract>(messages,
AsGenericExpr(common::Clone(irx)), AsGenericExpr(std::move(*zr)),
defaultRealKind)}) {
return Package(ConstructComplex(messages, std::move(*rr),
AsGenericExpr(-std::move(*zi)), defaultRealKind));
}
}
}
return NoExpr();
return std::nullopt;
}

// Promotes REAL(rk) and COMPLEX(zk) operands COMPLEX(max(rk,zk))
// then combine them with an operator.
template <template <typename> class OPR, TypeCategory XCAT, TypeCategory YCAT>
Expr<SomeComplex> PromoteMixedComplexReal(
Expr<SomeKind<XCAT>> &&x, Expr<SomeKind<YCAT>> &&y) {
static_assert(XCAT == TypeCategory::Complex || YCAT == TypeCategory::Complex);
static_assert(XCAT == TypeCategory::Real || YCAT == TypeCategory::Real);
return common::visit(
[&](const auto &kx, const auto &ky) {
constexpr int maxKind{std::max(
ResultType<decltype(kx)>::kind, ResultType<decltype(ky)>::kind)};
using ZTy = Type<TypeCategory::Complex, maxKind>;
return Expr<SomeComplex>{
Expr<ZTy>{OPR<ZTy>{ConvertToType<ZTy>(std::move(x)),
ConvertToType<ZTy>(std::move(y))}}};
},
x.u, y.u);
}

// N.B. When a "typeless" BOZ literal constant appears as one (not both!) of
Expand Down Expand Up @@ -397,20 +469,40 @@ std::optional<Expr<SomeType>> NumericOperation(
std::move(zx), std::move(zy)));
},
[&](Expr<SomeComplex> &&zx, Expr<SomeInteger> &&iy) {
return MixedComplexLeft<OPR>(
messages, std::move(zx), std::move(iy), defaultRealKind);
if (auto result{
MixedComplexLeft<OPR>(messages, zx, iy, defaultRealKind)}) {
return result;
} else {
return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
std::move(zx), ConvertTo(zx, std::move(iy))));
}
},
[&](Expr<SomeComplex> &&zx, Expr<SomeReal> &&ry) {
return MixedComplexLeft<OPR>(
messages, std::move(zx), std::move(ry), defaultRealKind);
if (auto result{
MixedComplexLeft<OPR>(messages, zx, ry, defaultRealKind)}) {
return result;
} else {
return Package(
PromoteMixedComplexReal<OPR>(std::move(zx), std::move(ry)));
}
},
[&](Expr<SomeInteger> &&ix, Expr<SomeComplex> &&zy) {
return MixedComplexRight<OPR>(
messages, std::move(ix), std::move(zy), defaultRealKind);
if (auto result{MixedComplexRight<OPR>(
messages, ix, zy, defaultRealKind)}) {
return result;
} else {
return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
ConvertTo(zy, std::move(ix)), std::move(zy)));
}
},
[&](Expr<SomeReal> &&rx, Expr<SomeComplex> &&zy) {
return MixedComplexRight<OPR>(
messages, std::move(rx), std::move(zy), defaultRealKind);
if (auto result{MixedComplexRight<OPR>(
messages, rx, zy, defaultRealKind)}) {
return result;
} else {
return Package(
PromoteMixedComplexReal<OPR>(std::move(rx), std::move(zy)));
}
},
// Operations with one typeless operand
[&](BOZLiteralConstant &&bx, Expr<SomeInteger> &&iy) {
Expand All @@ -433,7 +525,6 @@ std::optional<Expr<SomeType>> NumericOperation(
},
// Default case
[&](auto &&, auto &&) {
// TODO: defined operator
messages.Say("non-numeric operands to numeric operation"_err_en_US);
return NoExpr();
},
Expand Down Expand Up @@ -481,17 +572,14 @@ std::optional<Expr<SomeType>> Negation(
[&](Expr<SomeReal> &&x) { return Package(-std::move(x)); },
[&](Expr<SomeComplex> &&x) { return Package(-std::move(x)); },
[&](Expr<SomeCharacter> &&) {
// TODO: defined operator
messages.Say("CHARACTER cannot be negated"_err_en_US);
return NoExpr();
},
[&](Expr<SomeLogical> &&) {
// TODO: defined operator
messages.Say("LOGICAL cannot be negated"_err_en_US);
return NoExpr();
},
[&](Expr<SomeDerived> &&) {
// TODO: defined operator
messages.Say("Operand cannot be negated"_err_en_US);
return NoExpr();
},
Expand Down Expand Up @@ -643,8 +731,7 @@ std::optional<Expr<SomeType>> ConvertToType(
if (auto length{type.GetCharLength()}) {
converted = common::visit(
[&](auto &&x) {
using Ty = std::decay_t<decltype(x)>;
using CharacterType = typename Ty::Result;
using CharacterType = ResultType<decltype(x)>;
return Expr<SomeCharacter>{
Expr<CharacterType>{SetLength<CharacterType::kind>{
std::move(x), std::move(*length)}}};
Expand Down Expand Up @@ -1099,7 +1186,7 @@ static std::optional<Expr<SomeType>> DataConstantConversionHelper(
if (const auto *someExpr{UnwrapExpr<Expr<SomeKind<FROM>>>(*sized)}) {
return common::visit(
[](const auto &w) -> std::optional<Expr<SomeType>> {
using FromType = typename std::decay_t<decltype(w)>::Result;
using FromType = ResultType<decltype(w)>;
static constexpr int kind{FromType::kind};
if constexpr (IsValidKindOfIntrinsicType(TO, kind)) {
if (const auto *fromConst{UnwrapExpr<Constant<FromType>>(w)}) {
Expand Down
Loading