Skip to content

Commit f624ba2

Browse files
authored
[OpenMP][clang] 6.0: parsing/sema for num_threads 'strict' modifier (#145490)
Implement parsing and semantic analysis support for the optional 'strict' modifier of the num_threads clause. This modifier has been introduced in OpenMP 6.0, section 12.1.2. Note: this is basically 1:1 https://reviews.llvm.org/D138328.
1 parent 4d21da0 commit f624ba2

File tree

14 files changed

+272
-50
lines changed

14 files changed

+272
-50
lines changed

clang/include/clang/AST/OpenMPClause.h

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -824,31 +824,52 @@ class OMPNumThreadsClause final
824824
public OMPClauseWithPreInit {
825825
friend class OMPClauseReader;
826826

827+
/// Modifiers for 'num_threads' clause.
828+
OpenMPNumThreadsClauseModifier Modifier = OMPC_NUMTHREADS_unknown;
829+
830+
/// Location of the modifier.
831+
SourceLocation ModifierLoc;
832+
833+
/// Sets modifier.
834+
void setModifier(OpenMPNumThreadsClauseModifier M) { Modifier = M; }
835+
836+
/// Sets modifier location.
837+
void setModifierLoc(SourceLocation Loc) { ModifierLoc = Loc; }
838+
827839
/// Set condition.
828840
void setNumThreads(Expr *NThreads) { setStmt(NThreads); }
829841

830842
public:
831843
/// Build 'num_threads' clause with condition \a NumThreads.
832844
///
845+
/// \param Modifier Clause modifier.
833846
/// \param NumThreads Number of threads for the construct.
834847
/// \param HelperNumThreads Helper Number of threads for the construct.
835848
/// \param CaptureRegion Innermost OpenMP region where expressions in this
836849
/// clause must be captured.
837850
/// \param StartLoc Starting location of the clause.
838851
/// \param LParenLoc Location of '('.
852+
/// \param ModifierLoc Modifier location.
839853
/// \param EndLoc Ending location of the clause.
840-
OMPNumThreadsClause(Expr *NumThreads, Stmt *HelperNumThreads,
841-
OpenMPDirectiveKind CaptureRegion,
854+
OMPNumThreadsClause(OpenMPNumThreadsClauseModifier Modifier, Expr *NumThreads,
855+
Stmt *HelperNumThreads, OpenMPDirectiveKind CaptureRegion,
842856
SourceLocation StartLoc, SourceLocation LParenLoc,
843-
SourceLocation EndLoc)
857+
SourceLocation ModifierLoc, SourceLocation EndLoc)
844858
: OMPOneStmtClause(NumThreads, StartLoc, LParenLoc, EndLoc),
845-
OMPClauseWithPreInit(this) {
859+
OMPClauseWithPreInit(this), Modifier(Modifier),
860+
ModifierLoc(ModifierLoc) {
846861
setPreInitStmt(HelperNumThreads, CaptureRegion);
847862
}
848863

849864
/// Build an empty clause.
850865
OMPNumThreadsClause() : OMPOneStmtClause(), OMPClauseWithPreInit(this) {}
851866

867+
/// Gets modifier.
868+
OpenMPNumThreadsClauseModifier getModifier() const { return Modifier; }
869+
870+
/// Gets modifier location.
871+
SourceLocation getModifierLoc() const { return ModifierLoc; }
872+
852873
/// Returns number of threads.
853874
Expr *getNumThreads() const { return getStmtAs<Expr>(); }
854875
};

clang/include/clang/Basic/OpenMPKinds.def

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@
8686
#ifndef OPENMP_NUMTASKS_MODIFIER
8787
#define OPENMP_NUMTASKS_MODIFIER(Name)
8888
#endif
89+
#ifndef OPENMP_NUMTHREADS_MODIFIER
90+
#define OPENMP_NUMTHREADS_MODIFIER(Name)
91+
#endif
8992
#ifndef OPENMP_DOACROSS_MODIFIER
9093
#define OPENMP_DOACROSS_MODIFIER(Name)
9194
#endif
@@ -227,6 +230,9 @@ OPENMP_GRAINSIZE_MODIFIER(strict)
227230
// Modifiers for the 'num_tasks' clause.
228231
OPENMP_NUMTASKS_MODIFIER(strict)
229232

233+
// Modifiers for the 'num_tasks' clause.
234+
OPENMP_NUMTHREADS_MODIFIER(strict)
235+
230236
// Modifiers for 'allocate' clause.
231237
OPENMP_ALLOCATE_MODIFIER(allocator)
232238
OPENMP_ALLOCATE_MODIFIER(align)
@@ -238,6 +244,7 @@ OPENMP_DOACROSS_MODIFIER(sink_omp_cur_iteration)
238244
OPENMP_DOACROSS_MODIFIER(source_omp_cur_iteration)
239245

240246
#undef OPENMP_NUMTASKS_MODIFIER
247+
#undef OPENMP_NUMTHREADS_MODIFIER
241248
#undef OPENMP_GRAINSIZE_MODIFIER
242249
#undef OPENMP_BIND_KIND
243250
#undef OPENMP_ADJUST_ARGS_KIND

clang/include/clang/Basic/OpenMPKinds.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ enum OpenMPNumTasksClauseModifier {
223223
OMPC_NUMTASKS_unknown
224224
};
225225

226+
enum OpenMPNumThreadsClauseModifier {
227+
#define OPENMP_NUMTHREADS_MODIFIER(Name) OMPC_NUMTHREADS_##Name,
228+
#include "clang/Basic/OpenMPKinds.def"
229+
OMPC_NUMTHREADS_unknown
230+
};
231+
226232
/// OpenMP dependence types for 'doacross' clause.
227233
enum OpenMPDoacrossClauseModifier {
228234
#define OPENMP_DOACROSS_MODIFIER(Name) OMPC_DOACROSS_##Name,

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -877,10 +877,10 @@ class SemaOpenMP : public SemaBase {
877877
SourceLocation LParenLoc,
878878
SourceLocation EndLoc);
879879
/// Called on well-formed 'num_threads' clause.
880-
OMPClause *ActOnOpenMPNumThreadsClause(Expr *NumThreads,
881-
SourceLocation StartLoc,
882-
SourceLocation LParenLoc,
883-
SourceLocation EndLoc);
880+
OMPClause *ActOnOpenMPNumThreadsClause(
881+
OpenMPNumThreadsClauseModifier Modifier, Expr *NumThreads,
882+
SourceLocation StartLoc, SourceLocation LParenLoc,
883+
SourceLocation ModifierLoc, SourceLocation EndLoc);
884884
/// Called on well-formed 'align' clause.
885885
OMPClause *ActOnOpenMPAlignClause(Expr *Alignment, SourceLocation StartLoc,
886886
SourceLocation LParenLoc,

clang/lib/AST/OpenMPClause.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1830,6 +1830,11 @@ void OMPClausePrinter::VisitOMPFinalClause(OMPFinalClause *Node) {
18301830

18311831
void OMPClausePrinter::VisitOMPNumThreadsClause(OMPNumThreadsClause *Node) {
18321832
OS << "num_threads(";
1833+
OpenMPNumThreadsClauseModifier Modifier = Node->getModifier();
1834+
if (Modifier != OMPC_NUMTHREADS_unknown) {
1835+
OS << getOpenMPSimpleClauseTypeName(Node->getClauseKind(), Modifier)
1836+
<< ": ";
1837+
}
18331838
Node->getNumThreads()->printPretty(OS, nullptr, Policy, 0);
18341839
OS << ")";
18351840
}

clang/lib/Basic/OpenMPKinds.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,19 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str,
185185
#define OPENMP_ALLOCATE_MODIFIER(Name) .Case(#Name, OMPC_ALLOCATE_##Name)
186186
#include "clang/Basic/OpenMPKinds.def"
187187
.Default(OMPC_ALLOCATE_unknown);
188+
case OMPC_num_threads: {
189+
unsigned Type = llvm::StringSwitch<unsigned>(Str)
190+
#define OPENMP_NUMTHREADS_MODIFIER(Name) .Case(#Name, OMPC_NUMTHREADS_##Name)
191+
#include "clang/Basic/OpenMPKinds.def"
192+
.Default(OMPC_NUMTHREADS_unknown);
193+
if (LangOpts.OpenMP < 60)
194+
return OMPC_NUMTHREADS_unknown;
195+
return Type;
196+
}
188197
case OMPC_unknown:
189198
case OMPC_threadprivate:
190199
case OMPC_if:
191200
case OMPC_final:
192-
case OMPC_num_threads:
193201
case OMPC_safelen:
194202
case OMPC_simdlen:
195203
case OMPC_sizes:
@@ -520,11 +528,20 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind,
520528
#include "clang/Basic/OpenMPKinds.def"
521529
}
522530
llvm_unreachable("Invalid OpenMP 'allocate' clause modifier");
531+
case OMPC_num_threads:
532+
switch (Type) {
533+
case OMPC_NUMTHREADS_unknown:
534+
return "unknown";
535+
#define OPENMP_NUMTHREADS_MODIFIER(Name) \
536+
case OMPC_NUMTHREADS_##Name: \
537+
return #Name;
538+
#include "clang/Basic/OpenMPKinds.def"
539+
}
540+
llvm_unreachable("Invalid OpenMP 'num_threads' clause modifier");
523541
case OMPC_unknown:
524542
case OMPC_threadprivate:
525543
case OMPC_if:
526544
case OMPC_final:
527-
case OMPC_num_threads:
528545
case OMPC_safelen:
529546
case OMPC_simdlen:
530547
case OMPC_sizes:

clang/lib/Parse/ParseOpenMP.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3196,7 +3196,8 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
31963196
if ((CKind == OMPC_ordered || CKind == OMPC_partial) &&
31973197
PP.LookAhead(/*N=*/0).isNot(tok::l_paren))
31983198
Clause = ParseOpenMPClause(CKind, WrongDirective);
3199-
else if (CKind == OMPC_grainsize || CKind == OMPC_num_tasks)
3199+
else if (CKind == OMPC_grainsize || CKind == OMPC_num_tasks ||
3200+
CKind == OMPC_num_threads)
32003201
Clause = ParseOpenMPSingleExprWithArgClause(DKind, CKind, WrongDirective);
32013202
else
32023203
Clause = ParseOpenMPSingleExprClause(CKind, WrongDirective);
@@ -3981,6 +3982,33 @@ OMPClause *Parser::ParseOpenMPSingleExprWithArgClause(OpenMPDirectiveKind DKind,
39813982
Arg.push_back(OMPC_NUMTASKS_unknown);
39823983
KLoc.emplace_back();
39833984
}
3985+
} else if (Kind == OMPC_num_threads) {
3986+
// Parse optional <num_threads modifier> ':'
3987+
OpenMPNumThreadsClauseModifier Modifier =
3988+
static_cast<OpenMPNumThreadsClauseModifier>(getOpenMPSimpleClauseType(
3989+
Kind, Tok.isAnnotation() ? "" : PP.getSpelling(Tok),
3990+
getLangOpts()));
3991+
if (getLangOpts().OpenMP >= 60) {
3992+
if (NextToken().is(tok::colon)) {
3993+
Arg.push_back(Modifier);
3994+
KLoc.push_back(Tok.getLocation());
3995+
// Parse modifier
3996+
ConsumeAnyToken();
3997+
// Parse ':'
3998+
ConsumeAnyToken();
3999+
} else {
4000+
if (Modifier == OMPC_NUMTHREADS_strict) {
4001+
Diag(Tok, diag::err_modifier_expected_colon) << "strict";
4002+
// Parse modifier
4003+
ConsumeAnyToken();
4004+
}
4005+
Arg.push_back(OMPC_NUMTHREADS_unknown);
4006+
KLoc.emplace_back();
4007+
}
4008+
} else {
4009+
Arg.push_back(OMPC_NUMTHREADS_unknown);
4010+
KLoc.emplace_back();
4011+
}
39844012
} else {
39854013
assert(Kind == OMPC_if);
39864014
KLoc.push_back(Tok.getLocation());
@@ -4004,7 +4032,8 @@ OMPClause *Parser::ParseOpenMPSingleExprWithArgClause(OpenMPDirectiveKind DKind,
40044032
bool NeedAnExpression = (Kind == OMPC_schedule && DelimLoc.isValid()) ||
40054033
(Kind == OMPC_dist_schedule && DelimLoc.isValid()) ||
40064034
Kind == OMPC_if || Kind == OMPC_device ||
4007-
Kind == OMPC_grainsize || Kind == OMPC_num_tasks;
4035+
Kind == OMPC_grainsize || Kind == OMPC_num_tasks ||
4036+
Kind == OMPC_num_threads;
40084037
if (NeedAnExpression) {
40094038
SourceLocation ELoc = Tok.getLocation();
40104039
ExprResult LHS(ParseCastExpression(CastParseKind::AnyCastExpr, false,

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15509,9 +15509,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
1550915509
case OMPC_final:
1551015510
Res = ActOnOpenMPFinalClause(Expr, StartLoc, LParenLoc, EndLoc);
1551115511
break;
15512-
case OMPC_num_threads:
15513-
Res = ActOnOpenMPNumThreadsClause(Expr, StartLoc, LParenLoc, EndLoc);
15514-
break;
1551515512
case OMPC_safelen:
1551615513
Res = ActOnOpenMPSafelenClause(Expr, StartLoc, LParenLoc, EndLoc);
1551715514
break;
@@ -15565,6 +15562,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
1556515562
break;
1556615563
case OMPC_grainsize:
1556715564
case OMPC_num_tasks:
15565+
case OMPC_num_threads:
1556815566
case OMPC_device:
1556915567
case OMPC_if:
1557015568
case OMPC_default:
@@ -15911,10 +15909,41 @@ isNonNegativeIntegerValue(Expr *&ValExpr, Sema &SemaRef, OpenMPClauseKind CKind,
1591115909
return true;
1591215910
}
1591315911

15914-
OMPClause *SemaOpenMP::ActOnOpenMPNumThreadsClause(Expr *NumThreads,
15915-
SourceLocation StartLoc,
15916-
SourceLocation LParenLoc,
15917-
SourceLocation EndLoc) {
15912+
static std::string getListOfPossibleValues(OpenMPClauseKind K, unsigned First,
15913+
unsigned Last,
15914+
ArrayRef<unsigned> Exclude = {}) {
15915+
SmallString<256> Buffer;
15916+
llvm::raw_svector_ostream Out(Buffer);
15917+
unsigned Skipped = Exclude.size();
15918+
for (unsigned I = First; I < Last; ++I) {
15919+
if (llvm::is_contained(Exclude, I)) {
15920+
--Skipped;
15921+
continue;
15922+
}
15923+
Out << "'" << getOpenMPSimpleClauseTypeName(K, I) << "'";
15924+
if (I + Skipped + 2 == Last)
15925+
Out << " or ";
15926+
else if (I + Skipped + 1 != Last)
15927+
Out << ", ";
15928+
}
15929+
return std::string(Out.str());
15930+
}
15931+
15932+
OMPClause *SemaOpenMP::ActOnOpenMPNumThreadsClause(
15933+
OpenMPNumThreadsClauseModifier Modifier, Expr *NumThreads,
15934+
SourceLocation StartLoc, SourceLocation LParenLoc,
15935+
SourceLocation ModifierLoc, SourceLocation EndLoc) {
15936+
assert((ModifierLoc.isInvalid() || getLangOpts().OpenMP >= 60) &&
15937+
"Unexpected num_threads modifier in OpenMP < 60.");
15938+
15939+
if (ModifierLoc.isValid() && Modifier == OMPC_NUMTHREADS_unknown) {
15940+
std::string Values = getListOfPossibleValues(OMPC_num_threads, /*First=*/0,
15941+
OMPC_NUMTHREADS_unknown);
15942+
Diag(ModifierLoc, diag::err_omp_unexpected_clause_value)
15943+
<< Values << getOpenMPClauseNameForDiag(OMPC_num_threads);
15944+
return nullptr;
15945+
}
15946+
1591815947
Expr *ValExpr = NumThreads;
1591915948
Stmt *HelperValStmt = nullptr;
1592015949

@@ -15935,8 +15964,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPNumThreadsClause(Expr *NumThreads,
1593515964
HelperValStmt = buildPreInits(getASTContext(), Captures);
1593615965
}
1593715966

15938-
return new (getASTContext()) OMPNumThreadsClause(
15939-
ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
15967+
return new (getASTContext())
15968+
OMPNumThreadsClause(Modifier, ValExpr, HelperValStmt, CaptureRegion,
15969+
StartLoc, LParenLoc, ModifierLoc, EndLoc);
1594015970
}
1594115971

1594215972
ExprResult SemaOpenMP::VerifyPositiveIntegerConstantInClause(
@@ -16301,26 +16331,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSimpleClause(
1630116331
return Res;
1630216332
}
1630316333

16304-
static std::string getListOfPossibleValues(OpenMPClauseKind K, unsigned First,
16305-
unsigned Last,
16306-
ArrayRef<unsigned> Exclude = {}) {
16307-
SmallString<256> Buffer;
16308-
llvm::raw_svector_ostream Out(Buffer);
16309-
unsigned Skipped = Exclude.size();
16310-
for (unsigned I = First; I < Last; ++I) {
16311-
if (llvm::is_contained(Exclude, I)) {
16312-
--Skipped;
16313-
continue;
16314-
}
16315-
Out << "'" << getOpenMPSimpleClauseTypeName(K, I) << "'";
16316-
if (I + Skipped + 2 == Last)
16317-
Out << " or ";
16318-
else if (I + Skipped + 1 != Last)
16319-
Out << ", ";
16320-
}
16321-
return std::string(Out.str());
16322-
}
16323-
1632416334
OMPClause *SemaOpenMP::ActOnOpenMPDefaultClause(DefaultKind Kind,
1632516335
SourceLocation KindKwLoc,
1632616336
SourceLocation StartLoc,
@@ -16693,8 +16703,14 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprWithArgClause(
1669316703
static_cast<OpenMPNumTasksClauseModifier>(Argument.back()), Expr,
1669416704
StartLoc, LParenLoc, ArgumentLoc.back(), EndLoc);
1669516705
break;
16696-
case OMPC_final:
1669716706
case OMPC_num_threads:
16707+
assert(Argument.size() == 1 && ArgumentLoc.size() == 1 &&
16708+
"Modifier for num_threads clause and its location are expected.");
16709+
Res = ActOnOpenMPNumThreadsClause(
16710+
static_cast<OpenMPNumThreadsClauseModifier>(Argument.back()), Expr,
16711+
StartLoc, LParenLoc, ArgumentLoc.back(), EndLoc);
16712+
break;
16713+
case OMPC_final:
1669816714
case OMPC_safelen:
1669916715
case OMPC_simdlen:
1670016716
case OMPC_sizes:

clang/lib/Sema/TreeTransform.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,12 +1714,14 @@ class TreeTransform {
17141714
///
17151715
/// By default, performs semantic analysis to build the new OpenMP clause.
17161716
/// Subclasses may override this routine to provide different behavior.
1717-
OMPClause *RebuildOMPNumThreadsClause(Expr *NumThreads,
1717+
OMPClause *RebuildOMPNumThreadsClause(OpenMPNumThreadsClauseModifier Modifier,
1718+
Expr *NumThreads,
17181719
SourceLocation StartLoc,
17191720
SourceLocation LParenLoc,
1721+
SourceLocation ModifierLoc,
17201722
SourceLocation EndLoc) {
1721-
return getSema().OpenMP().ActOnOpenMPNumThreadsClause(NumThreads, StartLoc,
1722-
LParenLoc, EndLoc);
1723+
return getSema().OpenMP().ActOnOpenMPNumThreadsClause(
1724+
Modifier, NumThreads, StartLoc, LParenLoc, ModifierLoc, EndLoc);
17231725
}
17241726

17251727
/// Build a new OpenMP 'safelen' clause.
@@ -10461,7 +10463,8 @@ TreeTransform<Derived>::TransformOMPNumThreadsClause(OMPNumThreadsClause *C) {
1046110463
if (NumThreads.isInvalid())
1046210464
return nullptr;
1046310465
return getDerived().RebuildOMPNumThreadsClause(
10464-
NumThreads.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc());
10466+
C->getModifier(), NumThreads.get(), C->getBeginLoc(), C->getLParenLoc(),
10467+
C->getModifierLoc(), C->getEndLoc());
1046510468
}
1046610469

1046710470
template <typename Derived>

clang/lib/Serialization/ASTReader.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11461,7 +11461,9 @@ void OMPClauseReader::VisitOMPFinalClause(OMPFinalClause *C) {
1146111461

1146211462
void OMPClauseReader::VisitOMPNumThreadsClause(OMPNumThreadsClause *C) {
1146311463
VisitOMPClauseWithPreInit(C);
11464+
C->setModifier(Record.readEnum<OpenMPNumThreadsClauseModifier>());
1146411465
C->setNumThreads(Record.readSubExpr());
11466+
C->setModifierLoc(Record.readSourceLocation());
1146511467
C->setLParenLoc(Record.readSourceLocation());
1146611468
}
1146711469

clang/lib/Serialization/ASTWriter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7802,7 +7802,9 @@ void OMPClauseWriter::VisitOMPFinalClause(OMPFinalClause *C) {
78027802

78037803
void OMPClauseWriter::VisitOMPNumThreadsClause(OMPNumThreadsClause *C) {
78047804
VisitOMPClauseWithPreInit(C);
7805+
Record.writeEnum(C->getModifier());
78057806
Record.AddStmt(C->getNumThreads());
7807+
Record.AddSourceLocation(C->getModifierLoc());
78067808
Record.AddSourceLocation(C->getLParenLoc());
78077809
}
78087810

0 commit comments

Comments
 (0)