Skip to content

Commit 97da34e

Browse files
authored
[OpenACC] Add 'collapse' clause AST/basic Sema implementation (#109461)
The 'collapse' clause on a 'loop' construct is used to specify how many nested loops are associated with the 'loop' construct. It takes an optional 'force' tag, and an integer constant expression as arguments. There are many other restrictions based on the contents of the loop/etc, but those are implemented in followup patches, for now, this patch just adds the AST node and does basic argument checking on the loop-count.
1 parent 677e8cd commit 97da34e

21 files changed

+540
-28
lines changed

clang/include/clang/AST/OpenACCClause.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,32 @@ class OpenACCAsyncClause : public OpenACCClauseWithSingleIntExpr {
547547
SourceLocation EndLoc);
548548
};
549549

550+
/// Represents a 'collapse' clause on a 'loop' construct. This clause takes an
551+
/// integer constant expression 'N' that represents how deep to collapse the
552+
/// construct. It also takes an optional 'force' tag that permits intervening
553+
/// code in the loops.
554+
class OpenACCCollapseClause : public OpenACCClauseWithSingleIntExpr {
555+
bool HasForce = false;
556+
557+
OpenACCCollapseClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
558+
bool HasForce, Expr *LoopCount, SourceLocation EndLoc);
559+
560+
public:
561+
const Expr *getLoopCount() const { return getIntExpr(); }
562+
Expr *getLoopCount() { return getIntExpr(); }
563+
564+
bool hasForce() const { return HasForce; }
565+
566+
static bool classof(const OpenACCClause *C) {
567+
return C->getClauseKind() == OpenACCClauseKind::Collapse;
568+
}
569+
570+
static OpenACCCollapseClause *Create(const ASTContext &C,
571+
SourceLocation BeginLoc,
572+
SourceLocation LParenLoc, bool HasForce,
573+
Expr *LoopCount, SourceLocation EndLoc);
574+
};
575+
550576
/// Represents a clause with one or more 'var' objects, represented as an expr,
551577
/// as its arguments. Var-list is expected to be stored in trailing storage.
552578
/// For now, we're just storing the original expression in its entirety, unlike

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12606,6 +12606,9 @@ def note_acc_construct_here : Note<"'%0' construct is here">;
1260612606
def err_acc_loop_spec_conflict
1260712607
: Error<"OpenACC clause '%0' on '%1' construct conflicts with previous "
1260812608
"data dependence clause">;
12609+
def err_acc_collapse_loop_count
12610+
: Error<"OpenACC 'collapse' clause loop count must be a %select{constant "
12611+
"expression|positive integer value, evaluated to %1}0">;
1260912612

1261012613
// AMDGCN builtins diagnostics
1261112614
def err_amdgcn_global_load_lds_size_invalid_value : Error<"invalid size value">;

clang/include/clang/Basic/OpenACCClauses.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
VISIT_CLAUSE(Auto)
2525
VISIT_CLAUSE(Async)
2626
VISIT_CLAUSE(Attach)
27+
VISIT_CLAUSE(Collapse)
2728
VISIT_CLAUSE(Copy)
2829
CLAUSE_ALIAS(PCopy, Copy, true)
2930
CLAUSE_ALIAS(PresentOrCopy, Copy, true)

clang/include/clang/Sema/SemaOpenACC.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,14 @@ class SemaOpenACC : public SemaBase {
8787
SmallVector<Expr *> VarList;
8888
};
8989

90+
struct CollapseDetails {
91+
bool IsForce;
92+
Expr *LoopCount;
93+
};
94+
9095
std::variant<std::monostate, DefaultDetails, ConditionDetails,
9196
IntExprDetails, VarListDetails, WaitDetails, DeviceTypeDetails,
92-
ReductionDetails>
97+
ReductionDetails, CollapseDetails>
9398
Details = std::monostate{};
9499

95100
public:
@@ -246,6 +251,18 @@ class SemaOpenACC : public SemaBase {
246251
return std::get<VarListDetails>(Details).IsZero;
247252
}
248253

254+
bool isForce() const {
255+
assert(ClauseKind == OpenACCClauseKind::Collapse &&
256+
"Only 'collapse' has a force tag");
257+
return std::get<CollapseDetails>(Details).IsForce;
258+
}
259+
260+
Expr *getLoopCount() const {
261+
assert(ClauseKind == OpenACCClauseKind::Collapse &&
262+
"Only 'collapse' has a loop count");
263+
return std::get<CollapseDetails>(Details).LoopCount;
264+
}
265+
249266
ArrayRef<DeviceTypeArgument> getDeviceTypeArchitectures() const {
250267
assert((ClauseKind == OpenACCClauseKind::DeviceType ||
251268
ClauseKind == OpenACCClauseKind::DType) &&
@@ -384,6 +401,12 @@ class SemaOpenACC : public SemaBase {
384401
"Only 'device_type'/'dtype' has a device-type-arg list");
385402
Details = DeviceTypeDetails{std::move(Archs)};
386403
}
404+
405+
void setCollapseDetails(bool IsForce, Expr *LoopCount) {
406+
assert(ClauseKind == OpenACCClauseKind::Collapse &&
407+
"Only 'collapse' has collapse details");
408+
Details = CollapseDetails{IsForce, LoopCount};
409+
}
387410
};
388411

389412
SemaOpenACC(Sema &S);
@@ -448,6 +471,8 @@ class SemaOpenACC : public SemaBase {
448471
Expr *LowerBound,
449472
SourceLocation ColonLocFirst, Expr *Length,
450473
SourceLocation RBLoc);
474+
/// Checks the loop depth value for a collapse clause.
475+
ExprResult CheckCollapseLoopCount(Expr *LoopCount);
451476

452477
/// Helper type for the registration/assignment of constructs that need to
453478
/// 'know' about their parent constructs and hold a reference to them, such as

clang/lib/AST/OpenACCClause.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ bool OpenACCClauseWithCondition::classof(const OpenACCClause *C) {
4343
bool OpenACCClauseWithSingleIntExpr::classof(const OpenACCClause *C) {
4444
return OpenACCNumWorkersClause::classof(C) ||
4545
OpenACCVectorLengthClause::classof(C) ||
46-
OpenACCAsyncClause::classof(C);
46+
OpenACCCollapseClause::classof(C) || OpenACCAsyncClause::classof(C);
4747
}
4848
OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
4949
OpenACCDefaultClauseKind K,
@@ -134,6 +134,30 @@ OpenACCNumWorkersClause::Create(const ASTContext &C, SourceLocation BeginLoc,
134134
OpenACCNumWorkersClause(BeginLoc, LParenLoc, IntExpr, EndLoc);
135135
}
136136

137+
OpenACCCollapseClause::OpenACCCollapseClause(SourceLocation BeginLoc,
138+
SourceLocation LParenLoc,
139+
bool HasForce, Expr *LoopCount,
140+
SourceLocation EndLoc)
141+
: OpenACCClauseWithSingleIntExpr(OpenACCClauseKind::Collapse, BeginLoc,
142+
LParenLoc, LoopCount, EndLoc),
143+
HasForce(HasForce) {
144+
assert(LoopCount && "LoopCount required");
145+
}
146+
147+
OpenACCCollapseClause *
148+
OpenACCCollapseClause::Create(const ASTContext &C, SourceLocation BeginLoc,
149+
SourceLocation LParenLoc, bool HasForce,
150+
Expr *LoopCount, SourceLocation EndLoc) {
151+
assert(
152+
LoopCount &&
153+
(LoopCount->isInstantiationDependent() || isa<ConstantExpr>(LoopCount)) &&
154+
"Loop count not constant expression");
155+
void *Mem =
156+
C.Allocate(sizeof(OpenACCCollapseClause), alignof(OpenACCCollapseClause));
157+
return new (Mem)
158+
OpenACCCollapseClause(BeginLoc, LParenLoc, HasForce, LoopCount, EndLoc);
159+
}
160+
137161
OpenACCVectorLengthClause::OpenACCVectorLengthClause(SourceLocation BeginLoc,
138162
SourceLocation LParenLoc,
139163
Expr *IntExpr,
@@ -550,3 +574,11 @@ void OpenACCClausePrinter::VisitIndependentClause(
550574
void OpenACCClausePrinter::VisitSeqClause(const OpenACCSeqClause &C) {
551575
OS << "seq";
552576
}
577+
578+
void OpenACCClausePrinter::VisitCollapseClause(const OpenACCCollapseClause &C) {
579+
OS << "collapse(";
580+
if (C.hasForce())
581+
OS << "force:";
582+
printExpr(C.getLoopCount());
583+
OS << ")";
584+
}

clang/lib/AST/StmtProfile.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2558,6 +2558,12 @@ void OpenACCClauseProfiler::VisitNumWorkersClause(
25582558
Profiler.VisitStmt(Clause.getIntExpr());
25592559
}
25602560

2561+
void OpenACCClauseProfiler::VisitCollapseClause(
2562+
const OpenACCCollapseClause &Clause) {
2563+
assert(Clause.getLoopCount() && "collapse clause requires a valid int expr");
2564+
Profiler.VisitStmt(Clause.getLoopCount());
2565+
}
2566+
25612567
void OpenACCClauseProfiler::VisitPrivateClause(
25622568
const OpenACCPrivateClause &Clause) {
25632569
for (auto *E : Clause.getVarList())

clang/lib/AST/TextNodeDumper.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,12 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
419419
// but print 'clause' here so it is clear what is happening from the dump.
420420
OS << " clause";
421421
break;
422+
case OpenACCClauseKind::Collapse:
423+
OS << " clause";
424+
if (cast<OpenACCCollapseClause>(C)->hasForce())
425+
OS << ": force";
426+
break;
427+
422428
case OpenACCClauseKind::CopyIn:
423429
case OpenACCClauseKind::PCopyIn:
424430
case OpenACCClauseKind::PresentOrCopyIn:

clang/lib/Parse/ParseOpenACC.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -976,14 +976,25 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams(
976976
/*IsReadOnly=*/false, /*IsZero=*/false);
977977
break;
978978
case OpenACCClauseKind::Collapse: {
979-
tryParseAndConsumeSpecialTokenKind(*this, OpenACCSpecialTokenKind::Force,
980-
ClauseKind);
981-
ExprResult NumLoops =
979+
bool HasForce = tryParseAndConsumeSpecialTokenKind(
980+
*this, OpenACCSpecialTokenKind::Force, ClauseKind);
981+
ExprResult LoopCount =
982982
getActions().CorrectDelayedTyposInExpr(ParseConstantExpression());
983-
if (NumLoops.isInvalid()) {
983+
if (LoopCount.isInvalid()) {
984984
Parens.skipToEnd();
985985
return OpenACCCanContinue();
986986
}
987+
988+
LoopCount = getActions().OpenACC().ActOnIntExpr(
989+
OpenACCDirectiveKind::Invalid, ClauseKind,
990+
LoopCount.get()->getBeginLoc(), LoopCount.get());
991+
992+
if (LoopCount.isInvalid()) {
993+
Parens.skipToEnd();
994+
return OpenACCCanContinue();
995+
}
996+
997+
ParsedClause.setCollapseDetails(HasForce, LoopCount.get());
987998
break;
988999
}
9891000
case OpenACCClauseKind::Bind: {

clang/lib/Sema/SemaOpenACC.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,18 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind,
343343
return false;
344344
}
345345

346+
case OpenACCClauseKind::Collapse: {
347+
switch (DirectiveKind) {
348+
case OpenACCDirectiveKind::Loop:
349+
case OpenACCDirectiveKind::ParallelLoop:
350+
case OpenACCDirectiveKind::SerialLoop:
351+
case OpenACCDirectiveKind::KernelsLoop:
352+
return true;
353+
default:
354+
return false;
355+
}
356+
}
357+
346358
default:
347359
// Do nothing so we can go to the 'unimplemented' diagnostic instead.
348360
return true;
@@ -1037,6 +1049,26 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitReductionClause(
10371049
ValidVars, Clause.getEndLoc());
10381050
}
10391051

1052+
OpenACCClause *SemaOpenACCClauseVisitor::VisitCollapseClause(
1053+
SemaOpenACC::OpenACCParsedClause &Clause) {
1054+
// Duplicates here are not really sensible. We could possible permit
1055+
// multiples if they all had the same value, but there isn't really a good
1056+
// reason to do so. Also, this simplifies the suppression of duplicates, in
1057+
// that we know if we 'find' one after instantiation, that it is the same
1058+
// clause, which simplifies instantiation/checking/etc.
1059+
if (checkAlreadyHasClauseOfKind(SemaRef, ExistingClauses, Clause))
1060+
return nullptr;
1061+
1062+
ExprResult LoopCount = SemaRef.CheckCollapseLoopCount(Clause.getLoopCount());
1063+
1064+
if (!LoopCount.isUsable())
1065+
return nullptr;
1066+
1067+
return OpenACCCollapseClause::Create(Ctx, Clause.getBeginLoc(),
1068+
Clause.getLParenLoc(), Clause.isForce(),
1069+
LoopCount.get(), Clause.getEndLoc());
1070+
}
1071+
10401072
} // namespace
10411073

10421074
SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {}
@@ -1273,6 +1305,9 @@ ExprResult SemaOpenACC::ActOnIntExpr(OpenACCDirectiveKind DK,
12731305
}
12741306
} IntExprDiagnoser(DK, CK, IntExpr);
12751307

1308+
if (!IntExpr)
1309+
return ExprError();
1310+
12761311
ExprResult IntExprResult = SemaRef.PerformContextualImplicitConversion(
12771312
Loc, IntExpr, IntExprDiagnoser);
12781313
if (IntExprResult.isInvalid())
@@ -1583,6 +1618,34 @@ ExprResult SemaOpenACC::ActOnArraySectionExpr(Expr *Base, SourceLocation LBLoc,
15831618
OK_Ordinary, ColonLoc, RBLoc);
15841619
}
15851620

1621+
ExprResult SemaOpenACC::CheckCollapseLoopCount(Expr *LoopCount) {
1622+
if (!LoopCount)
1623+
return ExprError();
1624+
1625+
assert((LoopCount->isInstantiationDependent() ||
1626+
LoopCount->getType()->isIntegerType()) &&
1627+
"Loop argument non integer?");
1628+
1629+
// If this is dependent, there really isn't anything we can check.
1630+
if (LoopCount->isInstantiationDependent())
1631+
return ExprResult{LoopCount};
1632+
1633+
std::optional<llvm::APSInt> ICE =
1634+
LoopCount->getIntegerConstantExpr(getASTContext());
1635+
1636+
// OpenACC 3.3: 2.9.1
1637+
// The argument to the collapse clause must be a constant positive integer
1638+
// expression.
1639+
if (!ICE || *ICE <= 0) {
1640+
Diag(LoopCount->getBeginLoc(), diag::err_acc_collapse_loop_count)
1641+
<< ICE.has_value() << ICE.value_or(llvm::APSInt{}).getExtValue();
1642+
return ExprError();
1643+
}
1644+
1645+
return ExprResult{
1646+
ConstantExpr::Create(getASTContext(), LoopCount, APValue{*ICE})};
1647+
}
1648+
15861649
bool SemaOpenACC::ActOnStartStmtDirective(OpenACCDirectiveKind K,
15871650
SourceLocation StartLoc) {
15881651
SemaRef.DiscardCleanupsInEvaluationContext();

clang/lib/Sema/TreeTransform.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11822,6 +11822,31 @@ void OpenACCClauseTransform<Derived>::VisitReductionClause(
1182211822
ParsedClause.getLParenLoc(), C.getReductionOp(), ValidVars,
1182311823
ParsedClause.getEndLoc());
1182411824
}
11825+
11826+
template <typename Derived>
11827+
void OpenACCClauseTransform<Derived>::VisitCollapseClause(
11828+
const OpenACCCollapseClause &C) {
11829+
Expr *LoopCount = const_cast<Expr *>(C.getLoopCount());
11830+
assert(LoopCount && "collapse clause constructed with invalid loop count");
11831+
11832+
ExprResult NewLoopCount = Self.TransformExpr(LoopCount);
11833+
11834+
NewLoopCount = Self.getSema().OpenACC().ActOnIntExpr(
11835+
OpenACCDirectiveKind::Invalid, ParsedClause.getClauseKind(),
11836+
NewLoopCount.get()->getBeginLoc(), NewLoopCount.get());
11837+
11838+
NewLoopCount =
11839+
Self.getSema().OpenACC().CheckCollapseLoopCount(NewLoopCount.get());
11840+
11841+
if (!NewLoopCount.isUsable())
11842+
return;
11843+
11844+
ParsedClause.setCollapseDetails(C.hasForce(), NewLoopCount.get());
11845+
NewClause = OpenACCCollapseClause::Create(
11846+
Self.getSema().getASTContext(), ParsedClause.getBeginLoc(),
11847+
ParsedClause.getLParenLoc(), ParsedClause.isForce(),
11848+
ParsedClause.getLoopCount(), ParsedClause.getEndLoc());
11849+
}
1182511850
} // namespace
1182611851
template <typename Derived>
1182711852
OpenACCClause *TreeTransform<Derived>::TransformOpenACCClause(

clang/lib/Serialization/ASTReader.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12283,6 +12283,13 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
1228312283
return OpenACCIndependentClause::Create(getContext(), BeginLoc, EndLoc);
1228412284
case OpenACCClauseKind::Auto:
1228512285
return OpenACCAutoClause::Create(getContext(), BeginLoc, EndLoc);
12286+
case OpenACCClauseKind::Collapse: {
12287+
SourceLocation LParenLoc = readSourceLocation();
12288+
bool HasForce = readBool();
12289+
Expr *LoopCount = readSubExpr();
12290+
return OpenACCCollapseClause::Create(getContext(), BeginLoc, LParenLoc,
12291+
HasForce, LoopCount, EndLoc);
12292+
}
1228612293

1228712294
case OpenACCClauseKind::Finalize:
1228812295
case OpenACCClauseKind::IfPresent:
@@ -12296,7 +12303,6 @@ OpenACCClause *ASTRecordReader::readOpenACCClause() {
1229612303
case OpenACCClauseKind::DeviceResident:
1229712304
case OpenACCClauseKind::Host:
1229812305
case OpenACCClauseKind::Link:
12299-
case OpenACCClauseKind::Collapse:
1230012306
case OpenACCClauseKind::Bind:
1230112307
case OpenACCClauseKind::DeviceNum:
1230212308
case OpenACCClauseKind::DefaultAsync:

clang/lib/Serialization/ASTWriter.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8148,6 +8148,13 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
81488148
// Nothing to do here, there is no additional information beyond the
81498149
// begin/end loc and clause kind.
81508150
return;
8151+
case OpenACCClauseKind::Collapse: {
8152+
const auto *CC = cast<OpenACCCollapseClause>(C);
8153+
writeSourceLocation(CC->getLParenLoc());
8154+
writeBool(CC->hasForce());
8155+
AddStmt(const_cast<Expr *>(CC->getLoopCount()));
8156+
return;
8157+
}
81518158

81528159
case OpenACCClauseKind::Finalize:
81538160
case OpenACCClauseKind::IfPresent:
@@ -8161,7 +8168,6 @@ void ASTRecordWriter::writeOpenACCClause(const OpenACCClause *C) {
81618168
case OpenACCClauseKind::DeviceResident:
81628169
case OpenACCClauseKind::Host:
81638170
case OpenACCClauseKind::Link:
8164-
case OpenACCClauseKind::Collapse:
81658171
case OpenACCClauseKind::Bind:
81668172
case OpenACCClauseKind::DeviceNum:
81678173
case OpenACCClauseKind::DefaultAsync:

0 commit comments

Comments
 (0)