Skip to content

Commit 8926f0f

Browse files
committed
[flang] Fold MERGE() of derived type values
Generalize FoldMerge() to accommodate derived type arguments and results, rename it into Folder<T>::MERGE(), and remove it from the various FoldIntrinsicFunction() routines for intrinsic types. Fixes llvm-test-suite/Fortran/gfortran/regression/merge_init_expr_2.f90. Differential Revision: https://reviews.llvm.org/D157345
1 parent 6b8e338 commit 8926f0f

File tree

7 files changed

+47
-27
lines changed

7 files changed

+47
-27
lines changed

flang/lib/Evaluate/fold-character.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
8080
return FoldMaxvalMinval<T>(
8181
context, std::move(funcRef), RelationalOperator::GT, *identity);
8282
}
83-
} else if (name == "merge") {
84-
return FoldMerge<T>(context, std::move(funcRef));
8583
} else if (name == "min") {
8684
return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
8785
} else if (name == "minval") {

flang/lib/Evaluate/fold-complex.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
6464
}
6565
} else if (name == "dot_product") {
6666
return FoldDotProduct<T>(context, std::move(funcRef));
67-
} else if (name == "merge") {
68-
return FoldMerge<T>(context, std::move(funcRef));
6967
} else if (name == "product") {
7068
auto one{Scalar<Part>::FromInteger(value::Integer<8>{1}).value};
7169
return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{one});

flang/lib/Evaluate/fold-implementation.h

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ template <typename T> class Folder {
6464

6565
Expr<T> CSHIFT(FunctionRef<T> &&);
6666
Expr<T> EOSHIFT(FunctionRef<T> &&);
67+
Expr<T> MERGE(FunctionRef<T> &&);
6768
Expr<T> PACK(FunctionRef<T> &&);
6869
Expr<T> RESHAPE(FunctionRef<T> &&);
6970
Expr<T> SPREAD(FunctionRef<T> &&);
@@ -397,9 +398,11 @@ template <typename T> Expr<T> Folder<T>::Folding(Designator<T> &&designator) {
397398
template <typename T>
398399
Constant<T> *Folder<T>::Folding(std::optional<ActualArgument> &arg) {
399400
if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
400-
if (!UnwrapExpr<Expr<T>>(*expr)) {
401-
if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
402-
*expr = Fold(context_, std::move(*converted));
401+
if constexpr (T::category != TypeCategory::Derived) {
402+
if (!UnwrapExpr<Expr<T>>(*expr)) {
403+
if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
404+
*expr = Fold(context_, std::move(*converted));
405+
}
403406
}
404407
}
405408
return UnwrapConstantValue<T>(*expr);
@@ -411,8 +414,6 @@ template <typename... A, std::size_t... I>
411414
std::optional<std::tuple<const Constant<A> *...>> GetConstantArgumentsHelper(
412415
FoldingContext &context, ActualArguments &arguments,
413416
std::index_sequence<I...>) {
414-
static_assert(
415-
(... && IsSpecificIntrinsicType<A>)); // TODO derived types for MERGE?
416417
static_assert(sizeof...(A) > 0);
417418
std::tuple<const Constant<A> *...> args{
418419
Folder<A>{context}.Folding(arguments.at(I))...};
@@ -489,7 +490,6 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
489490
}
490491
}
491492
CHECK(rank == GetRank(shape));
492-
493493
// Compute all the scalar values of the results
494494
std::vector<Scalar<TR>> results;
495495
if (TotalElementCount(shape) > 0) {
@@ -513,6 +513,13 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
513513
auto len{static_cast<ConstantSubscript>(
514514
results.empty() ? 0 : results[0].length())};
515515
return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
516+
} else if constexpr (TR::category == TypeCategory::Derived) {
517+
if (!results.empty()) {
518+
return Expr<TR>{rank == 0
519+
? Constant<TR>{results.front()}
520+
: Constant<TR>{results.front().derivedTypeSpec(),
521+
std::move(results), std::move(shape)}};
522+
}
516523
} else {
517524
return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
518525
}
@@ -780,6 +787,16 @@ template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
780787
return MakeInvalidIntrinsic(std::move(funcRef));
781788
}
782789

790+
template <typename T> Expr<T> Folder<T>::MERGE(FunctionRef<T> &&funcRef) {
791+
return FoldElementalIntrinsic<T, T, T, LogicalResult>(context_,
792+
std::move(funcRef),
793+
ScalarFunc<T, T, T, LogicalResult>(
794+
[](const Scalar<T> &ifTrue, const Scalar<T> &ifFalse,
795+
const Scalar<LogicalResult> &predicate) -> Scalar<T> {
796+
return predicate.IsTrue() ? ifTrue : ifFalse;
797+
}));
798+
}
799+
783800
template <typename T> Expr<T> Folder<T>::PACK(FunctionRef<T> &&funcRef) {
784801
auto args{funcRef.arguments()};
785802
CHECK(args.size() == 3);
@@ -1126,6 +1143,8 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
11261143
return Folder<T>{context}.CSHIFT(std::move(funcRef));
11271144
} else if (name == "eoshift") {
11281145
return Folder<T>{context}.EOSHIFT(std::move(funcRef));
1146+
} else if (name == "merge") {
1147+
return Folder<T>{context}.MERGE(std::move(funcRef));
11291148
} else if (name == "pack") {
11301149
return Folder<T>{context}.PACK(std::move(funcRef));
11311150
} else if (name == "reshape") {
@@ -1147,17 +1166,6 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
11471166
return Expr<T>{std::move(funcRef)};
11481167
}
11491168

1150-
template <typename T>
1151-
Expr<T> FoldMerge(FoldingContext &context, FunctionRef<T> &&funcRef) {
1152-
return FoldElementalIntrinsic<T, T, T, LogicalResult>(context,
1153-
std::move(funcRef),
1154-
ScalarFunc<T, T, T, LogicalResult>(
1155-
[](const Scalar<T> &ifTrue, const Scalar<T> &ifFalse,
1156-
const Scalar<LogicalResult> &predicate) -> Scalar<T> {
1157-
return predicate.IsTrue() ? ifTrue : ifFalse;
1158-
}));
1159-
}
1160-
11611169
Expr<ImpliedDoIndex::Result> FoldOperation(FoldingContext &, ImpliedDoIndex &&);
11621170

11631171
// Array constructor folding

flang/lib/Evaluate/fold-integer.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,8 +1038,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
10381038
} else if (name == "maxval") {
10391039
return FoldMaxvalMinval<T>(context, std::move(funcRef),
10401040
RelationalOperator::GT, T::Scalar::Least());
1041-
} else if (name == "merge") {
1042-
return FoldMerge<T>(context, std::move(funcRef));
10431041
} else if (name == "merge_bits") {
10441042
return FoldElementalIntrinsic<T, T, T, T>(
10451043
context, std::move(funcRef), &Scalar<T>::MERGE_BITS);

flang/lib/Evaluate/fold-logical.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,6 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
215215
if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
216216
return Fold(context, ConvertToType<T>(std::move(*expr)));
217217
}
218-
} else if (name == "merge") {
219-
return FoldMerge<T>(context, std::move(funcRef));
220218
} else if (name == "parity") {
221219
return FoldAllAnyParity(
222220
context, std::move(funcRef), &Scalar<T>::NEQV, Scalar<T>{false});

flang/lib/Evaluate/fold-real.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
184184
} else if (name == "maxval") {
185185
return FoldMaxvalMinval<T>(context, std::move(funcRef),
186186
RelationalOperator::GT, T::Scalar::HUGE().Negate());
187-
} else if (name == "merge") {
188-
return FoldMerge<T>(context, std::move(funcRef));
189187
} else if (name == "min") {
190188
return FoldMINorMAX(context, std::move(funcRef), Ordering::Less);
191189
} else if (name == "minval") {

flang/test/Evaluate/fold-merge.f90

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
! RUN: %python %S/test_folding.py %s %flang_fc1
2+
! Tests folding of MERGE
3+
module m
4+
type t
5+
integer n
6+
end type
7+
logical, parameter :: test_01 = all(merge([1,2,3],4,[.true.,.false.,.true.]) == [1,4,3])
8+
logical, parameter :: test_02 = all(merge([1,2,3],4,.true.) == [1,2,3])
9+
logical, parameter :: test_03 = all(merge([1,2,3],4,.false.) == [4,4,4])
10+
logical, parameter :: test_04 = all(merge(1,4,[.true.,.false.,.true.,.false.]) == [1,4,1,4])
11+
type(t), parameter :: dt00a = merge(t(1),t(2),.true.)
12+
logical, parameter :: test_05 = dt00a%n == 1
13+
type(t), parameter :: dt00b = merge(t(1),t(2),.false.)
14+
logical, parameter :: test_06 = dt00b%n == 2
15+
type(t), parameter :: dt01(*) = merge([t(1),t(2)],[t(3),t(4)],[.false.,.true.])
16+
logical, parameter :: test_07 = all(dt01%n == [3,2])
17+
type(t), parameter :: dt02(*) = merge(t(1),[t(3),t(4)],.true.)
18+
logical, parameter :: test_08 = all(dt02%n == [1,1])
19+
type(t), parameter :: dt03(*) = merge([t(1),t(2)],t(3),[.true.,.false.])
20+
logical, parameter :: test_09 = all(dt03%n == [1,3])
21+
logical, parameter :: test_10 = merge('ab','cd',.true.) == 'ab'
22+
end

0 commit comments

Comments
 (0)