Skip to content

Commit 86077c4

Browse files
authored
[flang][OpenMP] Rewrite min/max with more than 2 arguments (#146423)
Given an atomic operation `w = max(w, x1, x2, ...)` rewrite it as `w = max(w, max(x1, x2, ...))`. This will avoid unnecessary non-atomic comparisons inside of the atomic operation (min/max are expanded inline). In particular, if some of the x_i's are optional dummy parameters in the containing function, this will avoid any presence tests within the atomic operation. Fixes #144838
1 parent 6e3465c commit 86077c4

File tree

3 files changed

+348
-13
lines changed

3 files changed

+348
-13
lines changed

flang/lib/Lower/OpenMP/Atomic.cpp

Lines changed: 277 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include "flang/Evaluate/expression.h"
1212
#include "flang/Evaluate/fold.h"
1313
#include "flang/Evaluate/tools.h"
14+
#include "flang/Evaluate/traverse.h"
15+
#include "flang/Evaluate/type.h"
1416
#include "flang/Lower/AbstractConverter.h"
1517
#include "flang/Lower/PFTBuilder.h"
1618
#include "flang/Lower/StatementContext.h"
@@ -41,6 +43,179 @@ namespace omp {
4143
using namespace Fortran::lower::omp;
4244
}
4345

46+
namespace {
47+
// An example of a type that can be used to get the return value from
48+
// the visitor:
49+
// visitor(type_identity<Xyz>) -> result_type
50+
using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>;
51+
52+
struct GetProc
53+
: public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
54+
false> {
55+
using Result = const evaluate::ProcedureDesignator *;
56+
using Base = evaluate::Traverse<GetProc, Result, false>;
57+
GetProc() : Base(*this) {}
58+
59+
using Base::operator();
60+
61+
static Result Default() { return nullptr; }
62+
63+
Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; }
64+
static Result Combine(Result a, Result b) { return a != nullptr ? a : b; }
65+
};
66+
67+
struct WithType {
68+
WithType(const evaluate::DynamicType &t) : type(t) {
69+
assert(type.category() != common::TypeCategory::Derived &&
70+
"Type cannot be a derived type");
71+
}
72+
73+
template <typename VisitorTy> //
74+
auto visit(VisitorTy &&visitor) const
75+
-> std::invoke_result_t<VisitorTy, SomeArgType> {
76+
switch (type.category()) {
77+
case common::TypeCategory::Integer:
78+
switch (type.kind()) {
79+
case 1:
80+
return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{});
81+
case 2:
82+
return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{});
83+
case 4:
84+
return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{});
85+
case 8:
86+
return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{});
87+
case 16:
88+
return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{});
89+
}
90+
break;
91+
case common::TypeCategory::Unsigned:
92+
switch (type.kind()) {
93+
case 1:
94+
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{});
95+
case 2:
96+
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{});
97+
case 4:
98+
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{});
99+
case 8:
100+
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{});
101+
case 16:
102+
return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{});
103+
}
104+
break;
105+
case common::TypeCategory::Real:
106+
switch (type.kind()) {
107+
case 2:
108+
return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{});
109+
case 3:
110+
return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{});
111+
case 4:
112+
return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{});
113+
case 8:
114+
return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{});
115+
case 10:
116+
return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{});
117+
case 16:
118+
return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{});
119+
}
120+
break;
121+
case common::TypeCategory::Complex:
122+
switch (type.kind()) {
123+
case 2:
124+
return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{});
125+
case 3:
126+
return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{});
127+
case 4:
128+
return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{});
129+
case 8:
130+
return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{});
131+
case 10:
132+
return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{});
133+
case 16:
134+
return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{});
135+
}
136+
break;
137+
case common::TypeCategory::Logical:
138+
switch (type.kind()) {
139+
case 1:
140+
return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{});
141+
case 2:
142+
return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{});
143+
case 4:
144+
return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{});
145+
case 8:
146+
return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{});
147+
}
148+
break;
149+
case common::TypeCategory::Character:
150+
switch (type.kind()) {
151+
case 1:
152+
return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{});
153+
case 2:
154+
return visitor(llvm::type_identity<evaluate::Type<Character, 2>>{});
155+
case 4:
156+
return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{});
157+
}
158+
break;
159+
case common::TypeCategory::Derived:
160+
(void)Derived;
161+
break;
162+
}
163+
llvm_unreachable("Unhandled type");
164+
}
165+
166+
const evaluate::DynamicType &type;
167+
168+
private:
169+
// Shorter names.
170+
static constexpr auto Character = common::TypeCategory::Character;
171+
static constexpr auto Complex = common::TypeCategory::Complex;
172+
static constexpr auto Derived = common::TypeCategory::Derived;
173+
static constexpr auto Integer = common::TypeCategory::Integer;
174+
static constexpr auto Logical = common::TypeCategory::Logical;
175+
static constexpr auto Real = common::TypeCategory::Real;
176+
static constexpr auto Unsigned = common::TypeCategory::Unsigned;
177+
};
178+
179+
/// Lvalue overload: produce an independent, non-const copy of `t` that the
/// caller may move from. The argument itself is left untouched.
template <typename T, typename U = std::remove_const_t<T>>
U AsRvalue(T &t) {
  U copy{t};
  // Return the local by name: this permits copy elision (and falls back to
  // an implicit move). Wrapping it in std::move would inhibit elision.
  return copy;
}

/// Rvalue overload: the argument is already expiring, so hand it through as
/// an rvalue reference without copying. For lvalue arguments the overload
/// above is the more specialized match and is selected instead.
template <typename T>
T &&AsRvalue(T &&t) {
  return std::move(t);
}
189+
190+
struct ArgumentReplacer
191+
: public evaluate::Traverse<ArgumentReplacer, bool, false> {
192+
using Base = evaluate::Traverse<ArgumentReplacer, bool, false>;
193+
using Result = bool;
194+
195+
Result Default() const { return false; }
196+
197+
ArgumentReplacer(evaluate::ActualArguments &&newArgs)
198+
: Base(*this), args_(std::move(newArgs)) {}
199+
200+
using Base::operator();
201+
202+
template <typename T>
203+
Result operator()(const evaluate::FunctionRef<T> &x) {
204+
assert(!done_);
205+
auto &mut = const_cast<evaluate::FunctionRef<T> &>(x);
206+
mut.arguments() = args_;
207+
done_ = true;
208+
return true;
209+
}
210+
211+
Result Combine(Result &&a, Result &&b) { return a || b; }
212+
213+
private:
214+
bool done_{false};
215+
evaluate::ActualArguments &&args_;
216+
};
217+
} // namespace
218+
44219
[[maybe_unused]] static void
45220
dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) {
46221
auto whatStr = [](int k) {
@@ -237,6 +412,85 @@ makeMemOrderAttr(lower::AbstractConverter &converter,
237412
return nullptr;
238413
}
239414

415+
/// Replace the arguments of the function reference contained in `expr` with
/// `newArgs`. Returns true if a function reference was found and rewritten,
/// false if `expr` was left unchanged.
static bool replaceArgs(semantics::SomeExpr &expr,
                        evaluate::ActualArguments &&newArgs) {
  ArgumentReplacer replacer(std::move(newArgs));
  return replacer(expr);
}
419+
420+
/// Build a generic expression that calls `proc` with a copy of `args`, where
/// the result type of the call is the (non-derived) intrinsic type `type`.
static semantics::SomeExpr makeCall(const evaluate::DynamicType &type,
                                    const evaluate::ProcedureDesignator &proc,
                                    const evaluate::ActualArguments &args) {
  // Invoked by WithType with a type_identity tag naming the static result
  // type that corresponds to `type`.
  auto buildCall = [&](auto &&tag) -> semantics::SomeExpr {
    using ResultType = typename llvm::remove_cvref_t<decltype(tag)>::type;
    // FunctionRef takes ownership of its operands, so pass it copies.
    evaluate::FunctionRef<ResultType> ref(AsRvalue(proc), AsRvalue(args));
    return evaluate::AsGenericExpr(std::move(ref));
  };
  return WithType(type).visit(buildCall);
}
429+
430+
/// Return the procedure designator found in `call`. The expression is
/// required to contain one (asserted).
static const evaluate::ProcedureDesignator &
getProcedureDesignator(const semantics::SomeExpr &call) {
  GetProc finder;
  const evaluate::ProcedureDesignator *designator = finder(call);
  assert(designator && "Call has no procedure designator");
  return *designator;
}
436+
437+
/// Reduce a min/max operation with more than two arguments to one with
/// exactly two arguments: the atomic argument and a nested min/max over all
/// the other arguments.
///
/// \param orig    The original expression containing the min/max call.
/// \param atomArg The element of `args` that is the atomic variable
///                (possibly wrapped in conversions). It must point into
///                `args`, since it is identified by address.
/// \param args    The arguments of the original min/max call.
/// \return `orig` unchanged when there are at most two arguments; otherwise
///         a copy of `orig` whose call arguments are (atomArg, tmp), where
///         tmp is the min/max of the remaining arguments.
static semantics::SomeExpr //
genReducedMinMax(const semantics::SomeExpr &orig,
                 const semantics::SomeExpr *atomArg,
                 const std::vector<semantics::SomeExpr> &args) {
  // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
  // One of the a_i's, say a_t, must be atomArg.
  // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate
  // call = min/max(a_t, tmp).
  // Return "call".

  // The min/max intrinsics have 2 mandatory arguments, the rest is optional.
  // Make sure that the "tmp = min/max(...)" doesn't promote an optional
  // argument to a non-optional position. This could happen if a_t is at
  // position 0 or 1.
  if (args.size() <= 2)
    return orig;
  assert(atomArg && "Atomic argument must be provided");

  // An ActualArgument owns its expression: copies are made for expressions
  // passed as lvalues, while rvalues are moved in.
  auto AsActual = [](semantics::SomeExpr x) {
    return evaluate::ActualArgument(std::move(x));
  };

  evaluate::ActualArguments nonAtoms;
  nonAtoms.reserve(args.size());

  // Semantic checks guarantee that the "atom" shows exactly once in the
  // argument list (with potential conversions around it).
  // For the first two (non-optional) arguments, if "atom" is among them,
  // replace it with another occurrence of the other non-optional argument:
  //   (atom, x, y...) -> (x, x, y...)
  //   (x, atom, y...) -> (x, x, y...)
  //   (x, y, z...)    -> unchanged
  const semantics::SomeExpr &first = atomArg == &args[0] ? args[1] : args[0];
  const semantics::SomeExpr &second = atomArg == &args[1] ? args[0] : args[1];
  nonAtoms.push_back(AsActual(first));
  nonAtoms.push_back(AsActual(second));

  // The rest of arguments are optional, so we can just skip "atom".
  for (size_t i = 2, e = args.size(); i != e; ++i) {
    if (atomArg != &args[i])
      nonAtoms.push_back(AsActual(args[i]));
  }

  // The type of the intermediate min/max is the same as the type of its
  // arguments, which may be different from the type of the original
  // expression. The original expression may have additional converts.
  semantics::SomeExpr tmp =
      makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms);
  semantics::SomeExpr call = orig;
  // tmp is not needed afterwards, so move it instead of copying the whole
  // expression tree. The result of replaceArgs is intentionally not
  // checked: if nothing is replaced, "call" is still a copy of "orig".
  replaceArgs(call, {AsActual(*atomArg), AsActual(std::move(tmp))});
  return call;
}
493+
240494
static mlir::Operation * //
241495
genAtomicRead(lower::AbstractConverter &converter,
242496
semantics::SemanticsContext &semaCtx, mlir::Location loc,
@@ -350,10 +604,29 @@ genAtomicUpdate(lower::AbstractConverter &converter,
350604
mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
351605

352606
// This must exist by now.
353-
semantics::SomeExpr input = *evaluate::GetConvertInput(assign.rhs);
354-
std::vector<semantics::SomeExpr> args =
355-
evaluate::GetTopLevelOperation(input).second;
607+
semantics::SomeExpr rhs = assign.rhs;
608+
semantics::SomeExpr input = *evaluate::GetConvertInput(rhs);
609+
auto [opcode, args] = evaluate::GetTopLevelOperation(input);
356610
assert(!args.empty() && "Update operation without arguments");
611+
612+
// Pass args as an argument to avoid capturing a structured binding.
613+
const semantics::SomeExpr *atomArg = [&](auto &args) {
614+
for (const semantics::SomeExpr &e : args) {
615+
if (evaluate::IsSameOrConvertOf(e, atom))
616+
return &e;
617+
}
618+
llvm_unreachable("Atomic variable not in argument list");
619+
}(args);
620+
621+
if (opcode == evaluate::operation::Operator::Min ||
622+
opcode == evaluate::operation::Operator::Max) {
623+
// Min and max operations are expanded inline, so reduce them to
624+
// operations with exactly two (non-optional) arguments.
625+
rhs = genReducedMinMax(rhs, atomArg, args);
626+
input = *evaluate::GetConvertInput(rhs);
627+
std::tie(opcode, args) = evaluate::GetTopLevelOperation(input);
628+
atomArg = nullptr; // No longer valid.
629+
}
357630
for (auto &arg : args) {
358631
if (!evaluate::IsSameOrConvertOf(arg, atom)) {
359632
mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));
@@ -372,7 +645,7 @@ genAtomicUpdate(lower::AbstractConverter &converter,
372645

373646
converter.overrideExprValues(&overrides);
374647
mlir::Value updated =
375-
fir::getBase(converter.genExprValue(assign.rhs, stmtCtx, &loc));
648+
fir::getBase(converter.genExprValue(rhs, stmtCtx, &loc));
376649
mlir::Value converted = builder.createConvert(loc, atomType, updated);
377650
builder.create<mlir::omp::YieldOp>(loc, converted);
378651
converter.resetExprOverrides();

flang/test/Lower/OpenMP/atomic-update.f90

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ program OmpAtomicUpdate
107107
!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_Y_DECLARE]]#0 : !fir.ref<i32> {
108108
!CHECK: ^bb0(%[[ARG:.*]]: i32):
109109
!CHECK: {{.*}} = arith.cmpi sgt, %[[ARG]], {{.*}} : i32
110-
!CHECK: {{.*}} = arith.select {{.*}}, %[[ARG]], {{.*}} : i32
111-
!CHECK: {{.*}} = arith.cmpi sgt, {{.*}}
112110
!CHECK: %[[TEMP:.*]] = arith.select {{.*}} : i32
113111
!CHECK: omp.yield(%[[TEMP]] : i32)
114112
!CHECK: }
@@ -177,13 +175,9 @@ program OmpAtomicUpdate
177175
!CHECK: %[[VAL_Z_LOADED:.*]] = fir.load %[[VAL_Z_DECLARE]]#0 : !fir.ref<i32>
178176
!CHECK: omp.atomic.update %[[VAL_W_DECLARE]]#0 : !fir.ref<i32> {
179177
!CHECK: ^bb0(%[[ARG_W:.*]]: i32):
180-
!CHECK: %[[WX_CMP:.*]] = arith.cmpi sgt, %[[ARG_W]], %[[VAL_X_LOADED]] : i32
181-
!CHECK: %[[WX_MIN:.*]] = arith.select %[[WX_CMP]], %[[ARG_W]], %[[VAL_X_LOADED]] : i32
182-
!CHECK: %[[WXY_CMP:.*]] = arith.cmpi sgt, %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
183-
!CHECK: %[[WXY_MIN:.*]] = arith.select %[[WXY_CMP]], %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
184-
!CHECK: %[[WXYZ_CMP:.*]] = arith.cmpi sgt, %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
185-
!CHECK: %[[WXYZ_MIN:.*]] = arith.select %[[WXYZ_CMP]], %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
186-
!CHECK: omp.yield(%[[WXYZ_MIN]] : i32)
178+
!CHECK: %[[W_CMP:.*]] = arith.cmpi sgt, %[[ARG_W]], {{.*}} : i32
179+
!CHECK: %[[WXYZ_MAX:.*]] = arith.select %[[W_CMP]], %[[ARG_W]], {{.*}} : i32
180+
!CHECK: omp.yield(%[[WXYZ_MAX]] : i32)
187181
!CHECK: }
188182
!$omp atomic update
189183
w = max(w,x,y,z)

0 commit comments

Comments
 (0)