Skip to content

Commit b1b4652

Browse files
committed
[OpenACC] 'wait' clause for compute construct sema
'wait' takes a few int-exprs (well, a series of async-arguments, but those are effectively just an int-expr), plus a pair of tags. This patch adds the support for this to the AST, and does the appropriate semantic analysis for them.
1 parent 8d2ab2a commit b1b4652

20 files changed

+711
-56
lines changed

clang/include/clang/AST/OpenACCClause.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,46 @@ class OpenACCClauseWithExprs : public OpenACCClauseWithParams {
192192
}
193193
};
194194

195+
// Represents the 'devnum' and expressions lists for the 'wait' clause.
196+
class OpenACCWaitClause final
197+
: public OpenACCClauseWithExprs,
198+
public llvm::TrailingObjects<OpenACCWaitClause, Expr *> {
199+
SourceLocation QueuesLoc;
200+
OpenACCWaitClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
201+
Expr *DevNumExpr, SourceLocation QueuesLoc,
202+
ArrayRef<Expr *> QueueIdExprs, SourceLocation EndLoc)
203+
: OpenACCClauseWithExprs(OpenACCClauseKind::Wait, BeginLoc, LParenLoc,
204+
EndLoc),
205+
QueuesLoc(QueuesLoc) {
206+
// The first element of the trailing storage is always the devnum expr,
207+
// whether it is used or not.
208+
std::uninitialized_copy(&DevNumExpr, &DevNumExpr + 1,
209+
getTrailingObjects<Expr *>());
210+
std::uninitialized_copy(QueueIdExprs.begin(), QueueIdExprs.end(),
211+
getTrailingObjects<Expr *>() + 1);
212+
setExprs(
213+
MutableArrayRef(getTrailingObjects<Expr *>(), QueueIdExprs.size() + 1));
214+
}
215+
216+
public:
217+
static OpenACCWaitClause *Create(const ASTContext &C, SourceLocation BeginLoc,
218+
SourceLocation LParenLoc, Expr *DevNumExpr,
219+
SourceLocation QueuesLoc,
220+
ArrayRef<Expr *> QueueIdExprs,
221+
SourceLocation EndLoc);
222+
223+
bool hasQueuesTag() const { return !QueuesLoc.isInvalid(); }
224+
SourceLocation getQueuesLoc() const { return QueuesLoc; }
225+
bool hasDevNumExpr() const { return getExprs()[0]; }
226+
Expr *getDevNumExpr() const { return getExprs()[0]; }
227+
llvm::ArrayRef<Expr *> getQueueIdExprs() {
228+
return OpenACCClauseWithExprs::getExprs().drop_front();
229+
}
230+
llvm::ArrayRef<Expr *> getQueueIdExprs() const {
231+
return OpenACCClauseWithExprs::getExprs().drop_front();
232+
}
233+
};
234+
195235
class OpenACCNumGangsClause final
196236
: public OpenACCClauseWithExprs,
197237
public llvm::TrailingObjects<OpenACCNumGangsClause, Expr *> {

clang/include/clang/Basic/OpenACCClauses.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ VISIT_CLAUSE(Present)
4646
VISIT_CLAUSE(Private)
4747
VISIT_CLAUSE(Self)
4848
VISIT_CLAUSE(VectorLength)
49+
VISIT_CLAUSE(Wait)
4950

5051
#undef VISIT_CLAUSE
5152
#undef CLAUSE_ALIAS

clang/include/clang/Parse/Parser.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3632,6 +3632,13 @@ class Parser : public CodeCompletionHandler {
36323632
// Wait constructs, we likely want to put that information in here as well.
36333633
};
36343634

3635+
struct OpenACCWaitParseInfo {
3636+
bool Failed = false;
3637+
Expr *DevNumExpr = nullptr;
3638+
SourceLocation QueuesLoc;
3639+
SmallVector<Expr *> QueueIdExprs;
3640+
};
3641+
36353642
/// Represents the 'error' state of parsing an OpenACC Clause, and stores
36363643
/// whether we can continue parsing, or should give up on the directive.
36373644
enum class OpenACCParseCanContinue { Cannot = 0, Can = 1 };
@@ -3674,7 +3681,8 @@ class Parser : public CodeCompletionHandler {
36743681
/// Parses the clause-list for an OpenACC directive.
36753682
SmallVector<OpenACCClause *>
36763683
ParseOpenACCClauseList(OpenACCDirectiveKind DirKind);
3677-
bool ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective);
3684+
OpenACCWaitParseInfo ParseOpenACCWaitArgument(SourceLocation Loc,
3685+
bool IsDirective);
36783686
/// Parses the clause of the 'bind' argument, which can be a string literal or
36793687
/// an ID expression.
36803688
ExprResult ParseOpenACCBindClauseArgument();

clang/include/clang/Sema/SemaOpenACC.h

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,14 @@ class SemaOpenACC : public SemaBase {
5454
bool IsZero;
5555
};
5656

57+
struct WaitDetails {
58+
Expr *DevNumExpr;
59+
SourceLocation QueuesLoc;
60+
SmallVector<Expr *> QueueIdExprs;
61+
};
62+
5763
std::variant<std::monostate, DefaultDetails, ConditionDetails,
58-
IntExprDetails, VarListDetails>
64+
IntExprDetails, VarListDetails, WaitDetails>
5965
Details = std::monostate{};
6066

6167
public:
@@ -104,14 +110,45 @@ class SemaOpenACC : public SemaBase {
104110
ClauseKind == OpenACCClauseKind::Async ||
105111
ClauseKind == OpenACCClauseKind::VectorLength) &&
106112
"Parsed clause kind does not have a int exprs");
107-
//
108-
// 'async' has an optional IntExpr, so be tolerant of that.
109-
if (ClauseKind == OpenACCClauseKind::Async &&
113+
114+
// 'async' and 'wait' have an optional IntExpr, so be tolerant of that.
115+
if ((ClauseKind == OpenACCClauseKind::Async ||
116+
ClauseKind == OpenACCClauseKind::Wait) &&
110117
std::holds_alternative<std::monostate>(Details))
111118
return 0;
112119
return std::get<IntExprDetails>(Details).IntExprs.size();
113120
}
114121

122+
SourceLocation getQueuesLoc() const {
123+
assert(ClauseKind == OpenACCClauseKind::Wait &&
124+
"Parsed clause kind does not have a queues location");
125+
126+
if (std::holds_alternative<std::monostate>(Details))
127+
return SourceLocation{};
128+
129+
return std::get<WaitDetails>(Details).QueuesLoc;
130+
}
131+
132+
Expr *getDevNumExpr() const {
133+
assert(ClauseKind == OpenACCClauseKind::Wait &&
134+
"Parsed clause kind does not have a device number expr");
135+
136+
if (std::holds_alternative<std::monostate>(Details))
137+
return nullptr;
138+
139+
return std::get<WaitDetails>(Details).DevNumExpr;
140+
}
141+
142+
ArrayRef<Expr *> getQueueIdExprs() const {
143+
assert(ClauseKind == OpenACCClauseKind::Wait &&
144+
"Parsed clause kind does not have a queue id expr list");
145+
146+
if (std::holds_alternative<std::monostate>(Details))
147+
return ArrayRef<Expr *>{std::nullopt};
148+
149+
return std::get<WaitDetails>(Details).QueueIdExprs;
150+
}
151+
115152
ArrayRef<Expr *> getIntExprs() {
116153
assert((ClauseKind == OpenACCClauseKind::NumGangs ||
117154
ClauseKind == OpenACCClauseKind::NumWorkers ||
@@ -282,6 +319,13 @@ class SemaOpenACC : public SemaBase {
282319
"zero: tag only valid on copyout/create");
283320
Details = VarListDetails{std::move(VarList), IsReadOnly, IsZero};
284321
}
322+
323+
void setWaitDetails(Expr *DevNum, SourceLocation QueuesLoc,
324+
llvm::SmallVector<Expr *> &&IntExprs) {
325+
assert(ClauseKind == OpenACCClauseKind::Wait &&
326+
"Parsed clause kind does not have a wait-details");
327+
Details = WaitDetails{DevNum, QueuesLoc, std::move(IntExprs)};
328+
}
285329
};
286330

287331
SemaOpenACC(Sema &S);

clang/include/clang/Serialization/ASTRecordReader.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ class ASTRecordReader
272272
/// Read a list of Exprs used for a var-list.
273273
llvm::SmallVector<Expr *> readOpenACCVarList();
274274

275+
/// Read a list of Exprs used for a int-expr-list.
276+
llvm::SmallVector<Expr *> readOpenACCIntExprList();
277+
275278
/// Read an OpenACC clause, advancing Idx.
276279
OpenACCClause *readOpenACCClause();
277280

clang/include/clang/Serialization/ASTRecordWriter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ class ASTRecordWriter
296296

297297
void writeOpenACCVarList(const OpenACCClauseWithVarList *C);
298298

299+
void writeOpenACCIntExprList(ArrayRef<Expr *> Exprs);
300+
299301
/// Writes out a single OpenACC Clause.
300302
void writeOpenACCClause(const OpenACCClause *C);
301303

clang/lib/AST/OpenACCClause.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,18 @@ OpenACCAsyncClause *OpenACCAsyncClause::Create(const ASTContext &C,
147147
return new (Mem) OpenACCAsyncClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
148148
}
149149

150+
OpenACCWaitClause *OpenACCWaitClause::Create(
151+
const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
152+
Expr *DevNumExpr, SourceLocation QueuesLoc, ArrayRef<Expr *> QueueIdExprs,
153+
SourceLocation EndLoc) {
154+
// Allocates enough room in trailing storage for all the int-exprs, plus a
155+
// placeholder for the devnum.
156+
void *Mem = C.Allocate(
157+
OpenACCWaitClause::totalSizeToAlloc<Expr *>(QueueIdExprs.size() + 1));
158+
return new (Mem) OpenACCWaitClause(BeginLoc, LParenLoc, DevNumExpr, QueuesLoc,
159+
QueueIdExprs, EndLoc);
160+
}
161+
150162
OpenACCNumGangsClause *OpenACCNumGangsClause::Create(const ASTContext &C,
151163
SourceLocation BeginLoc,
152164
SourceLocation LParenLoc,
@@ -393,3 +405,22 @@ void OpenACCClausePrinter::VisitCreateClause(const OpenACCCreateClause &C) {
393405
[&](const Expr *E) { printExpr(E); });
394406
OS << ")";
395407
}
408+
409+
void OpenACCClausePrinter::VisitWaitClause(const OpenACCWaitClause &C) {
410+
OS << "wait";
411+
if (!C.getLParenLoc().isInvalid()) {
412+
OS << "(";
413+
if (C.hasDevNumExpr()) {
414+
OS << "devnum: ";
415+
printExpr(C.getDevNumExpr());
416+
OS << " : ";
417+
}
418+
419+
if (C.hasQueuesTag())
420+
OS << "queues: ";
421+
422+
llvm::interleaveComma(C.getQueueIdExprs(), OS,
423+
[&](const Expr *E) { printExpr(E); });
424+
OS << ")";
425+
}
426+
}

clang/lib/AST/StmtProfile.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2578,6 +2578,13 @@ void OpenACCClauseProfiler::VisitAsyncClause(const OpenACCAsyncClause &Clause) {
25782578
if (Clause.hasIntExpr())
25792579
Profiler.VisitStmt(Clause.getIntExpr());
25802580
}
2581+
2582+
void OpenACCClauseProfiler::VisitWaitClause(const OpenACCWaitClause &Clause) {
2583+
if (Clause.hasDevNumExpr())
2584+
Profiler.VisitStmt(Clause.getDevNumExpr());
2585+
for (auto *E : Clause.getQueueIdExprs())
2586+
Profiler.VisitStmt(E);
2587+
}
25812588
} // namespace
25822589

25832590
void StmtProfiler::VisitOpenACCComputeConstruct(

clang/lib/AST/TextNodeDumper.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,13 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
437437
if (cast<OpenACCCreateClause>(C)->isZero())
438438
OS << " : zero";
439439
break;
440+
case OpenACCClauseKind::Wait:
441+
OS << " clause";
442+
if (cast<OpenACCWaitClause>(C)->hasDevNumExpr())
443+
OS << " has devnum";
444+
if (cast<OpenACCWaitClause>(C)->hasQueuesTag())
445+
OS << " has queues tag";
446+
break;
440447
default:
441448
// Nothing to do here.
442449
break;

clang/lib/Parse/ParseOpenACC.cpp

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,6 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
867867
SemaOpenACC::OpenACCParsedClause ParsedClause(DirKind, ClauseKind, ClauseLoc);
868868

869869
if (ClauseHasRequiredParens(DirKind, ClauseKind)) {
870-
ParsedClause.setLParenLoc(getCurToken().getLocation());
871870
if (Parens.expectAndConsume()) {
872871
// We are missing a paren, so assume that the person just forgot the
873872
// parameter. Return 'false' so we try to continue on and parse the next
@@ -876,6 +875,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
876875
Parser::StopBeforeMatch);
877876
return OpenACCCanContinue();
878877
}
878+
ParsedClause.setLParenLoc(Parens.getOpenLocation());
879879

880880
switch (ClauseKind) {
881881
case OpenACCClauseKind::Default: {
@@ -1048,8 +1048,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
10481048
return OpenACCCannotContinue();
10491049

10501050
} else if (ClauseHasOptionalParens(DirKind, ClauseKind)) {
1051-
ParsedClause.setLParenLoc(getCurToken().getLocation());
10521051
if (!Parens.consumeOpen()) {
1052+
ParsedClause.setLParenLoc(Parens.getOpenLocation());
10531053
switch (ClauseKind) {
10541054
case OpenACCClauseKind::Self: {
10551055
assert(DirKind != OpenACCDirectiveKind::Update);
@@ -1099,13 +1099,19 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
10991099
return OpenACCCanContinue();
11001100
}
11011101
break;
1102-
case OpenACCClauseKind::Wait:
1103-
if (ParseOpenACCWaitArgument(ClauseLoc,
1104-
/*IsDirective=*/false)) {
1102+
case OpenACCClauseKind::Wait: {
1103+
OpenACCWaitParseInfo Info =
1104+
ParseOpenACCWaitArgument(ClauseLoc,
1105+
/*IsDirective=*/false);
1106+
if (Info.Failed) {
11051107
Parens.skipToEnd();
11061108
return OpenACCCanContinue();
11071109
}
1110+
1111+
ParsedClause.setWaitDetails(Info.DevNumExpr, Info.QueuesLoc,
1112+
std::move(Info.QueueIdExprs));
11081113
break;
1114+
}
11091115
default:
11101116
llvm_unreachable("Not an optional parens type?");
11111117
}
@@ -1139,7 +1145,9 @@ Parser::ParseOpenACCAsyncArgument(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
11391145
/// In this section and throughout the specification, the term wait-argument
11401146
/// means:
11411147
/// [ devnum : int-expr : ] [ queues : ] async-argument-list
1142-
bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
1148+
Parser::OpenACCWaitParseInfo
1149+
Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
1150+
OpenACCWaitParseInfo Result;
11431151
// [devnum : int-expr : ]
11441152
if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::DevNum, Tok) &&
11451153
NextToken().is(tok::colon)) {
@@ -1153,18 +1161,25 @@ bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
11531161
: OpenACCDirectiveKind::Invalid,
11541162
IsDirective ? OpenACCClauseKind::Invalid : OpenACCClauseKind::Wait,
11551163
Loc);
1156-
if (Res.first.isInvalid() && Res.second == OpenACCParseCanContinue::Cannot)
1157-
return true;
1164+
if (Res.first.isInvalid() &&
1165+
Res.second == OpenACCParseCanContinue::Cannot) {
1166+
Result.Failed = true;
1167+
return Result;
1168+
}
11581169

1159-
if (ExpectAndConsume(tok::colon))
1160-
return true;
1170+
if (ExpectAndConsume(tok::colon)) {
1171+
Result.Failed = true;
1172+
return Result;
1173+
}
1174+
1175+
Result.DevNumExpr = Res.first.get();
11611176
}
11621177

11631178
// [ queues : ]
11641179
if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Queues, Tok) &&
11651180
NextToken().is(tok::colon)) {
11661181
// Consume queues.
1167-
ConsumeToken();
1182+
Result.QueuesLoc = ConsumeToken();
11681183
// Consume colon.
11691184
ConsumeToken();
11701185
}
@@ -1176,8 +1191,10 @@ bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
11761191
bool FirstArg = true;
11771192
while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
11781193
if (!FirstArg) {
1179-
if (ExpectAndConsume(tok::comma))
1180-
return true;
1194+
if (ExpectAndConsume(tok::comma)) {
1195+
Result.Failed = true;
1196+
return Result;
1197+
}
11811198
}
11821199
FirstArg = false;
11831200

@@ -1187,11 +1204,16 @@ bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
11871204
IsDirective ? OpenACCClauseKind::Invalid : OpenACCClauseKind::Wait,
11881205
Loc);
11891206

1190-
if (Res.first.isInvalid() && Res.second == OpenACCParseCanContinue::Cannot)
1191-
return true;
1207+
if (Res.first.isInvalid() &&
1208+
Res.second == OpenACCParseCanContinue::Cannot) {
1209+
Result.Failed = true;
1210+
return Result;
1211+
}
1212+
1213+
Result.QueueIdExprs.push_back(Res.first.get());
11921214
}
11931215

1194-
return false;
1216+
return Result;
11951217
}
11961218

11971219
ExprResult Parser::ParseOpenACCIDExpression() {
@@ -1360,7 +1382,7 @@ Parser::OpenACCDirectiveParseInfo Parser::ParseOpenACCDirective() {
13601382
break;
13611383
case OpenACCDirectiveKind::Wait:
13621384
// OpenACC has an optional paren-wrapped 'wait-argument'.
1363-
if (ParseOpenACCWaitArgument(StartLoc, /*IsDirective=*/true))
1385+
if (ParseOpenACCWaitArgument(StartLoc, /*IsDirective=*/true).Failed)
13641386
T.skipToEnd();
13651387
else
13661388
T.consumeClose();

0 commit comments

Comments
 (0)