Skip to content

[Flang][OpenMP] Push genEval calls to individual operations, NFC #77758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion flang/include/flang/Lower/OpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ void genOpenMPTerminator(fir::FirOpBuilder &, mlir::Operation *,
void genOpenMPConstruct(AbstractConverter &, Fortran::lower::SymMap &,
semantics::SemanticsContext &, pft::Evaluation &,
const parser::OpenMPConstruct &);
void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
void genOpenMPDeclarativeConstruct(AbstractConverter &,
Fortran::lower::SymMap &,
semantics::SemanticsContext &,
pft::Evaluation &,
const parser::OpenMPDeclarativeConstruct &);
/// Symbols in OpenMP code can have flags (e.g. threadprivate directive)
/// that require additional handling when lowering the corresponding
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2440,7 +2440,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
ompDeviceCodeFound =
ompDeviceCodeFound ||
Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl);
genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
genOpenMPDeclarativeConstruct(
*this, localSymbols, bridge.getSemanticsContext(), getEval(), ompDecl);
builder->restoreInsertionPoint(insertPt);
}

Expand Down
137 changes: 83 additions & 54 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,34 @@ static void gatherFuncAndVarSyms(
}
}

static Fortran::lower::pft::Evaluation *
getCollapsedEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Maybe getCollapsedLoopEval would be slightly more self-explanatory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make this change in another commit.

// Return the Evaluation of the innermost collapsed loop, or the current
// evaluation, if there is nothing to collapse.
if (collapseValue == 0)
return &eval;
Comment on lines +115 to +118
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Is it better to convert this to an assert (for > 0) and move this code to the parent function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make this change in another commit.


Fortran::lower::pft::Evaluation *curEval = &eval.getFirstNestedEvaluation();
for (int i = 1; i < collapseValue; i++) {
// The nested evaluations should be DoConstructs (i.e. they should form
// a loop nest). Each DoConstruct is a tuple <NonLabelDoStmt, Block,
// EndDoStmt>.
assert(curEval->isA<Fortran::parser::DoConstruct>());
curEval = &*std::next(curEval->getNestedEvaluations().begin());
}
return curEval;
}

static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
int collapseValue = 0) {
Fortran::lower::pft::Evaluation *curEval =
getCollapsedEval(eval, collapseValue);

for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
converter.genEval(e);
}

//===----------------------------------------------------------------------===//
// DataSharingProcessor
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2944,8 +2972,9 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
std::visit(
Fortran::common::visitors{
Expand Down Expand Up @@ -3034,6 +3063,9 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
createBodyOfOp<mlir::omp::SimdLoopOp>(simdLoopOp, converter, loc, eval,
&loopOpClauseList, iv,
/*outer=*/false, &dsp);

genNestedEvaluations(converter, eval,
Fortran::lower::getCollapseValue(loopOpClauseList));
}

static void createWsLoop(Fortran::lower::AbstractConverter &converter,
Expand Down Expand Up @@ -3107,11 +3139,15 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, eval,
&beginClauseList, iv,
/*outer=*/false, &dsp);

genNestedEvaluations(converter, eval,
Fortran::lower::getCollapseValue(beginClauseList));
}

static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
const auto &beginLoopDirective =
std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t);
Expand Down Expand Up @@ -3179,12 +3215,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
createWsLoop(converter, eval, ompDirective, loopOpClauseList, endClauseList,
currentLocation);
}

genOpenMPReduction(converter, loopOpClauseList);
}

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
const auto &beginBlockDirective =
std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
Expand Down Expand Up @@ -3298,10 +3337,15 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
}
}

genNestedEvaluations(converter, eval);
genOpenMPReduction(converter, beginClauseList);
}

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Expand Down Expand Up @@ -3336,10 +3380,13 @@ genOMP(Fortran::lower::AbstractConverter &converter,
}();
createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, converter, currentLocation,
eval);
genNestedEvaluations(converter, eval);
}

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
mlir::Location currentLocation = converter.getCurrentLocation();
Expand All @@ -3359,13 +3406,18 @@ genOMP(Fortran::lower::AbstractConverter &converter,
.t);
// Currently only private/firstprivate clause is handled, and
// all privatization is done within `omp.section` operations.
symTable.pushScope();
genOpWithBody<mlir::omp::SectionOp>(converter, eval, currentLocation,
/*outerCombined=*/false,
&sectionsClauseList);
genNestedEvaluations(converter, eval);
symTable.popScope();
}

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
mlir::Location currentLocation = converter.getCurrentLocation();
Expand Down Expand Up @@ -3406,10 +3458,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
/*reduction_vars=*/mlir::ValueRange(),
/*reductions=*/nullptr, allocateOperands,
allocatorOperands, nowaitClauseOperand);

genNestedEvaluations(converter, eval);
}

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
std::visit(
Expand Down Expand Up @@ -3453,6 +3509,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
}

static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
Expand Down Expand Up @@ -3504,24 +3562,28 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
}

static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPConstruct &ompConstruct) {
std::visit(
Fortran::common::visitors{
[&](const Fortran::parser::OpenMPStandaloneConstruct
&standaloneConstruct) {
genOMP(converter, eval, semanticsContext, standaloneConstruct);
genOMP(converter, symTable, semanticsContext, eval,
standaloneConstruct);
},
[&](const Fortran::parser::OpenMPSectionsConstruct
&sectionsConstruct) {
genOMP(converter, eval, sectionsConstruct);
genOMP(converter, symTable, semanticsContext, eval,
sectionsConstruct);
},
[&](const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
genOMP(converter, eval, sectionConstruct);
genOMP(converter, symTable, semanticsContext, eval,
sectionConstruct);
},
[&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
genOMP(converter, eval, semanticsContext, loopConstruct);
genOMP(converter, symTable, semanticsContext, eval, loopConstruct);
},
[&](const Fortran::parser::OpenMPDeclarativeAllocate
&execAllocConstruct) {
Expand All @@ -3536,21 +3598,25 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
TODO(converter.getCurrentLocation(), "OpenMPAllocatorsConstruct");
},
[&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
genOMP(converter, eval, semanticsContext, blockConstruct);
genOMP(converter, symTable, semanticsContext, eval, blockConstruct);
},
[&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
genOMP(converter, eval, atomicConstruct);
genOMP(converter, symTable, semanticsContext, eval,
atomicConstruct);
},
[&](const Fortran::parser::OpenMPCriticalConstruct
&criticalConstruct) {
genOMP(converter, eval, criticalConstruct);
genOMP(converter, symTable, semanticsContext, eval,
criticalConstruct);
},
},
ompConstruct.u);
}

static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) {
std::visit(
Expand All @@ -3570,7 +3636,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
},
[&](const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
genOMP(converter, eval, declareTargetConstruct);
genOMP(converter, symTable, semanticsContext, eval,
declareTargetConstruct);
},
[&](const Fortran::parser::OpenMPRequiresConstruct
&requiresConstruct) {
Expand Down Expand Up @@ -3607,57 +3674,19 @@ void Fortran::lower::genOpenMPConstruct(
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPConstruct &omp) {

symTable.pushScope();
genOMP(converter, semanticsContext, eval, omp);

const Fortran::parser::OpenMPLoopConstruct *ompLoop =
std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
const Fortran::parser::OpenMPBlockConstruct *ompBlock =
std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u);

// If loop is part of an OpenMP Construct then the OpenMP dialect
// workshare loop operation has already been created. Only the
// body needs to be created here and the do_loop can be skipped.
// Skip the number of collapsed loops, which is 1 when there is a
// no collapse requested.

Fortran::lower::pft::Evaluation *curEval = &eval;
const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr;
if (ompLoop) {
loopOpClauseList = &std::get<Fortran::parser::OmpClauseList>(
std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
int64_t collapseValue = Fortran::lower::getCollapseValue(*loopOpClauseList);

curEval = &curEval->getFirstNestedEvaluation();
for (int64_t i = 1; i < collapseValue; i++) {
curEval = &*std::next(curEval->getNestedEvaluations().begin());
}
}

for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
converter.genEval(e);

if (ompLoop) {
genOpenMPReduction(converter, *loopOpClauseList);
} else if (ompBlock) {
const auto &blockStart =
std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t);
const auto &blockClauses =
std::get<Fortran::parser::OmpClauseList>(blockStart.t);
genOpenMPReduction(converter, blockClauses);
}

genOMP(converter, symTable, semanticsContext, eval, omp);
symTable.popScope();
}

void Fortran::lower::genOpenMPDeclarativeConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclarativeConstruct &omp) {
genOMP(converter, eval, omp);
for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations())
converter.genEval(e);
genOMP(converter, symTable, semanticsContext, eval, omp);
genNestedEvaluations(converter, eval);
}

void Fortran::lower::genOpenMPSymbolProperties(
Expand Down