Skip to content

Commit 0bbb2d0

Browse files
committed
[flang] Fold CSHIFT
Implement folding of the transformational intrinsic function CSHIFT for all types. Differential Revision: https://reviews.llvm.org/D108931
1 parent db9de22 commit 0bbb2d0

File tree

9 files changed

+123
-20
lines changed

9 files changed

+123
-20
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,23 @@ class ScalarConstantExpander {
992992
std::optional<ConstantSubscripts> lbounds_;
993993
};
994994

995+
// Given a collection of element values, package them as a Constant.
996+
// If the type is Character or a derived type, take the length or type
997+
// (resp.) from a another Constant.
998+
template <typename T>
999+
Constant<T> PackageConstant(std::vector<Scalar<T>> &&elements,
1000+
const Constant<T> &reference, const ConstantSubscripts &shape) {
1001+
if constexpr (T::category == TypeCategory::Character) {
1002+
return Constant<T>{
1003+
reference.LEN(), std::move(elements), ConstantSubscripts{shape}};
1004+
} else if constexpr (T::category == TypeCategory::Derived) {
1005+
return Constant<T>{reference.GetType().GetDerivedTypeSpec(),
1006+
std::move(elements), ConstantSubscripts{shape}};
1007+
} else {
1008+
return Constant<T>{std::move(elements), ConstantSubscripts{shape}};
1009+
}
1010+
}
1011+
9951012
} // namespace Fortran::evaluate
9961013

9971014
namespace Fortran::semantics {

flang/lib/Evaluate/fold-character.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
102102
CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}};
103103
}
104104
}
105-
// TODO: cshift, eoshift, maxloc, minloc, pack, spread, transfer,
106-
// transpose, unpack
105+
// TODO: findloc, maxloc, minloc, transfer
107106
return Expr<T>{std::move(funcRef)};
108107
}
109108

flang/lib/Evaluate/fold-complex.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
6060
} else if (name == "sum") {
6161
return FoldSum<T>(context, std::move(funcRef));
6262
}
63-
// TODO: cshift, dot_product, eoshift, matmul, pack, spread, transfer,
64-
// transpose, unpack
63+
// TODO: dot_product, matmul, transfer
6564
return Expr<T>{std::move(funcRef)};
6665
}
6766

flang/lib/Evaluate/fold-implementation.h

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ template <typename T> class Folder {
6060
std::optional<Constant<T>> Folding(ArrayRef &);
6161
Expr<T> Folding(Designator<T> &&);
6262
Constant<T> *Folding(std::optional<ActualArgument> &);
63-
Expr<T> Reshape(FunctionRef<T> &&);
63+
64+
Expr<T> CSHIFT(FunctionRef<T> &&);
65+
Expr<T> RESHAPE(FunctionRef<T> &&);
6466

6567
private:
6668
FoldingContext &context_;
@@ -546,7 +548,78 @@ template <typename T> Expr<T> MakeInvalidIntrinsic(FunctionRef<T> &&funcRef) {
546548
ActualArguments{std::move(funcRef.arguments())}}};
547549
}
548550

549-
template <typename T> Expr<T> Folder<T>::Reshape(FunctionRef<T> &&funcRef) {
551+
template <typename T> Expr<T> Folder<T>::CSHIFT(FunctionRef<T> &&funcRef) {
552+
auto args{funcRef.arguments()};
553+
CHECK(args.size() == 3);
554+
const auto *array{UnwrapConstantValue<T>(args[0])};
555+
const auto *shiftExpr{UnwrapExpr<Expr<SomeInteger>>(args[1])};
556+
auto dim{GetInt64ArgOr(args[2], 1)};
557+
if (!array || !shiftExpr || !dim) {
558+
return Expr<T>{std::move(funcRef)};
559+
}
560+
auto convertedShift{Fold(context_,
561+
ConvertToType<SubscriptInteger>(Expr<SomeInteger>{*shiftExpr}))};
562+
const auto *shift{UnwrapConstantValue<SubscriptInteger>(convertedShift)};
563+
if (!shift) {
564+
return Expr<T>{std::move(funcRef)};
565+
}
566+
// Arguments are constant
567+
if (*dim < 1 || *dim > array->Rank()) {
568+
context_.messages().Say("Invalid 'dim=' argument (%jd) in CSHIFT"_err_en_US,
569+
static_cast<std::intmax_t>(*dim));
570+
} else if (shift->Rank() > 0 && shift->Rank() != array->Rank() - 1) {
571+
// message already emitted from intrinsic look-up
572+
} else {
573+
int rank{array->Rank()};
574+
int zbDim{static_cast<int>(*dim) - 1};
575+
bool ok{true};
576+
if (shift->Rank() > 0) {
577+
int k{0};
578+
for (int j{0}; j < rank; ++j) {
579+
if (j != zbDim) {
580+
if (array->shape()[j] != shift->shape()[k]) {
581+
context_.messages().Say(
582+
"Invalid 'shift=' argument in CSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US,
583+
k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
584+
static_cast<std::intmax_t>(array->shape()[j]));
585+
ok = false;
586+
}
587+
++k;
588+
}
589+
}
590+
}
591+
if (ok) {
592+
std::vector<Scalar<T>> resultElements;
593+
ConstantSubscripts arrayAt{array->lbounds()};
594+
ConstantSubscript dimLB{arrayAt[zbDim]};
595+
ConstantSubscript dimExtent{array->shape()[zbDim]};
596+
ConstantSubscripts shiftAt{shift->lbounds()};
597+
for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) {
598+
ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
599+
ConstantSubscript zbDimIndex{shiftCount % dimExtent};
600+
if (zbDimIndex < 0) {
601+
zbDimIndex += dimExtent;
602+
}
603+
for (ConstantSubscript j{0}; j < dimExtent; ++j) {
604+
arrayAt[zbDim] = dimLB + zbDimIndex;
605+
resultElements.push_back(array->At(arrayAt));
606+
if (++zbDimIndex == dimExtent) {
607+
zbDimIndex = 0;
608+
}
609+
}
610+
arrayAt[zbDim] = dimLB + dimExtent - 1;
611+
array->IncrementSubscripts(arrayAt);
612+
shift->IncrementSubscripts(shiftAt);
613+
}
614+
return Expr<T>{PackageConstant<T>(
615+
std::move(resultElements), *array, array->shape())};
616+
}
617+
}
618+
// Invalid, prevent re-folding
619+
return MakeInvalidIntrinsic(std::move(funcRef));
620+
}
621+
622+
template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
550623
auto args{funcRef.arguments()};
551624
CHECK(args.size() == 4);
552625
const auto *source{UnwrapConstantValue<T>(args[0])};
@@ -679,10 +752,13 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
679752
}
680753
if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
681754
const std::string name{intrinsic->name};
682-
if (name == "reshape") {
683-
return Folder<T>{context}.Reshape(std::move(funcRef));
755+
if (name == "cshift") {
756+
return Folder<T>{context}.CSHIFT(std::move(funcRef));
757+
} else if (name == "reshape") {
758+
return Folder<T>{context}.RESHAPE(std::move(funcRef));
684759
}
685-
// TODO: other type independent transformationals
760+
// TODO: eoshift, pack, spread, unpack, transpose
761+
// TODO: extends_type_of, same_type_as
686762
if constexpr (!std::is_same_v<T, SomeDerived>) {
687763
return FoldIntrinsicFunction(context, std::move(funcRef));
688764
}

flang/lib/Evaluate/fold-integer.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -689,10 +689,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
689689
} else if (name == "ubound") {
690690
return UBOUND(context, std::move(funcRef));
691691
}
692-
// TODO:
693-
// cshift, dot_product, eoshift, findloc, ibits, image_status, ishftc,
694-
// matmul, maxloc, minloc, not, pack, sign, spread, transfer, transpose,
695-
// unpack
692+
// TODO: count(w/ dim), dot_product, findloc, ibits, image_status, ishftc,
693+
// matmul, maxloc, minloc, sign, transfer
696694
return Expr<T>{std::move(funcRef)};
697695
}
698696

flang/lib/Evaluate/fold-logical.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,9 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
125125
name == "__builtin_ieee_support_underflow_control") {
126126
return Expr<T>{true};
127127
}
128-
// TODO: btest, cshift, dot_product, eoshift, is_iostat_end,
128+
// TODO: btest, dot_product, eoshift, is_iostat_end,
129129
// is_iostat_eor, lge, lgt, lle, llt, logical, matmul, out_of_range,
130-
// pack, parity, spread, transfer, transpose, unpack, extends_type_of,
131-
// same_type_as
130+
// parity, transfer
132131
return Expr<T>{std::move(funcRef)};
133132
}
134133

flang/lib/Evaluate/fold-real.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
135135
} else if (name == "tiny") {
136136
return Expr<T>{Scalar<T>::TINY()};
137137
}
138-
// TODO: cshift, dim, dot_product, eoshift, fraction, matmul,
139-
// maxloc, minloc, modulo, nearest, norm2, pack, rrspacing, scale,
140-
// set_exponent, spacing, spread, transfer, transpose, unpack,
138+
// TODO: dim, dot_product, fraction, matmul,
139+
// maxloc, minloc, modulo, nearest, norm2, rrspacing, scale,
140+
// set_exponent, spacing, transfer,
141141
// bessel_jn (transformational) and bessel_yn (transformational)
142142
return Expr<T>{std::move(funcRef)};
143143
}

flang/test/Evaluate/folding22.f90

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,3 @@
2020
logical, parameter :: test_zero_sized = len(zero_sized).eq.6
2121

2222
end
23-

flang/test/Evaluate/folding27.f90

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
! RUN: %S/test_folding.sh %s %t %flang_fc1
2+
! REQUIRES: shell
3+
! Tests folding of CSHIFT (valid cases)
4+
module m
5+
integer, parameter :: arr(2,3) = reshape([1, 2, 3, 4, 5, 6], shape(arr))
6+
logical, parameter :: test_sanity = all([arr] == [1, 2, 3, 4, 5, 6])
7+
logical, parameter :: test_cshift_0 = all(cshift([1, 2, 3], 0) == [1, 2, 3])
8+
logical, parameter :: test_cshift_1 = all(cshift([1, 2, 3], 1) == [2, 3, 1])
9+
logical, parameter :: test_cshift_2 = all(cshift([1, 2, 3], 3) == [1, 2, 3])
10+
logical, parameter :: test_cshift_3 = all(cshift([1, 2, 3], 4) == [2, 3, 1])
11+
logical, parameter :: test_cshift_4 = all(cshift([1, 2, 3], -1) == [3, 1, 2])
12+
logical, parameter :: test_cshift_5 = all([cshift(arr, 1, dim=1)] == [2, 1, 4, 3, 6, 5])
13+
logical, parameter :: test_cshift_6 = all([cshift(arr, 1, dim=2)] == [3, 5, 1, 4, 6, 2])
14+
logical, parameter :: test_cshift_7 = all([cshift(arr, [1, 2, 3])] == [2, 1, 3, 4, 6, 5])
15+
logical, parameter :: test_cshift_8 = all([cshift(arr, [1, 2], dim=2)] == [3, 5, 1, 6, 2, 4])
16+
end module

0 commit comments

Comments
 (0)