Skip to content

Commit 8256867

Browse files
committed
[flang] Fold FINDLOC()
Fold the transformational intrinsic function FINDLOC() for all combinations of optional arguments and data types. Differential Revision: https://reviews.llvm.org/D110757
1 parent 3fcb00d commit 8256867

File tree

10 files changed

+264
-86
lines changed

10 files changed

+264
-86
lines changed

flang/include/flang/Evaluate/constant.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class ConstantBounds {
6565
const ConstantSubscripts &shape() const { return shape_; }
6666
const ConstantSubscripts &lbounds() const { return lbounds_; }
6767
void set_lbounds(ConstantSubscripts &&);
68+
void SetLowerBoundsToOne();
6869
int Rank() const { return GetRank(shape_); }
6970
Constant<SubscriptInteger> SHAPE() const;
7071

@@ -140,8 +141,8 @@ template <typename T> class Constant : public ConstantBase<T> {
140141
}
141142
}
142143

143-
// Apply subscripts. An empty subscript list is allowed for
144-
// a scalar constant.
144+
// Apply subscripts. Excess subscripts are ignored, including the
145+
// case of a scalar.
145146
Element At(const ConstantSubscripts &) const;
146147

147148
Constant Reshape(ConstantSubscripts &&) const;

flang/include/flang/Parser/provenance.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace Fortran::parser {
3030

3131
// Each character in the contiguous source stream built by the
3232
// prescanner corresponds to a particular character in a source file,
33-
// include file, macro expansion, or compiler-inserted padding.
33+
// include file, macro expansion, or compiler-inserted text.
3434
// The location of this original character to which a parsable character
3535
// corresponds is its provenance.
3636
//

flang/lib/Evaluate/constant.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,6 @@
1414

1515
namespace Fortran::evaluate {
1616

17-
std::size_t TotalElementCount(const ConstantSubscripts &shape) {
18-
std::size_t size{1};
19-
for (auto dim : shape) {
20-
CHECK(dim >= 0);
21-
size *= dim;
22-
}
23-
return size;
24-
}
25-
2617
ConstantBounds::ConstantBounds(const ConstantSubscripts &shape)
2718
: shape_(shape), lbounds_(shape_.size(), 1) {}
2819

@@ -36,6 +27,12 @@ void ConstantBounds::set_lbounds(ConstantSubscripts &&lb) {
3627
lbounds_ = std::move(lb);
3728
}
3829

30+
void ConstantBounds::SetLowerBoundsToOne() {
31+
for (auto &n : lbounds_) {
32+
n = 1;
33+
}
34+
}
35+
3936
Constant<SubscriptInteger> ConstantBounds::SHAPE() const {
4037
return AsConstantShape(shape_);
4138
}
@@ -55,6 +52,10 @@ ConstantSubscript ConstantBounds::SubscriptsToOffset(
5552
return offset;
5653
}
5754

55+
std::size_t TotalElementCount(const ConstantSubscripts &shape) {
56+
return static_cast<std::size_t>(GetSize(shape));
57+
}
58+
5859
bool ConstantBounds::IncrementSubscripts(
5960
ConstantSubscripts &indices, const std::vector<int> *dimOrder) const {
6061
int rank{GetRank(shape_)};

flang/lib/Evaluate/fold-character.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ Expr<Type<TypeCategory::Character, KIND>> FoldIntrinsicFunction(
102102
CharacterUtils<KIND>::TRIM(std::get<Scalar<T>>(*scalar))}};
103103
}
104104
}
105-
// TODO: findloc, maxloc, minloc, transfer
105+
// TODO: maxloc, minloc, transfer
106106
return Expr<T>{std::move(funcRef)};
107107
}
108108

flang/lib/Evaluate/fold-integer.cpp

Lines changed: 152 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,8 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
182182
if (const Constant<LogicalResult> *mask{arg.empty()
183183
? nullptr
184184
: Folder<LogicalResult>{context}.Folding(arg[0])}) {
185-
std::optional<ConstantSubscript> dim;
186-
if (arg.size() > 1 && arg[1]) {
187-
dim = CheckDIM(context, arg[1], mask->Rank());
188-
if (!dim) {
189-
mask = nullptr;
190-
}
191-
}
192-
if (mask) {
185+
std::optional<int> dim;
186+
if (CheckReductionDIM(dim, context, arg, 1, mask->Rank())) {
193187
auto accumulator{[&](Scalar<T> &element, const ConstantSubscripts &at) {
194188
if (mask->At(at).IsTrue()) {
195189
element = element.AddSigned(Scalar<T>{1}).value;
@@ -201,13 +195,159 @@ static Expr<T> FoldCount(FoldingContext &context, FunctionRef<T> &&ref) {
201195
return Expr<T>{std::move(ref)};
202196
}
203197

198+
// FINDLOC()
199+
class FindlocHelper {
200+
public:
201+
FindlocHelper(
202+
DynamicType &&type, ActualArguments &arg, FoldingContext &context)
203+
: type_{type}, arg_{arg}, context_{context} {}
204+
using Result = std::optional<Constant<SubscriptInteger>>;
205+
using Types = AllIntrinsicTypes;
206+
207+
template <typename T> Result Test() const {
208+
if (T::category != type_.category() || T::kind != type_.kind()) {
209+
return std::nullopt;
210+
}
211+
CHECK(arg_.size() == 6);
212+
Folder<T> folder{context_};
213+
Constant<T> *array{folder.Folding(arg_[0])};
214+
Constant<T> *value{folder.Folding(arg_[1])};
215+
if (!array || !value) {
216+
return std::nullopt;
217+
}
218+
std::optional<int> dim;
219+
Constant<LogicalResult> *mask{
220+
GetReductionMASK(arg_[3], array->shape(), context_)};
221+
if ((!mask && arg_[3]) ||
222+
!CheckReductionDIM(dim, context_, arg_, 2, array->Rank())) {
223+
return std::nullopt;
224+
}
225+
bool back{false};
226+
if (arg_[5]) {
227+
const auto *backConst{Folder<LogicalResult>{context_}.Folding(arg_[5])};
228+
if (backConst) {
229+
back = backConst->GetScalarValue().value().IsTrue();
230+
} else {
231+
return std::nullopt;
232+
}
233+
}
234+
// Use lower bounds of 1 exclusively.
235+
array->SetLowerBoundsToOne();
236+
ConstantSubscripts at{array->lbounds()}, maskAt, resultIndices, resultShape;
237+
if (mask) {
238+
mask->SetLowerBoundsToOne();
239+
maskAt = mask->lbounds();
240+
}
241+
if (dim) { // DIM=
242+
if (*dim < 1 || *dim > array->Rank()) {
243+
context_.messages().Say(
244+
"FINDLOC(DIM=%d) is out of range"_err_en_US, *dim);
245+
return std::nullopt;
246+
}
247+
int zbDim{*dim - 1};
248+
resultShape = array->shape();
249+
resultShape.erase(
250+
resultShape.begin() + zbDim); // scalar if array is vector
251+
ConstantSubscript dimLength{array->shape()[zbDim]};
252+
ConstantSubscript n{GetSize(resultShape)};
253+
for (ConstantSubscript j{0}; j < n; ++j) {
254+
ConstantSubscript hit{array->lbounds()[zbDim] - 1};
255+
for (ConstantSubscript k{0}; k < dimLength;
256+
++k, ++at[zbDim], mask && ++maskAt[zbDim]) {
257+
if ((!mask || mask->At(maskAt).IsTrue()) &&
258+
IsHit(array->At(at), *value)) {
259+
hit = at[zbDim];
260+
if (!back) {
261+
break;
262+
}
263+
}
264+
}
265+
resultIndices.emplace_back(hit);
266+
at[zbDim] = array->lbounds()[zbDim] + dimLength - 1;
267+
array->IncrementSubscripts(at);
268+
at[zbDim] = array->lbounds()[zbDim];
269+
if (mask) {
270+
maskAt[zbDim] = mask->lbounds()[zbDim] + dimLength - 1;
271+
mask->IncrementSubscripts(maskAt);
272+
maskAt[zbDim] = mask->lbounds()[zbDim];
273+
}
274+
}
275+
} else { // no DIM=
276+
resultShape = ConstantSubscripts{array->Rank()}; // always a vector
277+
ConstantSubscript n{GetSize(array->shape())};
278+
resultIndices = ConstantSubscripts(array->Rank(), 0);
279+
for (ConstantSubscript j{0}; j < n; ++j, array->IncrementSubscripts(at),
280+
mask && mask->IncrementSubscripts(maskAt)) {
281+
if ((!mask || mask->At(maskAt).IsTrue()) &&
282+
IsHit(array->At(at), *value)) {
283+
resultIndices = at;
284+
if (!back) {
285+
break;
286+
}
287+
}
288+
}
289+
}
290+
std::vector<Scalar<SubscriptInteger>> resultElements;
291+
for (ConstantSubscript j : resultIndices) {
292+
resultElements.emplace_back(j);
293+
}
294+
return Constant<SubscriptInteger>{
295+
std::move(resultElements), std::move(resultShape)};
296+
}
297+
298+
private:
299+
template <typename T>
300+
bool IsHit(typename Constant<T>::Element element, Constant<T> value) const {
301+
std::optional<Expr<LogicalResult>> cmp;
302+
if constexpr (T::category == TypeCategory::Logical) {
303+
// array(at) .EQV. value?
304+
cmp.emplace(
305+
ConvertToType<LogicalResult>(Expr<T>{LogicalOperation<T::kind>{
306+
LogicalOperator::Eqv, Expr<T>{Constant<T>{std::move(element)}},
307+
Expr<T>{std::move(value)}}}));
308+
} else { // array(at) .EQ. value?
309+
cmp.emplace(PackageRelation(RelationalOperator::EQ,
310+
Expr<T>{Constant<T>{std::move(element)}}, Expr<T>{std::move(value)}));
311+
}
312+
Expr<LogicalResult> folded{Fold(context_, std::move(*cmp))};
313+
return GetScalarConstantValue<LogicalResult>(folded).value().IsTrue();
314+
}
315+
316+
DynamicType type_;
317+
ActualArguments &arg_;
318+
FoldingContext &context_;
319+
};
320+
321+
static std::optional<Constant<SubscriptInteger>> FoldFindlocCall(
322+
ActualArguments &arg, FoldingContext &context) {
323+
CHECK(arg.size() == 6);
324+
if (arg[0]) {
325+
if (auto type{arg[0]->GetType()}) {
326+
return common::SearchTypes(FindlocHelper{std::move(*type), arg, context});
327+
}
328+
}
329+
return std::nullopt;
330+
}
331+
332+
template <typename T>
333+
static Expr<T> FoldFindloc(FoldingContext &context, FunctionRef<T> &&ref) {
334+
static_assert(T::category == TypeCategory::Integer);
335+
if (std::optional<Constant<SubscriptInteger>> found{
336+
FoldFindlocCall(ref.arguments(), context)}) {
337+
return Expr<T>{Fold(
338+
context, ConvertToType<T>(Expr<SubscriptInteger>{std::move(*found)}))};
339+
} else {
340+
return Expr<T>{std::move(ref)};
341+
}
342+
}
343+
204344
// for IALL, IANY, & IPARITY
205345
template <typename T>
206346
static Expr<T> FoldBitReduction(FoldingContext &context, FunctionRef<T> &&ref,
207347
Scalar<T> (Scalar<T>::*operation)(const Scalar<T> &) const,
208348
Scalar<T> identity) {
209349
static_assert(T::category == TypeCategory::Integer);
210-
std::optional<ConstantSubscript> dim;
350+
std::optional<int> dim;
211351
if (std::optional<Constant<T>> array{
212352
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
213353
/*ARRAY=*/0, /*DIM=*/1, /*MASK=*/2)}) {
@@ -310,6 +450,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
310450
} else {
311451
DIE("exponent argument must be real");
312452
}
453+
} else if (name == "findloc") {
454+
return FoldFindloc<T>(context, std::move(funcRef));
313455
} else if (name == "huge") {
314456
return Expr<T>{Scalar<T>::HUGE()};
315457
} else if (name == "iachar" || name == "ichar") {
@@ -711,7 +853,7 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
711853
} else if (name == "ubound") {
712854
return UBOUND(context, std::move(funcRef));
713855
}
714-
// TODO: dot_product, findloc, ibits, image_status, ishftc,
856+
// TODO: dot_product, ibits, image_status, ishftc,
715857
// matmul, maxloc, minloc, sign, transfer
716858
return Expr<T>{std::move(funcRef)};
717859
}

flang/lib/Evaluate/fold-logical.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ static Expr<T> FoldAllAny(FoldingContext &context, FunctionRef<T> &&ref,
1919
Scalar<T> identity) {
2020
static_assert(T::category == TypeCategory::Logical);
2121
using Element = Scalar<T>;
22-
std::optional<ConstantSubscript> dim;
22+
std::optional<int> dim;
2323
if (std::optional<Constant<T>> array{
2424
ProcessReductionArgs<T>(context, ref.arguments(), dim, identity,
2525
/*ARRAY(MASK)=*/0, /*DIM=*/1)}) {

flang/lib/Evaluate/fold-reduction.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,39 @@
99
#include "fold-reduction.h"
1010

1111
namespace Fortran::evaluate {
12-
13-
std::optional<ConstantSubscript> CheckDIM(
14-
FoldingContext &context, std::optional<ActualArgument> &arg, int rank) {
15-
if (arg) {
16-
if (auto *dimConst{Folder<SubscriptInteger>{context}.Folding(arg)}) {
12+
bool CheckReductionDIM(std::optional<int> &dim, FoldingContext &context,
13+
ActualArguments &arg, std::optional<int> dimIndex, int rank) {
14+
if (dimIndex && static_cast<std::size_t>(*dimIndex) < arg.size()) {
15+
if (auto *dimConst{
16+
Folder<SubscriptInteger>{context}.Folding(arg[*dimIndex])}) {
1717
if (auto dimScalar{dimConst->GetScalarValue()}) {
18-
auto dim{dimScalar->ToInt64()};
19-
if (dim >= 1 && dim <= rank) {
20-
return {dim};
18+
auto dimVal{dimScalar->ToInt64()};
19+
if (dimVal >= 1 && dimVal <= rank) {
20+
dim = dimVal;
2121
} else {
2222
context.messages().Say(
2323
"DIM=%jd is not valid for an array of rank %d"_err_en_US,
24-
static_cast<std::intmax_t>(dim), rank);
24+
static_cast<std::intmax_t>(dimVal), rank);
25+
return false;
2526
}
2627
}
2728
}
2829
}
29-
return std::nullopt;
30+
return true;
3031
}
3132

33+
Constant<LogicalResult> *GetReductionMASK(
34+
std::optional<ActualArgument> &maskArg, const ConstantSubscripts &shape,
35+
FoldingContext &context) {
36+
Constant<LogicalResult> *mask{
37+
Folder<LogicalResult>{context}.Folding(maskArg)};
38+
if (mask &&
39+
!CheckConformance(context.messages(), AsShape(shape),
40+
AsShape(mask->shape()), CheckConformanceFlags::RightScalarExpandable,
41+
"ARRAY=", "MASK=")
42+
.value_or(false)) {
43+
mask = nullptr;
44+
}
45+
return mask;
46+
}
3247
} // namespace Fortran::evaluate

0 commit comments

Comments
 (0)