Skip to content

Commit dc20a0e

Browse files
authored
[OpenACC] Implement 'num_gangs' sema for compute constructs (#89460)
num_gangs takes an 'int-expr-list', for 'parallel', and an 'int-expr' for 'kernels'. This patch changes the parsing to always parse it as an 'int-expr-list', then correct the expression count during Sema. It also implements the rest of the semantic analysis changes for this clause.
1 parent 5c4b923 commit dc20a0e

18 files changed

+572
-44
lines changed

clang/include/clang/AST/OpenACCClause.h

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,33 +156,88 @@ class OpenACCSelfClause : public OpenACCClauseWithCondition {
156156
Expr *ConditionExpr, SourceLocation EndLoc);
157157
};
158158

159-
/// Represents oen of a handful of classes that have a single integer
159+
/// Represents a clause that has one or more IntExprs. It does not own the
160+
/// IntExprs, but provides 'children' and other accessors.
161+
class OpenACCClauseWithIntExprs : public OpenACCClauseWithParams {
162+
MutableArrayRef<Expr *> IntExprs;
163+
164+
protected:
165+
OpenACCClauseWithIntExprs(OpenACCClauseKind K, SourceLocation BeginLoc,
166+
SourceLocation LParenLoc, SourceLocation EndLoc)
167+
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc) {}
168+
169+
/// Used only for initialization, the leaf class can initialize this to
170+
/// trailing storage.
171+
void setIntExprs(MutableArrayRef<Expr *> NewIntExprs) {
172+
assert(IntExprs.empty() && "Cannot change IntExprs list");
173+
IntExprs = NewIntExprs;
174+
}
175+
176+
/// Gets the entire list of integer expressions, but leave it to the
177+
/// individual clauses to expose this how they'd like.
178+
llvm::ArrayRef<Expr *> getIntExprs() const { return IntExprs; }
179+
180+
public:
181+
child_range children() {
182+
return child_range(reinterpret_cast<Stmt **>(IntExprs.begin()),
183+
reinterpret_cast<Stmt **>(IntExprs.end()));
184+
}
185+
186+
const_child_range children() const {
187+
child_range Children =
188+
const_cast<OpenACCClauseWithIntExprs *>(this)->children();
189+
return const_child_range(Children.begin(), Children.end());
190+
}
191+
};
192+
193+
class OpenACCNumGangsClause final
194+
: public OpenACCClauseWithIntExprs,
195+
public llvm::TrailingObjects<OpenACCNumGangsClause, Expr *> {
196+
197+
OpenACCNumGangsClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
198+
ArrayRef<Expr *> IntExprs, SourceLocation EndLoc)
199+
: OpenACCClauseWithIntExprs(OpenACCClauseKind::NumGangs, BeginLoc,
200+
LParenLoc, EndLoc) {
201+
std::uninitialized_copy(IntExprs.begin(), IntExprs.end(),
202+
getTrailingObjects<Expr *>());
203+
setIntExprs(MutableArrayRef(getTrailingObjects<Expr *>(), IntExprs.size()));
204+
}
205+
206+
public:
207+
static OpenACCNumGangsClause *
208+
Create(const ASTContext &C, SourceLocation BeginLoc, SourceLocation LParenLoc,
209+
ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);
210+
211+
llvm::ArrayRef<Expr *> getIntExprs() {
212+
return OpenACCClauseWithIntExprs::getIntExprs();
213+
}
214+
215+
llvm::ArrayRef<Expr *> getIntExprs() const {
216+
return OpenACCClauseWithIntExprs::getIntExprs();
217+
}
218+
};
219+
220+
/// Represents one of a handful of clauses that have a single integer
160221
/// expression.
161-
class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithParams {
222+
class OpenACCClauseWithSingleIntExpr : public OpenACCClauseWithIntExprs {
162223
Expr *IntExpr;
163224

164225
protected:
165226
OpenACCClauseWithSingleIntExpr(OpenACCClauseKind K, SourceLocation BeginLoc,
166227
SourceLocation LParenLoc, Expr *IntExpr,
167228
SourceLocation EndLoc)
168-
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
169-
IntExpr(IntExpr) {}
229+
: OpenACCClauseWithIntExprs(K, BeginLoc, LParenLoc, EndLoc),
230+
IntExpr(IntExpr) {
231+
setIntExprs(MutableArrayRef<Expr *>{&this->IntExpr, 1});
232+
}
170233

171234
public:
172-
bool hasIntExpr() const { return IntExpr; }
173-
const Expr *getIntExpr() const { return IntExpr; }
174-
175-
Expr *getIntExpr() { return IntExpr; };
176-
177-
child_range children() {
178-
return child_range(reinterpret_cast<Stmt **>(&IntExpr),
179-
reinterpret_cast<Stmt **>(&IntExpr + 1));
235+
bool hasIntExpr() const { return !getIntExprs().empty(); }
236+
const Expr *getIntExpr() const {
237+
return hasIntExpr() ? getIntExprs()[0] : nullptr;
180238
}
181239

182-
const_child_range children() const {
183-
return const_child_range(reinterpret_cast<Stmt *const *>(&IntExpr),
184-
reinterpret_cast<Stmt *const *>(&IntExpr + 1));
185-
}
240+
Expr *getIntExpr() { return hasIntExpr() ? getIntExprs()[0] : nullptr; };
186241
};
187242

188243
class OpenACCNumWorkersClause : public OpenACCClauseWithSingleIntExpr {

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12293,4 +12293,9 @@ def note_acc_int_expr_conversion
1229312293
: Note<"conversion to %select{integral|enumeration}0 type %1">;
1229412294
def err_acc_int_expr_multiple_conversions
1229512295
: Error<"multiple conversions from expression type %0 to an integral type">;
12296+
def err_acc_num_gangs_num_args
12297+
: Error<"%select{no|too many}0 integer expression arguments provided to "
12298+
"OpenACC 'num_gangs' "
12299+
"%select{|clause: '%1' directive expects maximum of %2, %3 were "
12300+
"provided}0">;
1229612301
} // end of sema component.

clang/include/clang/Basic/OpenACCClauses.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
VISIT_CLAUSE(Default)
1919
VISIT_CLAUSE(If)
2020
VISIT_CLAUSE(Self)
21+
VISIT_CLAUSE(NumGangs)
2122
VISIT_CLAUSE(NumWorkers)
2223
VISIT_CLAUSE(VectorLength)
2324

clang/include/clang/Parse/Parser.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3644,10 +3644,22 @@ class Parser : public CodeCompletionHandler {
36443644
/// Parses the clause of the 'bind' argument, which can be a string literal or
36453645
/// an ID expression.
36463646
ExprResult ParseOpenACCBindClauseArgument();
3647+
3648+
/// A type to represent the state of parsing after an attempt to parse an
3649+
/// OpenACC int-expr. This is useful to determine whether an int-expr list can
3650+
/// continue parsing after a failed int-expr.
3651+
using OpenACCIntExprParseResult =
3652+
std::pair<ExprResult, OpenACCParseCanContinue>;
36473653
/// Parses the clause kind of 'int-expr', which can be any integral
36483654
/// expression.
3649-
ExprResult ParseOpenACCIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
3650-
SourceLocation Loc);
3655+
OpenACCIntExprParseResult ParseOpenACCIntExpr(OpenACCDirectiveKind DK,
3656+
OpenACCClauseKind CK,
3657+
SourceLocation Loc);
3658+
/// Parses the argument list for 'num_gangs', which allows up to 3
3659+
/// 'int-expr's.
3660+
bool ParseOpenACCIntExprList(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
3661+
SourceLocation Loc,
3662+
llvm::SmallVectorImpl<Expr *> &IntExprs);
36513663
/// Parses the 'device-type-list', which is a list of identifiers.
36523664
bool ParseOpenACCDeviceTypeList();
36533665
/// Parses the 'async-argument', which is an integral value with two

clang/include/clang/Sema/SemaOpenACC.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,16 @@ class SemaOpenACC : public SemaBase {
9393
}
9494

9595
unsigned getNumIntExprs() const {
96-
assert((ClauseKind == OpenACCClauseKind::NumWorkers ||
96+
assert((ClauseKind == OpenACCClauseKind::NumGangs ||
97+
ClauseKind == OpenACCClauseKind::NumWorkers ||
9798
ClauseKind == OpenACCClauseKind::VectorLength) &&
9899
"Parsed clause kind does not have a int exprs");
99100
return std::get<IntExprDetails>(Details).IntExprs.size();
100101
}
101102

102103
ArrayRef<Expr *> getIntExprs() {
103-
assert((ClauseKind == OpenACCClauseKind::NumWorkers ||
104+
assert((ClauseKind == OpenACCClauseKind::NumGangs ||
105+
ClauseKind == OpenACCClauseKind::NumWorkers ||
104106
ClauseKind == OpenACCClauseKind::VectorLength) &&
105107
"Parsed clause kind does not have a int exprs");
106108
return std::get<IntExprDetails>(Details).IntExprs;
@@ -134,11 +136,19 @@ class SemaOpenACC : public SemaBase {
134136
}
135137

136138
void setIntExprDetails(ArrayRef<Expr *> IntExprs) {
137-
assert((ClauseKind == OpenACCClauseKind::NumWorkers ||
139+
assert((ClauseKind == OpenACCClauseKind::NumGangs ||
140+
ClauseKind == OpenACCClauseKind::NumWorkers ||
138141
ClauseKind == OpenACCClauseKind::VectorLength) &&
139142
"Parsed clause kind does not have a int exprs");
140143
Details = IntExprDetails{{IntExprs.begin(), IntExprs.end()}};
141144
}
145+
void setIntExprDetails(llvm::SmallVector<Expr *> &&IntExprs) {
146+
assert((ClauseKind == OpenACCClauseKind::NumGangs ||
147+
ClauseKind == OpenACCClauseKind::NumWorkers ||
148+
ClauseKind == OpenACCClauseKind::VectorLength) &&
149+
"Parsed clause kind does not have a int exprs");
150+
Details = IntExprDetails{IntExprs};
151+
}
142152
};
143153

144154
SemaOpenACC(Sema &S);

clang/lib/AST/OpenACCClause.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,16 @@ OpenACCVectorLengthClause::Create(const ASTContext &C, SourceLocation BeginLoc,
124124
OpenACCVectorLengthClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
125125
}
126126

127+
OpenACCNumGangsClause *OpenACCNumGangsClause::Create(const ASTContext &C,
128+
SourceLocation BeginLoc,
129+
SourceLocation LParenLoc,
130+
ArrayRef<Expr *> IntExprs,
131+
SourceLocation EndLoc) {
132+
void *Mem = C.Allocate(
133+
OpenACCNumGangsClause::totalSizeToAlloc<Expr *>(IntExprs.size()));
134+
return new (Mem) OpenACCNumGangsClause(BeginLoc, LParenLoc, IntExprs, EndLoc);
135+
}
136+
127137
//===----------------------------------------------------------------------===//
128138
// OpenACC clauses printing methods
129139
//===----------------------------------------------------------------------===//
@@ -141,6 +151,12 @@ void OpenACCClausePrinter::VisitSelfClause(const OpenACCSelfClause &C) {
141151
OS << "(" << CondExpr << ")";
142152
}
143153

154+
void OpenACCClausePrinter::VisitNumGangsClause(const OpenACCNumGangsClause &C) {
155+
OS << "num_gangs(";
156+
llvm::interleaveComma(C.getIntExprs(), OS);
157+
OS << ")";
158+
}
159+
144160
void OpenACCClausePrinter::VisitNumWorkersClause(
145161
const OpenACCNumWorkersClause &C) {
146162
OS << "num_workers(" << C.getIntExpr() << ")";

clang/lib/AST/StmtProfile.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2497,6 +2497,12 @@ void OpenACCClauseProfiler::VisitSelfClause(const OpenACCSelfClause &Clause) {
24972497
Profiler.VisitStmt(Clause.getConditionExpr());
24982498
}
24992499

2500+
void OpenACCClauseProfiler::VisitNumGangsClause(
2501+
const OpenACCNumGangsClause &Clause) {
2502+
for (auto *E : Clause.getIntExprs())
2503+
Profiler.VisitStmt(E);
2504+
}
2505+
25002506
void OpenACCClauseProfiler::VisitNumWorkersClause(
25012507
const OpenACCNumWorkersClause &Clause) {
25022508
assert(Clause.hasIntExpr() && "num_workers clause requires a valid int expr");

clang/lib/AST/TextNodeDumper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
399399
break;
400400
case OpenACCClauseKind::If:
401401
case OpenACCClauseKind::Self:
402+
case OpenACCClauseKind::NumGangs:
402403
case OpenACCClauseKind::NumWorkers:
403404
case OpenACCClauseKind::VectorLength:
404405
// The condition expression will be printed as a part of the 'children',

clang/lib/Parse/ParseOpenACC.cpp

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -632,16 +632,54 @@ Parser::ParseOpenACCClauseList(OpenACCDirectiveKind DirKind) {
632632
return Clauses;
633633
}
634634

635-
ExprResult Parser::ParseOpenACCIntExpr(OpenACCDirectiveKind DK,
636-
OpenACCClauseKind CK,
637-
SourceLocation Loc) {
638-
ExprResult ER =
639-
getActions().CorrectDelayedTyposInExpr(ParseAssignmentExpression());
635+
Parser::OpenACCIntExprParseResult
636+
Parser::ParseOpenACCIntExpr(OpenACCDirectiveKind DK, OpenACCClauseKind CK,
637+
SourceLocation Loc) {
638+
ExprResult ER = ParseAssignmentExpression();
640639

640+
// If the actual parsing failed, we don't know the state of the parse, so
641+
// don't try to continue.
641642
if (!ER.isUsable())
642-
return ER;
643+
return {ER, OpenACCParseCanContinue::Cannot};
644+
645+
// Parsing can continue after the initial assignment expression parsing, so
646+
// even if there was a typo, we can continue.
647+
ER = getActions().CorrectDelayedTyposInExpr(ER);
648+
if (!ER.isUsable())
649+
return {ER, OpenACCParseCanContinue::Can};
650+
651+
return {getActions().OpenACC().ActOnIntExpr(DK, CK, Loc, ER.get()),
652+
OpenACCParseCanContinue::Can};
653+
}
654+
655+
bool Parser::ParseOpenACCIntExprList(OpenACCDirectiveKind DK,
656+
OpenACCClauseKind CK, SourceLocation Loc,
657+
llvm::SmallVectorImpl<Expr *> &IntExprs) {
658+
OpenACCIntExprParseResult CurResult = ParseOpenACCIntExpr(DK, CK, Loc);
659+
660+
if (!CurResult.first.isUsable() &&
661+
CurResult.second == OpenACCParseCanContinue::Cannot) {
662+
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
663+
Parser::StopBeforeMatch);
664+
return true;
665+
}
666+
667+
IntExprs.push_back(CurResult.first.get());
668+
669+
while (!getCurToken().isOneOf(tok::r_paren, tok::annot_pragma_openacc_end)) {
670+
ExpectAndConsume(tok::comma);
671+
672+
CurResult = ParseOpenACCIntExpr(DK, CK, Loc);
643673

644-
return getActions().OpenACC().ActOnIntExpr(DK, CK, Loc, ER.get());
674+
if (!CurResult.first.isUsable() &&
675+
CurResult.second == OpenACCParseCanContinue::Cannot) {
676+
SkipUntil(tok::r_paren, tok::annot_pragma_openacc_end,
677+
Parser::StopBeforeMatch);
678+
return true;
679+
}
680+
IntExprs.push_back(CurResult.first.get());
681+
}
682+
return false;
645683
}
646684

647685
bool Parser::ParseOpenACCClauseVarList(OpenACCClauseKind Kind) {
@@ -761,7 +799,7 @@ bool Parser::ParseOpenACCGangArg(SourceLocation GangLoc) {
761799
ConsumeToken();
762800
return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
763801
OpenACCClauseKind::Gang, GangLoc)
764-
.isInvalid();
802+
.first.isInvalid();
765803
}
766804

767805
if (isOpenACCSpecialToken(OpenACCSpecialTokenKind::Num, getCurToken()) &&
@@ -773,7 +811,7 @@ bool Parser::ParseOpenACCGangArg(SourceLocation GangLoc) {
773811
// This is just the 'num' case where 'num' is optional.
774812
return ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
775813
OpenACCClauseKind::Gang, GangLoc)
776-
.isInvalid();
814+
.first.isInvalid();
777815
}
778816

779817
bool Parser::ParseOpenACCGangArgList(SourceLocation GangLoc) {
@@ -946,13 +984,25 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
946984
}
947985
break;
948986
}
949-
case OpenACCClauseKind::NumGangs:
987+
case OpenACCClauseKind::NumGangs: {
988+
llvm::SmallVector<Expr *> IntExprs;
989+
990+
if (ParseOpenACCIntExprList(OpenACCDirectiveKind::Invalid,
991+
OpenACCClauseKind::NumGangs, ClauseLoc,
992+
IntExprs)) {
993+
Parens.skipToEnd();
994+
return OpenACCCanContinue();
995+
}
996+
ParsedClause.setIntExprDetails(std::move(IntExprs));
997+
break;
998+
}
950999
case OpenACCClauseKind::NumWorkers:
9511000
case OpenACCClauseKind::DeviceNum:
9521001
case OpenACCClauseKind::DefaultAsync:
9531002
case OpenACCClauseKind::VectorLength: {
9541003
ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
955-
ClauseKind, ClauseLoc);
1004+
ClauseKind, ClauseLoc)
1005+
.first;
9561006
if (IntExpr.isInvalid()) {
9571007
Parens.skipToEnd();
9581008
return OpenACCCanContinue();
@@ -1017,7 +1067,8 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
10171067
: OpenACCSpecialTokenKind::Num,
10181068
ClauseKind);
10191069
ExprResult IntExpr = ParseOpenACCIntExpr(OpenACCDirectiveKind::Invalid,
1020-
ClauseKind, ClauseLoc);
1070+
ClauseKind, ClauseLoc)
1071+
.first;
10211072
if (IntExpr.isInvalid()) {
10221073
Parens.skipToEnd();
10231074
return OpenACCCanContinue();
@@ -1081,11 +1132,13 @@ bool Parser::ParseOpenACCWaitArgument(SourceLocation Loc, bool IsDirective) {
10811132
// Consume colon.
10821133
ConsumeToken();
10831134

1084-
ExprResult IntExpr = ParseOpenACCIntExpr(
1085-
IsDirective ? OpenACCDirectiveKind::Wait
1086-
: OpenACCDirectiveKind::Invalid,
1087-
IsDirective ? OpenACCClauseKind::Invalid : OpenACCClauseKind::Wait,
1088-
Loc);
1135+
ExprResult IntExpr =
1136+
ParseOpenACCIntExpr(IsDirective ? OpenACCDirectiveKind::Wait
1137+
: OpenACCDirectiveKind::Invalid,
1138+
IsDirective ? OpenACCClauseKind::Invalid
1139+
: OpenACCClauseKind::Wait,
1140+
Loc)
1141+
.first;
10891142
if (IntExpr.isInvalid())
10901143
return true;
10911144

0 commit comments

Comments
 (0)