@@ -723,6 +723,48 @@ createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
723
723
}
724
724
}
725
725
726
+ static void createBodyOfTargetOp (
727
+ Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
728
+ const llvm::SmallVector<mlir::Type> &useDeviceTypes,
729
+ const llvm::SmallVector<mlir::Location> &useDeviceLocs,
730
+ const SmallVector<const Fortran::semantics::Symbol *> &useDeviceSymbols,
731
+ const mlir::Location ¤tLocation) {
732
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
733
+ mlir::Region ®ion = dataOp.getRegion ();
734
+
735
+ firOpBuilder.createBlock (®ion, {}, useDeviceTypes, useDeviceLocs);
736
+ firOpBuilder.create <mlir::omp::TerminatorOp>(currentLocation);
737
+ firOpBuilder.setInsertionPointToStart (®ion.front ());
738
+
739
+ unsigned argIndex = 0 ;
740
+ for (auto *sym : useDeviceSymbols) {
741
+ const mlir::BlockArgument &arg = region.front ().getArgument (argIndex);
742
+ mlir::Value val = fir::getBase (arg);
743
+ fir::ExtendedValue extVal = converter.getSymbolExtendedValue (*sym);
744
+ if (auto refType = val.getType ().dyn_cast <fir::ReferenceType>()) {
745
+ if (fir::isa_builtin_cptr_type (refType.getElementType ())) {
746
+ converter.bindSymbol (*sym, val);
747
+ } else {
748
+ extVal.match (
749
+ [&](const fir::MutableBoxValue &mbv) {
750
+ converter.bindSymbol (
751
+ *sym,
752
+ fir::MutableBoxValue (
753
+ val, fir::factory::getNonDeferredLenParams (extVal), {}));
754
+ },
755
+ [&](const auto &) {
756
+ TODO (converter.getCurrentLocation (),
757
+ " use_device clause operand unsupported type" );
758
+ });
759
+ }
760
+ } else {
761
+ TODO (converter.getCurrentLocation (),
762
+ " use_device clause operand unsupported type" );
763
+ }
764
+ argIndex++;
765
+ }
766
+ }
767
+
726
768
static void createTargetOp (Fortran::lower::AbstractConverter &converter,
727
769
const Fortran::parser::OmpClauseList &opClauseList,
728
770
const llvm::omp::Directive &directive,
@@ -732,13 +774,24 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
732
774
733
775
mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand;
734
776
mlir::UnitAttr nowaitAttr;
735
- llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand ,
736
- mapOperands ;
777
+ llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands ,
778
+ deviceAddrOperands ;
737
779
llvm::SmallVector<mlir::IntegerAttr> mapTypes;
780
+ llvm::SmallVector<mlir::Type> useDeviceTypes;
781
+ llvm::SmallVector<mlir::Location> useDeviceLocs;
782
+ SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
783
+
784
+ // / Check for unsupported map operand types.
785
+ auto checkType = [](auto currentLocation, mlir::Type type) {
786
+ if (auto refType = type.dyn_cast <fir::ReferenceType>())
787
+ type = refType.getElementType ();
788
+ if (auto boxType = type.dyn_cast_or_null <fir::BoxType>())
789
+ if (!boxType.getElementType ().isa <fir::PointerType>())
790
+ TODO (currentLocation, " OMPD_target_data MapOperand BoxType" );
791
+ };
738
792
739
- auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
740
- &mapTypes](const auto &mapClause,
741
- mlir::Location ¤tLocation) {
793
+ auto addMapClause = [&](const auto &mapClause,
794
+ mlir::Location ¤tLocation) {
742
795
auto mapType = std::get<Fortran::parser::OmpMapType::Type>(
743
796
std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v .t )
744
797
->t );
@@ -793,18 +846,25 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
793
846
converter, mapOperand);
794
847
795
848
for (mlir::Value mapOp : mapOperand) {
796
- // / Check for unsupported map operand types.
797
- mlir::Type checkType = mapOp.getType ();
798
- if (auto refType = checkType.dyn_cast <fir::ReferenceType>())
799
- checkType = refType.getElementType ();
800
- if (checkType.isa <fir::BoxType>())
801
- TODO (currentLocation, " OMPD_target_data MapOperand BoxType" );
802
-
849
+ checkType (mapOp.getLoc (), mapOp.getType ());
803
850
mapOperands.push_back (mapOp);
804
851
mapTypes.push_back (mapTypeAttr);
805
852
}
806
853
};
807
854
855
+ auto addUseDeviceClause = [&](const auto &useDeviceClause, auto &operands) {
856
+ genObjectList (useDeviceClause, converter, operands);
857
+ for (auto &operand : operands) {
858
+ checkType (operand.getLoc (), operand.getType ());
859
+ useDeviceTypes.push_back (operand.getType ());
860
+ useDeviceLocs.push_back (operand.getLoc ());
861
+ }
862
+ for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v ) {
863
+ Fortran::semantics::Symbol *sym = getOmpObjectSymbol (ompObject);
864
+ useDeviceSymbols.push_back (sym);
865
+ }
866
+ };
867
+
808
868
for (const Fortran::parser::OmpClause &clause : opClauseList.v ) {
809
869
mlir::Location currentLocation = converter.genLocation (clause.source );
810
870
if (const auto &ifClause =
@@ -825,19 +885,21 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
825
885
deviceOperand =
826
886
fir::getBase (converter.genExprValue (*deviceExpr, stmtCtx));
827
887
}
828
- } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
829
- &clause.u )) {
830
- TODO (currentLocation, " OMPD_target Use Device Ptr" );
831
- } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
832
- &clause.u )) {
833
- TODO (currentLocation, " OMPD_target Use Device Addr" );
834
888
} else if (const auto &threadLmtClause =
835
889
std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
836
890
&clause.u )) {
837
891
threadLmtOperand = fir::getBase (converter.genExprValue (
838
892
*Fortran::semantics::GetExpr (threadLmtClause->v ), stmtCtx));
839
893
} else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u )) {
840
894
nowaitAttr = firOpBuilder.getUnitAttr ();
895
+ } else if (const auto &devPtrClause =
896
+ std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
897
+ &clause.u )) {
898
+ addUseDeviceClause (devPtrClause->v , devicePtrOperands);
899
+ } else if (const auto &devAddrClause =
900
+ std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
901
+ &clause.u )) {
902
+ addUseDeviceClause (devAddrClause->v , deviceAddrOperands);
841
903
} else if (const auto &mapClause =
842
904
std::get_if<Fortran::parser::OmpClause::Map>(&clause.u )) {
843
905
addMapClause (mapClause, currentLocation);
@@ -859,9 +921,10 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
859
921
createBodyOfOp (targetOp, converter, currentLocation, *eval, &opClauseList);
860
922
} else if (directive == llvm::omp::Directive::OMPD_target_data) {
861
923
auto dataOp = firOpBuilder.create <omp::DataOp>(
862
- currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
863
- useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
864
- createBodyOfOp (dataOp, converter, currentLocation, *eval, &opClauseList);
924
+ currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
925
+ deviceAddrOperands, mapOperands, mapTypesArrayAttr);
926
+ createBodyOfTargetOp (converter, dataOp, useDeviceTypes, useDeviceLocs,
927
+ useDeviceSymbols, currentLocation);
865
928
} else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
866
929
firOpBuilder.create <omp::EnterDataOp>(currentLocation, ifClauseOperand,
867
930
deviceOperand, nowaitAttr,
@@ -1157,7 +1220,17 @@ genOMP(Fortran::lower::AbstractConverter &converter,
1157
1220
continue ;
1158
1221
} else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u )) {
1159
1222
// Map clause is exclusive to Target Data directives. It is handled
1160
- // as part of the DataOp creation.
1223
+ // as part of the TargetOp creation.
1224
+ continue ;
1225
+ } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
1226
+ &clause.u )) {
1227
+ // UseDevicePtr clause is exclusive to Target Data directives. It is
1228
+ // handled as part of the TargetOp creation.
1229
+ continue ;
1230
+ } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
1231
+ &clause.u )) {
1232
+ // UseDeviceAddr clause is exclusive to Target Data directives. It is
1233
+ // handled as part of the TargetOp creation.
1161
1234
continue ;
1162
1235
} else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
1163
1236
&clause.u )) {
0 commit comments