Skip to content

Commit 82e1e41

Browse files
authored
[flang][runtime] Treatment of NaN in MAXVAL/MAXLOC/MINVAL/MINLOC (#76999)
Detect NaN elements in data and handle them like gfortran does (at runtime); namely, NaN can be returned if all the data are NaNs, but any non-NaN value is preferable. Ensure that folding returns the same results as runtime computation. Fixes llvm-test-suite/Fortran/gfortran/regression/maxloc_2.f90 (and probably others).
1 parent 927b8a0 commit 82e1e41

File tree

10 files changed

+173
-103
lines changed

10 files changed

+173
-103
lines changed

flang/docs/Extensions.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,10 @@ end
657657
we don't round. This seems to be how the Intel Fortran compilers
658658
behave.
659659

660+
* For real `MAXVAL`, `MINVAL`, `MAXLOC`, and `MINLOC`, NaN values are
661+
essentially ignored unless there are some unmasked array entries and
662+
*all* of them are NaNs.
663+
660664
## De Facto Standard Features
661665

662666
* `EXTENDS_TYPE_OF()` returns `.TRUE.` if both of its arguments have the

flang/lib/Evaluate/fold-character.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
8484
return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
8585
} else if (name == "minval") {
8686
// Collating sequences correspond to positive integers (3.31)
87-
SingleCharType most{0x7fffffff >> (8 * (4 - KIND))};
87+
auto most{static_cast<SingleCharType>(0xffffffff >> (8 * (4 - KIND)))};
8888
if (auto identity{Identity<T>(
8989
StringType{most}, GetConstantLength(context, funcRef, 0))}) {
9090
return FoldMaxvalMinval<T>(

flang/lib/Evaluate/fold-integer.cpp

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@ template <typename T, int MASK_KIND> class CountAccumulator {
270270

271271
public:
272272
CountAccumulator(const Constant<MaskT> &mask) : mask_{mask} {}
273-
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
273+
void operator()(
274+
Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
274275
if (mask_.At(at).IsTrue()) {
275276
auto incremented{element.AddSigned(Scalar<T>{1})};
276277
overflow_ |= incremented.overflow;
@@ -287,22 +288,20 @@ template <typename T, int MASK_KIND> class CountAccumulator {
287288

288289
template <typename T, int maskKind>
289290
static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
290-
using LogicalResult = Type<TypeCategory::Logical, maskKind>;
291+
using KindLogical = Type<TypeCategory::Logical, maskKind>;
291292
static_assert(T::category == TypeCategory::Integer);
292-
ActualArguments &arg{ref.arguments()};
293-
if (const Constant<LogicalResult> *mask{arg.empty()
294-
? nullptr
295-
: Folder<LogicalResult>{context}.Folding(arg[0])}) {
296-
std::optional<int> dim;
297-
if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
298-
CountAccumulator<T, maskKind> accumulator{*mask};
299-
Constant<T> result{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
300-
if (accumulator.overflow()) {
301-
context.messages().Say(
302-
"Result of intrinsic function COUNT overflows its result type"_warn_en_US);
303-
}
304-
return Expr<T>{std::move(result)};
293+
std::optional<int> dim;
294+
if (std::optional<ArrayAndMask<KindLogical>> arrayAndMask{
295+
ProcessReductionArgs<KindLogical>(
296+
context, ref.arguments(), dim, /*ARRAY=*/0, /*DIM=*/1)}) {
297+
CountAccumulator<T, maskKind> accumulator{arrayAndMask->array};
298+
Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask,
299+
dim, Scalar<T>{}, accumulator)};
300+
if (accumulator.overflow()) {
301+
context.messages().Say(
302+
"Result of intrinsic function COUNT overflows its result type"_warn_en_US);
305303
}
304+
return Expr<T>{std::move(result)};
306305
}
307306
return Expr<T>{std::move(ref)};
308307
}
@@ -395,7 +394,7 @@ template <WhichLocation WHICH> class LocationHelper {
395394
for (ConstantSubscript k{0}; k < dimLength;
396395
++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
397396
if ((!mask || mask->At(maskAt).IsTrue()) &&
398-
IsHit(array->At(at), value, relation)) {
397+
IsHit(array->At(at), value, relation, back)) {
399398
hit = at[zbDim];
400399
if constexpr (WHICH == WhichLocation::Findloc) {
401400
if (!back) {
@@ -422,7 +421,7 @@ template <WhichLocation WHICH> class LocationHelper {
422421
for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
423422
mask && mask->IncrementSubscripts(maskAt)) {
424423
if ((!mask || mask->At(maskAt).IsTrue()) &&
425-
IsHit(array->At(at), value, relation)) {
424+
IsHit(array->At(at), value, relation, back)) {
426425
resultIndices = at;
427426
if constexpr (WHICH == WhichLocation::Findloc) {
428427
if (!back) {
@@ -444,7 +443,8 @@ template <WhichLocation WHICH> class LocationHelper {
444443
template <typename T>
445444
bool IsHit(typename Constant<T>::Element element,
446445
std::optional<Constant<T>> &value,
447-
[[maybe_unused]] RelationalOperator relation) const {
446+
[[maybe_unused]] RelationalOperator relation,
447+
[[maybe_unused]] bool back) const {
448448
std::optional<Expr<LogicalResult>> cmp;
449449
bool result{true};
450450
if (value) {
@@ -455,8 +455,19 @@ template <WhichLocation WHICH> class LocationHelper {
455455
Expr<T>{LogicalOperation<T::kind>{LogicalOperator::Eqv,
456456
Expr<T>{Constant<T>{element}}, Expr<T>{Constant<T>{*value}}}}));
457457
} else { // compare array(at) to value
458-
cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}},
459-
Expr<T>{Constant<T>{*value}}));
458+
if constexpr (T::category == TypeCategory::Real &&
459+
(WHICH == WhichLocation::Maxloc ||
460+
WHICH == WhichLocation::Minloc)) {
461+
if (value && value->GetScalarValue().value().IsNotANumber() &&
462+
(back || !element.IsNotANumber())) {
463+
// Replace NaN
464+
cmp.emplace(Constant<LogicalResult>{Scalar<LogicalResult>{true}});
465+
}
466+
}
467+
if (!cmp) {
468+
cmp.emplace(PackageRelation(relation, Expr<T>{Constant<T>{element}},
469+
Expr<T>{Constant<T>{*value}}));
470+
}
460471
}
461472
Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
462473
result = GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
@@ -523,11 +534,12 @@ static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
523534
Scalar<T> identity) {
524535
static_assert(T::category == TypeCategory::Integer);
525536
std::optional<int> dim;
526-
if (std::optional<Constant<T>> array{
527-
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
537+
if (std::optional<ArrayAndMask<T>> arrayAndMask{
538+
ProcessReductionArgs<T>(context, ref.arguments(), dim,
528539
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
529-
OperationAccumulator<T> accumulator{*array, operation};
530-
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
540+
OperationAccumulator<T> accumulator{arrayAndMask->array, operation};
541+
return Expr<T>{DoReduction<T>(
542+
arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
531543
}
532544
return Expr<T>{std::move(ref)};
533545
}

flang/lib/Evaluate/fold-logical.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@ static Expr<T> FoldAllAnyParity(FoldingContext &context, FunctionRef<T> &&ref,
3131
Scalar<T> identity) {
3232
static_assert(T::category == TypeCategory::Logical);
3333
std::optional<int> dim;
34-
if (std::optional<Constant<T>> array{
35-
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
34+
if (std::optional<ArrayAndMask<T>> arrayAndMask{
35+
ProcessReductionArgs<T>(context, ref.arguments(), dim,
3636
/*ARRAY(MASK)=*/0, /*DIM=*/1)}) {
37-
OperationAccumulator accumulator{*array, operation};
38-
return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
37+
OperationAccumulator accumulator{arrayAndMask->array, operation};
38+
return Expr<T>{DoReduction<T>(
39+
arrayAndMask->array, arrayAndMask->mask, dim, identity, accumulator)};
3940
}
4041
return Expr<T>{std::move(ref)};
4142
}

flang/lib/Evaluate/fold-real.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ template <int KIND> class Norm2Accumulator {
5252
Norm2Accumulator(
5353
const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
5454
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
55-
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
55+
void operator()(
56+
Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
5657
// Kahan summation of scaled elements:
5758
// Naively,
5859
// NORM2(A(:)) = SQRT(SUM(A(:)**2))
@@ -114,17 +115,18 @@ static Expr<Type<TypeCategory::Real, KIND>> FoldNorm2(FoldingContext &context,
114115
using T = Type<TypeCategory::Real, KIND>;
115116
using Element = typename Constant<T>::Element;
116117
std::optional<int> dim;
117-
const Element identity{};
118-
if (std::optional<Constant<T>> array{
119-
ProcessReductionArgs<T>(context, funcRef.arguments(), dim, identity,
118+
if (std::optional<ArrayAndMask<T>> arrayAndMask{
119+
ProcessReductionArgs<T>(context, funcRef.arguments(), dim,
120120
/*X=*/0, /*DIM=*/1)}) {
121121
MaxvalMinvalAccumulator<T, /*ABS=*/true> maxAbsAccumulator{
122-
RelationalOperator::GT, context, *array};
123-
Constant<T> maxAbs{
124-
DoReduction<T>(*array, dim, identity, maxAbsAccumulator)};
125-
Norm2Accumulator norm2Accumulator{
126-
*array, maxAbs, context.targetCharacteristics().roundingMode()};
127-
Constant<T> result{DoReduction<T>(*array, dim, identity, norm2Accumulator)};
122+
RelationalOperator::GT, context, arrayAndMask->array};
123+
const Element identity{};
124+
Constant<T> maxAbs{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask,
125+
dim, identity, maxAbsAccumulator)};
126+
Norm2Accumulator norm2Accumulator{arrayAndMask->array, maxAbs,
127+
context.targetCharacteristics().roundingMode()};
128+
Constant<T> result{DoReduction<T>(arrayAndMask->array, arrayAndMask->mask,
129+
dim, identity, norm2Accumulator)};
128130
if (norm2Accumulator.overflow()) {
129131
context.messages().Say(
130132
"NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND);

0 commit comments

Comments
 (0)