Skip to content

Commit 722164f

Browse files
klauslerZijunZhaoCCK
authored andcommitted
[flang] Don't blow up when combining mixed COMPLEX operations (llvm#66235)
Expression processing applies some straightforward rewriting of mixed complex/real and complex/integer operations to avoid having to promote the real/integer operand to complex and then perform a complex operation; for example, (a,b)+x becomes (a+x,b) rather than (a,b)+(x,0). But this can blow up the expression representation when the complex operand cannot be duplicated cheaply. So apply this technique only to complex operands that are appropriate to duplicate. Fixes llvm#65142.
1 parent fd4babc commit 722164f

File tree

3 files changed

+189
-98
lines changed

3 files changed

+189
-98
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,16 +149,6 @@ common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
149149

150150
Expr<SomeType> Parenthesize(Expr<SomeType> &&);
151151

152-
Expr<SomeReal> GetComplexPart(
153-
const Expr<SomeComplex> &, bool isImaginary = false);
154-
Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&, bool isImaginary = false);
155-
156-
template <int KIND>
157-
Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
158-
Expr<Type<TypeCategory::Real, KIND>> &&im) {
159-
return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
160-
}
161-
162152
template <typename A> constexpr bool IsNumericCategoryExpr() {
163153
if constexpr (common::HasMember<A, TypelessExpression>) {
164154
return false;

flang/lib/Evaluate/tools.cpp

Lines changed: 175 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,9 @@ std::optional<Expr<SomeType>> Package(
180180
std::optional<Expr<SomeKind<CAT>>> &&catExpr) {
181181
if (catExpr) {
182182
return {AsGenericExpr(std::move(*catExpr))};
183+
} else {
184+
return std::nullopt;
183185
}
184-
return NoExpr();
185186
}
186187

187188
// Mixed REAL+INTEGER operations. REAL**INTEGER is a special case that
@@ -204,6 +205,12 @@ std::optional<Expr<SomeType>> MixedRealLeft(
204205
std::move(rx.u)));
205206
}
206207

208+
template <int KIND>
209+
Expr<SomeComplex> MakeComplex(Expr<Type<TypeCategory::Real, KIND>> &&re,
210+
Expr<Type<TypeCategory::Real, KIND>> &&im) {
211+
return AsCategoryExpr(ComplexConstructor<KIND>{std::move(re), std::move(im)});
212+
}
213+
207214
std::optional<Expr<SomeComplex>> ConstructComplex(
208215
parser::ContextualMessages &messages, Expr<SomeType> &&real,
209216
Expr<SomeType> &&imaginary, int defaultRealKind) {
@@ -228,24 +235,87 @@ std::optional<Expr<SomeComplex>> ConstructComplex(
228235
return std::nullopt;
229236
}
230237

231-
Expr<SomeReal> GetComplexPart(const Expr<SomeComplex> &z, bool isImaginary) {
232-
return common::visit(
233-
[&](const auto &zk) {
234-
static constexpr int kind{ResultType<decltype(zk)>::kind};
235-
return AsCategoryExpr(ComplexComponent<kind>{isImaginary, zk});
236-
},
237-
z.u);
238-
}
238+
// Extracts the real or imaginary part of the result of a COMPLEX
239+
// expression, when that expression is simple enough to be duplicated.
240+
template <bool GET_IMAGINARY> struct ComplexPartExtractor {
241+
template <typename A> static std::optional<Expr<SomeReal>> Get(const A &) {
242+
return std::nullopt;
243+
}
239244

240-
Expr<SomeReal> GetComplexPart(Expr<SomeComplex> &&z, bool isImaginary) {
241-
return common::visit(
242-
[&](auto &&zk) {
243-
static constexpr int kind{ResultType<decltype(zk)>::kind};
244-
return AsCategoryExpr(
245-
ComplexComponent<kind>{isImaginary, std::move(zk)});
246-
},
247-
z.u);
248-
}
245+
template <int KIND>
246+
static std::optional<Expr<SomeReal>> Get(
247+
const Parentheses<Type<TypeCategory::Complex, KIND>> &kz) {
248+
if (auto x{Get(kz.left())}) {
249+
return AsGenericExpr(AsSpecificExpr(
250+
Parentheses<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
251+
} else {
252+
return std::nullopt;
253+
}
254+
}
255+
256+
template <int KIND>
257+
static std::optional<Expr<SomeReal>> Get(
258+
const Negate<Type<TypeCategory::Complex, KIND>> &kz) {
259+
if (auto x{Get(kz.left())}) {
260+
return AsGenericExpr(AsSpecificExpr(
261+
Negate<Type<TypeCategory::Real, KIND>>{std::move(*x)}));
262+
} else {
263+
return std::nullopt;
264+
}
265+
}
266+
267+
template <int KIND>
268+
static std::optional<Expr<SomeReal>> Get(
269+
const Convert<Type<TypeCategory::Complex, KIND>, TypeCategory::Complex>
270+
&kz) {
271+
if (auto x{Get(kz.left())}) {
272+
return AsGenericExpr(AsSpecificExpr(
273+
Convert<Type<TypeCategory::Real, KIND>, TypeCategory::Real>{
274+
AsGenericExpr(std::move(*x))}));
275+
} else {
276+
return std::nullopt;
277+
}
278+
}
279+
280+
template <int KIND>
281+
static std::optional<Expr<SomeReal>> Get(const ComplexConstructor<KIND> &kz) {
282+
return GET_IMAGINARY ? Get(kz.right()) : Get(kz.left());
283+
}
284+
285+
template <int KIND>
286+
static std::optional<Expr<SomeReal>> Get(
287+
const Constant<Type<TypeCategory::Complex, KIND>> &kz) {
288+
if (auto cz{kz.GetScalarValue()}) {
289+
return AsGenericExpr(
290+
AsSpecificExpr(GET_IMAGINARY ? cz->AIMAG() : cz->REAL()));
291+
} else {
292+
return std::nullopt;
293+
}
294+
}
295+
296+
template <int KIND>
297+
static std::optional<Expr<SomeReal>> Get(
298+
const Designator<Type<TypeCategory::Complex, KIND>> &kz) {
299+
if (const auto *symbolRef{std::get_if<SymbolRef>(&kz.u)}) {
300+
return AsGenericExpr(AsSpecificExpr(
301+
Designator<Type<TypeCategory::Complex, KIND>>{ComplexPart{
302+
DataRef{*symbolRef},
303+
GET_IMAGINARY ? ComplexPart::Part::IM : ComplexPart::Part::RE}}));
304+
} else {
305+
return std::nullopt;
306+
}
307+
}
308+
309+
template <int KIND>
310+
static std::optional<Expr<SomeReal>> Get(
311+
const Expr<Type<TypeCategory::Complex, KIND>> &kz) {
312+
return Get(kz.u);
313+
}
314+
315+
static std::optional<Expr<SomeReal>> Get(const Expr<SomeComplex> &z) {
316+
return Get(z.u);
317+
}
318+
};
249319

250320
// Convert REAL to COMPLEX of the same kind. Preserving the real operand kind
251321
// and then applying complex operand promotion rules allows the result to have
@@ -266,56 +336,48 @@ Expr<SomeComplex> PromoteRealToComplex(Expr<SomeReal> &&someX) {
266336
// corresponding COMPLEX+COMPLEX operation.
267337
template <template <typename> class OPR, TypeCategory RCAT>
268338
std::optional<Expr<SomeType>> MixedComplexLeft(
269-
parser::ContextualMessages &messages, Expr<SomeComplex> &&zx,
270-
Expr<SomeKind<RCAT>> &&iry, [[maybe_unused]] int defaultRealKind) {
271-
Expr<SomeReal> zr{GetComplexPart(zx, false)};
272-
Expr<SomeReal> zi{GetComplexPart(zx, true)};
273-
if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
339+
parser::ContextualMessages &messages, const Expr<SomeComplex> &zx,
340+
const Expr<SomeKind<RCAT>> &iry, [[maybe_unused]] int defaultRealKind) {
341+
if constexpr (RCAT == TypeCategory::Integer &&
342+
std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
343+
// COMPLEX**INTEGER is a special case that doesn't convert the exponent.
344+
return Package(common::visit(
345+
[&](const auto &zxk) {
346+
using Ty = ResultType<decltype(zxk)>;
347+
return AsCategoryExpr(AsExpr(
348+
RealToIntPower<Ty>{common::Clone(zxk), common::Clone(iry)}));
349+
},
350+
zx.u));
351+
}
352+
std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zx)};
353+
std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zx)};
354+
if (!zr || !zi) {
355+
} else if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>> ||
274356
std::is_same_v<OPR<LargestReal>, Subtract<LargestReal>>) {
275357
// (a,b) + x -> (a+x, b)
276358
// (a,b) - x -> (a-x, b)
277359
if (std::optional<Expr<SomeType>> rr{
278-
NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
279-
AsGenericExpr(std::move(iry)), defaultRealKind)}) {
360+
NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
361+
AsGenericExpr(common::Clone(iry)), defaultRealKind)}) {
280362
return Package(ConstructComplex(messages, std::move(*rr),
281-
AsGenericExpr(std::move(zi)), defaultRealKind));
363+
AsGenericExpr(std::move(*zi)), defaultRealKind));
282364
}
283365
} else if constexpr (allowOperandDuplication &&
284366
(std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>> ||
285367
std::is_same_v<OPR<LargestReal>, Divide<LargestReal>>)) {
286368
// (a,b) * x -> (a*x, b*x)
287369
// (a,b) / x -> (a/x, b/x)
288370
auto copy{iry};
289-
auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zr)),
290-
AsGenericExpr(std::move(iry)), defaultRealKind)};
291-
auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(zi)),
371+
auto rr{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zr)),
372+
AsGenericExpr(common::Clone(iry)), defaultRealKind)};
373+
auto ri{NumericOperation<OPR>(messages, AsGenericExpr(std::move(*zi)),
292374
AsGenericExpr(std::move(copy)), defaultRealKind)};
293375
if (auto parts{common::AllPresent(std::move(rr), std::move(ri))}) {
294376
return Package(ConstructComplex(messages, std::get<0>(std::move(*parts)),
295377
std::get<1>(std::move(*parts)), defaultRealKind));
296378
}
297-
} else if constexpr (RCAT == TypeCategory::Integer &&
298-
std::is_same_v<OPR<LargestReal>, Power<LargestReal>>) {
299-
// COMPLEX**INTEGER is a special case that doesn't convert the exponent.
300-
static_assert(RCAT == TypeCategory::Integer);
301-
return Package(common::visit(
302-
[&](auto &&zxk) {
303-
using Ty = ResultType<decltype(zxk)>;
304-
return AsCategoryExpr(
305-
AsExpr(RealToIntPower<Ty>{std::move(zxk), std::move(iry)}));
306-
},
307-
std::move(zx.u)));
308-
} else {
309-
// (a,b) ** x -> (a,b) ** (x,0)
310-
if constexpr (RCAT == TypeCategory::Integer) {
311-
Expr<SomeComplex> zy{ConvertTo(zx, std::move(iry))};
312-
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
313-
} else {
314-
Expr<SomeComplex> zy{PromoteRealToComplex(std::move(iry))};
315-
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
316-
}
317379
}
318-
return NoExpr();
380+
return std::nullopt;
319381
}
320382

321383
// Mixed COMPLEX operations with the COMPLEX operand on the right.
@@ -325,39 +387,49 @@ std::optional<Expr<SomeType>> MixedComplexLeft(
325387
// x / (a,b) -> (x,0) / (a,b) (and **)
326388
template <template <typename> class OPR, TypeCategory LCAT>
327389
std::optional<Expr<SomeType>> MixedComplexRight(
328-
parser::ContextualMessages &messages, Expr<SomeKind<LCAT>> &&irx,
329-
Expr<SomeComplex> &&zy, [[maybe_unused]] int defaultRealKind) {
390+
parser::ContextualMessages &messages, const Expr<SomeKind<LCAT>> &irx,
391+
const Expr<SomeComplex> &zy, [[maybe_unused]] int defaultRealKind) {
330392
if constexpr (std::is_same_v<OPR<LargestReal>, Add<LargestReal>>) {
331393
// x + (a,b) -> (a,b) + x -> (a+x, b)
332-
return MixedComplexLeft<OPR, LCAT>(
333-
messages, std::move(zy), std::move(irx), defaultRealKind);
394+
return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
334395
} else if constexpr (allowOperandDuplication &&
335396
std::is_same_v<OPR<LargestReal>, Multiply<LargestReal>>) {
336397
// x * (a,b) -> (a,b) * x -> (a*x, b*x)
337-
return MixedComplexLeft<OPR, LCAT>(
338-
messages, std::move(zy), std::move(irx), defaultRealKind);
398+
return MixedComplexLeft<OPR, LCAT>(messages, zy, irx, defaultRealKind);
339399
} else if constexpr (std::is_same_v<OPR<LargestReal>,
340400
Subtract<LargestReal>>) {
341401
// x - (a,b) -> (x-a, -b)
342-
Expr<SomeReal> zr{GetComplexPart(zy, false)};
343-
Expr<SomeReal> zi{GetComplexPart(zy, true)};
344-
if (std::optional<Expr<SomeType>> rr{
345-
NumericOperation<Subtract>(messages, AsGenericExpr(std::move(irx)),
346-
AsGenericExpr(std::move(zr)), defaultRealKind)}) {
347-
return Package(ConstructComplex(messages, std::move(*rr),
348-
AsGenericExpr(-std::move(zi)), defaultRealKind));
349-
}
350-
} else {
351-
// x / (a,b) -> (x,0) / (a,b)
352-
if constexpr (LCAT == TypeCategory::Integer) {
353-
Expr<SomeComplex> zx{ConvertTo(zy, std::move(irx))};
354-
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
355-
} else {
356-
Expr<SomeComplex> zx{PromoteRealToComplex(std::move(irx))};
357-
return Package(PromoteAndCombine<OPR>(std::move(zx), std::move(zy)));
402+
std::optional<Expr<SomeReal>> zr{ComplexPartExtractor<false>{}.Get(zy)};
403+
std::optional<Expr<SomeReal>> zi{ComplexPartExtractor<true>{}.Get(zy)};
404+
if (zr && zi) {
405+
if (std::optional<Expr<SomeType>> rr{NumericOperation<Subtract>(messages,
406+
AsGenericExpr(common::Clone(irx)), AsGenericExpr(std::move(*zr)),
407+
defaultRealKind)}) {
408+
return Package(ConstructComplex(messages, std::move(*rr),
409+
AsGenericExpr(-std::move(*zi)), defaultRealKind));
410+
}
358411
}
359412
}
360-
return NoExpr();
413+
return std::nullopt;
414+
}
415+
416+
// Promotes REAL(rk) and COMPLEX(zk) operands COMPLEX(max(rk,zk))
417+
// then combine them with an operator.
418+
template <template <typename> class OPR, TypeCategory XCAT, TypeCategory YCAT>
419+
Expr<SomeComplex> PromoteMixedComplexReal(
420+
Expr<SomeKind<XCAT>> &&x, Expr<SomeKind<YCAT>> &&y) {
421+
static_assert(XCAT == TypeCategory::Complex || YCAT == TypeCategory::Complex);
422+
static_assert(XCAT == TypeCategory::Real || YCAT == TypeCategory::Real);
423+
return common::visit(
424+
[&](const auto &kx, const auto &ky) {
425+
constexpr int maxKind{std::max(
426+
ResultType<decltype(kx)>::kind, ResultType<decltype(ky)>::kind)};
427+
using ZTy = Type<TypeCategory::Complex, maxKind>;
428+
return Expr<SomeComplex>{
429+
Expr<ZTy>{OPR<ZTy>{ConvertToType<ZTy>(std::move(x)),
430+
ConvertToType<ZTy>(std::move(y))}}};
431+
},
432+
x.u, y.u);
361433
}
362434

363435
// N.B. When a "typeless" BOZ literal constant appears as one (not both!) of
@@ -397,20 +469,40 @@ std::optional<Expr<SomeType>> NumericOperation(
397469
std::move(zx), std::move(zy)));
398470
},
399471
[&](Expr<SomeComplex> &&zx, Expr<SomeInteger> &&iy) {
400-
return MixedComplexLeft<OPR>(
401-
messages, std::move(zx), std::move(iy), defaultRealKind);
472+
if (auto result{
473+
MixedComplexLeft<OPR>(messages, zx, iy, defaultRealKind)}) {
474+
return result;
475+
} else {
476+
return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
477+
std::move(zx), ConvertTo(zx, std::move(iy))));
478+
}
402479
},
403480
[&](Expr<SomeComplex> &&zx, Expr<SomeReal> &&ry) {
404-
return MixedComplexLeft<OPR>(
405-
messages, std::move(zx), std::move(ry), defaultRealKind);
481+
if (auto result{
482+
MixedComplexLeft<OPR>(messages, zx, ry, defaultRealKind)}) {
483+
return result;
484+
} else {
485+
return Package(
486+
PromoteMixedComplexReal<OPR>(std::move(zx), std::move(ry)));
487+
}
406488
},
407489
[&](Expr<SomeInteger> &&ix, Expr<SomeComplex> &&zy) {
408-
return MixedComplexRight<OPR>(
409-
messages, std::move(ix), std::move(zy), defaultRealKind);
490+
if (auto result{MixedComplexRight<OPR>(
491+
messages, ix, zy, defaultRealKind)}) {
492+
return result;
493+
} else {
494+
return Package(PromoteAndCombine<OPR, TypeCategory::Complex>(
495+
ConvertTo(zy, std::move(ix)), std::move(zy)));
496+
}
410497
},
411498
[&](Expr<SomeReal> &&rx, Expr<SomeComplex> &&zy) {
412-
return MixedComplexRight<OPR>(
413-
messages, std::move(rx), std::move(zy), defaultRealKind);
499+
if (auto result{MixedComplexRight<OPR>(
500+
messages, rx, zy, defaultRealKind)}) {
501+
return result;
502+
} else {
503+
return Package(
504+
PromoteMixedComplexReal<OPR>(std::move(rx), std::move(zy)));
505+
}
414506
},
415507
// Operations with one typeless operand
416508
[&](BOZLiteralConstant &&bx, Expr<SomeInteger> &&iy) {
@@ -433,7 +525,6 @@ std::optional<Expr<SomeType>> NumericOperation(
433525
},
434526
// Default case
435527
[&](auto &&, auto &&) {
436-
// TODO: defined operator
437528
messages.Say("non-numeric operands to numeric operation"_err_en_US);
438529
return NoExpr();
439530
},
@@ -481,17 +572,14 @@ std::optional<Expr<SomeType>> Negation(
481572
[&](Expr<SomeReal> &&x) { return Package(-std::move(x)); },
482573
[&](Expr<SomeComplex> &&x) { return Package(-std::move(x)); },
483574
[&](Expr<SomeCharacter> &&) {
484-
// TODO: defined operator
485575
messages.Say("CHARACTER cannot be negated"_err_en_US);
486576
return NoExpr();
487577
},
488578
[&](Expr<SomeLogical> &&) {
489-
// TODO: defined operator
490579
messages.Say("LOGICAL cannot be negated"_err_en_US);
491580
return NoExpr();
492581
},
493582
[&](Expr<SomeDerived> &&) {
494-
// TODO: defined operator
495583
messages.Say("Operand cannot be negated"_err_en_US);
496584
return NoExpr();
497585
},
@@ -643,8 +731,7 @@ std::optional<Expr<SomeType>> ConvertToType(
643731
if (auto length{type.GetCharLength()}) {
644732
converted = common::visit(
645733
[&](auto &&x) {
646-
using Ty = std::decay_t<decltype(x)>;
647-
using CharacterType = typename Ty::Result;
734+
using CharacterType = ResultType<decltype(x)>;
648735
return Expr<SomeCharacter>{
649736
Expr<CharacterType>{SetLength<CharacterType::kind>{
650737
std::move(x), std::move(*length)}}};
@@ -1099,7 +1186,7 @@ static std::optional<Expr<SomeType>> DataConstantConversionHelper(
10991186
if (const auto *someExpr{UnwrapExpr<Expr<SomeKind<FROM>>>(*sized)}) {
11001187
return common::visit(
11011188
[](const auto &w) -> std::optional<Expr<SomeType>> {
1102-
using FromType = typename std::decay_t<decltype(w)>::Result;
1189+
using FromType = ResultType<decltype(w)>;
11031190
static constexpr int kind{FromType::kind};
11041191
if constexpr (IsValidKindOfIntrinsicType(TO, kind)) {
11051192
if (const auto *fromConst{UnwrapExpr<Constant<FromType>>(w)}) {

0 commit comments

Comments
 (0)