@@ -397,12 +397,13 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
397
397
}
398
398
}
399
399
400
- static mlir::acc::ParallelOp
401
- createParallelOp (Fortran::lower::AbstractConverter &converter,
402
- mlir::Location currentLocation,
403
- Fortran::semantics::SemanticsContext &semanticsContext,
404
- Fortran::lower::StatementContext &stmtCtx,
405
- const Fortran::parser::AccClauseList &accClauseList) {
400
+ template <typename Op>
401
+ static Op
402
+ createComputeOp (Fortran::lower::AbstractConverter &converter,
403
+ mlir::Location currentLocation,
404
+ Fortran::semantics::SemanticsContext &semanticsContext,
405
+ Fortran::lower::StatementContext &stmtCtx,
406
+ const Fortran::parser::AccClauseList &accClauseList) {
406
407
407
408
// Parallel operation operands
408
409
mlir::Value async;
@@ -550,9 +551,11 @@ createParallelOp(Fortran::lower::AbstractConverter &converter,
550
551
llvm::SmallVector<int32_t , 8 > operandSegments;
551
552
addOperand (operands, operandSegments, async);
552
553
addOperands (operands, operandSegments, waitOperands);
553
- addOperand (operands, operandSegments, numGangs);
554
- addOperand (operands, operandSegments, numWorkers);
555
- addOperand (operands, operandSegments, vectorLength);
554
+ if constexpr (std::is_same_v<Op, mlir::acc::ParallelOp>) {
555
+ addOperand (operands, operandSegments, numGangs);
556
+ addOperand (operands, operandSegments, numWorkers);
557
+ addOperand (operands, operandSegments, vectorLength);
558
+ }
556
559
addOperand (operands, operandSegments, ifCond);
557
560
addOperand (operands, operandSegments, selfCond);
558
561
addOperands (operands, operandSegments, reductionOperands);
@@ -570,28 +573,17 @@ createParallelOp(Fortran::lower::AbstractConverter &converter,
570
573
addOperands (operands, operandSegments, privateOperands);
571
574
addOperands (operands, operandSegments, firstprivateOperands);
572
575
573
- mlir::acc::ParallelOp parallelOp =
574
- createRegionOp<mlir::acc::ParallelOp, mlir::acc::YieldOp>(
575
- firOpBuilder, currentLocation, operands, operandSegments);
576
+ Op computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
577
+ firOpBuilder, currentLocation, operands, operandSegments);
576
578
577
579
if (addAsyncAttr)
578
- parallelOp .setAsyncAttrAttr (firOpBuilder.getUnitAttr ());
580
+ computeOp .setAsyncAttrAttr (firOpBuilder.getUnitAttr ());
579
581
if (addWaitAttr)
580
- parallelOp .setWaitAttrAttr (firOpBuilder.getUnitAttr ());
582
+ computeOp .setWaitAttrAttr (firOpBuilder.getUnitAttr ());
581
583
if (addSelfAttr)
582
- parallelOp .setSelfAttrAttr (firOpBuilder.getUnitAttr ());
584
+ computeOp .setSelfAttrAttr (firOpBuilder.getUnitAttr ());
583
585
584
- return parallelOp;
585
- }
586
-
587
- static void
588
- genACCParallelOp (Fortran::lower::AbstractConverter &converter,
589
- mlir::Location currentLocation,
590
- Fortran::semantics::SemanticsContext &semanticsContext,
591
- Fortran::lower::StatementContext &stmtCtx,
592
- const Fortran::parser::AccClauseList &accClauseList) {
593
- createParallelOp (converter, currentLocation, semanticsContext, stmtCtx,
594
- accClauseList);
586
+ return computeOp;
595
587
}
596
588
597
589
static void genACCDataOp (Fortran::lower::AbstractConverter &converter,
@@ -696,32 +688,21 @@ genACC(Fortran::lower::AbstractConverter &converter,
696
688
Fortran::lower::StatementContext stmtCtx;
697
689
698
690
if (blockDirective.v == llvm::acc::ACCD_parallel) {
699
- genACCParallelOp (converter, currentLocation, semanticsContext, stmtCtx,
700
- accClauseList);
691
+ createComputeOp<mlir::acc::ParallelOp>(
692
+ converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
701
693
} else if (blockDirective.v == llvm::acc::ACCD_data) {
702
694
genACCDataOp (converter, currentLocation, semanticsContext, stmtCtx,
703
695
accClauseList);
704
696
} else if (blockDirective.v == llvm::acc::ACCD_serial) {
705
- TODO (currentLocation, " serial construct lowering" );
697
+ createComputeOp<mlir::acc::SerialOp>(
698
+ converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
706
699
} else if (blockDirective.v == llvm::acc::ACCD_kernels) {
707
700
TODO (currentLocation, " kernels construct lowering" );
708
701
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
709
702
TODO (currentLocation, " host_data construct lowering" );
710
703
}
711
704
}
712
705
713
- static void
714
- genACCParallelLoopOps (Fortran::lower::AbstractConverter &converter,
715
- mlir::Location currentLocation,
716
- Fortran::semantics::SemanticsContext &semanticsContext,
717
- Fortran::lower::StatementContext &stmtCtx,
718
- const Fortran::parser::AccClauseList &accClauseList) {
719
- createParallelOp (converter, currentLocation, semanticsContext, stmtCtx,
720
- accClauseList);
721
- createLoopOp (converter, currentLocation, semanticsContext, stmtCtx,
722
- accClauseList);
723
- }
724
-
725
706
static void
726
707
genACC (Fortran::lower::AbstractConverter &converter,
727
708
Fortran::semantics::SemanticsContext &semanticsContext,
@@ -741,10 +722,15 @@ genACC(Fortran::lower::AbstractConverter &converter,
741
722
if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
742
723
TODO (currentLocation, " OpenACC Kernels Loop construct not lowered yet!" );
743
724
} else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
744
- genACCParallelLoopOps (converter, currentLocation, semanticsContext, stmtCtx,
745
- accClauseList);
725
+ createComputeOp<mlir::acc::ParallelOp>(
726
+ converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
727
+ createLoopOp (converter, currentLocation, semanticsContext, stmtCtx,
728
+ accClauseList);
746
729
} else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
747
- TODO (currentLocation, " OpenACC Serial Loop construct not lowered yet!" );
730
+ createComputeOp<mlir::acc::SerialOp>(
731
+ converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
732
+ createLoopOp (converter, currentLocation, semanticsContext, stmtCtx,
733
+ accClauseList);
748
734
} else {
749
735
llvm::report_fatal_error (" Unknown combined construct encountered" );
750
736
}
0 commit comments