Skip to content

Commit 29aa749

Browse files
committed
[OpenMP][Flang][MLIR] Lowering of OpenMP requires directive from parse tree to MLIR
This patch implements the lowering of the OpenMP 'requires' directive from the Flang parse tree to MLIR attributes attached to the top-level module. Target-related 'requires' clauses are gathered and combined for each top-level unit during semantics. At the end of lowering, a single module-level `omp.requires` attribute carrying that information is attached to the MLIR module. The `atomic_default_mem_order` clause is not addressed by this patch; it will be handled in a separate patch that follows a different approach. Depends on D147214, D150328, D150329 and D157983. Differential Revision: https://reviews.llvm.org/D147218
1 parent 094a63a commit 29aa749

File tree

7 files changed

+254
-34
lines changed

7 files changed

+254
-34
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ struct OmpEndLoopDirective;
3434
struct OmpClauseList;
3535
} // namespace parser
3636

37+
namespace semantics {
38+
class Symbol;
39+
} // namespace semantics
40+
3741
namespace lower {
3842

3943
class AbstractConverter;
@@ -62,6 +66,13 @@ fir::ConvertOp getConvertFromReductionOp(mlir::Operation *, mlir::Value);
6266
void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
6367
mlir::Value, fir::ConvertOp * = nullptr);
6468
void removeStoreOp(mlir::Operation *, mlir::Value);
69+
70+
/// Return true if the given construct is an OpenMP block or loop construct
/// whose directive is a member of llvm::omp::allTargetSet.
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
71+
/// Return true if the given declarative construct is a 'declare target'
/// directive that names at least one function or subroutine and whose device
/// type is anything other than `host`. Every other declarative construct
/// yields false.
bool isOpenMPDeviceDeclareTarget(Fortran::lower::AbstractConverter &,
72+
Fortran::lower::pft::Evaluation &,
73+
const parser::OpenMPDeclarativeConstruct &);
74+
/// Translate the 'requires' flags recorded on the given symbol (which may be
/// null, meaning no flags) and set them on the given operation, provided that
/// operation implements mlir::omp::OffloadModuleInterface.
void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *);
75+
6576
} // namespace lower
6677
} // namespace Fortran
6778

flang/lib/Lower/Bridge.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include "flang/Parser/parse-tree.h"
5151
#include "flang/Runtime/iostat.h"
5252
#include "flang/Semantics/runtime-type-info.h"
53+
#include "flang/Semantics/symbol.h"
5354
#include "flang/Semantics/tools.h"
5455
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
5556
#include "mlir/IR/PatternMatch.h"
@@ -294,20 +295,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
294295
// that they are available before lowering any function that may use
295296
// them.
296297
bool hasMainProgram = false;
298+
const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
297299
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
298300
std::visit(Fortran::common::visitors{
299301
[&](Fortran::lower::pft::FunctionLikeUnit &f) {
300302
if (f.isMainProgram())
301303
hasMainProgram = true;
302304
declareFunction(f);
305+
if (!globalOmpRequiresSymbol)
306+
globalOmpRequiresSymbol = f.getScope().symbol();
303307
},
304308
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
305309
lowerModuleDeclScope(m);
306310
for (Fortran::lower::pft::FunctionLikeUnit &f :
307311
m.nestedFunctions)
308312
declareFunction(f);
309313
},
310-
[&](Fortran::lower::pft::BlockDataUnit &b) {},
314+
[&](Fortran::lower::pft::BlockDataUnit &b) {
315+
if (!globalOmpRequiresSymbol)
316+
globalOmpRequiresSymbol = b.symTab.symbol();
317+
},
311318
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
312319
},
313320
u);
@@ -352,6 +359,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
352359
});
353360

354361
finalizeOpenACCLowering();
362+
finalizeOpenMPLowering(globalOmpRequiresSymbol);
355363
}
356364

357365
/// Declare a function.
@@ -2347,10 +2355,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
23472355

23482356
localSymbols.popScope();
23492357
builder->restoreInsertionPoint(insertPt);
2358+
2359+
// Register if a target region was found
2360+
ompDeviceCodeFound =
2361+
ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
23502362
}
23512363

23522364
void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
23532365
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
2366+
// Register if a declare target construct intended for a target device was
2367+
// found
2368+
ompDeviceCodeFound =
2369+
ompDeviceCodeFound ||
2370+
Fortran::lower::isOpenMPDeviceDeclareTarget(*this, getEval(), ompDecl);
23542371
genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
23552372
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
23562373
genFIR(e);
@@ -4758,6 +4775,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
47584775
accRoutineInfos);
47594776
}
47604777

4778+
/// Performing OpenMP lowering actions that were deferred to the end of
4779+
/// lowering.
4780+
void finalizeOpenMPLowering(
4781+
const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
4782+
// Set the module attribute related to OpenMP requires directives
4783+
if (ompDeviceCodeFound)
4784+
Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
4785+
globalOmpRequiresSymbol);
4786+
}
4787+
47614788
//===--------------------------------------------------------------------===//
47624789

47634790
Fortran::lower::LoweringBridge &bridge;
@@ -4804,6 +4831,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
48044831

48054832
/// Deferred OpenACC routine attachment.
48064833
Fortran::lower::AccRoutineInfoMappingList accRoutineInfos;
4834+
4835+
/// Whether an OpenMP target region or declare target function/subroutine
4836+
/// intended for device offloading has been detected
4837+
bool ompDeviceCodeFound = false;
48074838
};
48084839

48094840
} // namespace

flang/lib/Lower/OpenMP.cpp

Lines changed: 141 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
7878
static void gatherFuncAndVarSyms(
7979
const Fortran::parser::OmpObjectList &objList,
8080
mlir::omp::DeclareTargetCaptureClause clause,
81-
llvm::SmallVectorImpl<std::pair<mlir::omp::DeclareTargetCaptureClause,
82-
Fortran::semantics::Symbol>>
83-
&symbolAndClause) {
81+
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
8482
for (const Fortran::parser::OmpObject &ompObject : objList.v) {
8583
Fortran::common::visit(
8684
Fortran::common::visitors{
@@ -2453,6 +2451,71 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
24532451
reductionDeclSymbols));
24542452
}
24552453

2454+
/// Extract the list of function and variable symbols affected by the given
2455+
/// 'declare target' directive and return the intended device type for them.
2456+
static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
2457+
Fortran::lower::AbstractConverter &converter,
2458+
Fortran::lower::pft::Evaluation &eval,
2459+
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
2460+
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
2461+
2462+
// The default capture type
2463+
mlir::omp::DeclareTargetDeviceType deviceType =
2464+
mlir::omp::DeclareTargetDeviceType::any;
2465+
const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
2466+
declareTargetConstruct.t);
2467+
2468+
if (const auto *objectList{
2469+
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
2470+
// Case: declare target(func, var1, var2)
2471+
gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
2472+
symbolAndClause);
2473+
} else if (const auto *clauseList{
2474+
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
2475+
spec.u)}) {
2476+
if (clauseList->v.empty()) {
2477+
// Case: declare target, implicit capture of function
2478+
symbolAndClause.emplace_back(
2479+
mlir::omp::DeclareTargetCaptureClause::to,
2480+
eval.getOwningProcedure()->getSubprogramSymbol());
2481+
}
2482+
2483+
ClauseProcessor cp(converter, *clauseList);
2484+
cp.processTo(symbolAndClause);
2485+
cp.processLink(symbolAndClause);
2486+
cp.processDeviceType(deviceType);
2487+
cp.processTODO<Fortran::parser::OmpClause::Indirect>(
2488+
converter.getCurrentLocation(),
2489+
llvm::omp::Directive::OMPD_declare_target);
2490+
}
2491+
2492+
return deviceType;
2493+
}
2494+
2495+
static std::optional<mlir::omp::DeclareTargetDeviceType>
2496+
getDeclareTargetFunctionDevice(
2497+
Fortran::lower::AbstractConverter &converter,
2498+
Fortran::lower::pft::Evaluation &eval,
2499+
const Fortran::parser::OpenMPDeclareTargetConstruct
2500+
&declareTargetConstruct) {
2501+
llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
2502+
mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
2503+
converter, eval, declareTargetConstruct, symbolAndClause);
2504+
2505+
// Return the device type only if at least one of the targets for the
2506+
// directive is a function or subroutine
2507+
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
2508+
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
2509+
mlir::Operation *op = mod.lookupSymbol(
2510+
converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
2511+
2512+
if (mlir::isa<mlir::func::FuncOp>(op))
2513+
return deviceType;
2514+
}
2515+
2516+
return std::nullopt;
2517+
}
2518+
24562519
//===----------------------------------------------------------------------===//
24572520
// genOMP() Code generation helper functions
24582521
//===----------------------------------------------------------------------===//
@@ -2973,35 +3036,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
29733036
&declareTargetConstruct) {
29743037
llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
29753038
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
2976-
2977-
// The default capture type
2978-
mlir::omp::DeclareTargetDeviceType deviceType =
2979-
mlir::omp::DeclareTargetDeviceType::any;
2980-
const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
2981-
declareTargetConstruct.t);
2982-
if (const auto *objectList{
2983-
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
2984-
// Case: declare target(func, var1, var2)
2985-
gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
2986-
symbolAndClause);
2987-
} else if (const auto *clauseList{
2988-
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
2989-
spec.u)}) {
2990-
if (clauseList->v.empty()) {
2991-
// Case: declare target, implicit capture of function
2992-
symbolAndClause.emplace_back(
2993-
mlir::omp::DeclareTargetCaptureClause::to,
2994-
eval.getOwningProcedure()->getSubprogramSymbol());
2995-
}
2996-
2997-
ClauseProcessor cp(converter, *clauseList);
2998-
cp.processTo(symbolAndClause);
2999-
cp.processLink(symbolAndClause);
3000-
cp.processDeviceType(deviceType);
3001-
cp.processTODO<Fortran::parser::OmpClause::Indirect>(
3002-
converter.getCurrentLocation(),
3003-
llvm::omp::Directive::OMPD_declare_target);
3004-
}
3039+
mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
3040+
converter, eval, declareTargetConstruct, symbolAndClause);
30053041

30063042
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
30073043
mlir::Operation *op = mod.lookupSymbol(
@@ -3130,7 +3166,10 @@ void Fortran::lower::genOpenMPDeclarativeConstruct(
31303166
},
31313167
[&](const Fortran::parser::OpenMPRequiresConstruct
31323168
&requiresConstruct) {
3133-
TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct");
3169+
// Requires directives are gathered and processed in semantics and
3170+
// then combined in the lowering bridge before triggering codegen
3171+
// just once. Hence, there is no need to lower each individual
3172+
// occurrence here.
31343173
},
31353174
[&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
31363175
// The directive is lowered when instantiating the variable to
@@ -3444,3 +3483,72 @@ void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
34443483
}
34453484
}
34463485
}
3486+
3487+
bool Fortran::lower::isOpenMPTargetConstruct(
3488+
const Fortran::parser::OpenMPConstruct &omp) {
3489+
llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;
3490+
if (const auto *block =
3491+
std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u)) {
3492+
const auto &begin =
3493+
std::get<Fortran::parser::OmpBeginBlockDirective>(block->t);
3494+
dir = std::get<Fortran::parser::OmpBlockDirective>(begin.t).v;
3495+
} else if (const auto *loop =
3496+
std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)) {
3497+
const auto &begin =
3498+
std::get<Fortran::parser::OmpBeginLoopDirective>(loop->t);
3499+
dir = std::get<Fortran::parser::OmpLoopDirective>(begin.t).v;
3500+
}
3501+
return llvm::omp::allTargetSet.test(dir);
3502+
}
3503+
3504+
/// Tell whether \p ompDecl is a 'declare target' directive that applies to at
/// least one function or subroutine and whose device type is not `host`.
/// Every other declarative construct yields false.
bool Fortran::lower::isOpenMPDeviceDeclareTarget(
    Fortran::lower::AbstractConverter &converter,
    Fortran::lower::pft::Evaluation &eval,
    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
  using DeviceType = mlir::omp::DeclareTargetDeviceType;

  return std::visit(
      Fortran::common::visitors{
          [&](const Fortran::parser::OpenMPDeclareTargetConstruct
                  &declareTarget) {
            // An empty optional means no function or subroutine is targeted,
            // which is treated the same as a host-only directive.
            const std::optional<DeviceType> functionDevice =
                getDeclareTargetFunctionDevice(converter, eval, declareTarget);
            return functionDevice.has_value() &&
                   *functionDevice != DeviceType::host;
          },
          [](const auto &) { return false; },
      },
      ompDecl.u);
}
3520+
3521+
/// Set the OpenMP 'requires' flags on \p mod.
///
/// Nothing happens unless \p mod implements mlir::omp::OffloadModuleInterface.
/// The flags are read from the details of \p symbol when it is non-null and
/// its details derive from WithOmpDeclarative with 'requires' information
/// present; otherwise the set stays empty and `none` is written.
void Fortran::lower::genOpenMPRequires(
    mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) {
  using MlirRequires = mlir::omp::ClauseRequires;
  using SemaRequires = Fortran::semantics::WithOmpDeclarative::RequiresFlag;

  auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod);
  if (!offloadMod)
    return;

  // Fetch the flags gathered during semantics, if there are any.
  Fortran::semantics::WithOmpDeclarative::RequiresFlags gatheredFlags;
  if (symbol) {
    Fortran::common::visit(
        [&](const auto &details) {
          using DetailsType = std::decay_t<decltype(details)>;
          if constexpr (std::is_base_of_v<
                            Fortran::semantics::WithOmpDeclarative,
                            DetailsType>) {
            if (details.has_ompRequires())
              gatheredFlags = *details.ompRequires();
          }
        },
        symbol->details());
  }

  // Map each semantics flag that is set onto its MLIR counterpart.
  MlirRequires combined = MlirRequires::none;
  auto translate = [&](SemaRequires semaFlag, MlirRequires mlirFlag) {
    if (gatheredFlags.test(semaFlag))
      combined = combined | mlirFlag;
  };
  translate(SemaRequires::ReverseOffload, MlirRequires::reverse_offload);
  translate(SemaRequires::UnifiedAddress, MlirRequires::unified_address);
  translate(SemaRequires::UnifiedSharedMemory,
            MlirRequires::unified_shared_memory);
  translate(SemaRequires::DynamicAllocators, MlirRequires::dynamic_allocators);

  offloadMod.setRequires(combined);
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
! This test checks the lowering of REQUIRES inside of an unnamed BLOCK DATA.
2+
! The symbol of the `symTab` scope of the `BlockDataUnit` PFT node is null in
3+
! this case, resulting in the inability to store the REQUIRES flags gathered in
4+
! it.
5+
6+
! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
7+
! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
8+
! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
9+
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
10+
! XFAIL: *
11+
12+
!CHECK: module attributes {
13+
!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
14+
block data
15+
!$omp requires unified_shared_memory
16+
integer :: x
17+
common /block/ x
18+
data x / 10 /
19+
end
20+
21+
subroutine f
22+
!$omp declare target
23+
end subroutine f
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
2+
! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
3+
! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
4+
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
5+
6+
! This test checks the lowering of requires into MLIR
7+
8+
!CHECK: module attributes {
9+
!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
10+
block data init
11+
!$omp requires unified_shared_memory
12+
integer :: x
13+
common /block/ x
14+
data x / 10 /
15+
end
16+
17+
subroutine f
18+
!$omp declare target
19+
end subroutine f
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
2+
! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
3+
! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
4+
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
5+
6+
! This test checks that requires lowering into MLIR skips creating the
7+
! omp.requires attribute with target-related clauses if there are no device
8+
! functions in the compilation unit
9+
10+
!CHECK: module attributes {
11+
!CHECK-NOT: omp.requires
12+
program requires
13+
!$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
14+
end program requires

flang/test/Lower/OpenMP/requires.f90

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
2+
! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
3+
! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
4+
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
5+
6+
! This test checks the lowering of requires into MLIR
7+
8+
!CHECK: module attributes {
9+
!CHECK-SAME: omp.requires = #omp<clause_requires reverse_offload|unified_shared_memory>
10+
program requires
11+
!$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
12+
!$omp target
13+
!$omp end target
14+
end program requires

0 commit comments

Comments
 (0)