Skip to content

Commit 07d9f41

Browse files
[flang][OpenMP][OMPIRBuilder][mlir] Optionally pass reduction vars by ref
Previously, reduction variables were always passed by value into and out of the initialization and combiner regions of the OpenMP reduction declare operation. This worked well for reductions of primitive types (and might perform better than passing by reference). But passing by reference will be useful for array and derived type reductions (e.g. to move allocation inside of the init region). Passing reductions by reference requires different LLVM-IR generation when lowering from MLIR because some of the loads/stores/allocations will now be moved inside of the init and combiner regions. This alternate code generation is requested using a new attribute on omp.wsloop and omp.parallel. Existing lowerings from MLIR are unaffected (these will continue to use by-value argument passing). Flang will continue to use by-value argument passing for trivial types unless a (hidden) command line argument is supplied. Non-trivial types will always use the by-ref lowering. Array reductions are not ready yet (but are coming very soon). In the meantime, this is tested by forcing existing reductions to use by-ref. Commit series for by-ref OpenMP reductions 3/3 Co-authored-by: Mats Petersson <[email protected]>
1 parent ba0182b commit 07d9f41

37 files changed

+4393
-76
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
600600
return reductionSymbols;
601601
};
602602

603+
mlir::UnitAttr byrefAttr;
604+
if (ReductionProcessor::doReductionByRef(reductionVars))
605+
byrefAttr = converter.getFirOpBuilder().getUnitAttr();
606+
603607
OpWithBodyGenInfo genInfo =
604608
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
605609
.setGenNested(genNested)
@@ -619,7 +623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
619623
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
620624
reductionDeclSymbols),
621625
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
622-
/*privatizers=*/nullptr);
626+
/*privatizers=*/nullptr, byrefAttr);
623627
}
624628

625629
bool privatize = !outerCombined;
@@ -683,7 +687,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
683687
delayedPrivatizationInfo.privatizers.empty()
684688
? nullptr
685689
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
686-
privatizers));
690+
privatizers),
691+
byrefAttr);
687692
}
688693

689694
static mlir::omp::SectionOp
@@ -1568,7 +1573,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
15681573
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
15691574
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
15701575
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
1571-
mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
1576+
mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
15721577
mlir::IntegerAttr orderedClauseOperand;
15731578
mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
15741579
std::size_t loopVarTypeSize;
@@ -1585,6 +1590,9 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
15851590
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
15861591
loopVarTypeSize);
15871592

1593+
if (ReductionProcessor::doReductionByRef(reductionVars))
1594+
byrefOperand = firOpBuilder.getUnitAttr();
1595+
15881596
auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
15891597
loc, lowerBound, upperBound, step, linearVars, linearStepVars,
15901598
reductionVars,
@@ -1594,8 +1602,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
15941602
reductionDeclSymbols),
15951603
scheduleValClauseOperand, scheduleChunkClauseOperand,
15961604
/*schedule_modifiers=*/nullptr,
1597-
/*simd_modifier=*/nullptr, nowaitClauseOperand, orderedClauseOperand,
1598-
orderClauseOperand,
1605+
/*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
1606+
orderedClauseOperand, orderClauseOperand,
15991607
/*inclusive=*/firOpBuilder.getUnitAttr());
16001608

16011609
// Handle attribute based clauses.

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 93 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,16 @@
1414

1515
#include "flang/Lower/AbstractConverter.h"
1616
#include "flang/Optimizer/Builder/Todo.h"
17+
#include "flang/Optimizer/Dialect/FIRType.h"
1718
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1819
#include "flang/Parser/tools.h"
1920
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21+
#include "llvm/Support/CommandLine.h"
22+
23+
static llvm::cl::opt<bool> forceByrefReduction(
24+
"force-byref-reduction",
25+
llvm::cl::desc("Pass all reduction arguments by reference"),
26+
llvm::cl::Hidden);
2027

2128
namespace Fortran {
2229
namespace lower {
@@ -76,16 +83,24 @@ bool ReductionProcessor::supportedIntrinsicProcReduction(
7683
}
7784

7885
std::string ReductionProcessor::getReductionName(llvm::StringRef name,
79-
mlir::Type ty) {
86+
mlir::Type ty, bool isByRef) {
87+
ty = fir::unwrapRefType(ty);
88+
89+
// extra string to distinguish reduction functions for variables passed by
90+
// reference
91+
llvm::StringRef byrefAddition{""};
92+
if (isByRef)
93+
byrefAddition = "_byref";
94+
8095
return (llvm::Twine(name) +
8196
(ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
82-
llvm::Twine(ty.getIntOrFloatBitWidth()))
97+
llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
8398
.str();
8499
}
85100

86101
std::string ReductionProcessor::getReductionName(
87102
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
88-
mlir::Type ty) {
103+
mlir::Type ty, bool isByRef) {
89104
std::string reductionName;
90105

91106
switch (intrinsicOp) {
@@ -108,13 +123,14 @@ std::string ReductionProcessor::getReductionName(
108123
break;
109124
}
110125

111-
return getReductionName(reductionName, ty);
126+
return getReductionName(reductionName, ty, isByRef);
112127
}
113128

114129
mlir::Value
115130
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
116131
ReductionIdentifier redId,
117132
fir::FirOpBuilder &builder) {
133+
type = fir::unwrapRefType(type);
118134
assert((fir::isa_integer(type) || fir::isa_real(type) ||
119135
type.isa<fir::LogicalType>()) &&
120136
"only integer, logical and real types are currently supported");
@@ -188,6 +204,7 @@ mlir::Value ReductionProcessor::createScalarCombiner(
188204
fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
189205
mlir::Type type, mlir::Value op1, mlir::Value op2) {
190206
mlir::Value reductionOp;
207+
type = fir::unwrapRefType(type);
191208
switch (redId) {
192209
case ReductionIdentifier::MAX:
193210
reductionOp =
@@ -268,7 +285,8 @@ mlir::Value ReductionProcessor::createScalarCombiner(
268285

269286
mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
270287
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
271-
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
288+
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
289+
bool isByRef) {
272290
mlir::OpBuilder::InsertionGuard guard(builder);
273291
mlir::ModuleOp module = builder.getModule();
274292

@@ -278,14 +296,24 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
278296
return decl;
279297

280298
mlir::OpBuilder modBuilder(module.getBodyRegion());
299+
mlir::Type valTy = fir::unwrapRefType(type);
300+
if (!isByRef)
301+
type = valTy;
281302

282303
decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
283304
type);
284305
builder.createBlock(&decl.getInitializerRegion(),
285306
decl.getInitializerRegion().end(), {type}, {loc});
286307
builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
308+
287309
mlir::Value init = getReductionInitValue(loc, type, redId, builder);
288-
builder.create<mlir::omp::YieldOp>(loc, init);
310+
if (isByRef) {
311+
mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy);
312+
builder.createStoreWithConvert(loc, init, alloca);
313+
builder.create<mlir::omp::YieldOp>(loc, alloca);
314+
} else {
315+
builder.create<mlir::omp::YieldOp>(loc, init);
316+
}
289317

290318
builder.createBlock(&decl.getReductionRegion(),
291319
decl.getReductionRegion().end(), {type, type},
@@ -294,14 +322,41 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
294322
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
295323
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
296324
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
325+
mlir::Value outAddr = op1;
326+
327+
op1 = builder.loadIfRef(loc, op1);
328+
op2 = builder.loadIfRef(loc, op2);
297329

298330
mlir::Value reductionOp =
299331
createScalarCombiner(builder, loc, redId, type, op1, op2);
300-
builder.create<mlir::omp::YieldOp>(loc, reductionOp);
332+
if (isByRef) {
333+
builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
334+
builder.create<mlir::omp::YieldOp>(loc, outAddr);
335+
} else {
336+
builder.create<mlir::omp::YieldOp>(loc, reductionOp);
337+
}
301338

302339
return decl;
303340
}
304341

342+
bool ReductionProcessor::doReductionByRef(
343+
const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
344+
if (reductionVars.empty())
345+
return false;
346+
if (forceByrefReduction)
347+
return true;
348+
349+
for (mlir::Value reductionVar : reductionVars) {
350+
if (auto declare =
351+
mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
352+
reductionVar = declare.getMemref();
353+
354+
if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
355+
return true;
356+
}
357+
return false;
358+
}
359+
305360
void ReductionProcessor::addReductionDecl(
306361
mlir::Location currentLocation,
307362
Fortran::lower::AbstractConverter &converter,
@@ -315,6 +370,24 @@ void ReductionProcessor::addReductionDecl(
315370
const auto &redOperator{
316371
std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
317372
const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
373+
374+
// initial pass to collect all reduction vars so we can figure out if this
375+
// should happen byref
376+
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
377+
if (const auto *name{
378+
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
379+
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
380+
if (reductionSymbols)
381+
reductionSymbols->push_back(symbol);
382+
mlir::Value symVal = converter.getSymbolAddress(*symbol);
383+
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
384+
symVal = declOp.getBase();
385+
reductionVars.push_back(symVal);
386+
}
387+
}
388+
}
389+
const bool isByRef = doReductionByRef(reductionVars);
390+
318391
if (const auto &redDefinedOp =
319392
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
320393
const auto &intrinsicOp{
@@ -338,23 +411,20 @@ void ReductionProcessor::addReductionDecl(
338411
if (const auto *name{
339412
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
340413
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
341-
if (reductionSymbols)
342-
reductionSymbols->push_back(symbol);
343414
mlir::Value symVal = converter.getSymbolAddress(*symbol);
344415
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
345416
symVal = declOp.getBase();
346-
mlir::Type redType =
347-
symVal.getType().cast<fir::ReferenceType>().getEleTy();
348-
reductionVars.push_back(symVal);
349-
if (redType.isa<fir::LogicalType>())
417+
auto redType = symVal.getType().cast<fir::ReferenceType>();
418+
if (redType.getEleTy().isa<fir::LogicalType>())
350419
decl = createReductionDecl(
351420
firOpBuilder,
352-
getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
353-
redType, currentLocation);
354-
else if (redType.isIntOrIndexOrFloat()) {
355-
decl = createReductionDecl(firOpBuilder,
356-
getReductionName(intrinsicOp, redType),
357-
redId, redType, currentLocation);
421+
getReductionName(intrinsicOp, firOpBuilder.getI1Type(),
422+
isByRef),
423+
redId, redType, currentLocation, isByRef);
424+
else if (redType.getEleTy().isIntOrIndexOrFloat()) {
425+
decl = createReductionDecl(
426+
firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
427+
redId, redType, currentLocation, isByRef);
358428
} else {
359429
TODO(currentLocation, "Reduction of some types is not supported");
360430
}
@@ -374,21 +444,17 @@ void ReductionProcessor::addReductionDecl(
374444
if (const auto *name{
375445
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
376446
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
377-
if (reductionSymbols)
378-
reductionSymbols->push_back(symbol);
379447
mlir::Value symVal = converter.getSymbolAddress(*symbol);
380448
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
381449
symVal = declOp.getBase();
382-
mlir::Type redType =
383-
symVal.getType().cast<fir::ReferenceType>().getEleTy();
384-
reductionVars.push_back(symVal);
385-
assert(redType.isIntOrIndexOrFloat() &&
450+
auto redType = symVal.getType().cast<fir::ReferenceType>();
451+
assert(redType.getEleTy().isIntOrIndexOrFloat() &&
386452
"Unsupported reduction type");
387453
decl = createReductionDecl(
388454
firOpBuilder,
389455
getReductionName(getRealName(*reductionIntrinsic).ToString(),
390-
redType),
391-
redId, redType, currentLocation);
456+
redType, isByRef),
457+
redId, redType, currentLocation, isByRef);
392458
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
393459
firOpBuilder.getContext(), decl.getSymName()));
394460
}

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
1515

1616
#include "flang/Optimizer/Builder/FIRBuilder.h"
17+
#include "flang/Optimizer/Dialect/FIRType.h"
1718
#include "flang/Parser/parse-tree.h"
1819
#include "flang/Semantics/symbol.h"
1920
#include "flang/Semantics/type.h"
@@ -71,11 +72,15 @@ class ReductionProcessor {
7172
static const Fortran::semantics::SourceName
7273
getRealName(const Fortran::parser::ProcedureDesignator &pd);
7374

74-
static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
75+
static bool
76+
doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
77+
78+
static std::string getReductionName(llvm::StringRef name, mlir::Type ty,
79+
bool isByRef);
7580

7681
static std::string getReductionName(
7782
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
78-
mlir::Type ty);
83+
mlir::Type ty, bool isByRef);
7984

8085
/// This function returns the identity value of the operator \p
8186
/// reductionOpName. For example:
@@ -103,9 +108,11 @@ class ReductionProcessor {
103108
/// symbol table. The declaration has a constant initializer with the neutral
104109
/// value `initValue`, and the reduction combiner carried over from `reduce`.
105110
/// TODO: Generalize this for non-integer types, add atomic region.
106-
static mlir::omp::ReductionDeclareOp createReductionDecl(
107-
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
108-
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc);
111+
static mlir::omp::ReductionDeclareOp
112+
createReductionDecl(fir::FirOpBuilder &builder,
113+
llvm::StringRef reductionOpName,
114+
const ReductionIdentifier redId, mlir::Type type,
115+
mlir::Location loc, bool isByRef);
109116

110117
/// Creates a reduction declaration and associates it with an OpenMP block
111118
/// directive.
@@ -124,6 +131,7 @@ mlir::Value
124131
ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
125132
mlir::Type type, mlir::Location loc,
126133
mlir::Value op1, mlir::Value op2) {
134+
type = fir::unwrapRefType(type);
127135
assert(type.isIntOrIndexOrFloat() &&
128136
"only integer and float types are currently supported");
129137
if (type.isIntOrIndex())

0 commit comments

Comments
 (0)