11
11
#include " flang/Evaluate/expression.h"
12
12
#include " flang/Evaluate/fold.h"
13
13
#include " flang/Evaluate/tools.h"
14
+ #include " flang/Evaluate/traverse.h"
15
+ #include " flang/Evaluate/type.h"
14
16
#include " flang/Lower/AbstractConverter.h"
15
17
#include " flang/Lower/PFTBuilder.h"
16
18
#include " flang/Lower/StatementContext.h"
@@ -41,6 +43,179 @@ namespace omp {
41
43
using namespace Fortran ::lower::omp;
42
44
}
43
45
46
+ namespace {
47
+ // An example of a type that can be used to get the return value from
48
+ // the visitor:
49
+ // visitor(type_identity<Xyz>) -> result_type
50
+ using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4 >;
51
+
52
+ struct GetProc
53
+ : public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
54
+ false > {
55
+ using Result = const evaluate::ProcedureDesignator *;
56
+ using Base = evaluate::Traverse<GetProc, Result, false >;
57
+ GetProc () : Base(*this ) {}
58
+
59
+ using Base::operator ();
60
+
61
+ static Result Default () { return nullptr ; }
62
+
63
+ Result operator ()(const evaluate::ProcedureDesignator &p) const { return &p; }
64
+ static Result Combine (Result a, Result b) { return a != nullptr ? a : b; }
65
+ };
66
+
67
+ struct WithType {
68
+ WithType (const evaluate::DynamicType &t) : type(t) {
69
+ assert (type.category () != common::TypeCategory::Derived &&
70
+ " Type cannot be a derived type" );
71
+ }
72
+
73
+ template <typename VisitorTy> //
74
+ auto visit (VisitorTy &&visitor) const
75
+ -> std::invoke_result_t<VisitorTy, SomeArgType> {
76
+ switch (type.category ()) {
77
+ case common::TypeCategory::Integer:
78
+ switch (type.kind ()) {
79
+ case 1 :
80
+ return visitor (llvm::type_identity<evaluate::Type<Integer, 1 >>{});
81
+ case 2 :
82
+ return visitor (llvm::type_identity<evaluate::Type<Integer, 2 >>{});
83
+ case 4 :
84
+ return visitor (llvm::type_identity<evaluate::Type<Integer, 4 >>{});
85
+ case 8 :
86
+ return visitor (llvm::type_identity<evaluate::Type<Integer, 8 >>{});
87
+ case 16 :
88
+ return visitor (llvm::type_identity<evaluate::Type<Integer, 16 >>{});
89
+ }
90
+ break ;
91
+ case common::TypeCategory::Unsigned:
92
+ switch (type.kind ()) {
93
+ case 1 :
94
+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 1 >>{});
95
+ case 2 :
96
+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 2 >>{});
97
+ case 4 :
98
+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 4 >>{});
99
+ case 8 :
100
+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 8 >>{});
101
+ case 16 :
102
+ return visitor (llvm::type_identity<evaluate::Type<Unsigned, 16 >>{});
103
+ }
104
+ break ;
105
+ case common::TypeCategory::Real:
106
+ switch (type.kind ()) {
107
+ case 2 :
108
+ return visitor (llvm::type_identity<evaluate::Type<Real, 2 >>{});
109
+ case 3 :
110
+ return visitor (llvm::type_identity<evaluate::Type<Real, 3 >>{});
111
+ case 4 :
112
+ return visitor (llvm::type_identity<evaluate::Type<Real, 4 >>{});
113
+ case 8 :
114
+ return visitor (llvm::type_identity<evaluate::Type<Real, 8 >>{});
115
+ case 10 :
116
+ return visitor (llvm::type_identity<evaluate::Type<Real, 10 >>{});
117
+ case 16 :
118
+ return visitor (llvm::type_identity<evaluate::Type<Real, 16 >>{});
119
+ }
120
+ break ;
121
+ case common::TypeCategory::Complex:
122
+ switch (type.kind ()) {
123
+ case 2 :
124
+ return visitor (llvm::type_identity<evaluate::Type<Complex, 2 >>{});
125
+ case 3 :
126
+ return visitor (llvm::type_identity<evaluate::Type<Complex, 3 >>{});
127
+ case 4 :
128
+ return visitor (llvm::type_identity<evaluate::Type<Complex, 4 >>{});
129
+ case 8 :
130
+ return visitor (llvm::type_identity<evaluate::Type<Complex, 8 >>{});
131
+ case 10 :
132
+ return visitor (llvm::type_identity<evaluate::Type<Complex, 10 >>{});
133
+ case 16 :
134
+ return visitor (llvm::type_identity<evaluate::Type<Complex, 16 >>{});
135
+ }
136
+ break ;
137
+ case common::TypeCategory::Logical:
138
+ switch (type.kind ()) {
139
+ case 1 :
140
+ return visitor (llvm::type_identity<evaluate::Type<Logical, 1 >>{});
141
+ case 2 :
142
+ return visitor (llvm::type_identity<evaluate::Type<Logical, 2 >>{});
143
+ case 4 :
144
+ return visitor (llvm::type_identity<evaluate::Type<Logical, 4 >>{});
145
+ case 8 :
146
+ return visitor (llvm::type_identity<evaluate::Type<Logical, 8 >>{});
147
+ }
148
+ break ;
149
+ case common::TypeCategory::Character:
150
+ switch (type.kind ()) {
151
+ case 1 :
152
+ return visitor (llvm::type_identity<evaluate::Type<Character, 1 >>{});
153
+ case 2 :
154
+ return visitor (llvm::type_identity<evaluate::Type<Character, 2 >>{});
155
+ case 4 :
156
+ return visitor (llvm::type_identity<evaluate::Type<Character, 4 >>{});
157
+ }
158
+ break ;
159
+ case common::TypeCategory::Derived:
160
+ (void )Derived;
161
+ break ;
162
+ }
163
+ llvm_unreachable (" Unhandled type" );
164
+ }
165
+
166
+ const evaluate::DynamicType &type;
167
+
168
+ private:
169
+ // Shorter names.
170
+ static constexpr auto Character = common::TypeCategory::Character;
171
+ static constexpr auto Complex = common::TypeCategory::Complex;
172
+ static constexpr auto Derived = common::TypeCategory::Derived;
173
+ static constexpr auto Integer = common::TypeCategory::Integer;
174
+ static constexpr auto Logical = common::TypeCategory::Logical;
175
+ static constexpr auto Real = common::TypeCategory::Real;
176
+ static constexpr auto Unsigned = common::TypeCategory::Unsigned;
177
+ };
178
+
179
+ template <typename T, typename U = std::remove_const_t <T>>
180
+ U AsRvalue (T &t) {
181
+ U copy{t};
182
+ return std::move (copy);
183
+ }
184
+
185
+ template <typename T>
186
+ T &&AsRvalue(T &&t) {
187
+ return std::move (t);
188
+ }
189
+
190
+ struct ArgumentReplacer
191
+ : public evaluate::Traverse<ArgumentReplacer, bool , false > {
192
+ using Base = evaluate::Traverse<ArgumentReplacer, bool , false >;
193
+ using Result = bool ;
194
+
195
+ Result Default () const { return false ; }
196
+
197
+ ArgumentReplacer (evaluate::ActualArguments &&newArgs)
198
+ : Base(*this ), args_(std::move(newArgs)) {}
199
+
200
+ using Base::operator ();
201
+
202
+ template <typename T>
203
+ Result operator ()(const evaluate::FunctionRef<T> &x) {
204
+ assert (!done_);
205
+ auto &mut = const_cast <evaluate::FunctionRef<T> &>(x);
206
+ mut.arguments () = args_;
207
+ done_ = true ;
208
+ return true ;
209
+ }
210
+
211
+ Result Combine (Result &&a, Result &&b) { return a || b; }
212
+
213
+ private:
214
+ bool done_{false };
215
+ evaluate::ActualArguments &&args_;
216
+ };
217
+ } // namespace
218
+
44
219
[[maybe_unused]] static void
45
220
dumpAtomicAnalysis (const parser::OpenMPAtomicConstruct::Analysis &analysis) {
46
221
auto whatStr = [](int k) {
@@ -237,6 +412,85 @@ makeMemOrderAttr(lower::AbstractConverter &converter,
237
412
return nullptr ;
238
413
}
239
414
415
+ static bool replaceArgs (semantics::SomeExpr &expr,
416
+ evaluate::ActualArguments &&newArgs) {
417
+ return ArgumentReplacer (std::move (newArgs))(expr);
418
+ }
419
+
420
+ static semantics::SomeExpr makeCall (const evaluate::DynamicType &type,
421
+ const evaluate::ProcedureDesignator &proc,
422
+ const evaluate::ActualArguments &args) {
423
+ return WithType (type).visit ([&](auto &&s) -> semantics::SomeExpr {
424
+ using Type = typename llvm::remove_cvref_t <decltype (s)>::type;
425
+ return evaluate::AsGenericExpr (
426
+ evaluate::FunctionRef<Type>(AsRvalue (proc), AsRvalue (args)));
427
+ });
428
+ }
429
+
430
+ static const evaluate::ProcedureDesignator &
431
+ getProcedureDesignator (const semantics::SomeExpr &call) {
432
+ const evaluate::ProcedureDesignator *proc = GetProc{}(call);
433
+ assert (proc && " Call has no procedure designator" );
434
+ return *proc;
435
+ }
436
+
437
+ static semantics::SomeExpr //
438
+ genReducedMinMax (const semantics::SomeExpr &orig,
439
+ const semantics::SomeExpr *atomArg,
440
+ const std::vector<semantics::SomeExpr> &args) {
441
+ // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
442
+ // One of the a_i's, say a_t, must be atomArg.
443
+ // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate
444
+ // call = min/max(a_t, tmp).
445
+ // Return "call".
446
+
447
+ // The min/max intrinsics have 2 mandatory arguments, the rest is optional.
448
+ // Make sure that the "tmp = min/max(...)" doesn't promote an optional
449
+ // argument to a non-optional position. This could happen if a_t is at
450
+ // position 0 or 1.
451
+ if (args.size () <= 2 )
452
+ return orig;
453
+
454
+ evaluate::ActualArguments nonAtoms;
455
+
456
+ auto AsActual = [](const semantics::SomeExpr &x) {
457
+ semantics::SomeExpr copy = x;
458
+ return evaluate::ActualArgument (std::move (copy));
459
+ };
460
+ // Semantic checks guarantee that the "atom" shows exactly once in the
461
+ // argument list (with potential conversions around it).
462
+ // For the first two (non-optional) arguments, if "atom" is among them,
463
+ // replace it with another occurrence of the other non-optional argument.
464
+ if (atomArg == &args[0 ]) {
465
+ // (atom, x, y...) -> (x, x, y...)
466
+ nonAtoms.push_back (AsActual (args[1 ]));
467
+ nonAtoms.push_back (AsActual (args[1 ]));
468
+ } else if (atomArg == &args[1 ]) {
469
+ // (x, atom, y...) -> (x, x, y...)
470
+ nonAtoms.push_back (AsActual (args[0 ]));
471
+ nonAtoms.push_back (AsActual (args[0 ]));
472
+ } else {
473
+ // (x, y, z...) -> unchanged
474
+ nonAtoms.push_back (AsActual (args[0 ]));
475
+ nonAtoms.push_back (AsActual (args[1 ]));
476
+ }
477
+
478
+ // The rest of arguments are optional, so we can just skip "atom".
479
+ for (size_t i = 2 , e = args.size (); i != e; ++i) {
480
+ if (atomArg != &args[i])
481
+ nonAtoms.push_back (AsActual (args[i]));
482
+ }
483
+
484
+ // The type of the intermediate min/max is the same as the type of its
485
+ // arguments, which may be different from the type of the original
486
+ // expression. The original expression may have additional coverts.
487
+ auto tmp =
488
+ makeCall (*atomArg->GetType (), getProcedureDesignator (orig), nonAtoms);
489
+ semantics::SomeExpr call = orig;
490
+ replaceArgs (call, {AsActual (*atomArg), AsActual (tmp)});
491
+ return call;
492
+ }
493
+
240
494
static mlir::Operation * //
241
495
genAtomicRead (lower::AbstractConverter &converter,
242
496
semantics::SemanticsContext &semaCtx, mlir::Location loc,
@@ -350,10 +604,29 @@ genAtomicUpdate(lower::AbstractConverter &converter,
350
604
mlir::Type atomType = fir::unwrapRefType (atomAddr.getType ());
351
605
352
606
// This must exist by now.
353
- semantics::SomeExpr input = * evaluate::GetConvertInput ( assign.rhs ) ;
354
- std::vector< semantics::SomeExpr> args =
355
- evaluate::GetTopLevelOperation (input). second ;
607
+ semantics::SomeExpr rhs = assign.rhs ;
608
+ semantics::SomeExpr input = * evaluate::GetConvertInput (rhs);
609
+ auto [opcode, args] = evaluate::GetTopLevelOperation (input);
356
610
assert (!args.empty () && " Update operation without arguments" );
611
+
612
+ // Pass args as an argument to avoid capturing a structured binding.
613
+ const semantics::SomeExpr *atomArg = [&](auto &args) {
614
+ for (const semantics::SomeExpr &e : args) {
615
+ if (evaluate::IsSameOrConvertOf (e, atom))
616
+ return &e;
617
+ }
618
+ llvm_unreachable (" Atomic variable not in argument list" );
619
+ }(args);
620
+
621
+ if (opcode == evaluate::operation::Operator::Min ||
622
+ opcode == evaluate::operation::Operator::Max) {
623
+ // Min and max operations are expanded inline, so reduce them to
624
+ // operations with exactly two (non-optional) arguments.
625
+ rhs = genReducedMinMax (rhs, atomArg, args);
626
+ input = *evaluate::GetConvertInput (rhs);
627
+ std::tie (opcode, args) = evaluate::GetTopLevelOperation (input);
628
+ atomArg = nullptr ; // No longer valid.
629
+ }
357
630
for (auto &arg : args) {
358
631
if (!evaluate::IsSameOrConvertOf (arg, atom)) {
359
632
mlir::Value val = fir::getBase (converter.genExprValue (arg, naCtx, &loc));
@@ -372,7 +645,7 @@ genAtomicUpdate(lower::AbstractConverter &converter,
372
645
373
646
converter.overrideExprValues (&overrides);
374
647
mlir::Value updated =
375
- fir::getBase (converter.genExprValue (assign. rhs , stmtCtx, &loc));
648
+ fir::getBase (converter.genExprValue (rhs, stmtCtx, &loc));
376
649
mlir::Value converted = builder.createConvert (loc, atomType, updated);
377
650
builder.create <mlir::omp::YieldOp>(loc, converted);
378
651
converter.resetExprOverrides ();
0 commit comments