Skip to content

Commit daa5da0

Browse files
authored
[flang] Don't blow up when combining mixed COMPLEX operations (#66235)
Expression processing applies some straightforward rewriting of mixed complex/real and complex/integer operations to avoid having to promote the real/integer operand to complex and then perform a complex operation; for example, (a,b)+x becomes (a+x,b) rather than (a,b)+(x,0). But this can blow up the expression representation when the complex operand cannot be duplicated cheaply. So apply this technique only to complex operands that are appropriate to duplicate. Fixes #65142.
1 parent c8c075e commit daa5da0

File tree

3 files changed

+189
-98
lines changed

3 files changed

+189
-98
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,16 +149,6 @@ common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
149149

150150
Expr<SomeType> Parenthesize(Expr<SomeType> &&);
151151

152-
Expr<SomeReal> GetComplexPart(
153-
const Expr<SomeComplex> &, bool isImaginary = false);
154-
Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&, bool isImaginary = false);
155-
156-
template <int KIND>
157-
Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
158-
Expr<Type<TypeCategory::Real, KIND>> &&im) {
159-
return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
160-
}
161-
162152
template <typename A> constexpr bool IsNumericCategoryExpr() {
163153
if constexpr (common::HasMember<A, TypelessExpression>) {
164154
return false;

flang/lib/Evaluate/tools.cpp

Lines changed: 175 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ std::optional<Expr<SomeType>> Package(
180180
std::optional<Expr<SomeKind<CAT>>> &&catExpr) {
181181
if (catExpr) {
182182
return {AsGenericExpr(std::move(*catExpr))};
183+
} else {
184+
return std::nullopt;
183185
}
184-
return NoExpr();
185186
}
186187

187188
// Mixed REAL+INTEGER operations. REAL**INTEGER is a special case that
@@ -204,6 +205,12 @@ std::optional<Expr<SomeType>> MixedRealLeft(
204205
std::move(rx.u)));
205206
}
206207

208+
template <int KIND>
209+
Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
210+
Expr<Type<TypeCategory::Real, KIND>> &&im) {
211+
return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
212+
}
213+
207214
std::optional<Expr<SomeComplex>> ConstructComplex(
208215
parser::ContextualMessages &messages, Expr<SomeType> &&real,
209216
Expr<SomeType> &&imaginary, int defaultRealKind) {
@@ -228,24 +235,87 @@ std::optional<Expr<SomeComplex>> ConstructComplex(
228235
return std::nullopt;
229236
}
230237

231-
Expr<SomeReal> GetComplexPart(const Expr<SomeComplex> &z, bool isImaginary) {
232-
return common::visit(
233-
[&](const auto &zk) {
234-
static constexpr int kind{ResultType<decltype(zk)>::kind};
235-
return AsCategoryExpr(ComplexComponent<kind>{isImaginary, zk});
236-
},
237-
z.u);
238-
}
238+
// Extracts the real or imaginary part of the result of a COMPLEX
239+
// expression, when that expression is simple enough to be duplicated.
240+
template <bool GET_IMAGINARY> struct ComplexPartExtractor {
241+
template <typename A> static std::optional<Expr<SomeReal>> Get(const A &) {
242+
return std::nullopt;
243+
}
239244

240-
Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&z, bool isImaginary) {
241-
return common::visit(
242-
[&](auto &&zk) {
243-
static constexpr int kind{ResultType<decltype(zk)>::kind};
244-
return AsCategoryExpr(
245-
ComplexComponent<kind>{isImaginary, std::move(zk)});
246-
},
247-
z.u);
248-
}
245+
template <int KIND>
246+
static std::optional<Expr<SomeReal>> Get(
247+
const Parentheses<Type<TypeCategory::Complex, KIND>> &kz) {
248+
if (auto x{Get(kz.left())}) {
249+
return AsGenericExpr(AsSpecificExpr(
250+
Parentheses<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
251+
} else {
252+
return std::nullopt;
253+
}
254+
}
255+
256+
template <int KIND>
257+
static std::optional<Expr<SomeReal>> Get(
258+
const Negate<Type<TypeCategory::Complex, KIND>> &kz) {
259+
if (auto x{Get(kz.left())}) {
260+
return AsGenericExpr(AsSpecificExpr(
261+
Negate<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
262+
} else {
263+
return std::nullopt;
264+
}
265+
}
266+
267+
template <int KIND>
268+
static std::optional<Expr<SomeReal>> Get(
269+
const Convert<Type<TypeCategory::Complex, KIND>, TypeCategory::Complex>
270+
&kz) {
271+
if (auto x{Get(kz.left())}) {
272+
return AsGenericExpr(AsSpecificExpr(
273+
Convert<Type<TypeCategory::Real, KIND>, TypeCategory::Real>{
274+
AsGenericExpr(std::move(*x))}));
275+
} else {
276+
return std::nullopt;
277+
}
278+
}
279+
280+
template <int KIND>
281+
static std::optional<Expr<SomeReal>> Get(const ComplexConstructor<KIND> &kz) {
282+
return GET_IMAGINARY ? Get(kz.right()) : Get(kz.left());
283+
}
284+
285+
template <int KIND>
286+
static std::optional<Expr<SomeReal>> Get(
287+
const Constant<Type<TypeCategory::Complex, KIND>> &kz) {
288+
if (auto cz{kz.GetScalarValue()}) {
289+
return AsGenericExpr(
290+
AsSpecificExpr(GET_IMAGINARY ? cz->AIMAG() : cz->REAL()));
291+
} else {
292+
return std::nullopt;
293+
}
294+
}
295+
296+
template <int KIND>
297+
static std::optional<Expr<SomeReal>> Get(
298+
const Designator<Type<TypeCategory::Complex, KIND>> &kz) {
299+
if (const auto *symbolRef{std::get_if<SymbolRef>(&kz.u)}) {
300+
return AsGenericExpr(AsSpecificExpr(
301+
Designator<Type<TypeCategory::Complex, KIND>>{ComplexPart{
302+
DataRef{*symbolRef},
303+
GET_IMAGINARY ? ComplexPart::Part::IM : ComplexPart::Part::RE}}));
304+
} else {
305+
return std::nullopt;
306+
}
307+
}
308+
309+
template <int KIND>
310+
static std::optional<Expr<SomeReal>> Get(
311+
const Expr<Type<TypeCategory::Complex, KIND>> &kz) {
312+
return Get(kz.u);
313+
}
314+
315+
static std::optional<Expr<SomeReal>> Get(const Expr<SomeComplex> &z) {
316+
return Get(z.u);
317+
}
318+
};
249319

250320
// Convert REAL to COMPLEX of the same kind. Preserving the real operand kind
251321
// and then applying complex operand promotion rules allows the result to have
@@ -266,56 +336,48 @@ Expr<SomeComplex> PromoteRealToComplex(Expr<SomeReal> &&someX) {
266336
// corresponding COMPLEX+COMPLEX operation.
267337
template <template <typename> class OPR, TypeCategory RCAT>
268338
std::optional<Expr<SomeType>> MixedComplexLeft(
269-
parser::ContextualMessages &messages, Expr<SomeComplex> &&zx,
270-
Expr<SomeKind<RCAT>> &&iry, [[maybe_unused]] int defaultRealKind) {
271-
Expr<SomeReal> zr{GetComplexPart(zx, false)};
272-
Expr<SomeReal> zi{GetComplexPart(zx, true)};
273-
if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
339+
parser::ContextualMessages &messages, const Expr<SomeComplex> &zx,
340+
const Expr<SomeKind<RCAT>> &iry, [[maybe_unused]] int defaultRealKind) {
341+
if constexpr (RCAT == TypeCategory::Integer &&
342+
std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
343+
// COMPLEX**INTEGER is a special case that doesn't convert the exponent.
344+
return Package(common::visit(
345+
[&](const auto &zxk) {
346+
using Ty = ResultType<decltype(zxk)>;
347+
return AsCategoryExpr(AsExpr(
348+
RealToIntPower<Ty>{common::Clone(zxk), common::Clone(iry)}));
349+
},
350+
zx.u));
351+
}
352+
std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zx)};
353+
std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zx)};
354+
if (!zr || !zi) {
355+
} else if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
274356
std::is_same_v<OPR<LargestReal>, Subtract<LargestReal>>) {
275357
// (a,b) + x -> (a+x, b)
276358
// (a,b) - x -> (a-x, b)
277359
if (std::optional<Expr<SomeType>> rr{
278-
NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
279-
AsGenericExpr(std::move(iry)), defaultRealKind)}) {
360+
NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
361+
AsGenericExpr(common::Clone(iry)), defaultRealKind)}) {
280362
return Package(ConstructComplex(messages, std::move(*rr),
281-
AsGenericExpr(std::move(zi)), defaultRealKind));
363+
AsGenericExpr(std::move(*zi)), defaultRealKind));
282364
}
283365
} else if constexpr (allowOperandDuplication &&
284366
(std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>> ||
285367
std::is_same_v<OPR<LargestReal>, Divide<LargestReal>>)) {
286368
// (a,b) * x -> (a*x, b*x)
287369
// (a,b) / x -> (a/x, b/x)
288370
auto copy{iry};
289-
auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
290-
AsGenericExpr(std::move(iry)), defaultRealKind)};
291-
auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zi)),
371+
auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
372+
AsGenericExpr(common::Clone(iry)), defaultRealKind)};
373+
auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zi)),
292374
AsGenericExpr(std::move(copy)), defaultRealKind)};
293375
if (auto parts{common::AllPresent(std::move(rr), std::move(ri))}) {
294376
return Package(ConstructComplex(messages, std::get<0>(std::move(*parts)),
295377
std::get<1>(std::move(*parts)), defaultRealKind));
296378
}
297-
} else if constexpr (RCAT == TypeCategory::Integer &&
298-
std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
299-
// COMPLEX**INTEGER is a special case that doesn't convert the exponent.
300-
static_assert(RCAT == TypeCategory::Integer);
301-
return Package(common::visit(
302-
[&](auto &&zxk) {
303-
using Ty = ResultType<decltype(zxk)>;
304-
return AsCategoryExpr(
305-
AsExpr(RealToIntPower<Ty>{std::move(zxk), std::move(iry)}));
306-
},
307-
std::move(zx.u)));
308-
} else {
309-
// (a,b) ** x -> (a,b) ** (x,0)
310-
if constexpr (RCAT == TypeCategory::Integer) {
311-
Expr<SomeComplex> zy{ConvertTo(zx, std::move(iry))};
312-
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
313-
} else {
314-
Expr<SomeComplex> zy{PromoteRealToComplex(std::move(iry))};
315-
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
316-
}
317379
}
318-
return NoExpr();
380+
return std::nullopt;
319381
}
320382

321383
// Mixed COMPLEX operations with the COMPLEX operand on the right.
@@ -325,39 +387,49 @@ std::optional<Expr<SomeType>> MixedComplexLeft(
325387
// x / (a,b) -> (x,0) / (a,b) (and **)
326388
template <template <typename> class OPR, TypeCategory LCAT>
327389
std::optional<Expr<SomeType>> MixedComplexRight(
328-
parser::ContextualMessages &messages, Expr<SomeKind<LCAT>> &&irx,
329-
Expr<SomeComplex> &&zy, [[maybe_unused]] int defaultRealKind) {
390+
parser::ContextualMessages &messages, const Expr<SomeKind<LCAT>> &irx,
391+
const Expr<SomeComplex> &zy, [[maybe_unused]] int defaultRealKind) {
330392
if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>>) {
331393
// x + (a,b) -> (a,b) + x -> (a+x, b)
332-
return MixedComplexLeft<OPR, LCAT>(
333-
messages, std::move(zy), std::move(irx), defaultRealKind);
394+
return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
334395
} else if constexpr (allowOperandDuplication &&
335396
std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>>) {
336397
// x * (a,b) -> (a,b) * x -> (a*x, b*x)
337-
return MixedComplexLeft<OPR, LCAT>(
338-
messages, std::move(zy), std::move(irx), defaultRealKind);
398+
return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
339399
} else if constexpr (std::is_same_v<OPR<LargestReal>,
340400
Subtract<LargestReal>>) {
341401
// x - (a,b) -> (x-a, -b)
342-
Expr<SomeReal> zr{GetComplexPart(zy, false)};
343-
Expr<SomeReal> zi{GetComplexPart(zy, true)};
344-
if (std::optional<Expr<SomeType>> rr{
345-
NumericOperation<Subtract>(messages, AsGenericExpr(std::move(irx)),
346-
AsGenericExpr(std::move(zr)), defaultRealKind)}) {
347-
return Package(ConstructComplex(messages, std::move(*rr),
348-
AsGenericExpr(-std::move(zi)), defaultRealKind));
349-
}
350-
} else {
351-
// x / (a,b) -> (x,0) / (a,b)
352-
if constexpr (LCAT == TypeCategory::Integer) {
353-
Expr<SomeComplex> zx{ConvertTo(zy, std::move(irx))};
354-
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
355-
} else {
356-
Expr<SomeComplex> zx{PromoteRealToComplex(std::move(irx))};
357-
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
402+
std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zy)};
403+
std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zy)};
404+
if (zr && zi) {
405+
if (std::optional<Expr<SomeType>> rr{NumericOperation<Subtract>(messages,
406+
AsGenericExpr(common::Clone(irx)), AsGenericExpr(std::move(*zr)),
407+
defaultRealKind)}) {
408+
return Package(ConstructComplex(messages, std::move(*rr),
409+
AsGenericExpr(-std::move(*zi)), defaultRealKind));
410+
}
358411
}
359412
}
360-
return NoExpr();
413+
return std::nullopt;
414+
}
415+
416+
// Promotes REAL(rk) and COMPLEX(zk) operands COMPLEX(max(rk,zk))
417+
// then combine them with an operator.
418+
template <template <typename> class OPR, TypeCategory XCAT, TypeCategory YCAT>
419+
Expr<SomeComplex> PromoteMixedComplexReal(
420+
Expr<SomeKind<XCAT>> &&x, Expr<SomeKind<YCAT>> &&y) {
421+
static_assert(XCAT == TypeCategory::Complex || YCAT == TypeCategory::Complex);
422+
static_assert(XCAT == TypeCategory::Real || YCAT == TypeCategory::Real);
423+
return common::visit(
424+
[&](const auto &kx, const auto &ky) {
425+
constexpr int maxKind{std::max(
426+
ResultType<decltype(kx)>::kind, ResultType<decltype(ky)>::kind)};
427+
using ZTy = Type<TypeCategory::Complex, maxKind>;
428+
return Expr<SomeComplex>{
429+
Expr<ZTy>{OPR<ZTy>{ConvertToType<ZTy>(std::move(x)),
430+
ConvertToType<ZTy>(std::move(y))}}};
431+
},
432+
x.u, y.u);
361433
}
362434

363435
// N.B. When a "typeless" BOZ literal constant appears as one (not both!) of
@@ -397,20 +469,40 @@ std::optional<Expr<SomeType>> NumericOperation(
397469
std::move(zx), std::move(zy)));
398470
},
399471
[&](Expr<SomeComplex> &&zx, Expr<SomeInteger> &&iy) {
400-
return MixedComplexLeft<OPR>(
401-
messages, std::move(zx), std::move(iy), defaultRealKind);
472+
if (auto result{
473+
MixedComplexLeft<OPR>(messages, zx, iy, defaultRealKind)}) {
474+
return result;
475+
} else {
476+
return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
477+
std::move(zx), ConvertTo(zx, std::move(iy))));
478+
}
402479
},
403480
[&](Expr<SomeComplex> &&zx, Expr<SomeReal> &&ry) {
404-
return MixedComplexLeft<OPR>(
405-
messages, std::move(zx), std::move(ry), defaultRealKind);
481+
if (auto result{
482+
MixedComplexLeft<OPR>(messages, zx, ry, defaultRealKind)}) {
483+
return result;
484+
} else {
485+
return Package(
486+
PromoteMixedComplexReal<OPR>(std::move(zx), std::move(ry)));
487+
}
406488
},
407489
[&](Expr<SomeInteger> &&ix, Expr<SomeComplex> &&zy) {
408-
return MixedComplexRight<OPR>(
409-
messages, std::move(ix), std::move(zy), defaultRealKind);
490+
if (auto result{MixedComplexRight<OPR>(
491+
messages, ix, zy, defaultRealKind)}) {
492+
return result;
493+
} else {
494+
return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
495+
ConvertTo(zy, std::move(ix)), std::move(zy)));
496+
}
410497
},
411498
[&](Expr<SomeReal> &&rx, Expr<SomeComplex> &&zy) {
412-
return MixedComplexRight<OPR>(
413-
messages, std::move(rx), std::move(zy), defaultRealKind);
499+
if (auto result{MixedComplexRight<OPR>(
500+
messages, rx, zy, defaultRealKind)}) {
501+
return result;
502+
} else {
503+
return Package(
504+
PromoteMixedComplexReal<OPR>(std::move(rx), std::move(zy)));
505+
}
414506
},
415507
// Operations with one typeless operand
416508
[&](BOZLiteralConstant &&bx, Expr<SomeInteger> &&iy) {
@@ -433,7 +525,6 @@ std::optional<Expr<SomeType>> NumericOperation(
433525
},
434526
// Default case
435527
[&](auto &&, auto &&) {
436-
// TODO: defined operator
437528
messages.Say("non-numeric operands to numeric operation"_err_en_US);
438529
return NoExpr();
439530
},
@@ -481,17 +572,14 @@ std::optional<Expr<SomeType>> Negation(
481572
[&](Expr<SomeReal> &&x) { return Package(-std::move(x)); },
482573
[&](Expr<SomeComplex> &&x) { return Package(-std::move(x)); },
483574
[&](Expr<SomeCharacter> &&) {
484-
// TODO: defined operator
485575
messages.Say("CHARACTER cannot be negated"_err_en_US);
486576
return NoExpr();
487577
},
488578
[&](Expr<SomeLogical> &&) {
489-
// TODO: defined operator
490579
messages.Say("LOGICAL cannot be negated"_err_en_US);
491580
return NoExpr();
492581
},
493582
[&](Expr<SomeDerived> &&) {
494-
// TODO: defined operator
495583
messages.Say("Operand cannot be negated"_err_en_US);
496584
return NoExpr();
497585
},
@@ -643,8 +731,7 @@ std::optional<Expr<SomeType>> ConvertToType(
643731
if (auto length{type.GetCharLength()}) {
644732
converted = common::visit(
645733
[&](auto &&x) {
646-
using Ty = std::decay_t<decltype(x)>;
647-
using CharacterType = typename Ty::Result;
734+
using CharacterType = ResultType<decltype(x)>;
648735
return Expr<SomeCharacter>{
649736
Expr<CharacterType>{SetLength<CharacterType::kind>{
650737
std::move(x), std::move(*length)}}};
@@ -1099,7 +1186,7 @@ static std::optional<Expr<SomeType>> DataConstantConversionHelper(
10991186
if (const auto *someExpr{UnwrapExpr<Expr<SomeKind<FROM>>>(*sized)}) {
11001187
return common::visit(
11011188
[](const auto &w) -> std::optional<Expr<SomeType>> {
1102-
using FromType = typename std::decay_t<decltype(w)>::Result;
1189+
using FromType = ResultType<decltype(w)>;
11031190
static constexpr int kind{FromType::kind};
11041191
if constexpr (IsValidKindOfIntrinsicType(TO, kind)) {
11051192
if (const auto *fromConst{UnwrapExpr<Constant<FromType>>(w)}) {

0 commit comments

Comments
 (0)