Skip to content

[flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProc… #81623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,28 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
}
}

struct ExtractSubstringHelper {
template <typename T> static std::optional<Substring> visit(T &&) {
return std::nullopt;
}

static std::optional<Substring> visit(const Substring &e) { return e; }

template <typename T>
static std::optional<Substring> visit(const Designator<T> &e) {
return std::visit([](auto &&s) { return visit(s); }, e.u);
}

template <typename T>
static std::optional<Substring> visit(const Expr<T> &e) {
return std::visit([](auto &&s) { return visit(s); }, e.u);
}
};

template <typename A> std::optional<Substring> ExtractSubstring(const A &x) {
return ExtractSubstringHelper::visit(x);
}

// If an expression is simply a whole symbol data designator,
// extract and return that symbol, else null.
template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
Expand Down
218 changes: 96 additions & 122 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Large diffs are not rendered by default.

29 changes: 25 additions & 4 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,8 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
mlir::Value &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
mlir::Value &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;

Expand Down Expand Up @@ -178,6 +177,10 @@ class ClauseProcessor {
/// if at least one instance was found.
template <typename T>
bool findRepeatableClause(
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
callbackFn) const;
template <typename T>
bool findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const;

Expand All @@ -195,7 +198,7 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
return findRepeatableClause<T>(
return findRepeatableClause2<T>(
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Expand Down Expand Up @@ -295,6 +298,24 @@ const T *ClauseProcessor::findUniqueClause(

template <typename T>
bool ClauseProcessor::findRepeatableClause(
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
ClauseIterator nextIt, endIt = clauses.end();
for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
nextIt = findClause<T>(it, endIt);

if (nextIt != endIt) {
callbackFn(std::get<T>(nextIt->u), nextIt->source);
found = true;
++nextIt;
}
}
return found;
}

template <typename T>
bool ClauseProcessor::findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
Expand Down
6 changes: 0 additions & 6 deletions flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,6 @@ namespace clause {
#undef EMPTY_CLASS
#undef WRAPPER_CLASS

using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
using ProcedureDesignator =
tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
using ReductionOperator =
tomp::clause::ReductionOperatorT<SymIdent, SymReference>;

DefinedOperator makeDefinedOperator(const parser::DefinedOperator &inp,
semantics::SemanticsContext &semaCtx) {
return std::visit(
Expand Down
6 changes: 6 additions & 0 deletions flang/lib/Lower/OpenMP/Clauses.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ namespace clause {
#undef EMPTY_CLASS
#undef WRAPPER_CLASS

using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
using ProcedureDesignator =
tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
using ReductionOperator =
tomp::clause::ReductionOperatorT<SymIdent, SymReference>;

// "Requires" clauses are handled early on, and the aggregated information
// is stored in the Symbol details of modules, programs, and subprograms.
// These clauses are still handled here to cover all alternatives in the
Expand Down
186 changes: 86 additions & 100 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,8 +574,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
cp.processProcBind(procBindKindAttr);
cp.processDefault();
Expand Down Expand Up @@ -751,8 +750,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
dependOperands;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Task, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processDefault();
cp.processFinal(stmtCtx, finalClauseOperand);
Expand Down Expand Up @@ -865,8 +863,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::TargetData, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
useDeviceSymbols);
Expand Down Expand Up @@ -911,20 +908,17 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
llvm::SmallVector<mlir::Attribute> dependTypeOperands;

Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
clause::If::DirectiveNameModifier directiveName;
// GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
[[maybe_unused]] llvm::omp::Directive directive;
if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
directive = llvm::omp::Directive::OMPD_target_enter_data;
} else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
directiveName = clause::If::DirectiveNameModifier::TargetExitData;
directive = llvm::omp::Directive::OMPD_target_exit_data;
} else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
directiveName =
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
directive = llvm::omp::Directive::OMPD_target_update;
} else {
return nullptr;
Expand Down Expand Up @@ -1126,8 +1120,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Target, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLimitOperand);
cp.processDepend(dependTypeOperands, dependOperands);
Expand Down Expand Up @@ -1258,8 +1251,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;

ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
cp.processAllocate(allocatorOperands, allocateOperands);
cp.processDefault();
cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
Expand Down Expand Up @@ -1298,8 +1290,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(

if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
ObjectList objects{makeList(*objectList, semaCtx)};
// Case: declare target(func, var1, var2)
gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
symbolAndClause);
} else if (const auto *clauseList{
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
Expand Down Expand Up @@ -1438,7 +1431,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
if (const auto &ompObjectList =
std::get<std::optional<Fortran::parser::OmpObjectList>>(
flushConstruct.t))
genObjectList(*ompObjectList, converter, operandRange);
genObjectList2(*ompObjectList, converter, operandRange);
const auto &memOrderClause =
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
flushConstruct.t);
Expand Down Expand Up @@ -1600,8 +1593,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
loopVarTypeSize);
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
ifClauseOperand);
cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
cp.processSimdlen(simdlenClauseOperand);
cp.processSafelen(safelenClauseOperand);
cp.processTODO<Fortran::parser::OmpClause::Aligned,
Expand Down Expand Up @@ -2419,106 +2411,100 @@ void Fortran::lower::genOpenMPReduction(
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

for (const Fortran::parser::OmpClause &clause : clauseList.v) {
List<Clause> clauses{makeList(clauseList, semaCtx)};

for (const Clause &clause : clauses) {
if (const auto &reductionClause =
std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
reductionClause->v.t)};
const auto &objectList{
std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
std::get_if<clause::Reduction>(&clause.u)) {
const auto &redOperator{
std::get<clause::ReductionOperator>(reductionClause->t)};
const auto &objects{std::get<ObjectList>(reductionClause->t)};
if (const auto *reductionOp =
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
std::get_if<clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
std::get<clause::DefinedOperator::IntrinsicOperator>(
reductionOp->u)};

switch (intrinsicOp) {
case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
case clause::DefinedOperator::IntrinsicOperator::Add:
case clause::DefinedOperator::IntrinsicOperator::Multiply:
case clause::DefinedOperator::IntrinsicOperator::AND:
case clause::DefinedOperator::IntrinsicOperator::EQV:
case clause::DefinedOperator::IntrinsicOperator::OR:
case clause::DefinedOperator::IntrinsicOperator::NEQV:
break;
default:
continue;
}
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
reductionVal = declOp.getBase();
mlir::Type reductionType =
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
if (!reductionType.isa<fir::LogicalType>()) {
if (!reductionType.isIntOrIndexOrFloat())
continue;
}
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
reductionValUse.getOwner())) {
mlir::Value loadVal = loadOp.getRes();
if (reductionType.isa<fir::LogicalType>()) {
mlir::Operation *reductionOp = findReductionChain(loadVal);
fir::ConvertOp convertOp =
getConvertFromReductionOp(reductionOp, loadVal);
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal, &convertOp);
removeStoreOp(reductionOp, reductionVal);
} else if (mlir::Operation *reductionOp =
findReductionChain(loadVal, &reductionVal)) {
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal);
}
for (const Object &object : objects) {
if (const Fortran::semantics::Symbol *symbol = object.id()) {
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
reductionVal = declOp.getBase();
mlir::Type reductionType =
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
if (!reductionType.isa<fir::LogicalType>()) {
if (!reductionType.isIntOrIndexOrFloat())
continue;
}
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
if (auto loadOp =
mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
mlir::Value loadVal = loadOp.getRes();
if (reductionType.isa<fir::LogicalType>()) {
mlir::Operation *reductionOp = findReductionChain(loadVal);
fir::ConvertOp convertOp =
getConvertFromReductionOp(reductionOp, loadVal);
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal, &convertOp);
removeStoreOp(reductionOp, reductionVal);
} else if (mlir::Operation *reductionOp =
findReductionChain(loadVal, &reductionVal)) {
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal);
}
}
}
}
}
} else if (const auto *reductionIntrinsic =
std::get_if<Fortran::parser::ProcedureDesignator>(
&redOperator.u)) {
std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic))
continue;
ReductionProcessor::ReductionIdentifier redId =
ReductionProcessor::getReductionType(*reductionIntrinsic);
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
reductionVal = declOp.getBase();
for (const mlir::OpOperand &reductionValUse :
reductionVal.getUses()) {
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
reductionValUse.getOwner())) {
mlir::Value loadVal = loadOp.getRes();
// Max is lowered as a compare -> select.
// Match the pattern here.
mlir::Operation *reductionOp =
findReductionChain(loadVal, &reductionVal);
if (reductionOp == nullptr)
continue;

if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
redId == ReductionProcessor::ReductionIdentifier::MIN) {
assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
"Selection Op not found in reduction intrinsic");
mlir::Operation *compareOp =
getCompareFromReductionOp(reductionOp, loadVal);
updateReduction(compareOp, firOpBuilder, loadVal,
reductionVal);
}
if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
redId == ReductionProcessor::ReductionIdentifier::IEOR ||
redId == ReductionProcessor::ReductionIdentifier::IAND) {
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal);
}
for (const Object &object : objects) {
if (const Fortran::semantics::Symbol *symbol = object.id()) {
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
reductionVal = declOp.getBase();
for (const mlir::OpOperand &reductionValUse :
reductionVal.getUses()) {
if (auto loadOp =
mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
mlir::Value loadVal = loadOp.getRes();
// Max is lowered as a compare -> select.
// Match the pattern here.
mlir::Operation *reductionOp =
findReductionChain(loadVal, &reductionVal);
if (reductionOp == nullptr)
continue;

if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
redId == ReductionProcessor::ReductionIdentifier::MIN) {
assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
"Selection Op not found in reduction intrinsic");
mlir::Operation *compareOp =
getCompareFromReductionOp(reductionOp, loadVal);
updateReduction(compareOp, firOpBuilder, loadVal,
reductionVal);
}
if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
redId == ReductionProcessor::ReductionIdentifier::IEOR ||
redId == ReductionProcessor::ReductionIdentifier::IAND) {
updateReduction(reductionOp, firOpBuilder, loadVal,
reductionVal);
}
}
}
Expand Down
Loading