Skip to content

Commit 37c28db

Browse files
committed
[Clang][OpenMP] Allow num_teams to accept multiple expressions
1 parent 378fe2f commit 37c28db

File tree

15 files changed

+478
-312
lines changed

15 files changed

+478
-312
lines changed

clang/include/clang/AST/OpenMPClause.h

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6131,60 +6131,77 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
61316131
/// \endcode
61326132
/// In this example directive '#pragma omp teams' has clause 'num_teams'
61336133
/// with single expression 'n'.
6134-
class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
6135-
friend class OMPClauseReader;
6134+
///
6135+
/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
6136+
/// can accept up to three expressions.
6137+
///
6138+
/// \code
6139+
/// #pragma omp target teams ompx_bare num_teams(x, y, z)
6140+
/// \endcode
6141+
class OMPNumTeamsClause final
6142+
: public OMPVarListClause<OMPNumTeamsClause>,
6143+
public OMPClauseWithPreInit,
6144+
private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
6145+
friend OMPVarListClause;
6146+
friend TrailingObjects;
61366147

61376148
/// Location of '('.
61386149
SourceLocation LParenLoc;
61396150

6140-
/// NumTeams number.
6141-
Stmt *NumTeams = nullptr;
6151+
OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
6152+
SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
6153+
: OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
6154+
N),
6155+
OMPClauseWithPreInit(this) {}
61426156

6143-
/// Set the NumTeams number.
6144-
///
6145-
/// \param E NumTeams number.
6146-
void setNumTeams(Expr *E) { NumTeams = E; }
6157+
/// Build an empty clause.
6158+
OMPNumTeamsClause(unsigned N)
6159+
: OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6160+
SourceLocation(), SourceLocation(), N),
6161+
OMPClauseWithPreInit(this) {}
61476162

61486163
public:
6149-
/// Build 'num_teams' clause.
6164+
/// Creates clause with a list of variables \a VL.
61506165
///
6151-
/// \param E Expression associated with this clause.
6152-
/// \param HelperE Helper Expression associated with this clause.
6153-
/// \param CaptureRegion Innermost OpenMP region where expressions in this
6154-
/// clause must be captured.
6166+
/// \param C AST context.
61556167
/// \param StartLoc Starting location of the clause.
61566168
/// \param LParenLoc Location of '('.
61576169
/// \param EndLoc Ending location of the clause.
6158-
OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
6159-
SourceLocation StartLoc, SourceLocation LParenLoc,
6160-
SourceLocation EndLoc)
6161-
: OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
6162-
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
6163-
setPreInitStmt(HelperE, CaptureRegion);
6164-
}
6170+
/// \param VL List of references to the variables.
6171+
/// \param PreInit
6172+
static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc,
6173+
SourceLocation LParenLoc,
6174+
SourceLocation EndLoc, ArrayRef<Expr *> VL,
6175+
Stmt *PreInit);
61656176

6166-
/// Build an empty clause.
6167-
OMPNumTeamsClause()
6168-
: OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6169-
SourceLocation()),
6170-
OMPClauseWithPreInit(this) {}
6177+
/// Creates an empty clause with \a N variables.
6178+
///
6179+
/// \param C AST context.
6180+
/// \param N The number of variables.
6181+
static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
61716182

61726183
/// Sets the location of '('.
61736184
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
61746185

61756186
/// Returns the location of '('.
61766187
SourceLocation getLParenLoc() const { return LParenLoc; }
61776188

6178-
/// Return NumTeams number.
6179-
Expr *getNumTeams() { return cast<Expr>(NumTeams); }
6189+
/// Return NumTeams expressions.
6190+
ArrayRef<Expr *> getNumTeams() { return getVarRefs(); }
61806191

6181-
/// Return NumTeams number.
6182-
Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
6192+
/// Return NumTeams expressions.
6193+
ArrayRef<Expr *> getNumTeams() const {
6194+
return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
6195+
}
61836196

6184-
child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
6197+
child_range children() {
6198+
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
6199+
reinterpret_cast<Stmt **>(varlist_end()));
6200+
}
61856201

61866202
const_child_range children() const {
6187-
return const_child_range(&NumTeams, &NumTeams + 1);
6203+
auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
6204+
return const_child_range(Children.begin(), Children.end());
61886205
}
61896206

61906207
child_range used_children() {

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
37933793
template <typename Derived>
37943794
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
37953795
OMPNumTeamsClause *C) {
3796+
TRY_TO(VisitOMPClauseList(C));
37963797
TRY_TO(VisitOMPClauseWithPreInit(C));
3797-
TRY_TO(TraverseStmt(C->getNumTeams()));
37983798
return true;
37993799
}
38003800

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,8 @@ class SemaOpenMP : public SemaBase {
12261226
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
12271227
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
12281228
/// Called on well-formed 'num_teams' clause.
1229-
OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
1229+
OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
1230+
SourceLocation StartLoc,
12301231
SourceLocation LParenLoc,
12311232
SourceLocation EndLoc);
12321233
/// Called on well-formed 'thread_limit' clause.

clang/lib/AST/OpenMPClause.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
17201720
return *It;
17211721
}
17221722

1723+
OMPNumTeamsClause *
1724+
OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc,
1725+
SourceLocation LParenLoc, SourceLocation EndLoc,
1726+
ArrayRef<Expr *> VL, Stmt *PreInit) {
1727+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
1728+
OMPNumTeamsClause *Clause =
1729+
new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
1730+
Clause->setVarRefs(VL);
1731+
Clause->setPreInitStmt(PreInit);
1732+
return Clause;
1733+
}
1734+
1735+
OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
1736+
unsigned N) {
1737+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
1738+
return new (Mem) OMPNumTeamsClause(N);
1739+
}
1740+
17231741
//===----------------------------------------------------------------------===//
17241742
// OpenMP clauses printing methods
17251743
//===----------------------------------------------------------------------===//
@@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
19771995
}
19781996

19791997
void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
1980-
OS << "num_teams(";
1981-
Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
1982-
OS << ")";
1998+
if (!Node->varlist_empty()) {
1999+
OS << "num_teams";
2000+
VisitOMPClauseList(Node, '(');
2001+
OS << ")";
2002+
}
19832003
}
19842004

19852005
void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {

clang/lib/AST/StmtProfile.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
843843
VisitOMPClauseList(C);
844844
}
845845
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
846+
VisitOMPClauseList(C);
846847
VistOMPClauseWithPreInit(C);
847-
if (C->getNumTeams())
848-
Profiler->VisitStmt(C->getNumTeams());
849848
}
850849
void OMPClauseProfiler::VisitOMPThreadLimitClause(
851850
const OMPThreadLimitClause *C) {

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6036,8 +6036,9 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
60366036
dyn_cast_or_null<OMPExecutableDirective>(ChildStmt)) {
60376037
if (isOpenMPTeamsDirective(NestedDir->getDirectiveKind())) {
60386038
if (NestedDir->hasClausesOfKind<OMPNumTeamsClause>()) {
6039-
const Expr *NumTeams =
6040-
NestedDir->getSingleClause<OMPNumTeamsClause>()->getNumTeams();
6039+
const Expr *NumTeams = NestedDir->getSingleClause<OMPNumTeamsClause>()
6040+
->getNumTeams()
6041+
.front();
60416042
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
60426043
if (auto Constant =
60436044
NumTeams->getIntegerConstantExpr(CGF.getContext()))
@@ -6062,7 +6063,7 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
60626063
case OMPD_target_teams_distribute_parallel_for_simd: {
60636064
if (D.hasClausesOfKind<OMPNumTeamsClause>()) {
60646065
const Expr *NumTeams =
6065-
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams();
6066+
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams().front();
60666067
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
60676068
if (auto Constant = NumTeams->getIntegerConstantExpr(CGF.getContext()))
60686069
MinTeamsVal = MaxTeamsVal = Constant->getExtValue();

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6859,7 +6859,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
68596859
const auto *NT = S.getSingleClause<OMPNumTeamsClause>();
68606860
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
68616861
if (NT || TL) {
6862-
const Expr *NumTeams = NT ? NT->getNumTeams() : nullptr;
6862+
const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
68636863
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;
68646864

68656865
CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,

clang/lib/Parse/ParseOpenMP.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
30983098
case OMPC_simdlen:
30993099
case OMPC_collapse:
31003100
case OMPC_ordered:
3101-
case OMPC_num_teams:
31023101
case OMPC_thread_limit:
31033102
case OMPC_priority:
31043103
case OMPC_grainsize:
@@ -3252,6 +3251,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
32523251
? ParseOpenMPSimpleClause(CKind, WrongDirective)
32533252
: ParseOpenMPClause(CKind, WrongDirective);
32543253
break;
3254+
case OMPC_num_teams:
3255+
if (!FirstClause) {
3256+
Diag(Tok, diag::err_omp_more_one_clause)
3257+
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
3258+
ErrorFound = true;
3259+
}
3260+
[[clang::fallthrough]];
32553261
case OMPC_private:
32563262
case OMPC_firstprivate:
32573263
case OMPC_lastprivate:

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13785,6 +13785,16 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective(
1378513785
return StmtError();
1378613786
}
1378713787

13788+
auto NumTeamsClauseItr =
13789+
llvm::find_if(Clauses, llvm::IsaPred<OMPNumTeamsClause>);
13790+
if (NumTeamsClauseItr != Clauses.end()) {
13791+
ArrayRef<const Expr *> NumTeams =
13792+
cast<OMPNumTeamsClause>(*NumTeamsClauseItr)->getNumTeams();
13793+
if (!HasBareClause && NumTeams.size() > 1) {
13794+
return StmtError();
13795+
}
13796+
}
13797+
1378813798
return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc,
1378913799
Clauses, AStmt);
1379013800
}
@@ -14925,9 +14935,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
1492514935
case OMPC_ordered:
1492614936
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
1492714937
break;
14928-
case OMPC_num_teams:
14929-
Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
14930-
break;
1493114938
case OMPC_thread_limit:
1493214939
Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
1493314940
break;
@@ -15031,6 +15038,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
1503115038
case OMPC_affinity:
1503215039
case OMPC_when:
1503315040
case OMPC_bind:
15041+
case OMPC_num_teams:
1503415042
default:
1503515043
llvm_unreachable("Clause is not allowed.");
1503615044
}
@@ -16894,6 +16902,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
1689416902
static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
1689516903
ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
1689616904
break;
16905+
case OMPC_num_teams:
16906+
Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
16907+
break;
1689716908
case OMPC_if:
1689816909
case OMPC_depobj:
1689916910
case OMPC_final:
@@ -16924,7 +16935,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
1692416935
case OMPC_device:
1692516936
case OMPC_threads:
1692616937
case OMPC_simd:
16927-
case OMPC_num_teams:
1692816938
case OMPC_thread_limit:
1692916939
case OMPC_priority:
1693016940
case OMPC_grainsize:
@@ -21587,32 +21597,39 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
2158721597
return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
2158821598
}
2158921599

21590-
OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
21600+
OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
2159121601
SourceLocation StartLoc,
2159221602
SourceLocation LParenLoc,
2159321603
SourceLocation EndLoc) {
21594-
Expr *ValExpr = NumTeams;
21595-
Stmt *HelperValStmt = nullptr;
21596-
21597-
// OpenMP [teams Constrcut, Restrictions]
21598-
// The num_teams expression must evaluate to a positive integer value.
21599-
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
21600-
/*StrictlyPositive=*/true))
21604+
if (VarList.empty())
2160121605
return nullptr;
2160221606

21607+
for (Expr *ValExpr : VarList) {
21608+
// OpenMP [teams Constrcut, Restrictions]
21609+
// The num_teams expression must evaluate to a positive integer value.
21610+
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
21611+
/*StrictlyPositive=*/true))
21612+
return nullptr;
21613+
}
21614+
2160321615
OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
2160421616
OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
2160521617
DKind, OMPC_num_teams, getLangOpts().OpenMP);
21606-
if (CaptureRegion != OMPD_unknown &&
21607-
!SemaRef.CurContext->isDependentContext()) {
21618+
if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
21619+
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc,
21620+
EndLoc, VarList, /*PreInit=*/nullptr);
21621+
21622+
llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
21623+
SmallVector<Expr *, 3> Vars;
21624+
for (Expr *ValExpr : VarList) {
2160821625
ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
21609-
llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
2161021626
ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
21611-
HelperValStmt = buildPreInits(getASTContext(), Captures);
21627+
Vars.push_back(ValExpr);
2161221628
}
2161321629

21614-
return new (getASTContext()) OMPNumTeamsClause(
21615-
ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
21630+
Stmt *PreInit = buildPreInits(getASTContext(), Captures);
21631+
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
21632+
Vars, PreInit);
2161621633
}
2161721634

2161821635
OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,

clang/lib/Sema/TreeTransform.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,10 +2065,11 @@ class TreeTransform {
20652065
///
20662066
/// By default, performs semantic analysis to build the new statement.
20672067
/// Subclasses may override this routine to provide different behavior.
2068-
OMPClause *RebuildOMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
2068+
OMPClause *RebuildOMPNumTeamsClause(ArrayRef<Expr *> VarList,
2069+
SourceLocation StartLoc,
20692070
SourceLocation LParenLoc,
20702071
SourceLocation EndLoc) {
2071-
return getSema().OpenMP().ActOnOpenMPNumTeamsClause(NumTeams, StartLoc,
2072+
return getSema().OpenMP().ActOnOpenMPNumTeamsClause(VarList, StartLoc,
20722073
LParenLoc, EndLoc);
20732074
}
20742075

@@ -10872,7 +10873,7 @@ TreeTransform<Derived>::TransformOMPAllocateClause(OMPAllocateClause *C) {
1087210873
template <typename Derived>
1087310874
OMPClause *
1087410875
TreeTransform<Derived>::TransformOMPNumTeamsClause(OMPNumTeamsClause *C) {
10875-
ExprResult E = getDerived().TransformExpr(C->getNumTeams());
10876+
ExprResult E = getDerived().TransformExpr(C->getNumTeams().front());
1087610877
if (E.isInvalid())
1087710878
return nullptr;
1087810879
return getDerived().RebuildOMPNumTeamsClause(

clang/lib/Serialization/ASTReader.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
#include "llvm/ADT/IntrusiveRefCntPtr.h"
105105
#include "llvm/ADT/STLExtras.h"
106106
#include "llvm/ADT/ScopeExit.h"
107+
#include "llvm/ADT/Sequence.h"
107108
#include "llvm/ADT/SmallPtrSet.h"
108109
#include "llvm/ADT/SmallString.h"
109110
#include "llvm/ADT/SmallVector.h"
@@ -10562,7 +10563,7 @@ OMPClause *OMPClauseReader::readClause() {
1056210563
break;
1056310564
}
1056410565
case llvm::omp::OMPC_num_teams:
10565-
C = new (Context) OMPNumTeamsClause();
10566+
C = OMPNumTeamsClause::CreateEmpty(Context, Record.readInt());
1056610567
break;
1056710568
case llvm::omp::OMPC_thread_limit:
1056810569
C = new (Context) OMPThreadLimitClause();
@@ -11350,8 +11351,13 @@ void OMPClauseReader::VisitOMPAllocateClause(OMPAllocateClause *C) {
1135011351

1135111352
void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
1135211353
VisitOMPClauseWithPreInit(C);
11353-
C->setNumTeams(Record.readSubExpr());
1135411354
C->setLParenLoc(Record.readSourceLocation());
11355+
unsigned NumVars = C->varlist_size();
11356+
SmallVector<Expr *, 16> Vars;
11357+
Vars.reserve(NumVars);
11358+
for ([[maybe_unused]] unsigned I : llvm::seq<unsigned>(NumVars))
11359+
Vars.push_back(Record.readSubExpr());
11360+
C->setVarRefs(Vars);
1135511361
}
1135611362

1135711363
void OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {

0 commit comments

Comments
 (0)