Skip to content

Commit fd59f45

Browse files
committed
[Clang][Sema][OpenMP] Allow thread_limit to accept multiple expressions
1 parent 3696a34 commit fd59f45

19 files changed

+181
-82
lines changed

clang/docs/OpenMPSupport.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ considered for standardization. Please post on the
363363
| device extension | `'ompx_bare' clause on 'target teams' construct | :good:`prototyped` | #66844, #70612 |
364364
| | <https://www.osti.gov/servlets/purl/2205717>`_ | | |
365365
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
366-
| device extension | Multi-dim 'num_teams' clause on 'target teams ompx_bare' construct | :good:`partial` | #99732, #101407 |
366+
| device extension | Multi-dim 'num_teams' and 'thread_limit' clause on 'target teams ompx_bare' | :good:`partial` | #99732, #101407, #102715 |
367+
| | construct | | |
367368
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
368369

369370
.. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35

clang/docs/ReleaseNotes.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,9 @@ Improvements
360360
^^^^^^^^^^^^
361361
- Improve the handling of mapping array-section for struct containing nested structs with user defined mappers
362362

363-
- `num_teams` now accepts multiple expressions when it is used along in ``target teams ompx_bare`` construct.
364-
This allows the target region to be launched with multi-dim grid on GPUs.
363+
- `num_teams` and `thread_limit` now accept multiple expressions when they are used
364+
in a ``target teams ompx_bare`` construct. This allows the target region
365+
to be launched with multi-dim grid on GPUs.
365366

366367
Additional Information
367368
======================

clang/include/clang/AST/OpenMPClause.h

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6462,61 +6462,78 @@ class OMPNumTeamsClause final
64626462
/// \endcode
64636463
/// In this example directive '#pragma omp teams' has clause 'thread_limit'
64646464
/// with single expression 'n'.
6465-
class OMPThreadLimitClause : public OMPClause, public OMPClauseWithPreInit {
6466-
friend class OMPClauseReader;
6465+
///
6466+
/// When 'ompx_bare' clause exists on a 'target' directive, 'thread_limit'
6467+
/// clause can accept up to three expressions.
6468+
///
6469+
/// \code
6470+
/// #pragma omp target teams ompx_bare thread_limit(x, y, z)
6471+
/// \endcode
6472+
class OMPThreadLimitClause final
6473+
: public OMPVarListClause<OMPThreadLimitClause>,
6474+
public OMPClauseWithPreInit,
6475+
private llvm::TrailingObjects<OMPThreadLimitClause, Expr *> {
6476+
friend OMPVarListClause;
6477+
friend TrailingObjects;
64676478

64686479
/// Location of '('.
64696480
SourceLocation LParenLoc;
64706481

6471-
/// ThreadLimit number.
6472-
Stmt *ThreadLimit = nullptr;
6482+
OMPThreadLimitClause(const ASTContext &C, SourceLocation StartLoc,
6483+
SourceLocation LParenLoc, SourceLocation EndLoc,
6484+
unsigned N)
6485+
: OMPVarListClause(llvm::omp::OMPC_thread_limit, StartLoc, LParenLoc,
6486+
EndLoc, N),
6487+
OMPClauseWithPreInit(this) {}
64736488

6474-
/// Set the ThreadLimit number.
6475-
///
6476-
/// \param E ThreadLimit number.
6477-
void setThreadLimit(Expr *E) { ThreadLimit = E; }
6489+
/// Build an empty clause.
6490+
OMPThreadLimitClause(unsigned N)
6491+
: OMPVarListClause(llvm::omp::OMPC_thread_limit, SourceLocation(),
6492+
SourceLocation(), SourceLocation(), N),
6493+
OMPClauseWithPreInit(this) {}
64786494

64796495
public:
6480-
/// Build 'thread_limit' clause.
6496+
/// Creates clause with a list of variables \a VL.
64816497
///
6482-
/// \param E Expression associated with this clause.
6483-
/// \param HelperE Helper Expression associated with this clause.
6484-
/// \param CaptureRegion Innermost OpenMP region where expressions in this
6485-
/// clause must be captured.
6498+
/// \param C AST context.
64866499
/// \param StartLoc Starting location of the clause.
64876500
/// \param LParenLoc Location of '('.
64886501
/// \param EndLoc Ending location of the clause.
6489-
OMPThreadLimitClause(Expr *E, Stmt *HelperE,
6490-
OpenMPDirectiveKind CaptureRegion,
6491-
SourceLocation StartLoc, SourceLocation LParenLoc,
6492-
SourceLocation EndLoc)
6493-
: OMPClause(llvm::omp::OMPC_thread_limit, StartLoc, EndLoc),
6494-
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), ThreadLimit(E) {
6495-
setPreInitStmt(HelperE, CaptureRegion);
6496-
}
6502+
/// \param VL List of references to the variables.
6503+
/// \param PreInit Pre-init statement recorded on the clause together with \a CaptureRegion.
6504+
static OMPThreadLimitClause *
6505+
Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
6506+
SourceLocation StartLoc, SourceLocation LParenLoc,
6507+
SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);
64976508

6498-
/// Build an empty clause.
6499-
OMPThreadLimitClause()
6500-
: OMPClause(llvm::omp::OMPC_thread_limit, SourceLocation(),
6501-
SourceLocation()),
6502-
OMPClauseWithPreInit(this) {}
6509+
/// Creates an empty clause with \a N variables.
6510+
///
6511+
/// \param C AST context.
6512+
/// \param N The number of variables.
6513+
static OMPThreadLimitClause *CreateEmpty(const ASTContext &C, unsigned N);
65036514

65046515
/// Sets the location of '('.
65056516
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
65066517

65076518
/// Returns the location of '('.
65086519
SourceLocation getLParenLoc() const { return LParenLoc; }
65096520

6510-
/// Return ThreadLimit number.
6511-
Expr *getThreadLimit() { return cast<Expr>(ThreadLimit); }
6521+
/// Return ThreadLimit expressions.
6522+
ArrayRef<Expr *> getThreadLimit() { return getVarRefs(); }
65126523

6513-
/// Return ThreadLimit number.
6514-
Expr *getThreadLimit() const { return cast<Expr>(ThreadLimit); }
6524+
/// Return ThreadLimit expressions.
6525+
ArrayRef<Expr *> getThreadLimit() const {
6526+
return const_cast<OMPThreadLimitClause *>(this)->getThreadLimit();
6527+
}
65156528

6516-
child_range children() { return child_range(&ThreadLimit, &ThreadLimit + 1); }
6529+
child_range children() {
6530+
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
6531+
reinterpret_cast<Stmt **>(varlist_end()));
6532+
}
65176533

65186534
const_child_range children() const {
6519-
return const_child_range(&ThreadLimit, &ThreadLimit + 1);
6535+
auto Children = const_cast<OMPThreadLimitClause *>(this)->children();
6536+
return const_child_range(Children.begin(), Children.end());
65206537
}
65216538

65226539
child_range used_children() {

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3836,8 +3836,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
38363836
template <typename Derived>
38373837
bool RecursiveASTVisitor<Derived>::VisitOMPThreadLimitClause(
38383838
OMPThreadLimitClause *C) {
3839+
TRY_TO(VisitOMPClauseList(C));
38393840
TRY_TO(VisitOMPClauseWithPreInit(C));
3840-
TRY_TO(TraverseStmt(C->getThreadLimit()));
38413841
return true;
38423842
}
38433843

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1264,7 +1264,7 @@ class SemaOpenMP : public SemaBase {
12641264
SourceLocation LParenLoc,
12651265
SourceLocation EndLoc);
12661266
/// Called on well-formed 'thread_limit' clause.
1267-
OMPClause *ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
1267+
OMPClause *ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList,
12681268
SourceLocation StartLoc,
12691269
SourceLocation LParenLoc,
12701270
SourceLocation EndLoc);

clang/lib/AST/OpenMPClause.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,6 +1773,24 @@ OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
17731773
return new (Mem) OMPNumTeamsClause(N);
17741774
}
17751775

1776+
OMPThreadLimitClause *OMPThreadLimitClause::Create(
1777+
const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
1778+
SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
1779+
ArrayRef<Expr *> VL, Stmt *PreInit) {
1780+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
1781+
OMPThreadLimitClause *Clause =
1782+
new (Mem) OMPThreadLimitClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
1783+
Clause->setVarRefs(VL);
1784+
Clause->setPreInitStmt(PreInit, CaptureRegion);
1785+
return Clause;
1786+
}
1787+
1788+
OMPThreadLimitClause *OMPThreadLimitClause::CreateEmpty(const ASTContext &C,
1789+
unsigned N) {
1790+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
1791+
return new (Mem) OMPThreadLimitClause(N);
1792+
}
1793+
17761794
//===----------------------------------------------------------------------===//
17771795
// OpenMP clauses printing methods
17781796
//===----------------------------------------------------------------------===//
@@ -2081,9 +2099,11 @@ void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
20812099
}
20822100

20832101
void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
2084-
OS << "thread_limit(";
2085-
Node->getThreadLimit()->printPretty(OS, nullptr, Policy, 0);
2086-
OS << ")";
2102+
if (!Node->varlist_empty()) {
2103+
OS << "thread_limit";
2104+
VisitOMPClauseList(Node, '(');
2105+
OS << ")";
2106+
}
20872107
}
20882108

20892109
void OMPClausePrinter::VisitOMPPriorityClause(OMPPriorityClause *Node) {

clang/lib/AST/StmtProfile.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -862,9 +862,8 @@ void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
862862
}
863863
void OMPClauseProfiler::VisitOMPThreadLimitClause(
864864
const OMPThreadLimitClause *C) {
865+
VisitOMPClauseList(C);
865866
VistOMPClauseWithPreInit(C);
866-
if (C->getThreadLimit())
867-
Profiler->VisitStmt(C->getThreadLimit());
868867
}
869868
void OMPClauseProfiler::VisitOMPPriorityClause(const OMPPriorityClause *C) {
870869
VistOMPClauseWithPreInit(C);

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6332,7 +6332,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
63326332
CGOpenMPInnerExprInfo CGInfo(CGF, *CS);
63336333
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
63346334
CodeGenFunction::LexicalScope Scope(
6335-
CGF, ThreadLimitClause->getThreadLimit()->getSourceRange());
6335+
CGF,
6336+
ThreadLimitClause->getThreadLimit().front()->getSourceRange());
63366337
if (const auto *PreInit =
63376338
cast_or_null<DeclStmt>(ThreadLimitClause->getPreInitStmt())) {
63386339
for (const auto *I : PreInit->decls()) {
@@ -6349,7 +6350,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
63496350
}
63506351
}
63516352
if (ThreadLimitClause)
6352-
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
6353+
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
6354+
ThreadLimitExpr);
63536355
if (const auto *Dir = dyn_cast_or_null<OMPExecutableDirective>(Child)) {
63546356
if (isOpenMPTeamsDirective(Dir->getDirectiveKind()) &&
63556357
!isOpenMPDistributeDirective(Dir->getDirectiveKind())) {
@@ -6370,7 +6372,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
63706372
if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
63716373
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
63726374
const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
6373-
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
6375+
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
6376+
ThreadLimitExpr);
63746377
}
63756378
const CapturedStmt *CS = D.getInnermostCapturedStmt();
63766379
getNumThreads(CGF, CS, NTPtr, UpperBound, UpperBoundOnly, CondVal);
@@ -6388,7 +6391,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
63886391
if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
63896392
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
63906393
const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
6391-
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
6394+
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
6395+
ThreadLimitExpr);
63926396
}
63936397
getNumThreads(CGF, D.getInnermostCapturedStmt(), NTPtr, UpperBound,
63946398
UpperBoundOnly, CondVal);
@@ -6424,7 +6428,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
64246428
if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
64256429
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
64266430
const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
6427-
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
6431+
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
6432+
ThreadLimitExpr);
64286433
}
64296434
if (D.hasClausesOfKind<OMPNumThreadsClause>()) {
64306435
CodeGenFunction::RunCleanupsScope NumThreadsScope(CGF);

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5259,7 +5259,7 @@ void CodeGenFunction::EmitOMPTargetTaskBasedDirective(
52595259
// enclosing this target region. This will indirectly set the thread_limit
52605260
// for every applicable construct within target region.
52615261
CGF.CGM.getOpenMPRuntime().emitThreadLimitClause(
5262-
CGF, TL->getThreadLimit(), S.getBeginLoc());
5262+
CGF, TL->getThreadLimit().front(), S.getBeginLoc());
52635263
}
52645264
BodyGen(CGF);
52655265
};
@@ -6860,7 +6860,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
68606860
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
68616861
if (NT || TL) {
68626862
const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
6863-
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;
6863+
const Expr *ThreadLimit = TL ? TL->getThreadLimit().front() : nullptr;
68646864

68656865
CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,
68666866
S.getBeginLoc());

clang/lib/Parse/ParseOpenMP.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3175,7 +3175,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
31753175
case OMPC_simdlen:
31763176
case OMPC_collapse:
31773177
case OMPC_ordered:
3178-
case OMPC_thread_limit:
31793178
case OMPC_priority:
31803179
case OMPC_grainsize:
31813180
case OMPC_num_tasks:
@@ -3332,6 +3331,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
33323331
: ParseOpenMPClause(CKind, WrongDirective);
33333332
break;
33343333
case OMPC_num_teams:
3334+
case OMPC_thread_limit:
33353335
if (!FirstClause) {
33363336
Diag(Tok, diag::err_omp_more_one_clause)
33373337
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;

0 commit comments

Comments
 (0)