Skip to content

[flang][OpenMP][OMPIRBuilder][mlir] Optionally pass reduction vars by ref #84304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
void createStoreWithConvert(mlir::Location loc, mlir::Value val,
mlir::Value addr);

/// Create a fir.load if \p val is a reference or pointer type.
/// \returns the result of the load if one was created; otherwise \p val
/// itself.
mlir::Value loadIfRef(mlir::Location loc, mlir::Value val);

/// Create a new FuncOp. If the function may have already been created, use
/// `addNamedFunction` instead.
mlir::func::FuncOp createFunction(mlir::Location loc, llvm::StringRef name,
Expand Down
12 changes: 2 additions & 10 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1041,14 +1041,6 @@ static mlir::Value genLogicalCombiner(fir::FirOpBuilder &builder,
return builder.create<fir::ConvertOp>(loc, value1.getType(), combined);
}

/// Dereference \p value when it is address-like. If the type of \p value is a
/// fir::ReferenceType, fir::PointerType or fir::HeapType, a fir::LoadOp is
/// created at \p loc and its result is returned; any other value is handed
/// back as-is.
static mlir::Value loadIfRef(fir::FirOpBuilder &builder, mlir::Location loc,
                             mlir::Value value) {
  const bool isAddressLike =
      mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType>(
          value.getType());
  // Nothing to load from: the value is already usable directly.
  if (!isAddressLike)
    return value;
  return builder.create<fir::LoadOp>(loc, value);
}

static mlir::Value genComparisonCombiner(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::arith::CmpIPredicate pred,
Expand All @@ -1066,8 +1058,8 @@ static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder,
mlir::acc::ReductionOperator op,
mlir::Type ty, mlir::Value value1,
mlir::Value value2) {
value1 = loadIfRef(builder, loc, value1);
value2 = loadIfRef(builder, loc, value2);
value1 = builder.loadIfRef(loc, value1);
value2 = builder.loadIfRef(loc, value2);
if (op == mlir::acc::ReductionOperator::AccAdd) {
if (ty.isIntOrIndex())
return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
Expand Down
18 changes: 13 additions & 5 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
return reductionSymbols;
};

mlir::UnitAttr byrefAttr;
if (ReductionProcessor::doReductionByRef(reductionVars))
byrefAttr = converter.getFirOpBuilder().getUnitAttr();

OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
Expand All @@ -619,7 +623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
/*privatizers=*/nullptr);
/*privatizers=*/nullptr, byrefAttr);
}

bool privatize = !outerCombined;
Expand Down Expand Up @@ -683,7 +687,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
delayedPrivatizationInfo.privatizers.empty()
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
privatizers));
privatizers),
byrefAttr);
}

static mlir::omp::SectionOp
Expand Down Expand Up @@ -1568,7 +1573,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
mlir::IntegerAttr orderedClauseOperand;
mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
std::size_t loopVarTypeSize;
Expand All @@ -1585,6 +1590,9 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
loopVarTypeSize);

if (ReductionProcessor::doReductionByRef(reductionVars))
byrefOperand = firOpBuilder.getUnitAttr();

auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
loc, lowerBound, upperBound, step, linearVars, linearStepVars,
reductionVars,
Expand All @@ -1594,8 +1602,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
reductionDeclSymbols),
scheduleValClauseOperand, scheduleChunkClauseOperand,
/*schedule_modifiers=*/nullptr,
/*simd_modifier=*/nullptr, nowaitClauseOperand, orderedClauseOperand,
orderClauseOperand,
/*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
orderedClauseOperand, orderClauseOperand,
/*inclusive=*/firOpBuilder.getUnitAttr());

// Handle attribute based clauses.
Expand Down
137 changes: 110 additions & 27 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@

#include "flang/Lower/AbstractConverter.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"

// Hidden command-line switch `force-byref-reduction`.
// ReductionProcessor::doReductionByRef consults it: when the flag is set and
// the list of reduction variables is non-empty, the reduction is reported as
// by-reference without inspecting the variables' types.
static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
llvm::cl::desc("Pass all reduction arguments by reference"),
llvm::cl::Hidden);

namespace Fortran {
namespace lower {
Expand Down Expand Up @@ -76,16 +83,24 @@ bool ReductionProcessor::supportedIntrinsicProcReduction(
}

std::string ReductionProcessor::getReductionName(llvm::StringRef name,
mlir::Type ty) {
mlir::Type ty, bool isByRef) {
ty = fir::unwrapRefType(ty);

// extra string to distinguish reduction functions for variables passed by
// reference
llvm::StringRef byrefAddition{""};
if (isByRef)
byrefAddition = "_byref";

return (llvm::Twine(name) +
(ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
llvm::Twine(ty.getIntOrFloatBitWidth()))
llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
.str();
}

std::string ReductionProcessor::getReductionName(
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
mlir::Type ty) {
mlir::Type ty, bool isByRef) {
std::string reductionName;

switch (intrinsicOp) {
Expand All @@ -108,13 +123,14 @@ std::string ReductionProcessor::getReductionName(
break;
}

return getReductionName(reductionName, ty);
return getReductionName(reductionName, ty, isByRef);
}

mlir::Value
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
ReductionIdentifier redId,
fir::FirOpBuilder &builder) {
type = fir::unwrapRefType(type);
assert((fir::isa_integer(type) || fir::isa_real(type) ||
type.isa<fir::LogicalType>()) &&
"only integer, logical and real types are currently supported");
Expand Down Expand Up @@ -188,6 +204,7 @@ mlir::Value ReductionProcessor::createScalarCombiner(
fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
mlir::Type type, mlir::Value op1, mlir::Value op2) {
mlir::Value reductionOp;
type = fir::unwrapRefType(type);
switch (redId) {
case ReductionIdentifier::MAX:
reductionOp =
Expand Down Expand Up @@ -268,7 +285,8 @@ mlir::Value ReductionProcessor::createScalarCombiner(

mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
bool isByRef) {
mlir::OpBuilder::InsertionGuard guard(builder);
mlir::ModuleOp module = builder.getModule();

Expand All @@ -278,14 +296,24 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
return decl;

mlir::OpBuilder modBuilder(module.getBodyRegion());
mlir::Type valTy = fir::unwrapRefType(type);
if (!isByRef)
type = valTy;

decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
type);
builder.createBlock(&decl.getInitializerRegion(),
decl.getInitializerRegion().end(), {type}, {loc});
builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());

mlir::Value init = getReductionInitValue(loc, type, redId, builder);
builder.create<mlir::omp::YieldOp>(loc, init);
if (isByRef) {
mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy);
builder.createStoreWithConvert(loc, init, alloca);
builder.create<mlir::omp::YieldOp>(loc, alloca);
} else {
builder.create<mlir::omp::YieldOp>(loc, init);
}

builder.createBlock(&decl.getReductionRegion(),
decl.getReductionRegion().end(), {type, type},
Expand All @@ -294,14 +322,45 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
mlir::Value outAddr = op1;

op1 = builder.loadIfRef(loc, op1);
op2 = builder.loadIfRef(loc, op2);

mlir::Value reductionOp =
createScalarCombiner(builder, loc, redId, type, op1, op2);
builder.create<mlir::omp::YieldOp>(loc, reductionOp);
if (isByRef) {
builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
builder.create<mlir::omp::YieldOp>(loc, outAddr);
} else {
builder.create<mlir::omp::YieldOp>(loc, reductionOp);
}

return decl;
}

// TODO: By-ref vs by-val reductions are currently toggled for the whole
// operation (possibly affecting multiple reduction variables).
// This could cause a problem with openmp target reductions because
// by-ref trivial types may not be supported.
bool ReductionProcessor::doReductionByRef(
const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
if (reductionVars.empty())
return false;
if (forceByrefReduction)
return true;

for (mlir::Value reductionVar : reductionVars) {
if (auto declare =
mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
reductionVar = declare.getMemref();

if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
return true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this imply that all reductions on a clause have to be by ref or by val? E.g. if we have an array reduction on the clause does that mean an integer reduction also changes to byref?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it does. I did this to keep things simpler. Currently byref vs byval is toggled over the whole wsloop or parallel region. A more sophisticated implementation could instead track this per reduction argument. I chose not to do this to keep things simple.

I suspect that in most cases, if an integer reduction and an array reduction are used together, the array reduction would take long enough that the performance loss from doing the integer reduction by reference would not be significant. I have not measured this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's probably true wrt performance, but I believe when it comes to openmp target reductions doing reductions on basic types will affect correctness. As far as I remember you do not need to ensure manually that for example the INTEGER exists on the target device, whereas you do for an array (and would need to if the integer is passed by reference).

I think fixing that in a subsequent patch is probably fine though, as long as we add a TODO mentioning that ideally it should be considered separately per argument.

}
return false;
}

void ReductionProcessor::addReductionDecl(
mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
Expand All @@ -315,6 +374,37 @@ void ReductionProcessor::addReductionDecl(
const auto &redOperator{
std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};

if (!std::holds_alternative<Fortran::parser::DefinedOperator>(
redOperator.u)) {
if (const auto *reductionIntrinsic =
std::get_if<Fortran::parser::ProcedureDesignator>(&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
return;
}
} else {
return;
}
}

// initial pass to collect all reduction vars so we can figure out if this
// should happen byref
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
reductionVars.push_back(symVal);
}
}
}
const bool isByRef = doReductionByRef(reductionVars);

if (const auto &redDefinedOp =
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
Expand All @@ -338,23 +428,20 @@ void ReductionProcessor::addReductionDecl(
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
mlir::Type redType =
symVal.getType().cast<fir::ReferenceType>().getEleTy();
reductionVars.push_back(symVal);
if (redType.isa<fir::LogicalType>())
auto redType = symVal.getType().cast<fir::ReferenceType>();
if (redType.getEleTy().isa<fir::LogicalType>())
decl = createReductionDecl(
firOpBuilder,
getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
redType, currentLocation);
else if (redType.isIntOrIndexOrFloat()) {
decl = createReductionDecl(firOpBuilder,
getReductionName(intrinsicOp, redType),
redId, redType, currentLocation);
getReductionName(intrinsicOp, firOpBuilder.getI1Type(),
isByRef),
redId, redType, currentLocation, isByRef);
else if (redType.getEleTy().isIntOrIndexOrFloat()) {
decl = createReductionDecl(
firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
redId, redType, currentLocation, isByRef);
} else {
TODO(currentLocation, "Reduction of some types is not supported");
}
Expand All @@ -374,21 +461,17 @@ void ReductionProcessor::addReductionDecl(
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
mlir::Type redType =
symVal.getType().cast<fir::ReferenceType>().getEleTy();
reductionVars.push_back(symVal);
assert(redType.isIntOrIndexOrFloat() &&
auto redType = symVal.getType().cast<fir::ReferenceType>();
assert(redType.getEleTy().isIntOrIndexOrFloat() &&
"Unsupported reduction type");
decl = createReductionDecl(
firOpBuilder,
getReductionName(getRealName(*reductionIntrinsic).ToString(),
redType),
redId, redType, currentLocation);
redType, isByRef),
redId, redType, currentLocation, isByRef);
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
firOpBuilder.getContext(), decl.getSymName()));
}
Expand Down
18 changes: 13 additions & 5 deletions flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
Expand Down Expand Up @@ -71,11 +72,15 @@ class ReductionProcessor {
static const Fortran::semantics::SourceName
getRealName(const Fortran::parser::ProcedureDesignator &pd);

static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
static bool
doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);

static std::string getReductionName(llvm::StringRef name, mlir::Type ty,
bool isByRef);

static std::string getReductionName(
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
mlir::Type ty);
mlir::Type ty, bool isByRef);

/// This function returns the identity value of the operator \p
/// reductionOpName. For example:
Expand Down Expand Up @@ -103,9 +108,11 @@ class ReductionProcessor {
/// symbol table. The declaration has a constant initializer with the neutral
/// value `initValue`, and the reduction combiner carried over from `reduce`.
/// TODO: Generalize this for non-integer types, add atomic region.
static mlir::omp::ReductionDeclareOp createReductionDecl(
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc);
static mlir::omp::ReductionDeclareOp
createReductionDecl(fir::FirOpBuilder &builder,
llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type,
mlir::Location loc, bool isByRef);

/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
Expand All @@ -124,6 +131,7 @@ mlir::Value
ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
mlir::Type type, mlir::Location loc,
mlir::Value op1, mlir::Value op2) {
type = fir::unwrapRefType(type);
assert(type.isIntOrIndexOrFloat() &&
"only integer and float types are currently supported");
if (type.isIntOrIndex())
Expand Down
Loading