Skip to content

Commit 85c9855

Browse files
committed
Moved check for whether a 'target teams loop' construct can potentially
be considered equivalent to 'target teams distribute parallel for' from CodeGen to Sema.
1 parent bb118c4 commit 85c9855

File tree

10 files changed

+101
-85
lines changed

10 files changed

+101
-85
lines changed

clang/include/clang/AST/StmtOpenMP.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6109,6 +6109,8 @@ class OMPTeamsGenericLoopDirective final : public OMPLoopDirective {
61096109
class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
61106110
friend class ASTStmtReader;
61116111
friend class OMPExecutableDirective;
6112+
/// true if loop directive's associated loop can be a parallel for.
6113+
bool CanBeParallelFor = false;
61126114
/// Build directive with the given start and end location.
61136115
///
61146116
/// \param StartLoc Starting location of the directive kind.
@@ -6131,6 +6133,9 @@ class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
61316133
llvm::omp::OMPD_target_teams_loop, SourceLocation(),
61326134
SourceLocation(), CollapsedNum) {}
61336135

6136+
/// Set whether associated loop can be a parallel for.
6137+
void setCanBeParallelFor(bool ParFor) { CanBeParallelFor = ParFor; }
6138+
61346139
public:
61356140
/// Creates directive with a list of \p Clauses.
61366141
///
@@ -6145,7 +6150,7 @@ class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
61456150
static OMPTargetTeamsGenericLoopDirective *
61466151
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
61476152
unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses,
6148-
Stmt *AssociatedStmt, const HelperExprs &Exprs);
6153+
Stmt *AssociatedStmt, const HelperExprs &Exprs, bool CanBeParallelFor);
61496154

61506155
/// Creates an empty directive with the place
61516156
/// for \a NumClauses clauses.
@@ -6159,6 +6164,10 @@ class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
61596164
unsigned CollapsedNum,
61606165
EmptyShell);
61616166

6167+
/// Return true if current loop directive's associated loop can be a
6168+
/// parallel for.
6169+
bool canBeParallelFor() const { return CanBeParallelFor; }
6170+
61626171
static bool classof(const Stmt *T) {
61636172
return T->getStmtClass() == OMPTargetTeamsGenericLoopDirectiveClass;
61646173
}

clang/include/clang/Sema/Sema.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10219,6 +10219,10 @@ class Sema final {
1021910219
bool isInstantiationRecord() const;
1022010220
};
1022110221

10222+
/// [target] teams loop is equivalent to parallel for if associated loop
10223+
/// nest meets certain criteria.
10224+
bool teamsLoopCanBeParallelFor(Stmt *Astmt);
10225+
1022210226
/// A stack object to be created when performing template
1022310227
/// instantiation.
1022410228
///

clang/lib/AST/StmtOpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2431,7 +2431,7 @@ OMPTeamsGenericLoopDirective::CreateEmpty(const ASTContext &C,
24312431
OMPTargetTeamsGenericLoopDirective *OMPTargetTeamsGenericLoopDirective::Create(
24322432
const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
24332433
unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
2434-
const HelperExprs &Exprs) {
2434+
const HelperExprs &Exprs, bool CanBeParallelFor) {
24352435
auto *Dir = createDirective<OMPTargetTeamsGenericLoopDirective>(
24362436
C, Clauses, AssociatedStmt,
24372437
numLoopChildren(CollapsedNum, OMPD_target_teams_loop), StartLoc, EndLoc,
@@ -2473,6 +2473,7 @@ OMPTargetTeamsGenericLoopDirective *OMPTargetTeamsGenericLoopDirective::Create(
24732473
Dir->setCombinedNextUpperBound(Exprs.DistCombinedFields.NUB);
24742474
Dir->setCombinedDistCond(Exprs.DistCombinedFields.DistCond);
24752475
Dir->setCombinedParForInDistCond(Exprs.DistCombinedFields.ParForInDistCond);
2476+
Dir->setCanBeParallelFor(CanBeParallelFor);
24762477
return Dir;
24772478
}
24782479

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,9 @@ static bool supportsSPMDExecutionMode(CodeGenModule &CGM,
661661
case OMPD_target_teams_loop:
662662
// Whether this is true or not depends on how the directive will
663663
// eventually be emitted.
664-
return CGM.teamsLoopCanBeParallelFor(D);
664+
if (auto *TTLD = dyn_cast<OMPTargetTeamsGenericLoopDirective>(&D))
665+
return TTLD->canBeParallelFor();
666+
return false;
665667
case OMPD_parallel:
666668
case OMPD_for:
667669
case OMPD_parallel_for:

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,9 +1432,12 @@ void CodeGenFunction::EmitOMPReductionClauseFinal(
14321432
*this, D.getBeginLoc(),
14331433
isOpenMPWorksharingDirective(D.getDirectiveKind()));
14341434
}
1435+
bool TeamsLoopCanBeParallel = false;
1436+
if (auto *TTLD = dyn_cast<OMPTargetTeamsGenericLoopDirective>(&D))
1437+
TeamsLoopCanBeParallel = TTLD->canBeParallelFor();
14351438
bool WithNowait = D.getSingleClause<OMPNowaitClause>() ||
14361439
isOpenMPParallelDirective(D.getDirectiveKind()) ||
1437-
CGM.teamsLoopCanBeParallelFor(D) ||
1440+
TeamsLoopCanBeParallel ||
14381441
ReductionKind == OMPD_simd;
14391442
bool SimpleReduction = ReductionKind == OMPD_simd;
14401443
// Emit nowait reduction if nowait clause is present or directive is a
@@ -8014,7 +8017,7 @@ static void emitTargetTeamsGenericLoopRegionAsDistribute(
80148017
void CodeGenFunction::EmitOMPTargetTeamsGenericLoopDirective(
80158018
const OMPTargetTeamsGenericLoopDirective &S) {
80168019
auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
8017-
if (CGF.CGM.teamsLoopCanBeParallelFor(S))
8020+
if (S.canBeParallelFor())
80188021
emitTargetTeamsGenericLoopRegionAsParallel(CGF, Action, S);
80198022
else
80208023
emitTargetTeamsGenericLoopRegionAsDistribute(CGF, Action, S);
@@ -8027,7 +8030,7 @@ void CodeGenFunction::EmitOMPTargetTeamsGenericLoopDeviceFunction(
80278030
const OMPTargetTeamsGenericLoopDirective &S) {
80288031
// Emit SPMD target parallel loop region as a standalone region.
80298032
auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
8030-
if (CGF.CGM.teamsLoopCanBeParallelFor(S))
8033+
if (S.canBeParallelFor())
80318034
emitTargetTeamsGenericLoopRegionAsParallel(CGF, Action, S);
80328035
else
80338036
emitTargetTeamsGenericLoopRegionAsDistribute(CGF, Action, S);

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -7579,82 +7579,6 @@ void CodeGenModule::printPostfixForExternalizedDecl(llvm::raw_ostream &OS,
75797579
}
75807580
}
75817581

7582-
namespace {
7583-
/// A 'teams loop' with a nested 'loop bind(parallel)' or generic function
7584-
/// call in the associated loop-nest cannot be a 'parllel for'.
7585-
class TeamsLoopChecker final : public ConstStmtVisitor<TeamsLoopChecker> {
7586-
public:
7587-
TeamsLoopChecker(CodeGenModule &CGM)
7588-
: CGM(CGM), TeamsLoopCanBeParallelFor{true} {}
7589-
bool teamsLoopCanBeParallelFor() const { return TeamsLoopCanBeParallelFor; }
7590-
// Is there a nested OpenMP loop bind(parallel)
7591-
void VisitOMPExecutableDirective(const OMPExecutableDirective *D) {
7592-
if (D->getDirectiveKind() == llvm::omp::Directive::OMPD_loop) {
7593-
if (const auto *C = D->getSingleClause<OMPBindClause>())
7594-
if (C->getBindKind() == OMPC_BIND_parallel) {
7595-
TeamsLoopCanBeParallelFor = false;
7596-
// No need to continue visiting any more
7597-
return;
7598-
}
7599-
}
7600-
for (const Stmt *Child : D->children())
7601-
if (Child)
7602-
Visit(Child);
7603-
}
7604-
7605-
void VisitCallExpr(const CallExpr *C) {
7606-
// Function calls inhibit parallel loop translation of 'target teams loop'
7607-
// unless the assume-no-nested-parallelism flag has been specified.
7608-
// OpenMP API runtime library calls do not inhibit parallel loop
7609-
// translation, regardless of the assume-no-nested-parallelism.
7610-
if (C) {
7611-
bool IsOpenMPAPI = false;
7612-
auto *FD = dyn_cast_or_null<FunctionDecl>(C->getCalleeDecl());
7613-
if (FD) {
7614-
std::string Name = FD->getNameInfo().getAsString();
7615-
IsOpenMPAPI = Name.find("omp_") == 0;
7616-
}
7617-
TeamsLoopCanBeParallelFor =
7618-
IsOpenMPAPI || CGM.getLangOpts().OpenMPNoNestedParallelism;
7619-
if (!TeamsLoopCanBeParallelFor)
7620-
return;
7621-
}
7622-
for (const Stmt *Child : C->children())
7623-
if (Child)
7624-
Visit(Child);
7625-
}
7626-
7627-
void VisitCapturedStmt(const CapturedStmt *S) {
7628-
if (!S)
7629-
return;
7630-
Visit(S->getCapturedDecl()->getBody());
7631-
}
7632-
7633-
void VisitStmt(const Stmt *S) {
7634-
if (!S)
7635-
return;
7636-
for (const Stmt *Child : S->children())
7637-
if (Child)
7638-
Visit(Child);
7639-
}
7640-
7641-
private:
7642-
CodeGenModule &CGM;
7643-
bool TeamsLoopCanBeParallelFor;
7644-
};
7645-
} // namespace
7646-
7647-
/// Determine if 'teams loop' can be emitted using 'parallel for'.
7648-
bool CodeGenModule::teamsLoopCanBeParallelFor(const OMPExecutableDirective &D) {
7649-
if (D.getDirectiveKind() != llvm::omp::Directive::OMPD_target_teams_loop)
7650-
return false;
7651-
assert(D.hasAssociatedStmt() &&
7652-
"Loop directive must have associated statement.");
7653-
TeamsLoopChecker Checker(*this);
7654-
Checker.Visit(D.getAssociatedStmt());
7655-
return Checker.teamsLoopCanBeParallelFor();
7656-
}
7657-
76587582
void CodeGenModule::emitTargetTeamsLoopCodegenStatus(
76597583
std::string StatusMsg, const OMPExecutableDirective &D, bool IsDevice) {
76607584
#ifndef NDEBUG

clang/lib/CodeGen/CodeGenModule.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,8 +1532,6 @@ class CodeGenModule : public CodeGenTypeCache {
15321532
LValueBaseInfo *BaseInfo = nullptr,
15331533
TBAAAccessInfo *TBAAInfo = nullptr);
15341534
bool stopAutoInit();
1535-
/// Determine if 'teams loop' can be emitted using 'parallel for'.
1536-
bool teamsLoopCanBeParallelFor(const OMPExecutableDirective &D);
15371535

15381536
/// Print the postfix for externalized static variable or kernels for single
15391537
/// source offloading languages CUDA and HIP. The unique postfix is created

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6137,6 +6137,78 @@ processImplicitMapsWithDefaultMappers(Sema &S, DSAStackTy *Stack,
61376137
}
61386138
}
61396139

6140+
namespace {
6141+
/// A 'teams loop' with a nested 'loop bind(parallel)' or generic function
6142+
/// call in the associated loop-nest cannot be a 'parallel for'.
6143+
class TeamsLoopChecker final
6144+
: public ConstStmtVisitor<TeamsLoopChecker> {
6145+
Sema &SemaRef;
6146+
public:
6147+
bool teamsLoopCanBeParallelFor() const { return TeamsLoopCanBeParallelFor; }
6148+
6149+
// Is there a nested OpenMP loop bind(parallel)
6150+
void VisitOMPExecutableDirective(const OMPExecutableDirective *D) {
6151+
if (D->getDirectiveKind() == llvm::omp::Directive::OMPD_loop) {
6152+
if (const auto *C = D->getSingleClause<OMPBindClause>())
6153+
if (C->getBindKind() == OMPC_BIND_parallel) {
6154+
TeamsLoopCanBeParallelFor = false;
6155+
// No need to continue visiting any more
6156+
return;
6157+
}
6158+
}
6159+
for (const Stmt *Child : D->children())
6160+
if (Child)
6161+
Visit(Child);
6162+
}
6163+
6164+
void VisitCallExpr(const CallExpr *C) {
6165+
// Function calls inhibit parallel loop translation of 'target teams loop'
6166+
// unless the assume-no-nested-parallelism flag has been specified.
6167+
// OpenMP API runtime library calls do not inhibit parallel loop
6168+
// translation, regardless of the assume-no-nested-parallelism.
6169+
if (C) {
6170+
bool IsOpenMPAPI = false;
6171+
auto *FD = dyn_cast_or_null<FunctionDecl>(C->getCalleeDecl());
6172+
if (FD) {
6173+
std::string Name = FD->getNameInfo().getAsString();
6174+
IsOpenMPAPI = Name.find("omp_") == 0;
6175+
}
6176+
TeamsLoopCanBeParallelFor =
6177+
IsOpenMPAPI || SemaRef.getLangOpts().OpenMPNoNestedParallelism;
6178+
if (!TeamsLoopCanBeParallelFor)
6179+
return;
6180+
}
6181+
for (const Stmt *Child : C->children())
6182+
if (Child)
6183+
Visit(Child);
6184+
}
6185+
6186+
void VisitCapturedStmt(const CapturedStmt *S) {
6187+
if (!S)
6188+
return;
6189+
Visit(S->getCapturedDecl()->getBody());
6190+
}
6191+
6192+
void VisitStmt(const Stmt *S) {
6193+
if (!S)
6194+
return;
6195+
for (const Stmt *Child : S->children())
6196+
if (Child)
6197+
Visit(Child);
6198+
}
6199+
explicit TeamsLoopChecker(Sema &SemaRef)
6200+
: SemaRef(SemaRef), TeamsLoopCanBeParallelFor(true) {}
6201+
private:
6202+
bool TeamsLoopCanBeParallelFor;
6203+
};
6204+
} // namespace
6205+
6206+
bool Sema::teamsLoopCanBeParallelFor(Stmt *AStmt) {
6207+
TeamsLoopChecker Checker(*this);
6208+
Checker.Visit(AStmt);
6209+
return Checker.teamsLoopCanBeParallelFor();
6210+
}
6211+
61406212
bool Sema::mapLoopConstruct(llvm::SmallVector<OMPClause *> &ClausesWithoutBind,
61416213
ArrayRef<OMPClause *> Clauses,
61426214
OpenMPBindClauseKind &BindKind,
@@ -10897,7 +10969,8 @@ StmtResult Sema::ActOnOpenMPTargetTeamsGenericLoopDirective(
1089710969
setFunctionHasBranchProtectedScope();
1089810970

1089910971
return OMPTargetTeamsGenericLoopDirective::Create(
10900-
Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
10972+
Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B,
10973+
teamsLoopCanBeParallelFor(AStmt));
1090110974
}
1090210975

1090310976
StmtResult Sema::ActOnOpenMPParallelGenericLoopDirective(

clang/lib/Serialization/ASTReaderStmt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2776,6 +2776,7 @@ void ASTStmtReader::VisitOMPTeamsGenericLoopDirective(
27762776
void ASTStmtReader::VisitOMPTargetTeamsGenericLoopDirective(
27772777
OMPTargetTeamsGenericLoopDirective *D) {
27782778
VisitOMPLoopDirective(D);
2779+
D->setCanBeParallelFor(Record.readBool());
27792780
}
27802781

27812782
void ASTStmtReader::VisitOMPParallelGenericLoopDirective(

clang/lib/Serialization/ASTWriterStmt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,6 +2823,7 @@ void ASTStmtWriter::VisitOMPTeamsGenericLoopDirective(
28232823
void ASTStmtWriter::VisitOMPTargetTeamsGenericLoopDirective(
28242824
OMPTargetTeamsGenericLoopDirective *D) {
28252825
VisitOMPLoopDirective(D);
2826+
Record.writeBool(D->canBeParallelFor());
28262827
Code = serialization::STMT_OMP_TARGET_TEAMS_GENERIC_LOOP_DIRECTIVE;
28272828
}
28282829

0 commit comments

Comments
 (0)