Skip to content

Commit b85372b

Browse files
committed
Moved check for whether a 'target teams loop' construct can potentially
be considered equivalent to 'target teams distribute parallel for' from CodeGen to Sema.
1 parent 7c183c7 commit b85372b

File tree

10 files changed

+100
-86
lines changed

10 files changed

+100
-86
lines changed

clang/include/clang/AST/StmtOpenMP.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6106,6 +6106,8 @@ class OMPTeamsGenericLoopDirective final : public OMPLoopDirective {
61066106
class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
61076107
friend class ASTStmtReader;
61086108
friend class OMPExecutableDirective;
6109+
/// true if loop directive's associated loop can be a parallel for.
6110+
bool CanBeParallelFor = false;
61096111
/// Build directive with the given start and end location.
61106112
///
61116113
/// \param StartLoc Starting location of the directive kind.
@@ -6128,6 +6130,9 @@ class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
61286130
llvm::omp::OMPD_target_teams_loop, SourceLocation(),
61296131
SourceLocation(), CollapsedNum) {}
61306132

6133+
/// Set whether associated loop can be a parallel for.
6134+
void setCanBeParallelFor(bool ParFor) { CanBeParallelFor = ParFor; }
6135+
61316136
public:
61326137
/// Creates directive with a list of \p Clauses.
61336138
///
@@ -6142,7 +6147,7 @@ class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
61426147
static OMPTargetTeamsGenericLoopDirective *
61436148
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
61446149
unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses,
6145-
Stmt *AssociatedStmt, const HelperExprs &Exprs);
6150+
Stmt *AssociatedStmt, const HelperExprs &Exprs, bool CanBeParallelFor);
61466151

61476152
/// Creates an empty directive with the place
61486153
/// for \a NumClauses clauses.
@@ -6156,6 +6161,10 @@ class OMPTargetTeamsGenericLoopDirective final : public OMPLoopDirective {
61566161
unsigned CollapsedNum,
61576162
EmptyShell);
61586163

6164+
/// Return true if current loop directive's associated loop can be a
6165+
/// parallel for.
6166+
bool canBeParallelFor() const { return CanBeParallelFor; }
6167+
61596168
static bool classof(const Stmt *T) {
61606169
return T->getStmtClass() == OMPTargetTeamsGenericLoopDirectiveClass;
61616170
}

clang/include/clang/Sema/Sema.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11311,6 +11311,9 @@ class Sema final {
1131111311
OpenMPDirectiveKind &Kind,
1131211312
OpenMPDirectiveKind &PrevMappedDirective);
1131311313

11314+
/// [target] teams loop is equivalent to parallel for if associated loop
11315+
/// nest meets certain critera.
11316+
bool teamsLoopCanBeParallelFor(Stmt *Astmt);
1131411317
public:
1131511318
/// The declarator \p D defines a function in the scope \p S which is nested
1131611319
/// in an `omp begin/end declare variant` scope. In this method we create a

clang/lib/AST/StmtOpenMP.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2431,7 +2431,7 @@ OMPTeamsGenericLoopDirective::CreateEmpty(const ASTContext &C,
24312431
OMPTargetTeamsGenericLoopDirective *OMPTargetTeamsGenericLoopDirective::Create(
24322432
const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
24332433
unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
2434-
const HelperExprs &Exprs) {
2434+
const HelperExprs &Exprs, bool CanBeParallelFor) {
24352435
auto *Dir = createDirective<OMPTargetTeamsGenericLoopDirective>(
24362436
C, Clauses, AssociatedStmt,
24372437
numLoopChildren(CollapsedNum, OMPD_target_teams_loop), StartLoc, EndLoc,
@@ -2473,6 +2473,7 @@ OMPTargetTeamsGenericLoopDirective *OMPTargetTeamsGenericLoopDirective::Create(
24732473
Dir->setCombinedNextUpperBound(Exprs.DistCombinedFields.NUB);
24742474
Dir->setCombinedDistCond(Exprs.DistCombinedFields.DistCond);
24752475
Dir->setCombinedParForInDistCond(Exprs.DistCombinedFields.ParForInDistCond);
2476+
Dir->setCanBeParallelFor(CanBeParallelFor);
24762477
return Dir;
24772478
}
24782479

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,9 @@ static bool supportsSPMDExecutionMode(CodeGenModule &CGM,
661661
case OMPD_target_teams_loop:
662662
// Whether this is true or not depends on how the directive will
663663
// eventually be emitted.
664-
return CGM.teamsLoopCanBeParallelFor(D);
664+
if (auto *TTLD = dyn_cast<OMPTargetTeamsGenericLoopDirective>(&D))
665+
return TTLD->canBeParallelFor();
666+
return false;
665667
case OMPD_parallel:
666668
case OMPD_for:
667669
case OMPD_parallel_for:

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,9 +1436,12 @@ void CodeGenFunction::EmitOMPReductionClauseFinal(
14361436
*this, D.getBeginLoc(),
14371437
isOpenMPWorksharingDirective(D.getDirectiveKind()));
14381438
}
1439+
bool TeamsLoopCanBeParallel = false;
1440+
if (auto *TTLD = dyn_cast<OMPTargetTeamsGenericLoopDirective>(&D))
1441+
TeamsLoopCanBeParallel = TTLD->canBeParallelFor();
14391442
bool WithNowait = D.getSingleClause<OMPNowaitClause>() ||
14401443
isOpenMPParallelDirective(D.getDirectiveKind()) ||
1441-
CGM.teamsLoopCanBeParallelFor(D) ||
1444+
TeamsLoopCanBeParallel ||
14421445
ReductionKind == OMPD_simd;
14431446
bool SimpleReduction = ReductionKind == OMPD_simd;
14441447
// Emit nowait reduction if nowait clause is present or directive is a
@@ -7965,7 +7968,7 @@ static void emitTargetTeamsGenericLoopRegionAsDistribute(
79657968
void CodeGenFunction::EmitOMPTargetTeamsGenericLoopDirective(
79667969
const OMPTargetTeamsGenericLoopDirective &S) {
79677970
auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
7968-
if (CGF.CGM.teamsLoopCanBeParallelFor(S))
7971+
if (S.canBeParallelFor())
79697972
emitTargetTeamsGenericLoopRegionAsParallel(CGF, Action, S);
79707973
else
79717974
emitTargetTeamsGenericLoopRegionAsDistribute(CGF, Action, S);
@@ -7978,7 +7981,7 @@ void CodeGenFunction::EmitOMPTargetTeamsGenericLoopDeviceFunction(
79787981
const OMPTargetTeamsGenericLoopDirective &S) {
79797982
// Emit SPMD target parallel loop region as a standalone region.
79807983
auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
7981-
if (CGF.CGM.teamsLoopCanBeParallelFor(S))
7984+
if (S.canBeParallelFor())
79827985
emitTargetTeamsGenericLoopRegionAsParallel(CGF, Action, S);
79837986
else
79847987
emitTargetTeamsGenericLoopRegionAsDistribute(CGF, Action, S);

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -7485,82 +7485,6 @@ void CodeGenModule::printPostfixForExternalizedDecl(llvm::raw_ostream &OS,
74857485
}
74867486
}
74877487

7488-
namespace {
7489-
/// A 'teams loop' with a nested 'loop bind(parallel)' or generic function
7490-
/// call in the associated loop-nest cannot be a 'parllel for'.
7491-
class TeamsLoopChecker final : public ConstStmtVisitor<TeamsLoopChecker> {
7492-
public:
7493-
TeamsLoopChecker(CodeGenModule &CGM)
7494-
: CGM(CGM), TeamsLoopCanBeParallelFor{true} {}
7495-
bool teamsLoopCanBeParallelFor() const { return TeamsLoopCanBeParallelFor; }
7496-
// Is there a nested OpenMP loop bind(parallel)
7497-
void VisitOMPExecutableDirective(const OMPExecutableDirective *D) {
7498-
if (D->getDirectiveKind() == llvm::omp::Directive::OMPD_loop) {
7499-
if (const auto *C = D->getSingleClause<OMPBindClause>())
7500-
if (C->getBindKind() == OMPC_BIND_parallel) {
7501-
TeamsLoopCanBeParallelFor = false;
7502-
// No need to continue visiting any more
7503-
return;
7504-
}
7505-
}
7506-
for (const Stmt *Child : D->children())
7507-
if (Child)
7508-
Visit(Child);
7509-
}
7510-
7511-
void VisitCallExpr(const CallExpr *C) {
7512-
// Function calls inhibit parallel loop translation of 'target teams loop'
7513-
// unless the assume-no-nested-parallelism flag has been specified.
7514-
// OpenMP API runtime library calls do not inhibit parallel loop
7515-
// translation, regardless of the assume-no-nested-parallelism.
7516-
if (C) {
7517-
bool IsOpenMPAPI = false;
7518-
auto *FD = dyn_cast_or_null<FunctionDecl>(C->getCalleeDecl());
7519-
if (FD) {
7520-
std::string Name = FD->getNameInfo().getAsString();
7521-
IsOpenMPAPI = Name.find("omp_") == 0;
7522-
}
7523-
TeamsLoopCanBeParallelFor =
7524-
IsOpenMPAPI || CGM.getLangOpts().OpenMPNoNestedParallelism;
7525-
if (!TeamsLoopCanBeParallelFor)
7526-
return;
7527-
}
7528-
for (const Stmt *Child : C->children())
7529-
if (Child)
7530-
Visit(Child);
7531-
}
7532-
7533-
void VisitCapturedStmt(const CapturedStmt *S) {
7534-
if (!S)
7535-
return;
7536-
Visit(S->getCapturedDecl()->getBody());
7537-
}
7538-
7539-
void VisitStmt(const Stmt *S) {
7540-
if (!S)
7541-
return;
7542-
for (const Stmt *Child : S->children())
7543-
if (Child)
7544-
Visit(Child);
7545-
}
7546-
7547-
private:
7548-
CodeGenModule &CGM;
7549-
bool TeamsLoopCanBeParallelFor;
7550-
};
7551-
} // namespace
7552-
7553-
/// Determine if 'teams loop' can be emitted using 'parallel for'.
7554-
bool CodeGenModule::teamsLoopCanBeParallelFor(const OMPExecutableDirective &D) {
7555-
if (D.getDirectiveKind() != llvm::omp::Directive::OMPD_target_teams_loop)
7556-
return false;
7557-
assert(D.hasAssociatedStmt() &&
7558-
"Loop directive must have associated statement.");
7559-
TeamsLoopChecker Checker(*this);
7560-
Checker.Visit(D.getAssociatedStmt());
7561-
return Checker.teamsLoopCanBeParallelFor();
7562-
}
7563-
75647488
void CodeGenModule::emitTargetTeamsLoopCodegenStatus(
75657489
std::string StatusMsg, const OMPExecutableDirective &D, bool IsDevice) {
75667490
#ifndef NDEBUG

clang/lib/CodeGen/CodeGenModule.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,8 +1528,6 @@ class CodeGenModule : public CodeGenTypeCache {
15281528
LValueBaseInfo *BaseInfo = nullptr,
15291529
TBAAAccessInfo *TBAAInfo = nullptr);
15301530
bool stopAutoInit();
1531-
/// Determine if 'teams loop' can be emitted using 'parallel for'.
1532-
bool teamsLoopCanBeParallelFor(const OMPExecutableDirective &D);
15331531

15341532
/// Print the postfix for externalized static variable or kernels for single
15351533
/// source offloading languages CUDA and HIP. The unique postfix is created

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6124,6 +6124,78 @@ processImplicitMapsWithDefaultMappers(Sema &S, DSAStackTy *Stack,
61246124
}
61256125
}
61266126

6127+
namespace {
6128+
/// A 'teams loop' with a nested 'loop bind(parallel)' or generic function
6129+
/// call in the associated loop-nest cannot be a 'parallel for'.
6130+
class TeamsLoopChecker final
6131+
: public ConstStmtVisitor<TeamsLoopChecker> {
6132+
Sema &SemaRef;
6133+
public:
6134+
bool teamsLoopCanBeParallelFor() const { return TeamsLoopCanBeParallelFor; }
6135+
6136+
// Is there a nested OpenMP loop bind(parallel)
6137+
void VisitOMPExecutableDirective(const OMPExecutableDirective *D) {
6138+
if (D->getDirectiveKind() == llvm::omp::Directive::OMPD_loop) {
6139+
if (const auto *C = D->getSingleClause<OMPBindClause>())
6140+
if (C->getBindKind() == OMPC_BIND_parallel) {
6141+
TeamsLoopCanBeParallelFor = false;
6142+
// No need to continue visiting any more
6143+
return;
6144+
}
6145+
}
6146+
for (const Stmt *Child : D->children())
6147+
if (Child)
6148+
Visit(Child);
6149+
}
6150+
6151+
void VisitCallExpr(const CallExpr *C) {
6152+
// Function calls inhibit parallel loop translation of 'target teams loop'
6153+
// unless the assume-no-nested-parallelism flag has been specified.
6154+
// OpenMP API runtime library calls do not inhibit parallel loop
6155+
// translation, regardless of the assume-no-nested-parallelism.
6156+
if (C) {
6157+
bool IsOpenMPAPI = false;
6158+
auto *FD = dyn_cast_or_null<FunctionDecl>(C->getCalleeDecl());
6159+
if (FD) {
6160+
std::string Name = FD->getNameInfo().getAsString();
6161+
IsOpenMPAPI = Name.find("omp_") == 0;
6162+
}
6163+
TeamsLoopCanBeParallelFor =
6164+
IsOpenMPAPI || SemaRef.getLangOpts().OpenMPNoNestedParallelism;
6165+
if (!TeamsLoopCanBeParallelFor)
6166+
return;
6167+
}
6168+
for (const Stmt *Child : C->children())
6169+
if (Child)
6170+
Visit(Child);
6171+
}
6172+
6173+
void VisitCapturedStmt(const CapturedStmt *S) {
6174+
if (!S)
6175+
return;
6176+
Visit(S->getCapturedDecl()->getBody());
6177+
}
6178+
6179+
void VisitStmt(const Stmt *S) {
6180+
if (!S)
6181+
return;
6182+
for (const Stmt *Child : S->children())
6183+
if (Child)
6184+
Visit(Child);
6185+
}
6186+
explicit TeamsLoopChecker(Sema &SemaRef)
6187+
: SemaRef(SemaRef), TeamsLoopCanBeParallelFor(true) {}
6188+
private:
6189+
bool TeamsLoopCanBeParallelFor;
6190+
};
6191+
} // namespace
6192+
6193+
bool Sema::teamsLoopCanBeParallelFor(Stmt *AStmt) {
6194+
TeamsLoopChecker Checker(*this);
6195+
Checker.Visit(AStmt);
6196+
return Checker.teamsLoopCanBeParallelFor();
6197+
}
6198+
61276199
bool Sema::mapLoopConstruct(llvm::SmallVector<OMPClause *> &ClausesWithoutBind,
61286200
ArrayRef<OMPClause *> Clauses,
61296201
OpenMPBindClauseKind BindKind,
@@ -6234,7 +6306,6 @@ StmtResult Sema::ActOnOpenMPExecutableDirective(
62346306

62356307
UseClausesWithoutBind = mapLoopConstruct(ClausesWithoutBind, Clauses,
62366308
BindKind, Kind, PrevMappedDirective);
6237-
62386309
llvm::SmallVector<OMPClause *, 8> ClausesWithImplicit;
62396310
VarsWithInheritedDSAType VarsWithInheritedDSA;
62406311
bool ErrorFound = false;
@@ -10870,7 +10941,8 @@ StmtResult Sema::ActOnOpenMPTargetTeamsGenericLoopDirective(
1087010941
setFunctionHasBranchProtectedScope();
1087110942

1087210943
return OMPTargetTeamsGenericLoopDirective::Create(
10873-
Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
10944+
Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B,
10945+
teamsLoopCanBeParallelFor(AStmt));
1087410946
}
1087510947

1087610948
StmtResult Sema::ActOnOpenMPParallelGenericLoopDirective(

clang/lib/Serialization/ASTReaderStmt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2715,6 +2715,7 @@ void ASTStmtReader::VisitOMPTeamsGenericLoopDirective(
27152715
void ASTStmtReader::VisitOMPTargetTeamsGenericLoopDirective(
27162716
OMPTargetTeamsGenericLoopDirective *D) {
27172717
VisitOMPLoopDirective(D);
2718+
D->setCanBeParallelFor(Record.readBool());
27182719
}
27192720

27202721
void ASTStmtReader::VisitOMPParallelGenericLoopDirective(

clang/lib/Serialization/ASTWriterStmt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2696,6 +2696,7 @@ void ASTStmtWriter::VisitOMPTeamsGenericLoopDirective(
26962696
void ASTStmtWriter::VisitOMPTargetTeamsGenericLoopDirective(
26972697
OMPTargetTeamsGenericLoopDirective *D) {
26982698
VisitOMPLoopDirective(D);
2699+
Record.writeBool(D->canBeParallelFor());
26992700
Code = serialization::STMT_OMP_TARGET_TEAMS_GENERIC_LOOP_DIRECTIVE;
27002701
}
27012702

0 commit comments

Comments
 (0)