Skip to content

Commit 76600ae

Browse files
authored
[OpenACC] Implement 'num_workers' clause for compute constructs (#89151)
This clause just takes an 'int expr', which is not optional. This patch implements the clause on compute constructs.
1 parent 8ba0041 commit 76600ae

18 files changed

+746
-27
lines changed

clang/include/clang/AST/OpenACCClause.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,46 @@ class OpenACCSelfClause : public OpenACCClauseWithCondition {
156156
Expr *ConditionExpr, SourceLocation EndLoc);
157157
};
158158

159+
/// Represents oen of a handful of classes that have a single integer
160+
/// expression.
161+
class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithParams {
162+
Expr *IntExpr;
163+
164+
protected:
165+
OpenACCClauseWithSingleIntExpr(OpenACCClauseKind K, SourceLocation BeginLoc,
166+
SourceLocation LParenLoc, Expr *IntExpr,
167+
SourceLocation EndLoc)
168+
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
169+
IntExpr(IntExpr) {}
170+
171+
public:
172+
bool hasIntExpr() const { return IntExpr; }
173+
const Expr *getIntExpr() const { return IntExpr; }
174+
175+
Expr *getIntExpr() { return IntExpr; };
176+
177+
child_range children() {
178+
return child_range(reinterpret_cast<Stmt **>(&IntExpr),
179+
reinterpret_cast<Stmt **>(&IntExpr + 1));
180+
}
181+
182+
const_child_range children() const {
183+
return const_child_range(reinterpret_cast<Stmt *const *>(&IntExpr),
184+
reinterpret_cast<Stmt *const *>(&IntExpr + 1));
185+
}
186+
};
187+
188+
class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {
189+
OpenACCNumWorkersClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
190+
Expr *IntExpr, SourceLocation EndLoc);
191+
192+
public:
193+
static OpenACCNumWorkersClause *Create(const ASTContext &C,
194+
SourceLocation BeginLoc,
195+
SourceLocation LParenLoc,
196+
Expr *IntExpr, SourceLocation EndLoc);
197+
};
198+
159199
template <class Impl> class OpenACCClauseVisitor {
160200
Impl &getDerived() { return static_cast<Impl &>(*this); }
161201

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12271,4 +12271,16 @@ def warn_acc_if_self_conflict
1227112271
: Warning<"OpenACC construct 'self' has no effect when an 'if' clause "
1227212272
"evaluates to true">,
1227312273
InGroup<DiagGroup<"openacc-self-if-potential-conflict">>;
12274+
def err_acc_int_expr_requires_integer
12275+
: Error<"OpenACC %select{clause|directive}0 '%1' requires expression of "
12276+
"integer type (%2 invalid)">;
12277+
def err_acc_int_expr_incomplete_class_type
12278+
: Error<"OpenACC integer expression has incomplete class type %0">;
12279+
def err_acc_int_expr_explicit_conversion
12280+
: Error<"OpenACC integer expression type %0 requires explicit conversion "
12281+
"to %1">;
12282+
def note_acc_int_expr_conversion
12283+
: Note<"conversion to %select{integral|enumeration}0 type %1">;
12284+
def err_acc_int_expr_multiple_conversions
12285+
: Error<"multiple conversions from expression type %0 to an integral type">;
1227412286
} // end of sema component.

clang/include/clang/Basic/OpenACCClauses.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818
VISIT_CLAUSE(Default)
1919
VISIT_CLAUSE(If)
2020
VISIT_CLAUSE(Self)
21+
VISIT_CLAUSE(NumWorkers)
2122

2223
#undef VISIT_CLAUSE

clang/include/clang/Parse/Parser.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3640,13 +3640,14 @@ class Parser : public CodeCompletionHandler {
36403640
/// Parses the clause-list for an OpenACC directive.
36413641
SmallVector<OpenACCClause *>
36423642
ParseOpenACCClauseList(OpenACCDirectiveKind DirKind);
3643-
bool ParseOpenACCWaitArgument();
3643+
bool ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective);
36443644
/// Parses the clause of the 'bind' argument, which can be a string literal or
36453645
/// an ID expression.
36463646
ExprResult ParseOpenACCBindClauseArgument();
36473647
/// Parses the clause kind of 'int-expr', which can be any integral
36483648
/// expression.
3649-
ExprResult ParseOpenACCIntExpr();
3649+
ExprResult ParseOpenACCIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
3650+
SourceLocation Loc);
36503651
/// Parses the 'device-type-list', which is a list of identifiers.
36513652
bool ParseOpenACCDeviceTypeList();
36523653
/// Parses the 'async-argument', which is an integral value with two
@@ -3657,9 +3658,9 @@ class Parser : public CodeCompletionHandler {
36573658
/// Parses a comma delimited list of 'size-expr's.
36583659
bool ParseOpenACCSizeExprList();
36593660
/// Parses a 'gang-arg-list', used for the 'gang' clause.
3660-
bool ParseOpenACCGangArgList();
3661+
bool ParseOpenACCGangArgList(SourceLocation GangLoc);
36613662
/// Parses a 'gang-arg', used for the 'gang' clause.
3662-
bool ParseOpenACCGangArg();
3663+
bool ParseOpenACCGangArg(SourceLocation GangLoc);
36633664
/// Parses a 'condition' expr, ensuring it results in a
36643665
ExprResult ParseOpenACCConditionExpr();
36653666

clang/include/clang/Sema/SemaOpenACC.h

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,13 @@ class SemaOpenACC : public SemaBase {
4444
Expr *ConditionExpr;
4545
};
4646

47-
std::variant<std::monostate, DefaultDetails, ConditionDetails> Details =
48-
std::monostate{};
47+
struct IntExprDetails {
48+
SmallVector<Expr *> IntExprs;
49+
};
50+
51+
std::variant<std::monostate, DefaultDetails, ConditionDetails,
52+
IntExprDetails>
53+
Details = std::monostate{};
4954

5055
public:
5156
OpenACCParsedClause(OpenACCDirectiveKind DirKind,
@@ -87,6 +92,22 @@ class SemaOpenACC : public SemaBase {
8792
return std::get<ConditionDetails>(Details).ConditionExpr;
8893
}
8994

95+
unsigned getNumIntExprs() const {
96+
assert(ClauseKind == OpenACCClauseKind::NumWorkers &&
97+
"Parsed clause kind does not have a int exprs");
98+
return std::get<IntExprDetails>(Details).IntExprs.size();
99+
}
100+
101+
ArrayRef<Expr *> getIntExprs() {
102+
assert(ClauseKind == OpenACCClauseKind::NumWorkers &&
103+
"Parsed clause kind does not have a int exprs");
104+
return std::get<IntExprDetails>(Details).IntExprs;
105+
}
106+
107+
ArrayRef<Expr *> getIntExprs() const {
108+
return const_cast<OpenACCParsedClause *>(this)->getIntExprs();
109+
}
110+
90111
void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
91112
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }
92113

@@ -109,6 +130,12 @@ class SemaOpenACC : public SemaBase {
109130

110131
Details = ConditionDetails{ConditionExpr};
111132
}
133+
134+
void setIntExprDetails(ArrayRef<Expr *> IntExprs) {
135+
assert(ClauseKind == OpenACCClauseKind::NumWorkers &&
136+
"Parsed clause kind does not have a int exprs");
137+
Details = IntExprDetails{{IntExprs.begin(), IntExprs.end()}};
138+
}
112139
};
113140

114141
SemaOpenACC(Sema &S);
@@ -148,6 +175,11 @@ class SemaOpenACC : public SemaBase {
148175
/// Called after the directive has been completely parsed, including the
149176
/// declaration group or associated statement.
150177
DeclGroupRef ActOnEndDeclDirective();
178+
179+
/// Called when encountering an 'int-expr' for OpenACC, and manages
180+
/// conversions and diagnostics to 'int'.
181+
ExprResult ActOnIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
182+
SourceLocation Loc, Expr *IntExpr);
151183
};
152184

153185
} // namespace clang

clang/lib/AST/OpenACCClause.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,27 @@ OpenACCClause::child_range OpenACCClause::children() {
8282
return child_range(child_iterator(), child_iterator());
8383
}
8484

85+
OpenACCNumWorkersClause::OpenACCNumWorkersClause(SourceLocation BeginLoc,
86+
SourceLocation LParenLoc,
87+
Expr *IntExpr,
88+
SourceLocation EndLoc)
89+
: OpenACCClauseWithSingleIntExpr(OpenACCClauseKind::NumWorkers, BeginLoc,
90+
LParenLoc, IntExpr, EndLoc) {
91+
assert((!IntExpr || IntExpr->isInstantiationDependent() ||
92+
IntExpr->getType()->isIntegerType()) &&
93+
"Condition expression type not scalar/dependent");
94+
}
95+
96+
OpenACCNumWorkersClause *
97+
OpenACCNumWorkersClause::Create(const ASTContext &C, SourceLocation BeginLoc,
98+
SourceLocation LParenLoc, Expr *IntExpr,
99+
SourceLocation EndLoc) {
100+
void *Mem = C.Allocate(sizeof(OpenACCNumWorkersClause),
101+
alignof(OpenACCNumWorkersClause));
102+
return new (Mem)
103+
OpenACCNumWorkersClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
104+
}
105+
85106
//===----------------------------------------------------------------------===//
86107
// OpenACC clauses printing methods
87108
//===----------------------------------------------------------------------===//
@@ -98,3 +119,8 @@ void OpenACCClausePrinter::VisitSelfClause(const OpenACCSelfClause &C) {
98119
if (const Expr *CondExpr = C.getConditionExpr())
99120
OS << "(" << CondExpr << ")";
100121
}
122+
123+
void OpenACCClausePrinter::VisitNumWorkersClause(
124+
const OpenACCNumWorkersClause &C) {
125+
OS << "num_workers(" << C.getIntExpr() << ")";
126+
}

clang/lib/AST/StmtProfile.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2496,6 +2496,13 @@ void OpenACCClauseProfiler::VisitSelfClause(const OpenACCSelfClause &Clause) {
24962496
if (Clause.hasConditionExpr())
24972497
Profiler.VisitStmt(Clause.getConditionExpr());
24982498
}
2499+
2500+
void OpenACCClauseProfiler::VisitNumWorkersClause(
2501+
const OpenACCNumWorkersClause &Clause) {
2502+
assert(Clause.hasIntExpr() && "num_workers clause requires a valid int expr");
2503+
Profiler.VisitStmt(Clause.getIntExpr());
2504+
}
2505+
24992506
} // namespace
25002507

25012508
void StmtProfiler::VisitOpenACCComputeConstruct(

clang/lib/AST/TextNodeDumper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
399399
break;
400400
case OpenACCClauseKind::If:
401401
case OpenACCClauseKind::Self:
402+
case OpenACCClauseKind::NumWorkers:
402403
// The condition expression will be printed as a part of the 'children',
403404
// but print 'clause' here so it is clear what is happening from the dump.
404405
OS << " clause";

clang/lib/Parse/ParseOpenACC.cpp

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -632,10 +632,16 @@ Parser::ParseOpenACCClauseList(OpenACCDirectiveKind DirKind) {
632632
return Clauses;
633633
}
634634

635-
ExprResult Parser::ParseOpenACCIntExpr() {
636-
// FIXME: this is required to be an integer expression (or dependent), so we
637-
// should ensure that is the case by passing this to SEMA here.
638-
return getActions().CorrectDelayedTyposInExpr(ParseAssignmentExpression());
635+
ExprResult Parser::ParseOpenACCIntExpr(OpenACCDirectiveKind DK,
636+
OpenACCClauseKind CK,
637+
SourceLocation Loc) {
638+
ExprResult ER =
639+
getActions().CorrectDelayedTyposInExpr(ParseAssignmentExpression());
640+
641+
if (!ER.isUsable())
642+
return ER;
643+
644+
return getActions().OpenACC().ActOnIntExpr(DK, CK, Loc, ER.get());
639645
}
640646

641647
bool Parser::ParseOpenACCClauseVarList(OpenACCClauseKind Kind) {
@@ -739,7 +745,7 @@ bool Parser::ParseOpenACCSizeExprList() {
739745
/// [num:]int-expr
740746
/// dim:int-expr
741747
/// static:size-expr
742-
bool Parser::ParseOpenACCGangArg() {
748+
bool Parser::ParseOpenACCGangArg(SourceLocation GangLoc) {
743749

744750
if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Static, getCurToken()) &&
745751
NextToken().is(tok::colon)) {
@@ -753,7 +759,9 @@ bool Parser::ParseOpenACCGangArg() {
753759
NextToken().is(tok::colon)) {
754760
ConsumeToken();
755761
ConsumeToken();
756-
return ParseOpenACCIntExpr().isInvalid();
762+
return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
763+
OpenACCClauseKind::Gang, GangLoc)
764+
.isInvalid();
757765
}
758766

759767
if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Num, getCurToken()) &&
@@ -763,11 +771,13 @@ bool Parser::ParseOpenACCGangArg() {
763771
// Fallthrough to the 'int-expr' handling for when 'num' is omitted.
764772
}
765773
// This is just the 'num' case where 'num' is optional.
766-
return ParseOpenACCIntExpr().isInvalid();
774+
return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
775+
OpenACCClauseKind::Gang, GangLoc)
776+
.isInvalid();
767777
}
768778

769-
bool Parser::ParseOpenACCGangArgList() {
770-
if (ParseOpenACCGangArg()) {
779+
bool Parser::ParseOpenACCGangArgList(SourceLocation GangLoc) {
780+
if (ParseOpenACCGangArg(GangLoc)) {
771781
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
772782
Parser::StopBeforeMatch);
773783
return false;
@@ -776,7 +786,7 @@ bool Parser::ParseOpenACCGangArgList() {
776786
while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
777787
ExpectAndConsume(tok::comma);
778788

779-
if (ParseOpenACCGangArg()) {
789+
if (ParseOpenACCGangArg(GangLoc)) {
780790
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
781791
Parser::StopBeforeMatch);
782792
return false;
@@ -941,11 +951,18 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
941951
case OpenACCClauseKind::DeviceNum:
942952
case OpenACCClauseKind::DefaultAsync:
943953
case OpenACCClauseKind::VectorLength: {
944-
ExprResult IntExpr = ParseOpenACCIntExpr();
954+
ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
955+
ClauseKind, ClauseLoc);
945956
if (IntExpr.isInvalid()) {
946957
Parens.skipToEnd();
947958
return OpenACCCanContinue();
948959
}
960+
961+
// TODO OpenACC: as we implement the 'rest' of the above, this 'if' should
962+
// be removed leaving just the 'setIntExprDetails'.
963+
if (ClauseKind == OpenACCClauseKind::NumWorkers)
964+
ParsedClause.setIntExprDetails(IntExpr.get());
965+
949966
break;
950967
}
951968
case OpenACCClauseKind::DType:
@@ -998,7 +1015,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
9981015
? OpenACCSpecialTokenKind::Length
9991016
: OpenACCSpecialTokenKind::Num,
10001017
ClauseKind);
1001-
ExprResult IntExpr = ParseOpenACCIntExpr();
1018+
ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
1019+
ClauseKind, ClauseLoc);
10021020
if (IntExpr.isInvalid()) {
10031021
Parens.skipToEnd();
10041022
return OpenACCCanContinue();
@@ -1014,13 +1032,14 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
10141032
break;
10151033
}
10161034
case OpenACCClauseKind::Gang:
1017-
if (ParseOpenACCGangArgList()) {
1035+
if (ParseOpenACCGangArgList(ClauseLoc)) {
10181036
Parens.skipToEnd();
10191037
return OpenACCCanContinue();
10201038
}
10211039
break;
10221040
case OpenACCClauseKind::Wait:
1023-
if (ParseOpenACCWaitArgument()) {
1041+
if (ParseOpenACCWaitArgument(ClauseLoc,
1042+
/*IsDirective=*/false)) {
10241043
Parens.skipToEnd();
10251044
return OpenACCCanContinue();
10261045
}
@@ -1052,7 +1071,7 @@ ExprResult Parser::ParseOpenACCAsyncArgument() {
10521071
/// In this section and throughout the specification, the term wait-argument
10531072
/// means:
10541073
/// [ devnum : int-expr : ] [ queues : ] async-argument-list
1055-
bool Parser::ParseOpenACCWaitArgument() {
1074+
bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
10561075
// [devnum : int-expr : ]
10571076
if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::DevNum, Tok) &&
10581077
NextToken().is(tok::colon)) {
@@ -1061,7 +1080,11 @@ bool Parser::ParseOpenACCWaitArgument() {
10611080
// Consume colon.
10621081
ConsumeToken();
10631082

1064-
ExprResult IntExpr = ParseOpenACCIntExpr();
1083+
ExprResult IntExpr = ParseOpenACCIntExpr(
1084+
IsDirective ? OpenACCDirectiveKind::Wait
1085+
: OpenACCDirectiveKind::Invalid,
1086+
IsDirective ? OpenACCClauseKind::Invalid : OpenACCClauseKind::Wait,
1087+
Loc);
10651088
if (IntExpr.isInvalid())
10661089
return true;
10671090

@@ -1245,7 +1268,7 @@ Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {
12451268
break;
12461269
case OpenACCDirectiveKind::Wait:
12471270
// OpenACC has an optional paren-wrapped 'wait-argument'.
1248-
if (ParseOpenACCWaitArgument())
1271+
if (ParseOpenACCWaitArgument(StartLoc, /*IsDirective=*/true))
12491272
T.skipToEnd();
12501273
else
12511274
T.consumeClose();

0 commit comments

Comments
 (0)