@@ -78,9 +78,7 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
78
78
static void gatherFuncAndVarSyms (
79
79
const Fortran::parser::OmpObjectList &objList,
80
80
mlir::omp::DeclareTargetCaptureClause clause,
81
- llvm::SmallVectorImpl<std::pair<mlir::omp::DeclareTargetCaptureClause,
82
- Fortran::semantics::Symbol>>
83
- &symbolAndClause) {
81
+ llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
84
82
for (const Fortran::parser::OmpObject &ompObject : objList.v ) {
85
83
Fortran::common::visit (
86
84
Fortran::common::visitors{
@@ -2453,6 +2451,71 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
2453
2451
reductionDeclSymbols));
2454
2452
}
2455
2453
2454
+ // / Extract the list of function and variable symbols affected by the given
2455
+ // / 'declare target' directive and return the intended device type for them.
2456
+ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo (
2457
+ Fortran::lower::AbstractConverter &converter,
2458
+ Fortran::lower::pft::Evaluation &eval,
2459
+ const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
2460
+ llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
2461
+
2462
+ // The default capture type
2463
+ mlir::omp::DeclareTargetDeviceType deviceType =
2464
+ mlir::omp::DeclareTargetDeviceType::any;
2465
+ const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
2466
+ declareTargetConstruct.t );
2467
+
2468
+ if (const auto *objectList{
2469
+ Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u )}) {
2470
+ // Case: declare target(func, var1, var2)
2471
+ gatherFuncAndVarSyms (*objectList, mlir::omp::DeclareTargetCaptureClause::to,
2472
+ symbolAndClause);
2473
+ } else if (const auto *clauseList{
2474
+ Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
2475
+ spec.u )}) {
2476
+ if (clauseList->v .empty ()) {
2477
+ // Case: declare target, implicit capture of function
2478
+ symbolAndClause.emplace_back (
2479
+ mlir::omp::DeclareTargetCaptureClause::to,
2480
+ eval.getOwningProcedure ()->getSubprogramSymbol ());
2481
+ }
2482
+
2483
+ ClauseProcessor cp (converter, *clauseList);
2484
+ cp.processTo (symbolAndClause);
2485
+ cp.processLink (symbolAndClause);
2486
+ cp.processDeviceType (deviceType);
2487
+ cp.processTODO <Fortran::parser::OmpClause::Indirect>(
2488
+ converter.getCurrentLocation (),
2489
+ llvm::omp::Directive::OMPD_declare_target);
2490
+ }
2491
+
2492
+ return deviceType;
2493
+ }
2494
+
2495
+ static std::optional<mlir::omp::DeclareTargetDeviceType>
2496
+ getDeclareTargetFunctionDevice (
2497
+ Fortran::lower::AbstractConverter &converter,
2498
+ Fortran::lower::pft::Evaluation &eval,
2499
+ const Fortran::parser::OpenMPDeclareTargetConstruct
2500
+ &declareTargetConstruct) {
2501
+ llvm::SmallVector<DeclareTargetCapturePair, 0 > symbolAndClause;
2502
+ mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo (
2503
+ converter, eval, declareTargetConstruct, symbolAndClause);
2504
+
2505
+ // Return the device type only if at least one of the targets for the
2506
+ // directive is a function or subroutine
2507
+ mlir::ModuleOp mod = converter.getFirOpBuilder ().getModule ();
2508
+ for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
2509
+ mlir::Operation *op = mod.lookupSymbol (
2510
+ converter.mangleName (std::get<Fortran::semantics::Symbol>(symClause)));
2511
+
2512
+ if (mlir::isa<mlir::func::FuncOp>(op))
2513
+ return deviceType;
2514
+ }
2515
+
2516
+ return std::nullopt;
2517
+ }
2518
+
2456
2519
// ===----------------------------------------------------------------------===//
2457
2520
// genOMP() Code generation helper functions
2458
2521
// ===----------------------------------------------------------------------===//
@@ -2973,35 +3036,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
2973
3036
&declareTargetConstruct) {
2974
3037
llvm::SmallVector<DeclareTargetCapturePair, 0 > symbolAndClause;
2975
3038
mlir::ModuleOp mod = converter.getFirOpBuilder ().getModule ();
2976
-
2977
- // The default capture type
2978
- mlir::omp::DeclareTargetDeviceType deviceType =
2979
- mlir::omp::DeclareTargetDeviceType::any;
2980
- const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
2981
- declareTargetConstruct.t );
2982
- if (const auto *objectList{
2983
- Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u )}) {
2984
- // Case: declare target(func, var1, var2)
2985
- gatherFuncAndVarSyms (*objectList, mlir::omp::DeclareTargetCaptureClause::to,
2986
- symbolAndClause);
2987
- } else if (const auto *clauseList{
2988
- Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
2989
- spec.u )}) {
2990
- if (clauseList->v .empty ()) {
2991
- // Case: declare target, implicit capture of function
2992
- symbolAndClause.emplace_back (
2993
- mlir::omp::DeclareTargetCaptureClause::to,
2994
- eval.getOwningProcedure ()->getSubprogramSymbol ());
2995
- }
2996
-
2997
- ClauseProcessor cp (converter, *clauseList);
2998
- cp.processTo (symbolAndClause);
2999
- cp.processLink (symbolAndClause);
3000
- cp.processDeviceType (deviceType);
3001
- cp.processTODO <Fortran::parser::OmpClause::Indirect>(
3002
- converter.getCurrentLocation (),
3003
- llvm::omp::Directive::OMPD_declare_target);
3004
- }
3039
+ mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo (
3040
+ converter, eval, declareTargetConstruct, symbolAndClause);
3005
3041
3006
3042
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
3007
3043
mlir::Operation *op = mod.lookupSymbol (
@@ -3130,7 +3166,10 @@ void Fortran::lower::genOpenMPDeclarativeConstruct(
3130
3166
},
3131
3167
[&](const Fortran::parser::OpenMPRequiresConstruct
3132
3168
&requiresConstruct) {
3133
- TODO (converter.getCurrentLocation (), " OpenMPRequiresConstruct" );
3169
+ // Requires directives are gathered and processed in semantics and
3170
+ // then combined in the lowering bridge before triggering codegen
3171
+ // just once. Hence, there is no need to lower each individual
3172
+ // occurrence here.
3134
3173
},
3135
3174
[&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
3136
3175
// The directive is lowered when instantiating the variable to
@@ -3444,3 +3483,72 @@ void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
3444
3483
}
3445
3484
}
3446
3485
}
3486
+
3487
+ bool Fortran::lower::isOpenMPTargetConstruct (
3488
+ const Fortran::parser::OpenMPConstruct &omp) {
3489
+ llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;
3490
+ if (const auto *block =
3491
+ std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u )) {
3492
+ const auto &begin =
3493
+ std::get<Fortran::parser::OmpBeginBlockDirective>(block->t );
3494
+ dir = std::get<Fortran::parser::OmpBlockDirective>(begin.t ).v ;
3495
+ } else if (const auto *loop =
3496
+ std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u )) {
3497
+ const auto &begin =
3498
+ std::get<Fortran::parser::OmpBeginLoopDirective>(loop->t );
3499
+ dir = std::get<Fortran::parser::OmpLoopDirective>(begin.t ).v ;
3500
+ }
3501
+ return llvm::omp::allTargetSet.test (dir);
3502
+ }
3503
+
3504
+ bool Fortran::lower::isOpenMPDeviceDeclareTarget (
3505
+ Fortran::lower::AbstractConverter &converter,
3506
+ Fortran::lower::pft::Evaluation &eval,
3507
+ const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
3508
+ return std::visit (
3509
+ Fortran::common::visitors{
3510
+ [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompReq) {
3511
+ mlir::omp::DeclareTargetDeviceType targetType =
3512
+ getDeclareTargetFunctionDevice (converter, eval, ompReq)
3513
+ .value_or (mlir::omp::DeclareTargetDeviceType::host);
3514
+ return targetType != mlir::omp::DeclareTargetDeviceType::host;
3515
+ },
3516
+ [&](const auto &) { return false ; },
3517
+ },
3518
+ ompDecl.u );
3519
+ }
3520
+
3521
+ void Fortran::lower::genOpenMPRequires (
3522
+ mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) {
3523
+ using MlirRequires = mlir::omp::ClauseRequires;
3524
+ using SemaRequires = Fortran::semantics::WithOmpDeclarative::RequiresFlag;
3525
+
3526
+ if (auto offloadMod =
3527
+ llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) {
3528
+ Fortran::semantics::WithOmpDeclarative::RequiresFlags semaFlags;
3529
+ if (symbol) {
3530
+ Fortran::common::visit (
3531
+ [&](const auto &details) {
3532
+ if constexpr (std::is_base_of_v<
3533
+ Fortran::semantics::WithOmpDeclarative,
3534
+ std::decay_t <decltype (details)>>) {
3535
+ if (details.has_ompRequires ())
3536
+ semaFlags = *details.ompRequires ();
3537
+ }
3538
+ },
3539
+ symbol->details ());
3540
+ }
3541
+
3542
+ MlirRequires mlirFlags = MlirRequires::none;
3543
+ if (semaFlags.test (SemaRequires::ReverseOffload))
3544
+ mlirFlags = mlirFlags | MlirRequires::reverse_offload;
3545
+ if (semaFlags.test (SemaRequires::UnifiedAddress))
3546
+ mlirFlags = mlirFlags | MlirRequires::unified_address;
3547
+ if (semaFlags.test (SemaRequires::UnifiedSharedMemory))
3548
+ mlirFlags = mlirFlags | MlirRequires::unified_shared_memory;
3549
+ if (semaFlags.test (SemaRequires::DynamicAllocators))
3550
+ mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
3551
+
3552
+ offloadMod.setRequires (mlirFlags);
3553
+ }
3554
+ }
0 commit comments