Skip to content

Commit daa8836

Browse files
authored
[OpenACC] Implement 'if' clause for Compute Constructs (#88411)
Like with the 'default' clause, this is being applied to only Compute Constructs for now. The 'if' clause takes a condition expression which is used as a runtime value. This is not a particularly complex semantic implementation, as there isn't much to this clause, other than its interactions with 'self', which will be managed in the patch to implement that.
1 parent c6cd460 commit daa8836

18 files changed

+557
-37
lines changed

clang/include/clang/AST/ASTNodeTraverser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,8 @@ class ASTNodeTraverser
243243
void Visit(const OpenACCClause *C) {
244244
getNodeDelegate().AddChild([=] {
245245
getNodeDelegate().Visit(C);
246-
// TODO OpenACC: Switch on clauses that have children, and add them.
246+
for (const auto *S : C->children())
247+
Visit(S);
247248
});
248249
}
249250

clang/include/clang/AST/OpenACCClause.h

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef LLVM_CLANG_AST_OPENACCCLAUSE_H
1515
#define LLVM_CLANG_AST_OPENACCCLAUSE_H
1616
#include "clang/AST/ASTContext.h"
17+
#include "clang/AST/StmtIterator.h"
1718
#include "clang/Basic/OpenACCKinds.h"
1819

1920
namespace clang {
@@ -34,6 +35,17 @@ class OpenACCClause {
3435

3536
static bool classof(const OpenACCClause *) { return true; }
3637

38+
using child_iterator = StmtIterator;
39+
using const_child_iterator = ConstStmtIterator;
40+
using child_range = llvm::iterator_range<child_iterator>;
41+
using const_child_range = llvm::iterator_range<const_child_iterator>;
42+
43+
child_range children();
44+
const_child_range children() const {
45+
auto Children = const_cast<OpenACCClause *>(this)->children();
46+
return const_child_range(Children.begin(), Children.end());
47+
}
48+
3749
virtual ~OpenACCClause() = default;
3850
};
3951

@@ -49,6 +61,13 @@ class OpenACCClauseWithParams : public OpenACCClause {
4961

5062
public:
5163
SourceLocation getLParenLoc() const { return LParenLoc; }
64+
65+
child_range children() {
66+
return child_range(child_iterator(), child_iterator());
67+
}
68+
const_child_range children() const {
69+
return const_child_range(const_child_iterator(), const_child_iterator());
70+
}
5271
};
5372

5473
/// A 'default' clause, has the optional 'none' or 'present' argument.
@@ -81,6 +100,51 @@ class OpenACCDefaultClause : public OpenACCClauseWithParams {
81100
SourceLocation EndLoc);
82101
};
83102

103+
/// Represents one of the handful of classes that has an optional/required
104+
/// 'condition' expression as an argument.
105+
class OpenACCClauseWithCondition : public OpenACCClauseWithParams {
106+
Expr *ConditionExpr = nullptr;
107+
108+
protected:
109+
OpenACCClauseWithCondition(OpenACCClauseKind K, SourceLocation BeginLoc,
110+
SourceLocation LParenLoc, Expr *ConditionExpr,
111+
SourceLocation EndLoc)
112+
: OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc),
113+
ConditionExpr(ConditionExpr) {}
114+
115+
public:
116+
bool hasConditionExpr() const { return ConditionExpr; }
117+
const Expr *getConditionExpr() const { return ConditionExpr; }
118+
Expr *getConditionExpr() { return ConditionExpr; }
119+
120+
child_range children() {
121+
if (ConditionExpr)
122+
return child_range(reinterpret_cast<Stmt **>(&ConditionExpr),
123+
reinterpret_cast<Stmt **>(&ConditionExpr + 1));
124+
return child_range(child_iterator(), child_iterator());
125+
}
126+
127+
const_child_range children() const {
128+
if (ConditionExpr)
129+
return const_child_range(
130+
reinterpret_cast<Stmt *const *>(&ConditionExpr),
131+
reinterpret_cast<Stmt *const *>(&ConditionExpr + 1));
132+
return const_child_range(const_child_iterator(), const_child_iterator());
133+
}
134+
};
135+
136+
/// An 'if' clause, which has a required condition expression.
137+
class OpenACCIfClause : public OpenACCClauseWithCondition {
138+
protected:
139+
OpenACCIfClause(SourceLocation BeginLoc, SourceLocation LParenLoc,
140+
Expr *ConditionExpr, SourceLocation EndLoc);
141+
142+
public:
143+
static OpenACCIfClause *Create(const ASTContext &C, SourceLocation BeginLoc,
144+
SourceLocation LParenLoc, Expr *ConditionExpr,
145+
SourceLocation EndLoc);
146+
};
147+
84148
template <class Impl> class OpenACCClauseVisitor {
85149
Impl &getDerived() { return static_cast<Impl &>(*this); }
86150

@@ -98,6 +162,9 @@ template <class Impl> class OpenACCClauseVisitor {
98162
case OpenACCClauseKind::Default:
99163
VisitOpenACCDefaultClause(*cast<OpenACCDefaultClause>(C));
100164
return;
165+
case OpenACCClauseKind::If:
166+
VisitOpenACCIfClause(*cast<OpenACCIfClause>(C));
167+
return;
101168
case OpenACCClauseKind::Finalize:
102169
case OpenACCClauseKind::IfPresent:
103170
case OpenACCClauseKind::Seq:
@@ -106,7 +173,6 @@ template <class Impl> class OpenACCClauseVisitor {
106173
case OpenACCClauseKind::Worker:
107174
case OpenACCClauseKind::Vector:
108175
case OpenACCClauseKind::NoHost:
109-
case OpenACCClauseKind::If:
110176
case OpenACCClauseKind::Self:
111177
case OpenACCClauseKind::Copy:
112178
case OpenACCClauseKind::UseDevice:
@@ -145,9 +211,13 @@ template <class Impl> class OpenACCClauseVisitor {
145211
llvm_unreachable("Invalid Clause kind");
146212
}
147213

148-
void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause) {
149-
return getDerived().VisitOpenACCDefaultClause(Clause);
214+
#define VISIT_CLAUSE(CLAUSE_NAME) \
215+
void VisitOpenACC##CLAUSE_NAME##Clause( \
216+
const OpenACC##CLAUSE_NAME##Clause &Clause) { \
217+
return getDerived().VisitOpenACC##CLAUSE_NAME##Clause(Clause); \
150218
}
219+
220+
#include "clang/Basic/OpenACCClauses.def"
151221
};
152222

153223
class OpenACCClausePrinter final
@@ -165,7 +235,10 @@ class OpenACCClausePrinter final
165235
}
166236
OpenACCClausePrinter(raw_ostream &OS) : OS(OS) {}
167237

168-
void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
238+
#define VISIT_CLAUSE(CLAUSE_NAME) \
239+
void VisitOpenACC##CLAUSE_NAME##Clause( \
240+
const OpenACC##CLAUSE_NAME##Clause &Clause);
241+
#include "clang/Basic/OpenACCClauses.def"
169242
};
170243

171244
} // namespace clang
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===-- OpenACCClauses.def - List of implemented OpenACC Clauses -- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines a list of currently implemented OpenACC Clauses (and
10+
// eventually, the entire list) in a way that makes generating 'visitor' and
11+
// other lists easier.
12+
//
13+
// The primary macro is a single-argument version taking the name of the Clause
14+
// as used in Clang source (so `Default` instead of `default`).
15+
//
16+
// VISIT_CLAUSE(CLAUSE_NAME)
17+
18+
VISIT_CLAUSE(Default)
19+
VISIT_CLAUSE(If)
20+
21+
#undef VISIT_CLAUSE

clang/include/clang/Parse/Parser.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3611,6 +3611,9 @@ class Parser : public CodeCompletionHandler {
36113611
OpenACCClauseParseResult OpenACCCannotContinue();
36123612
OpenACCClauseParseResult OpenACCSuccess(OpenACCClause *Clause);
36133613

3614+
using OpenACCConditionExprParseResult =
3615+
std::pair<ExprResult, OpenACCParseCanContinue>;
3616+
36143617
/// Parses the OpenACC directive (the entire pragma) including the clause
36153618
/// list, but does not produce the main AST node.
36163619
OpenACCDirectiveParseInfo ParseOpenACCDirective();
@@ -3657,6 +3660,8 @@ class Parser : public CodeCompletionHandler {
36573660
bool ParseOpenACCGangArgList();
36583661
/// Parses a 'gang-arg', used for the 'gang' clause.
36593662
bool ParseOpenACCGangArg();
3663+
/// Parses a 'condition' expr, ensuring it results in a
3664+
ExprResult ParseOpenACCConditionExpr();
36603665

36613666
private:
36623667
//===--------------------------------------------------------------------===//

clang/include/clang/Sema/SemaOpenACC.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ class SemaOpenACC : public SemaBase {
4040
OpenACCDefaultClauseKind DefaultClauseKind;
4141
};
4242

43-
std::variant<DefaultDetails> Details;
43+
struct ConditionDetails {
44+
Expr *ConditionExpr;
45+
};
46+
47+
std::variant<DefaultDetails, ConditionDetails> Details;
4448

4549
public:
4650
OpenACCParsedClause(OpenACCDirectiveKind DirKind,
@@ -63,6 +67,16 @@ class SemaOpenACC : public SemaBase {
6367
return std::get<DefaultDetails>(Details).DefaultClauseKind;
6468
}
6569

70+
const Expr *getConditionExpr() const {
71+
return const_cast<OpenACCParsedClause *>(this)->getConditionExpr();
72+
}
73+
74+
Expr *getConditionExpr() {
75+
assert(ClauseKind == OpenACCClauseKind::If &&
76+
"Parsed clause kind does not have a condition expr");
77+
return std::get<ConditionDetails>(Details).ConditionExpr;
78+
}
79+
6680
void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; }
6781
void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); }
6882

@@ -71,6 +85,18 @@ class SemaOpenACC : public SemaBase {
7185
"Parsed clause is not a default clause");
7286
Details = DefaultDetails{DefKind};
7387
}
88+
89+
void setConditionDetails(Expr *ConditionExpr) {
90+
assert(ClauseKind == OpenACCClauseKind::If &&
91+
"Parsed clause kind does not have a condition expr");
92+
// In C++ we can count on this being a 'bool', but in C this gets left as
93+
// some sort of scalar that codegen will have to take care of converting.
94+
assert((!ConditionExpr || ConditionExpr->isInstantiationDependent() ||
95+
ConditionExpr->getType()->isScalarType()) &&
96+
"Condition expression type not scalar/dependent");
97+
98+
Details = ConditionDetails{ConditionExpr};
99+
}
74100
};
75101

76102
SemaOpenACC(Sema &S);

clang/lib/AST/OpenACCClause.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "clang/AST/OpenACCClause.h"
1515
#include "clang/AST/ASTContext.h"
16+
#include "clang/AST/Expr.h"
1617

1718
using namespace clang;
1819

@@ -27,10 +28,48 @@ OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C,
2728
return new (Mem) OpenACCDefaultClause(K, BeginLoc, LParenLoc, EndLoc);
2829
}
2930

31+
OpenACCIfClause *OpenACCIfClause::Create(const ASTContext &C,
32+
SourceLocation BeginLoc,
33+
SourceLocation LParenLoc,
34+
Expr *ConditionExpr,
35+
SourceLocation EndLoc) {
36+
void *Mem = C.Allocate(sizeof(OpenACCIfClause), alignof(OpenACCIfClause));
37+
return new (Mem) OpenACCIfClause(BeginLoc, LParenLoc, ConditionExpr, EndLoc);
38+
}
39+
40+
OpenACCIfClause::OpenACCIfClause(SourceLocation BeginLoc,
41+
SourceLocation LParenLoc, Expr *ConditionExpr,
42+
SourceLocation EndLoc)
43+
: OpenACCClauseWithCondition(OpenACCClauseKind::If, BeginLoc, LParenLoc,
44+
ConditionExpr, EndLoc) {
45+
assert(ConditionExpr && "if clause requires condition expr");
46+
assert((ConditionExpr->isInstantiationDependent() ||
47+
ConditionExpr->getType()->isScalarType()) &&
48+
"Condition expression type not scalar/dependent");
49+
}
50+
51+
OpenACCClause::child_range OpenACCClause::children() {
52+
switch (getClauseKind()) {
53+
default:
54+
assert(false && "Clause children function not implemented");
55+
break;
56+
#define VISIT_CLAUSE(CLAUSE_NAME) \
57+
case OpenACCClauseKind::CLAUSE_NAME: \
58+
return cast<OpenACC##CLAUSE_NAME##Clause>(this)->children();
59+
60+
#include "clang/Basic/OpenACCClauses.def"
61+
}
62+
return child_range(child_iterator(), child_iterator());
63+
}
64+
3065
//===----------------------------------------------------------------------===//
3166
// OpenACC clauses printing methods
3267
//===----------------------------------------------------------------------===//
3368
void OpenACCClausePrinter::VisitOpenACCDefaultClause(
3469
const OpenACCDefaultClause &C) {
3570
OS << "default(" << C.getDefaultClauseKind() << ")";
3671
}
72+
73+
void OpenACCClausePrinter::VisitOpenACCIfClause(const OpenACCIfClause &C) {
74+
OS << "if(" << C.getConditionExpr() << ")";
75+
}

clang/lib/AST/StmtProfile.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,9 +2445,10 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
24452445
namespace {
24462446
class OpenACCClauseProfiler
24472447
: public OpenACCClauseVisitor<OpenACCClauseProfiler> {
2448+
StmtProfiler &Profiler;
24482449

24492450
public:
2450-
OpenACCClauseProfiler() = default;
2451+
OpenACCClauseProfiler(StmtProfiler &P) : Profiler(P) {}
24512452

24522453
void VisitOpenACCClauseList(ArrayRef<const OpenACCClause *> Clauses) {
24532454
for (const OpenACCClause *Clause : Clauses) {
@@ -2456,20 +2457,32 @@ class OpenACCClauseProfiler
24562457
Visit(Clause);
24572458
}
24582459
}
2459-
void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause);
2460+
2461+
#define VISIT_CLAUSE(CLAUSE_NAME) \
2462+
void VisitOpenACC##CLAUSE_NAME##Clause( \
2463+
const OpenACC##CLAUSE_NAME##Clause &Clause);
2464+
2465+
#include "clang/Basic/OpenACCClauses.def"
24602466
};
24612467

24622468
/// Nothing to do here, there are no sub-statements.
24632469
void OpenACCClauseProfiler::VisitOpenACCDefaultClause(
24642470
const OpenACCDefaultClause &Clause) {}
2471+
2472+
void OpenACCClauseProfiler::VisitOpenACCIfClause(
2473+
const OpenACCIfClause &Clause) {
2474+
assert(Clause.hasConditionExpr() &&
2475+
"if clause requires a valid condition expr");
2476+
Profiler.VisitStmt(Clause.getConditionExpr());
2477+
}
24652478
} // namespace
24662479

24672480
void StmtProfiler::VisitOpenACCComputeConstruct(
24682481
const OpenACCComputeConstruct *S) {
24692482
// VisitStmt handles children, so the AssociatedStmt is handled.
24702483
VisitStmt(S);
24712484

2472-
OpenACCClauseProfiler P;
2485+
OpenACCClauseProfiler P{*this};
24732486
P.VisitOpenACCClauseList(S->clauses());
24742487
}
24752488

clang/lib/AST/TextNodeDumper.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,11 @@ void TextNodeDumper::Visit(const OpenACCClause *C) {
397397
case OpenACCClauseKind::Default:
398398
OS << '(' << cast<OpenACCDefaultClause>(C)->getDefaultClauseKind() << ')';
399399
break;
400+
case OpenACCClauseKind::If:
401+
// The condition expression will be printed as a part of the 'children',
402+
// but print 'clause' here so it is clear what is happening from the dump.
403+
OS << " clause";
404+
break;
400405
default:
401406
// Nothing to do here.
402407
break;

0 commit comments

Comments
 (0)