Skip to content

Commit 299dca8

Browse files
committed
[Clang][OpenMP] Allow num_teams to accept multiple expressions
1 parent 05f0e86 commit 299dca8

File tree

7 files changed

+99
-57
lines changed

7 files changed

+99
-57
lines changed

clang/include/clang/AST/OpenMPClause.h

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6131,60 +6131,77 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
61316131
/// \endcode
61326132
/// In this example directive '#pragma omp teams' has clause 'num_teams'
61336133
/// with single expression 'n'.
6134-
class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
6135-
friend class OMPClauseReader;
6134+
///
6135+
/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
6136+
/// can accept up to three expressions.
6137+
///
6138+
/// \code
6139+
/// #pragma omp target teams ompx_bare num_teams(x, y, z)
6140+
/// \endcode
6141+
class OMPNumTeamsClause final
6142+
: public OMPVarListClause<OMPNumTeamsClause>,
6143+
public OMPClauseWithPreInit,
6144+
private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
6145+
friend OMPVarListClause;
6146+
friend TrailingObjects;
61366147

61376148
/// Location of '('.
61386149
SourceLocation LParenLoc;
61396150

6140-
/// NumTeams number.
6141-
Stmt *NumTeams = nullptr;
6151+
OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
6152+
SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
6153+
: OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
6154+
N),
6155+
OMPClauseWithPreInit(this) {}
61426156

6143-
/// Set the NumTeams number.
6144-
///
6145-
/// \param E NumTeams number.
6146-
void setNumTeams(Expr *E) { NumTeams = E; }
6157+
/// Build an empty clause.
6158+
OMPNumTeamsClause(unsigned N)
6159+
: OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6160+
SourceLocation(), SourceLocation(), N),
6161+
OMPClauseWithPreInit(this) {}
61476162

61486163
public:
6149-
/// Build 'num_teams' clause.
6164+
/// Creates clause with a list of variables \a VL.
61506165
///
6151-
/// \param E Expression associated with this clause.
6152-
/// \param HelperE Helper Expression associated with this clause.
6153-
/// \param CaptureRegion Innermost OpenMP region where expressions in this
6154-
/// clause must be captured.
6166+
/// \param C AST context.
61556167
/// \param StartLoc Starting location of the clause.
61566168
/// \param LParenLoc Location of '('.
61576169
/// \param EndLoc Ending location of the clause.
6158-
OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
6159-
SourceLocation StartLoc, SourceLocation LParenLoc,
6160-
SourceLocation EndLoc)
6161-
: OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
6162-
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
6163-
setPreInitStmt(HelperE, CaptureRegion);
6164-
}
6170+
/// \param VL List of references to the variables.
6171+
/// \param PreInit
6172+
static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc,
6173+
SourceLocation LParenLoc,
6174+
SourceLocation EndLoc, ArrayRef<Expr *> VL,
6175+
Stmt *PreInit);
61656176

6166-
/// Build an empty clause.
6167-
OMPNumTeamsClause()
6168-
: OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6169-
SourceLocation()),
6170-
OMPClauseWithPreInit(this) {}
6177+
/// Creates an empty clause with \a N variables.
6178+
///
6179+
/// \param C AST context.
6180+
/// \param N The number of variables.
6181+
static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
61716182

61726183
/// Sets the location of '('.
61736184
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
61746185

61756186
/// Returns the location of '('.
61766187
SourceLocation getLParenLoc() const { return LParenLoc; }
61776188

6178-
/// Return NumTeams number.
6179-
Expr *getNumTeams() { return cast<Expr>(NumTeams); }
6189+
/// Return NumTeams number. By default, we return the first expression.
6190+
Expr *getNumTeams() { return getVarRefs().front(); }
61806191

6181-
/// Return NumTeams number.
6182-
Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
6192+
/// Return NumTeams number. By default, we return the first expression.
6193+
Expr *getNumTeams() const {
6194+
return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
6195+
}
61836196

6184-
child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
6197+
child_range children() {
6198+
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
6199+
reinterpret_cast<Stmt **>(varlist_end()));
6200+
}
61856201

61866202
const_child_range children() const {
6187-
return const_child_range(&NumTeams, &NumTeams + 1);
6203+
auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
6204+
return const_child_range(Children.begin(), Children.end());
61886205
}
61896206

61906207
child_range used_children() {

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
37933793
template <typename Derived>
37943794
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
37953795
OMPNumTeamsClause *C) {
3796+
TRY_TO(VisitOMPClauseList(C));
37963797
TRY_TO(VisitOMPClauseWithPreInit(C));
3797-
TRY_TO(TraverseStmt(C->getNumTeams()));
37983798
return true;
37993799
}
38003800

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,8 @@ class SemaOpenMP : public SemaBase {
12271227
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
12281228
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
12291229
/// Called on well-formed 'num_teams' clause.
1230-
OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
1230+
OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
1231+
SourceLocation StartLoc,
12311232
SourceLocation LParenLoc,
12321233
SourceLocation EndLoc);
12331234
/// Called on well-formed 'thread_limit' clause.

clang/lib/AST/OpenMPClause.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
17201720
return *It;
17211721
}
17221722

1723+
OMPNumTeamsClause *
1724+
OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc,
1725+
SourceLocation LParenLoc, SourceLocation EndLoc,
1726+
ArrayRef<Expr *> VL, Stmt *PreInit) {
1727+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
1728+
OMPNumTeamsClause *Clause =
1729+
new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
1730+
Clause->setVarRefs(VL);
1731+
Clause->setPreInitStmt(PreInit);
1732+
return Clause;
1733+
}
1734+
1735+
OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
1736+
unsigned N) {
1737+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
1738+
return new (Mem) OMPNumTeamsClause(N);
1739+
}
1740+
17231741
//===----------------------------------------------------------------------===//
17241742
// OpenMP clauses printing methods
17251743
//===----------------------------------------------------------------------===//
@@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
19771995
}
19781996

19791997
void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
1980-
OS << "num_teams(";
1981-
Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
1982-
OS << ")";
1998+
if (!Node->varlist_empty()) {
1999+
OS << "num_teams";
2000+
VisitOMPClauseList(Node, '(');
2001+
OS << ")";
2002+
}
19832003
}
19842004

19852005
void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {

clang/lib/AST/StmtProfile.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
843843
VisitOMPClauseList(C);
844844
}
845845
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
846+
VisitOMPClauseList(C);
846847
VistOMPClauseWithPreInit(C);
847-
if (C->getNumTeams())
848-
Profiler->VisitStmt(C->getNumTeams());
849848
}
850849
void OMPClauseProfiler::VisitOMPThreadLimitClause(
851850
const OMPThreadLimitClause *C) {

clang/lib/Parse/ParseOpenMP.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
30983098
case OMPC_simdlen:
30993099
case OMPC_collapse:
31003100
case OMPC_ordered:
3101-
case OMPC_num_teams:
31023101
case OMPC_thread_limit:
31033102
case OMPC_priority:
31043103
case OMPC_grainsize:
@@ -3279,6 +3278,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
32793278
case OMPC_affinity:
32803279
case OMPC_doacross:
32813280
case OMPC_enter:
3281+
case OMPC_num_teams:
32823282
if (getLangOpts().OpenMP >= 52 && DKind == OMPD_ordered &&
32833283
CKind == OMPC_depend)
32843284
Diag(Tok, diag::warn_omp_depend_in_ordered_deprecated);

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15041,9 +15041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
1504115041
case OMPC_ordered:
1504215042
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
1504315043
break;
15044-
case OMPC_num_teams:
15045-
Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
15046-
break;
1504715044
case OMPC_thread_limit:
1504815045
Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
1504915046
break;
@@ -15147,6 +15144,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
1514715144
case OMPC_affinity:
1514815145
case OMPC_when:
1514915146
case OMPC_bind:
15147+
case OMPC_num_teams:
1515015148
default:
1515115149
llvm_unreachable("Clause is not allowed.");
1515215150
}
@@ -17010,6 +17008,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
1701017008
static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
1701117009
ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
1701217010
break;
17011+
case OMPC_num_teams:
17012+
Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
17013+
break;
1701317014
case OMPC_if:
1701417015
case OMPC_depobj:
1701517016
case OMPC_final:
@@ -17040,7 +17041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
1704017041
case OMPC_device:
1704117042
case OMPC_threads:
1704217043
case OMPC_simd:
17043-
case OMPC_num_teams:
1704417044
case OMPC_thread_limit:
1704517045
case OMPC_priority:
1704617046
case OMPC_grainsize:
@@ -21703,32 +21703,37 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
2170321703
return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
2170421704
}
2170521705

21706-
OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
21706+
OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
2170721707
SourceLocation StartLoc,
2170821708
SourceLocation LParenLoc,
2170921709
SourceLocation EndLoc) {
21710-
Expr *ValExpr = NumTeams;
21711-
Stmt *HelperValStmt = nullptr;
21712-
21713-
// OpenMP [teams Constrcut, Restrictions]
21714-
// The num_teams expression must evaluate to a positive integer value.
21715-
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
21716-
/*StrictlyPositive=*/true))
21710+
if (VarList.empty())
2171721711
return nullptr;
2171821712

2171921713
OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
2172021714
OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
2172121715
DKind, OMPC_num_teams, getLangOpts().OpenMP);
21722-
if (CaptureRegion != OMPD_unknown &&
21723-
!SemaRef.CurContext->isDependentContext()) {
21716+
21717+
if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
21718+
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc,
21719+
EndLoc, VarList, /*PreInit=*/nullptr);
21720+
21721+
llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
21722+
SmallVector<Expr *, 3> Vars;
21723+
for (Expr *ValExpr : VarList) {
21724+
// OpenMP [teams Constrcut, Restrictions]
21725+
// The num_teams expression must evaluate to a positive integer value.
21726+
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
21727+
/*StrictlyPositive=*/true))
21728+
return nullptr;
2172421729
ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
21725-
llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
2172621730
ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
21727-
HelperValStmt = buildPreInits(getASTContext(), Captures);
21731+
Vars.push_back(ValExpr);
2172821732
}
2172921733

21730-
return new (getASTContext()) OMPNumTeamsClause(
21731-
ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
21734+
Stmt *PreInit = buildPreInits(getASTContext(), Captures);
21735+
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
21736+
Vars, PreInit);
2173221737
}
2173321738

2173421739
OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,

0 commit comments

Comments
 (0)