Skip to content

Commit 9a9a018

Browse files
elizabethandrewsErich Keane
andauthored
[SYCL] Fix bugs with recursion in SYCL kernel (#3958)
This patch removes ConstexprDepthRAII which incorrectly maintained constexpr context. Instead we don't traverse statements which are forced constant expressions. This patch fixes the following bugs - constexpr int j = test_constexpr_context(recfn(1)); int k; if constexpr (false) k = recfn(1); Here, recfn() is a recursive function. The usage 1. should not diagnose in SYCL kernel since the function is called in constexpr context. Similarly 2. should not diagnose in SYCL kernel since the recursive function call is in a discarded branch. Includes commits in PR - #3714 Signed-off-by: Elizabeth Andrews <[email protected]> Co-authored-by: Erich Keane <[email protected]>
1 parent 0498efe commit 9a9a018

File tree

4 files changed

+100
-45
lines changed

4 files changed

+100
-45
lines changed

clang/include/clang/Analysis/CallGraph.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
namespace clang {
3131

32+
class ASTContext;
3233
class CallGraphNode;
3334
class Decl;
3435
class DeclContext;
@@ -51,6 +52,12 @@ class CallGraph : public RecursiveASTVisitor<CallGraph> {
5152
/// This is a virtual root node that has edges to all the functions.
5253
CallGraphNode *Root;
5354

55+
/// A setting to determine whether this should include calls that are done in
56+
/// a constant expression's context. This DOES require the ASTContext object
57+
/// for constexpr-if, so setting it requires a valid ASTContext.
58+
bool ShouldSkipConstexpr = false;
59+
ASTContext *Ctx;
60+
5461
public:
5562
CallGraph();
5663
~CallGraph();
@@ -66,7 +73,7 @@ class CallGraph : public RecursiveASTVisitor<CallGraph> {
6673
/// Determine if a declaration should be included in the graph.
6774
static bool includeInGraph(const Decl *D);
6875

69-
/// Determine if a declaration should be included in the graph for the
76+
/// Determine if a declaration should be included in the graph for the
7077
/// purposes of being a callee. This is similar to includeInGraph except
7178
/// it permits declarations, not just definitions.
7279
static bool includeCalleeInGraph(const Decl *D);
@@ -138,6 +145,15 @@ class CallGraph : public RecursiveASTVisitor<CallGraph> {
138145
bool shouldWalkTypesOfTypeLocs() const { return false; }
139146
bool shouldVisitTemplateInstantiations() const { return true; }
140147
bool shouldVisitImplicitCode() const { return true; }
148+
bool shouldSkipConstantExpressions() const { return ShouldSkipConstexpr; }
149+
void setSkipConstantExpressions(ASTContext &Context) {
150+
Ctx = &Context;
151+
ShouldSkipConstexpr = true;
152+
}
153+
ASTContext &getASTContext() {
154+
assert(Ctx);
155+
return *Ctx;
156+
}
141157

142158
private:
143159
/// Add the given declaration to the call graph.

clang/lib/Analysis/CallGraph.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "clang/Analysis/CallGraph.h"
14+
#include "clang/AST/ASTContext.h"
1415
#include "clang/AST/Decl.h"
1516
#include "clang/AST/DeclBase.h"
1617
#include "clang/AST/DeclObjC.h"
@@ -136,6 +137,37 @@ class CGBuilder : public StmtVisitor<CGBuilder> {
136137
}
137138
}
138139

140+
void VisitIfStmt(IfStmt *If) {
141+
if (G->shouldSkipConstantExpressions()) {
142+
if (llvm::Optional<Stmt *> ActiveStmt =
143+
If->getNondiscardedCase(G->getASTContext())) {
144+
if (*ActiveStmt)
145+
this->Visit(*ActiveStmt);
146+
return;
147+
}
148+
}
149+
150+
StmtVisitor::VisitIfStmt(If);
151+
}
152+
153+
void VisitDeclStmt(DeclStmt *DS) {
154+
if (G->shouldSkipConstantExpressions()) {
155+
auto IsConstexprVarDecl = [](Decl *D) {
156+
if (const auto *VD = dyn_cast<VarDecl>(D))
157+
return VD->isConstexpr();
158+
return false;
159+
};
160+
if (llvm::any_of(DS->decls(), IsConstexprVarDecl)) {
161+
assert(llvm::all_of(DS->decls(), IsConstexprVarDecl) &&
162+
"Situation where a decl-group would be a mix of decl types, or "
163+
"constexpr and not?");
164+
return;
165+
}
166+
}
167+
168+
StmtVisitor::VisitDeclStmt(DS);
169+
}
170+
139171
void VisitChildren(Stmt *S) {
140172
for (Stmt *SubStmt : S->children())
141173
if (SubStmt)

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -579,27 +579,9 @@ static void collectSYCLAttributes(Sema &S, FunctionDecl *FD,
579579
}
580580

581581
class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
582-
// Used to keep track of the constexpr depth, so we know whether to skip
583-
// diagnostics.
584-
unsigned ConstexprDepth = 0;
585582
Sema &SemaRef;
586583
const llvm::SmallPtrSetImpl<const FunctionDecl *> &RecursiveFuncs;
587584

588-
struct ConstexprDepthRAII {
589-
DiagDeviceFunction &DDF;
590-
bool Increment;
591-
592-
ConstexprDepthRAII(DiagDeviceFunction &DDF, bool Increment = true)
593-
: DDF(DDF), Increment(Increment) {
594-
if (Increment)
595-
++DDF.ConstexprDepth;
596-
}
597-
~ConstexprDepthRAII() {
598-
if (Increment)
599-
--DDF.ConstexprDepth;
600-
}
601-
};
602-
603585
public:
604586
DiagDeviceFunction(
605587
Sema &S,
@@ -617,7 +599,7 @@ class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
617599
// instantiation as template functions. It means that
618600
// all functions used by kernel have already been parsed and have
619601
// definitions.
620-
if (RecursiveFuncs.count(Callee) && !ConstexprDepth) {
602+
if (RecursiveFuncs.count(Callee)) {
621603
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
622604
<< Sema::KernelCallRecursiveFunction;
623605
SemaRef.Diag(Callee->getSourceRange().getBegin(),
@@ -670,45 +652,41 @@ class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
670652

671653
// Skip checking rules on variables initialized during constant evaluation.
672654
bool TraverseVarDecl(VarDecl *VD) {
673-
ConstexprDepthRAII R(*this, VD->isConstexpr());
655+
if (VD->isConstexpr())
656+
return true;
674657
return RecursiveASTVisitor::TraverseVarDecl(VD);
675658
}
676659

677660
// Skip checking rules on template arguments, since these are constant
678661
// expressions.
679662
bool TraverseTemplateArgumentLoc(const TemplateArgumentLoc &ArgLoc) {
680-
ConstexprDepthRAII R(*this);
681-
return RecursiveASTVisitor::TraverseTemplateArgumentLoc(ArgLoc);
663+
return true;
682664
}
683665

684666
// Skip checking the static assert, both components are required to be
685667
// constant expressions.
686-
bool TraverseStaticAssertDecl(StaticAssertDecl *D) {
687-
ConstexprDepthRAII R(*this);
688-
return RecursiveASTVisitor::TraverseStaticAssertDecl(D);
689-
}
668+
bool TraverseStaticAssertDecl(StaticAssertDecl *D) { return true; }
690669

691670
// Make sure we skip the condition of the case, since that is a constant
692671
// expression.
693672
bool TraverseCaseStmt(CaseStmt *S) {
694-
{
695-
ConstexprDepthRAII R(*this);
696-
if (!TraverseStmt(S->getLHS()))
697-
return false;
698-
if (!TraverseStmt(S->getRHS()))
699-
return false;
700-
}
701673
return TraverseStmt(S->getSubStmt());
702674
}
703675

704676
// Skip checking the size expr, since a constant array type loc's size expr is
705677
// a constant expression.
706678
bool TraverseConstantArrayTypeLoc(const ConstantArrayTypeLoc &ArrLoc) {
707-
if (!TraverseTypeLoc(ArrLoc.getElementLoc()))
708-
return false;
679+
return true;
680+
}
709681

710-
ConstexprDepthRAII R(*this);
711-
return TraverseStmt(ArrLoc.getSizeExpr());
682+
bool TraverseIfStmt(IfStmt *S) {
683+
if (llvm::Optional<Stmt *> ActiveStmt =
684+
S->getNondiscardedCase(SemaRef.Context)) {
685+
if (*ActiveStmt)
686+
return TraverseStmt(*ActiveStmt);
687+
return true;
688+
}
689+
return RecursiveASTVisitor::TraverseIfStmt(S);
712690
}
713691
};
714692

@@ -749,6 +727,7 @@ class DeviceFunctionTracker {
749727

750728
public:
751729
DeviceFunctionTracker(Sema &S) : SemaRef(S) {
730+
CG.setSkipConstantExpressions(S.Context);
752731
CG.addToCallGraph(S.getASTContext().getTranslationUnitDecl());
753732
CollectSyclExternalFuncs();
754733
}

clang/test/SemaSYCL/allow-constexpr-recursion.cpp

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ sycl::queue q;
88

99
constexpr int constexpr_recurse1(int n);
1010

11-
// expected-note@+1 3{{function implemented using recursion declared here}}
11+
// expected-note@+1 5{{function implemented using recursion declared here}}
1212
constexpr int constexpr_recurse(int n) {
1313
if (n)
1414
return constexpr_recurse1(n - 1);
@@ -20,6 +20,10 @@ constexpr int constexpr_recurse1(int n) {
2020
return constexpr_recurse(n) + 1;
2121
}
2222

23+
constexpr int test_constexpr_context(int n) {
24+
return n;
25+
}
26+
2327
template <int I>
2428
void bar() {}
2529

@@ -55,15 +59,13 @@ void ConstexprIf2() {
5559
// they should not diagnose.
5660
void constexpr_recurse_test() {
5761
constexpr int i = constexpr_recurse(1);
62+
constexpr int j = test_constexpr_context(constexpr_recurse(1));
5863
bar<constexpr_recurse(2)>();
5964
bar2<1, 2, constexpr_recurse(2)>();
6065
static_assert(constexpr_recurse(2) == 105, "");
6166

62-
int j;
6367
switch (105) {
6468
case constexpr_recurse(2):
65-
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
66-
j = constexpr_recurse(5);
6769
break;
6870
}
6971

@@ -78,14 +80,40 @@ void constexpr_recurse_test() {
7880

7981
ConditionallyExplicitCtor c(1);
8082

81-
ConstexprIf1<0>(); // Should not cause a diagnostic.
82-
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
83-
ConstexprIf2<1>();
83+
ConstexprIf1<0>();
84+
85+
int k;
86+
if constexpr (false)
87+
k = constexpr_recurse(1);
88+
else
89+
constexpr int l = test_constexpr_context(constexpr_recurse(1));
8490
}
8591

8692
void constexpr_recurse_test_err() {
8793
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
8894
int i = constexpr_recurse(1);
95+
96+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
97+
ConstexprIf2<1>();
98+
99+
int j, k;
100+
if constexpr (true)
101+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
102+
j = constexpr_recurse(1);
103+
104+
if constexpr (false)
105+
j = constexpr_recurse(1); // Should not diagnose in discarded branch
106+
else
107+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
108+
k = constexpr_recurse(1);
109+
110+
switch (105) {
111+
case constexpr_recurse(2):
112+
constexpr int l = test_constexpr_context(constexpr_recurse(1));
113+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
114+
j = constexpr_recurse(5);
115+
break;
116+
}
89117
}
90118

91119
int main() {

0 commit comments

Comments
 (0)