Skip to content

Commit d30b4e5

Browse files
committed
[flang][openacc] Lower serial and serial loop construct
Lower the parse tree to acc dialects operations. Make use of the parallel construct lowering and make it suitable for all compute constructs lowering. Reviewed By: PeteSteinfeld Differential Revision: https://reviews.llvm.org/D148273
1 parent 7cf1608 commit d30b4e5

File tree

3 files changed

+841
-44
lines changed

3 files changed

+841
-44
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,13 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
397397
}
398398
}
399399

400-
static mlir::acc::ParallelOp
401-
createParallelOp(Fortran::lower::AbstractConverter &converter,
402-
mlir::Location currentLocation,
403-
Fortran::semantics::SemanticsContext &semanticsContext,
404-
Fortran::lower::StatementContext &stmtCtx,
405-
const Fortran::parser::AccClauseList &accClauseList) {
400+
template <typename Op>
401+
static Op
402+
createComputeOp(Fortran::lower::AbstractConverter &converter,
403+
mlir::Location currentLocation,
404+
Fortran::semantics::SemanticsContext &semanticsContext,
405+
Fortran::lower::StatementContext &stmtCtx,
406+
const Fortran::parser::AccClauseList &accClauseList) {
406407

407408
// Parallel operation operands
408409
mlir::Value async;
@@ -550,9 +551,11 @@ createParallelOp(Fortran::lower::AbstractConverter &converter,
550551
llvm::SmallVector<int32_t, 8> operandSegments;
551552
addOperand(operands, operandSegments, async);
552553
addOperands(operands, operandSegments, waitOperands);
553-
addOperand(operands, operandSegments, numGangs);
554-
addOperand(operands, operandSegments, numWorkers);
555-
addOperand(operands, operandSegments, vectorLength);
554+
if constexpr (std::is_same_v<Op, mlir::acc::ParallelOp>) {
555+
addOperand(operands, operandSegments, numGangs);
556+
addOperand(operands, operandSegments, numWorkers);
557+
addOperand(operands, operandSegments, vectorLength);
558+
}
556559
addOperand(operands, operandSegments, ifCond);
557560
addOperand(operands, operandSegments, selfCond);
558561
addOperands(operands, operandSegments, reductionOperands);
@@ -570,28 +573,17 @@ createParallelOp(Fortran::lower::AbstractConverter &converter,
570573
addOperands(operands, operandSegments, privateOperands);
571574
addOperands(operands, operandSegments, firstprivateOperands);
572575

573-
mlir::acc::ParallelOp parallelOp =
574-
createRegionOp<mlir::acc::ParallelOp, mlir::acc::YieldOp>(
575-
firOpBuilder, currentLocation, operands, operandSegments);
576+
Op computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
577+
firOpBuilder, currentLocation, operands, operandSegments);
576578

577579
if (addAsyncAttr)
578-
parallelOp.setAsyncAttrAttr(firOpBuilder.getUnitAttr());
580+
computeOp.setAsyncAttrAttr(firOpBuilder.getUnitAttr());
579581
if (addWaitAttr)
580-
parallelOp.setWaitAttrAttr(firOpBuilder.getUnitAttr());
582+
computeOp.setWaitAttrAttr(firOpBuilder.getUnitAttr());
581583
if (addSelfAttr)
582-
parallelOp.setSelfAttrAttr(firOpBuilder.getUnitAttr());
584+
computeOp.setSelfAttrAttr(firOpBuilder.getUnitAttr());
583585

584-
return parallelOp;
585-
}
586-
587-
static void
588-
genACCParallelOp(Fortran::lower::AbstractConverter &converter,
589-
mlir::Location currentLocation,
590-
Fortran::semantics::SemanticsContext &semanticsContext,
591-
Fortran::lower::StatementContext &stmtCtx,
592-
const Fortran::parser::AccClauseList &accClauseList) {
593-
createParallelOp(converter, currentLocation, semanticsContext, stmtCtx,
594-
accClauseList);
586+
return computeOp;
595587
}
596588

597589
static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
@@ -696,32 +688,21 @@ genACC(Fortran::lower::AbstractConverter &converter,
696688
Fortran::lower::StatementContext stmtCtx;
697689

698690
if (blockDirective.v == llvm::acc::ACCD_parallel) {
699-
genACCParallelOp(converter, currentLocation, semanticsContext, stmtCtx,
700-
accClauseList);
691+
createComputeOp<mlir::acc::ParallelOp>(
692+
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
701693
} else if (blockDirective.v == llvm::acc::ACCD_data) {
702694
genACCDataOp(converter, currentLocation, semanticsContext, stmtCtx,
703695
accClauseList);
704696
} else if (blockDirective.v == llvm::acc::ACCD_serial) {
705-
TODO(currentLocation, "serial construct lowering");
697+
createComputeOp<mlir::acc::SerialOp>(
698+
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
706699
} else if (blockDirective.v == llvm::acc::ACCD_kernels) {
707700
TODO(currentLocation, "kernels construct lowering");
708701
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
709702
TODO(currentLocation, "host_data construct lowering");
710703
}
711704
}
712705

713-
static void
714-
genACCParallelLoopOps(Fortran::lower::AbstractConverter &converter,
715-
mlir::Location currentLocation,
716-
Fortran::semantics::SemanticsContext &semanticsContext,
717-
Fortran::lower::StatementContext &stmtCtx,
718-
const Fortran::parser::AccClauseList &accClauseList) {
719-
createParallelOp(converter, currentLocation, semanticsContext, stmtCtx,
720-
accClauseList);
721-
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
722-
accClauseList);
723-
}
724-
725706
static void
726707
genACC(Fortran::lower::AbstractConverter &converter,
727708
Fortran::semantics::SemanticsContext &semanticsContext,
@@ -741,10 +722,15 @@ genACC(Fortran::lower::AbstractConverter &converter,
741722
if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
742723
TODO(currentLocation, "OpenACC Kernels Loop construct not lowered yet!");
743724
} else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
744-
genACCParallelLoopOps(converter, currentLocation, semanticsContext, stmtCtx,
745-
accClauseList);
725+
createComputeOp<mlir::acc::ParallelOp>(
726+
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
727+
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
728+
accClauseList);
746729
} else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
747-
TODO(currentLocation, "OpenACC Serial Loop construct not lowered yet!");
730+
createComputeOp<mlir::acc::SerialOp>(
731+
converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
732+
createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
733+
accClauseList);
748734
} else {
749735
llvm::report_fatal_error("Unknown combined construct encountered");
750736
}

0 commit comments

Comments
 (0)