@@ -725,6 +725,112 @@ genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
725
725
exitDataOp.finalizeAttr (firOpBuilder.getUnitAttr ());
726
726
}
727
727
728
+ static void
729
+ genACCUpdateOp (Fortran::lower::AbstractConverter &converter,
730
+ const Fortran::parser::AccClauseList &accClauseList) {
731
+ mlir::Value ifCond, async, waitDevnum;
732
+ SmallVector<Value, 2 > hostOperands, deviceOperands, waitOperands,
733
+ deviceTypeOperands;
734
+
735
+ // Async and wait clause have optional values but can be present with
736
+ // no value as well. When there is no value, the op has an attribute to
737
+ // represent the clause.
738
+ bool addAsyncAttr = false ;
739
+ bool addWaitAttr = false ;
740
+ bool addIfPresentAttr = false ;
741
+
742
+ auto &firOpBuilder = converter.getFirOpBuilder ();
743
+ auto currentLocation = converter.getCurrentLocation ();
744
+
745
+ // Lower clauses values mapped to operands.
746
+ // Keep track of each group of operands separatly as clauses can appear
747
+ // more than once.
748
+ for (const auto &clause : accClauseList.v ) {
749
+ if (const auto *ifClause =
750
+ std::get_if<Fortran::parser::AccClause::If>(&clause.u )) {
751
+ mlir::Value cond = fir::getBase (
752
+ converter.genExprValue (*Fortran::semantics::GetExpr (ifClause->v )));
753
+ ifCond = firOpBuilder.createConvert (currentLocation,
754
+ firOpBuilder.getI1Type (), cond);
755
+ } else if (const auto *asyncClause =
756
+ std::get_if<Fortran::parser::AccClause::Async>(&clause.u )) {
757
+ const auto &asyncClauseValue = asyncClause->v ;
758
+ if (asyncClauseValue) { // async has a value.
759
+ async = fir::getBase (converter.genExprValue (
760
+ *Fortran::semantics::GetExpr (*asyncClauseValue)));
761
+ } else {
762
+ addAsyncAttr = true ;
763
+ }
764
+ } else if (const auto *waitClause =
765
+ std::get_if<Fortran::parser::AccClause::Wait>(&clause.u )) {
766
+ const auto &waitClauseValue = waitClause->v ;
767
+ if (waitClauseValue) { // wait has a value.
768
+ const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
769
+ const std::list<Fortran::parser::ScalarIntExpr> &waitList =
770
+ std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t );
771
+ for (const Fortran::parser::ScalarIntExpr &value : waitList) {
772
+ mlir::Value v = fir::getBase (
773
+ converter.genExprValue (*Fortran::semantics::GetExpr (value)));
774
+ waitOperands.push_back (v);
775
+ }
776
+
777
+ const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
778
+ std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t );
779
+ if (waitDevnumValue)
780
+ waitDevnum = fir::getBase (converter.genExprValue (
781
+ *Fortran::semantics::GetExpr (*waitDevnumValue)));
782
+ } else {
783
+ addWaitAttr = true ;
784
+ }
785
+ } else if (const auto *deviceTypeClause =
786
+ std::get_if<Fortran::parser::AccClause::DeviceType>(
787
+ &clause.u )) {
788
+
789
+ const auto &deviceTypeValue = deviceTypeClause->v ;
790
+ if (deviceTypeValue) {
791
+ for (const auto &scalarIntExpr : *deviceTypeValue) {
792
+ mlir::Value expr = fir::getBase (converter.genExprValue (
793
+ *Fortran::semantics::GetExpr (scalarIntExpr)));
794
+ deviceTypeOperands.push_back (expr);
795
+ }
796
+ } else {
797
+ // * was passed as value and will be represented as a -1 constant
798
+ // integer.
799
+ mlir::Value star = firOpBuilder.createIntegerConstant (
800
+ currentLocation, firOpBuilder.getIntegerType (32 ), /* STAR */ -1 );
801
+ deviceTypeOperands.push_back (star);
802
+ }
803
+ } else if (const auto *hostClause =
804
+ std::get_if<Fortran::parser::AccClause::Host>(&clause.u )) {
805
+ genObjectList (hostClause->v , converter, hostOperands);
806
+ } else if (const auto *deviceClause =
807
+ std::get_if<Fortran::parser::AccClause::Device>(&clause.u )) {
808
+ genObjectList (deviceClause->v , converter, deviceOperands);
809
+ }
810
+ }
811
+
812
+ // Prepare the operand segement size attribute and the operands value range.
813
+ SmallVector<mlir::Value, 14 > operands;
814
+ SmallVector<int32_t , 7 > operandSegments;
815
+ addOperand (operands, operandSegments, async);
816
+ addOperand (operands, operandSegments, waitDevnum);
817
+ addOperands (operands, operandSegments, waitOperands);
818
+ addOperands (operands, operandSegments, deviceTypeOperands);
819
+ addOperand (operands, operandSegments, ifCond);
820
+ addOperands (operands, operandSegments, hostOperands);
821
+ addOperands (operands, operandSegments, deviceOperands);
822
+
823
+ auto updateOp = createSimpleOp<mlir::acc::UpdateOp>(
824
+ firOpBuilder, currentLocation, operands, operandSegments);
825
+
826
+ if (addAsyncAttr)
827
+ updateOp.asyncAttr (firOpBuilder.getUnitAttr ());
828
+ if (addWaitAttr)
829
+ updateOp.waitAttr (firOpBuilder.getUnitAttr ());
830
+ if (addIfPresentAttr)
831
+ updateOp.ifPresentAttr (firOpBuilder.getUnitAttr ());
832
+ }
833
+
728
834
static void
729
835
genACC (Fortran::lower::AbstractConverter &converter,
730
836
Fortran::lower::pft::Evaluation &eval,
@@ -745,7 +851,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
745
851
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) {
746
852
TODO (" OpenACC set directive not lowered yet!" );
747
853
} else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) {
748
- TODO ( " OpenACC update directive not lowered yet! " );
854
+ genACCUpdateOp (converter, accClauseList );
749
855
}
750
856
}
751
857
0 commit comments