Skip to content

Commit f46f5a0

Browse files
[flang][OpenMP][OMPIRBuilder][mlir] Optionally pass reduction vars by ref (#84304)
Previously reduction variables were always passed by value into and out of the initialization and combiner regions of the OpenMP reduction declare operation. This worked well for reductions of primitive types (and might perform better than passing by reference). But passing by reference will be useful for array and derived type reductions (e.g. to move allocation inside of the init region). Passing reductions by reference requires different LLVM-IR generation when lowering from MLIR because some of the loads/stores/allocations will now be moved inside of the init and combiner regions. This alternate code generation is requested using a new attribute to omp.wsloop and omp.parallel. Existing lowerings from MLIR are unaffected (these will continue to use the by-value argument passing). Flang will continue to use by-value argument passing for trivial types unless a (hidden) command line argument is supplied. Non-trivial types will always use the by-ref lowering. Array reductions are not ready yet (but are coming very soon). In the meantime, this is tested by forcing existing reductions to use by-ref. Commit series for by-ref OpenMP reductions 3/3 --------- Co-authored-by: Mats Petersson <[email protected]>
1 parent f18d78b commit f46f5a0

38 files changed

+4451
-76
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
601601
return reductionSymbols;
602602
};
603603

604+
mlir::UnitAttr byrefAttr;
605+
if (ReductionProcessor::doReductionByRef(reductionVars))
606+
byrefAttr = converter.getFirOpBuilder().getUnitAttr();
607+
604608
OpWithBodyGenInfo genInfo =
605609
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
606610
.setGenNested(genNested)
@@ -620,7 +624,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
620624
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
621625
reductionDeclSymbols),
622626
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
623-
/*privatizers=*/nullptr);
627+
/*privatizers=*/nullptr, byrefAttr);
624628
}
625629

626630
bool privatize = !outerCombined;
@@ -684,7 +688,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
684688
delayedPrivatizationInfo.privatizers.empty()
685689
? nullptr
686690
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
687-
privatizers));
691+
privatizers),
692+
byrefAttr);
688693
}
689694

690695
static mlir::omp::SectionOp
@@ -1583,7 +1588,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
15831588
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
15841589
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
15851590
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
1586-
mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
1591+
mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
15871592
mlir::IntegerAttr orderedClauseOperand;
15881593
mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
15891594
std::size_t loopVarTypeSize;
@@ -1600,6 +1605,9 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
16001605
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
16011606
loopVarTypeSize);
16021607

1608+
if (ReductionProcessor::doReductionByRef(reductionVars))
1609+
byrefOperand = firOpBuilder.getUnitAttr();
1610+
16031611
auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
16041612
loc, lowerBound, upperBound, step, linearVars, linearStepVars,
16051613
reductionVars,
@@ -1609,8 +1617,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
16091617
reductionDeclSymbols),
16101618
scheduleValClauseOperand, scheduleChunkClauseOperand,
16111619
/*schedule_modifiers=*/nullptr,
1612-
/*simd_modifier=*/nullptr, nowaitClauseOperand, orderedClauseOperand,
1613-
orderClauseOperand,
1620+
/*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
1621+
orderedClauseOperand, orderClauseOperand,
16141622
/*inclusive=*/firOpBuilder.getUnitAttr());
16151623

16161624
// Handle attribute based clauses.

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 110 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,16 @@
1414

1515
#include "flang/Lower/AbstractConverter.h"
1616
#include "flang/Optimizer/Builder/Todo.h"
17+
#include "flang/Optimizer/Dialect/FIRType.h"
1718
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1819
#include "flang/Parser/tools.h"
1920
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21+
#include "llvm/Support/CommandLine.h"
22+
23+
static llvm::cl::opt<bool> forceByrefReduction(
24+
"force-byref-reduction",
25+
llvm::cl::desc("Pass all reduction arguments by reference"),
26+
llvm::cl::Hidden);
2027

2128
namespace Fortran {
2229
namespace lower {
@@ -76,16 +83,24 @@ bool ReductionProcessor::supportedIntrinsicProcReduction(
7683
}
7784

7885
std::string ReductionProcessor::getReductionName(llvm::StringRef name,
79-
mlir::Type ty) {
86+
mlir::Type ty, bool isByRef) {
87+
ty = fir::unwrapRefType(ty);
88+
89+
// extra string to distinguish reduction functions for variables passed by
90+
// reference
91+
llvm::StringRef byrefAddition{""};
92+
if (isByRef)
93+
byrefAddition = "_byref";
94+
8095
return (llvm::Twine(name) +
8196
(ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
82-
llvm::Twine(ty.getIntOrFloatBitWidth()))
97+
llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
8398
.str();
8499
}
85100

86101
std::string ReductionProcessor::getReductionName(
87102
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
88-
mlir::Type ty) {
103+
mlir::Type ty, bool isByRef) {
89104
std::string reductionName;
90105

91106
switch (intrinsicOp) {
@@ -108,13 +123,14 @@ std::string ReductionProcessor::getReductionName(
108123
break;
109124
}
110125

111-
return getReductionName(reductionName, ty);
126+
return getReductionName(reductionName, ty, isByRef);
112127
}
113128

114129
mlir::Value
115130
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
116131
ReductionIdentifier redId,
117132
fir::FirOpBuilder &builder) {
133+
type = fir::unwrapRefType(type);
118134
assert((fir::isa_integer(type) || fir::isa_real(type) ||
119135
type.isa<fir::LogicalType>()) &&
120136
"only integer, logical and real types are currently supported");
@@ -188,6 +204,7 @@ mlir::Value ReductionProcessor::createScalarCombiner(
188204
fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
189205
mlir::Type type, mlir::Value op1, mlir::Value op2) {
190206
mlir::Value reductionOp;
207+
type = fir::unwrapRefType(type);
191208
switch (redId) {
192209
case ReductionIdentifier::MAX:
193210
reductionOp =
@@ -268,7 +285,8 @@ mlir::Value ReductionProcessor::createScalarCombiner(
268285

269286
mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
270287
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
271-
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
288+
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
289+
bool isByRef) {
272290
mlir::OpBuilder::InsertionGuard guard(builder);
273291
mlir::ModuleOp module = builder.getModule();
274292

@@ -278,14 +296,24 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
278296
return decl;
279297

280298
mlir::OpBuilder modBuilder(module.getBodyRegion());
299+
mlir::Type valTy = fir::unwrapRefType(type);
300+
if (!isByRef)
301+
type = valTy;
281302

282303
decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
283304
type);
284305
builder.createBlock(&decl.getInitializerRegion(),
285306
decl.getInitializerRegion().end(), {type}, {loc});
286307
builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
308+
287309
mlir::Value init = getReductionInitValue(loc, type, redId, builder);
288-
builder.create<mlir::omp::YieldOp>(loc, init);
310+
if (isByRef) {
311+
mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy);
312+
builder.createStoreWithConvert(loc, init, alloca);
313+
builder.create<mlir::omp::YieldOp>(loc, alloca);
314+
} else {
315+
builder.create<mlir::omp::YieldOp>(loc, init);
316+
}
289317

290318
builder.createBlock(&decl.getReductionRegion(),
291319
decl.getReductionRegion().end(), {type, type},
@@ -294,14 +322,45 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
294322
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
295323
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
296324
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
325+
mlir::Value outAddr = op1;
326+
327+
op1 = builder.loadIfRef(loc, op1);
328+
op2 = builder.loadIfRef(loc, op2);
297329

298330
mlir::Value reductionOp =
299331
createScalarCombiner(builder, loc, redId, type, op1, op2);
300-
builder.create<mlir::omp::YieldOp>(loc, reductionOp);
332+
if (isByRef) {
333+
builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
334+
builder.create<mlir::omp::YieldOp>(loc, outAddr);
335+
} else {
336+
builder.create<mlir::omp::YieldOp>(loc, reductionOp);
337+
}
301338

302339
return decl;
303340
}
304341

342+
// TODO: By-ref vs by-val reductions are currently toggled for the whole
343+
// operation (possibly effecting multiple reduction variables).
344+
// This could cause a problem with openmp target reductions because
345+
// by-ref trivial types may not be supported.
346+
bool ReductionProcessor::doReductionByRef(
347+
const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
348+
if (reductionVars.empty())
349+
return false;
350+
if (forceByrefReduction)
351+
return true;
352+
353+
for (mlir::Value reductionVar : reductionVars) {
354+
if (auto declare =
355+
mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
356+
reductionVar = declare.getMemref();
357+
358+
if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
359+
return true;
360+
}
361+
return false;
362+
}
363+
305364
void ReductionProcessor::addReductionDecl(
306365
mlir::Location currentLocation,
307366
Fortran::lower::AbstractConverter &converter,
@@ -315,6 +374,37 @@ void ReductionProcessor::addReductionDecl(
315374
const auto &redOperator{
316375
std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
317376
const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
377+
378+
if (!std::holds_alternative<Fortran::parser::DefinedOperator>(
379+
redOperator.u)) {
380+
if (const auto *reductionIntrinsic =
381+
std::get_if<Fortran::parser::ProcedureDesignator>(&redOperator.u)) {
382+
if (!ReductionProcessor::supportedIntrinsicProcReduction(
383+
*reductionIntrinsic)) {
384+
return;
385+
}
386+
} else {
387+
return;
388+
}
389+
}
390+
391+
// initial pass to collect all reduction vars so we can figure out if this
392+
// should happen byref
393+
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
394+
if (const auto *name{
395+
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
396+
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
397+
if (reductionSymbols)
398+
reductionSymbols->push_back(symbol);
399+
mlir::Value symVal = converter.getSymbolAddress(*symbol);
400+
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
401+
symVal = declOp.getBase();
402+
reductionVars.push_back(symVal);
403+
}
404+
}
405+
}
406+
const bool isByRef = doReductionByRef(reductionVars);
407+
318408
if (const auto &redDefinedOp =
319409
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
320410
const auto &intrinsicOp{
@@ -338,23 +428,20 @@ void ReductionProcessor::addReductionDecl(
338428
if (const auto *name{
339429
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
340430
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
341-
if (reductionSymbols)
342-
reductionSymbols->push_back(symbol);
343431
mlir::Value symVal = converter.getSymbolAddress(*symbol);
344432
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
345433
symVal = declOp.getBase();
346-
mlir::Type redType =
347-
symVal.getType().cast<fir::ReferenceType>().getEleTy();
348-
reductionVars.push_back(symVal);
349-
if (redType.isa<fir::LogicalType>())
434+
auto redType = symVal.getType().cast<fir::ReferenceType>();
435+
if (redType.getEleTy().isa<fir::LogicalType>())
350436
decl = createReductionDecl(
351437
firOpBuilder,
352-
getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
353-
redType, currentLocation);
354-
else if (redType.isIntOrIndexOrFloat()) {
355-
decl = createReductionDecl(firOpBuilder,
356-
getReductionName(intrinsicOp, redType),
357-
redId, redType, currentLocation);
438+
getReductionName(intrinsicOp, firOpBuilder.getI1Type(),
439+
isByRef),
440+
redId, redType, currentLocation, isByRef);
441+
else if (redType.getEleTy().isIntOrIndexOrFloat()) {
442+
decl = createReductionDecl(
443+
firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
444+
redId, redType, currentLocation, isByRef);
358445
} else {
359446
TODO(currentLocation, "Reduction of some types is not supported");
360447
}
@@ -374,21 +461,17 @@ void ReductionProcessor::addReductionDecl(
374461
if (const auto *name{
375462
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
376463
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
377-
if (reductionSymbols)
378-
reductionSymbols->push_back(symbol);
379464
mlir::Value symVal = converter.getSymbolAddress(*symbol);
380465
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
381466
symVal = declOp.getBase();
382-
mlir::Type redType =
383-
symVal.getType().cast<fir::ReferenceType>().getEleTy();
384-
reductionVars.push_back(symVal);
385-
assert(redType.isIntOrIndexOrFloat() &&
467+
auto redType = symVal.getType().cast<fir::ReferenceType>();
468+
assert(redType.getEleTy().isIntOrIndexOrFloat() &&
386469
"Unsupported reduction type");
387470
decl = createReductionDecl(
388471
firOpBuilder,
389472
getReductionName(getRealName(*reductionIntrinsic).ToString(),
390-
redType),
391-
redId, redType, currentLocation);
473+
redType, isByRef),
474+
redId, redType, currentLocation, isByRef);
392475
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
393476
firOpBuilder.getContext(), decl.getSymName()));
394477
}

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
1515

1616
#include "flang/Optimizer/Builder/FIRBuilder.h"
17+
#include "flang/Optimizer/Dialect/FIRType.h"
1718
#include "flang/Parser/parse-tree.h"
1819
#include "flang/Semantics/symbol.h"
1920
#include "flang/Semantics/type.h"
@@ -71,11 +72,15 @@ class ReductionProcessor {
7172
static const Fortran::semantics::SourceName
7273
getRealName(const Fortran::parser::ProcedureDesignator &pd);
7374

74-
static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
75+
static bool
76+
doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
77+
78+
static std::string getReductionName(llvm::StringRef name, mlir::Type ty,
79+
bool isByRef);
7580

7681
static std::string getReductionName(
7782
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
78-
mlir::Type ty);
83+
mlir::Type ty, bool isByRef);
7984

8085
/// This function returns the identity value of the operator \p
8186
/// reductionOpName. For example:
@@ -103,9 +108,11 @@ class ReductionProcessor {
103108
/// symbol table. The declaration has a constant initializer with the neutral
104109
/// value `initValue`, and the reduction combiner carried over from `reduce`.
105110
/// TODO: Generalize this for non-integer types, add atomic region.
106-
static mlir::omp::ReductionDeclareOp createReductionDecl(
107-
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
108-
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc);
111+
static mlir::omp::ReductionDeclareOp
112+
createReductionDecl(fir::FirOpBuilder &builder,
113+
llvm::StringRef reductionOpName,
114+
const ReductionIdentifier redId, mlir::Type type,
115+
mlir::Location loc, bool isByRef);
109116

110117
/// Creates a reduction declaration and associates it with an OpenMP block
111118
/// directive.
@@ -124,6 +131,7 @@ mlir::Value
124131
ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
125132
mlir::Type type, mlir::Location loc,
126133
mlir::Value op1, mlir::Value op2) {
134+
type = fir::unwrapRefType(type);
127135
assert(type.isIntOrIndexOrFloat() &&
128136
"only integer and float types are currently supported");
129137
if (type.isIntOrIndex())

0 commit comments

Comments
 (0)