Skip to content

Commit 14e17ea

Browse files
[flang][acc] Add support for lowering combined constructs (llvm#86696)
PR #80319 added support for recording combined-construct semantics via an attribute. This commit adds the corresponding lowering support.
1 parent 4c72cfa commit 14e17ea

File tree

6 files changed

+273
-216
lines changed

6 files changed

+273
-216
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 42 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1667,15 +1667,17 @@ static void privatizeIv(Fortran::lower::AbstractConverter &converter,
16671667
ivPrivate.push_back(privateValue);
16681668
}
16691669

1670-
static mlir::acc::LoopOp
1671-
createLoopOp(Fortran::lower::AbstractConverter &converter,
1672-
mlir::Location currentLocation,
1673-
Fortran::semantics::SemanticsContext &semanticsContext,
1674-
Fortran::lower::StatementContext &stmtCtx,
1675-
const Fortran::parser::DoConstruct &outerDoConstruct,
1676-
Fortran::lower::pft::Evaluation &eval,
1677-
const Fortran::parser::AccClauseList &accClauseList,
1678-
bool needEarlyReturnHandling = false) {
1670+
static mlir::acc::LoopOp createLoopOp(
1671+
Fortran::lower::AbstractConverter &converter,
1672+
mlir::Location currentLocation,
1673+
Fortran::semantics::SemanticsContext &semanticsContext,
1674+
Fortran::lower::StatementContext &stmtCtx,
1675+
const Fortran::parser::DoConstruct &outerDoConstruct,
1676+
Fortran::lower::pft::Evaluation &eval,
1677+
const Fortran::parser::AccClauseList &accClauseList,
1678+
std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
1679+
std::nullopt,
1680+
bool needEarlyReturnHandling = false) {
16791681
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
16801682
llvm::SmallVector<mlir::Value> tileOperands, privateOperands, ivPrivate,
16811683
reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
@@ -2015,6 +2017,10 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
20152017
if (!collapseDeviceTypes.empty())
20162018
loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes));
20172019

2020+
if (combinedConstructs)
2021+
loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get(
2022+
builder.getContext(), *combinedConstructs));
2023+
20182024
return loopOp;
20192025
}
20202026

@@ -2060,7 +2066,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
20602066
std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t);
20612067
auto loopOp = createLoopOp(converter, currentLocation, semanticsContext,
20622068
stmtCtx, *outerDoConstruct, eval, accClauseList,
2063-
needEarlyExitHandling);
2069+
/*combinedConstructs=*/{}, needEarlyExitHandling);
20642070
if (needEarlyExitHandling)
20652071
return loopOp.getResult(0);
20662072

@@ -2092,14 +2098,14 @@ static void genDataOperandOperationsWithModifier(
20922098
}
20932099

20942100
template <typename Op>
2095-
static Op
2096-
createComputeOp(Fortran::lower::AbstractConverter &converter,
2097-
mlir::Location currentLocation,
2098-
Fortran::lower::pft::Evaluation &eval,
2099-
Fortran::semantics::SemanticsContext &semanticsContext,
2100-
Fortran::lower::StatementContext &stmtCtx,
2101-
const Fortran::parser::AccClauseList &accClauseList,
2102-
bool outerCombined = false) {
2101+
static Op createComputeOp(
2102+
Fortran::lower::AbstractConverter &converter,
2103+
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
2104+
Fortran::semantics::SemanticsContext &semanticsContext,
2105+
Fortran::lower::StatementContext &stmtCtx,
2106+
const Fortran::parser::AccClauseList &accClauseList,
2107+
std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
2108+
std::nullopt) {
21032109

21042110
// Parallel operation operands
21052111
mlir::Value ifCond;
@@ -2292,7 +2298,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
22922298
} else if (const auto *privateClause =
22932299
std::get_if<Fortran::parser::AccClause::Private>(
22942300
&clause.u)) {
2295-
if (!outerCombined)
2301+
if (!combinedConstructs)
22962302
genPrivatizations<mlir::acc::PrivateRecipeOp>(
22972303
privateClause->v, converter, semanticsContext, stmtCtx,
22982304
privateOperands, privatizations);
@@ -2310,7 +2316,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
23102316
// combined - delay it to the loop. However, a reduction clause on a
23112317
// combined construct implies a copy clause so issue an implicit copy
23122318
// instead.
2313-
if (!outerCombined) {
2319+
if (!combinedConstructs) {
23142320
genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
23152321
reductionOperands, reductionRecipes);
23162322
} else {
@@ -2362,11 +2368,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
23622368
if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
23632369
computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
23642370
builder, currentLocation, currentLocation, eval, operands,
2365-
operandSegments, outerCombined);
2371+
operandSegments, /*outerCombined=*/combinedConstructs.has_value());
23662372
else
23672373
computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
23682374
builder, currentLocation, currentLocation, eval, operands,
2369-
operandSegments, outerCombined);
2375+
operandSegments, /*outerCombined=*/combinedConstructs.has_value());
23702376

23712377
if (addSelfAttr)
23722378
computeOp.setSelfAttrAttr(builder.getUnitAttr());
@@ -2419,6 +2425,9 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
24192425
mlir::ArrayAttr::get(builder.getContext(), firstPrivatizations));
24202426
}
24212427

2428+
if (combinedConstructs)
2429+
computeOp.setCombinedAttr(builder.getUnitAttr());
2430+
24222431
auto insPt = builder.saveInsertionPoint();
24232432
builder.setInsertionPointAfter(computeOp);
24242433

@@ -2734,21 +2743,24 @@ genACC(Fortran::lower::AbstractConverter &converter,
27342743
if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
27352744
createComputeOp<mlir::acc::KernelsOp>(
27362745
converter, currentLocation, eval, semanticsContext, stmtCtx,
2737-
accClauseList, /*outerCombined=*/true);
2746+
accClauseList, mlir::acc::CombinedConstructsType::KernelsLoop);
27382747
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
2739-
*outerDoConstruct, eval, accClauseList);
2748+
*outerDoConstruct, eval, accClauseList,
2749+
mlir::acc::CombinedConstructsType::KernelsLoop);
27402750
} else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
27412751
createComputeOp<mlir::acc::ParallelOp>(
27422752
converter, currentLocation, eval, semanticsContext, stmtCtx,
2743-
accClauseList, /*outerCombined=*/true);
2753+
accClauseList, mlir::acc::CombinedConstructsType::ParallelLoop);
27442754
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
2745-
*outerDoConstruct, eval, accClauseList);
2755+
*outerDoConstruct, eval, accClauseList,
2756+
mlir::acc::CombinedConstructsType::ParallelLoop);
27462757
} else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
2747-
createComputeOp<mlir::acc::SerialOp>(converter, currentLocation, eval,
2748-
semanticsContext, stmtCtx,
2749-
accClauseList, /*outerCombined=*/true);
2758+
createComputeOp<mlir::acc::SerialOp>(
2759+
converter, currentLocation, eval, semanticsContext, stmtCtx,
2760+
accClauseList, mlir::acc::CombinedConstructsType::SerialLoop);
27502761
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
2751-
*outerDoConstruct, eval, accClauseList);
2762+
*outerDoConstruct, eval, accClauseList,
2763+
mlir::acc::CombinedConstructsType::SerialLoop);
27522764
} else {
27532765
llvm::report_fatal_error("Unknown combined construct encountered");
27542766
}

0 commit comments

Comments
 (0)