@@ -270,7 +270,8 @@ template <typename T, int MASK_KIND> class CountAccumulator {
270
270
271
271
public:
272
272
CountAccumulator (const Constant<MaskT> &mask) : mask_{mask} {}
273
- void operator ()(Scalar<T> &element, const ConstantSubscripts &at) {
273
+ void operator ()(
274
+ Scalar<T> &element, const ConstantSubscripts &at, bool /* first*/ ) {
274
275
if (mask_.At (at).IsTrue ()) {
275
276
auto incremented{element.AddSigned (Scalar<T>{1 })};
276
277
overflow_ |= incremented.overflow ;
@@ -287,22 +288,20 @@ template <typename T, int MASK_KIND> class CountAccumulator {
287
288
288
289
template <typename T, int maskKind>
289
290
static Expr<T> FoldCount (FoldingContext &context, FunctionRef<T> &&ref) {
290
- using LogicalResult = Type<TypeCategory::Logical, maskKind>;
291
+ using KindLogical = Type<TypeCategory::Logical, maskKind>;
291
292
static_assert (T::category == TypeCategory::Integer);
292
- ActualArguments &arg{ref.arguments ()};
293
- if (const Constant<LogicalResult> *mask{arg.empty ()
294
- ? nullptr
295
- : Folder<LogicalResult>{context}.Folding (arg[0 ])}) {
296
- std::optional<int > dim;
297
- if (CheckReductionDIM (dim, context, arg, 1 , mask->Rank ())) {
298
- CountAccumulator<T, maskKind> accumulator{*mask};
299
- Constant<T> result{DoReduction<T>(*mask, dim, Scalar<T>{}, accumulator)};
300
- if (accumulator.overflow ()) {
301
- context.messages ().Say (
302
- " Result of intrinsic function COUNT overflows its result type" _warn_en_US);
303
- }
304
- return Expr<T>{std::move (result)};
293
+ std::optional<int > dim;
294
+ if (std::optional<ArrayAndMask<KindLogical>> arrayAndMask{
295
+ ProcessReductionArgs<KindLogical>(
296
+ context, ref.arguments (), dim, /* ARRAY=*/ 0 , /* DIM=*/ 1 )}) {
297
+ CountAccumulator<T, maskKind> accumulator{arrayAndMask->array };
298
+ Constant<T> result{DoReduction<T>(arrayAndMask->array , arrayAndMask->mask ,
299
+ dim, Scalar<T>{}, accumulator)};
300
+ if (accumulator.overflow ()) {
301
+ context.messages ().Say (
302
+ " Result of intrinsic function COUNT overflows its result type" _warn_en_US);
305
303
}
304
+ return Expr<T>{std::move (result)};
306
305
}
307
306
return Expr<T>{std::move (ref)};
308
307
}
@@ -395,7 +394,7 @@ template <WhichLocation WHICH> class LocationHelper {
395
394
for (ConstantSubscript k{0 }; k < dimLength;
396
395
++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
397
396
if ((!mask || mask->At (maskAt).IsTrue ()) &&
398
- IsHit (array->At (at), value, relation)) {
397
+ IsHit (array->At (at), value, relation, back )) {
399
398
hit = at[zbDim];
400
399
if constexpr (WHICH == WhichLocation::Findloc) {
401
400
if (!back) {
@@ -422,7 +421,7 @@ template <WhichLocation WHICH> class LocationHelper {
422
421
for (ConstantSubscript j{0 }; j < n; ++j, array->IncrementSubscripts (at),
423
422
mask && mask->IncrementSubscripts (maskAt)) {
424
423
if ((!mask || mask->At (maskAt).IsTrue ()) &&
425
- IsHit (array->At (at), value, relation)) {
424
+ IsHit (array->At (at), value, relation, back )) {
426
425
resultIndices = at;
427
426
if constexpr (WHICH == WhichLocation::Findloc) {
428
427
if (!back) {
@@ -444,7 +443,8 @@ template <WhichLocation WHICH> class LocationHelper {
444
443
template <typename T>
445
444
bool IsHit (typename Constant<T>::Element element,
446
445
std::optional<Constant<T>> &value,
447
- [[maybe_unused]] RelationalOperator relation) const {
446
+ [[maybe_unused]] RelationalOperator relation,
447
+ [[maybe_unused]] bool back) const {
448
448
std::optional<Expr<LogicalResult>> cmp;
449
449
bool result{true };
450
450
if (value) {
@@ -455,8 +455,19 @@ template <WhichLocation WHICH> class LocationHelper {
455
455
Expr<T>{LogicalOperation<T::kind>{LogicalOperator::Eqv,
456
456
Expr<T>{Constant<T>{element}}, Expr<T>{Constant<T>{*value}}}}));
457
457
} else { // compare array(at) to value
458
- cmp.emplace (PackageRelation (relation, Expr<T>{Constant<T>{element}},
459
- Expr<T>{Constant<T>{*value}}));
458
+ if constexpr (T::category == TypeCategory::Real &&
459
+ (WHICH == WhichLocation::Maxloc ||
460
+ WHICH == WhichLocation::Minloc)) {
461
+ if (value && value->GetScalarValue ().value ().IsNotANumber () &&
462
+ (back || !element.IsNotANumber ())) {
463
+ // Replace NaN
464
+ cmp.emplace (Constant<LogicalResult>{Scalar<LogicalResult>{true }});
465
+ }
466
+ }
467
+ if (!cmp) {
468
+ cmp.emplace (PackageRelation (relation, Expr<T>{Constant<T>{element}},
469
+ Expr<T>{Constant<T>{*value}}));
470
+ }
460
471
}
461
472
Expr<LogicalResult> folded{Fold (context_, std::move (*cmp))};
462
473
result = GetScalarConstantValue<LogicalResult>(folded).value ().IsTrue ();
@@ -523,11 +534,12 @@ static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
523
534
Scalar<T> identity) {
524
535
static_assert (T::category == TypeCategory::Integer);
525
536
std::optional<int > dim;
526
- if (std::optional<Constant <T>> array {
527
- ProcessReductionArgs<T>(context, ref.arguments (), dim, identity,
537
+ if (std::optional<ArrayAndMask <T>> arrayAndMask {
538
+ ProcessReductionArgs<T>(context, ref.arguments (), dim,
528
539
/* ARRAY=*/ 0 , /* DIM=*/ 1 , /* MASK=*/ 2 )}) {
529
- OperationAccumulator<T> accumulator{*array, operation};
530
- return Expr<T>{DoReduction<T>(*array, dim, identity, accumulator)};
540
+ OperationAccumulator<T> accumulator{arrayAndMask->array , operation};
541
+ return Expr<T>{DoReduction<T>(
542
+ arrayAndMask->array , arrayAndMask->mask , dim, identity, accumulator)};
531
543
}
532
544
return Expr<T>{std::move (ref)};
533
545
}
0 commit comments