Skip to content

Commit 6fd33a9

Browse files
[SYCL] Fix bugs with recursion in SYCL kernel
This patch removes ConstexprDepthRAII which incorrectly maintained constexpr context. Instead we don't traverse statements which are forced constant expressions. This patch fixes the following bugs - 1. constexpr int j = test_constexpr_context(recfn(1)); 2. int k; if constexpr (false) k = recfn(1); Here, recfn() is a recursive function. The usage 1. and 2. should not diagnose in SYCL kernel since the function is called in constexpr context. Signed-off-by: Elizabeth Andrews <[email protected]>
1 parent 292b77e commit 6fd33a9

File tree

3 files changed

+50
-45
lines changed

3 files changed

+50
-45
lines changed

clang/include/clang/Analysis/CallGraph.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ class CallGraph : public RecursiveASTVisitor<CallGraph> {
145145
bool shouldWalkTypesOfTypeLocs() const { return false; }
146146
bool shouldVisitTemplateInstantiations() const { return true; }
147147
bool shouldVisitImplicitCode() const { return true; }
148-
bool shouldVisitConstantExpressions() const { return false; }
149148
bool shouldSkipConstantExpressions() const { return shouldSkipConstexpr; }
150149
void setSkipConstantExpressions(ASTContext &Context) {
151150
Ctx = &Context;

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -395,27 +395,9 @@ static void collectSYCLAttributes(Sema &S, FunctionDecl *FD,
395395
}
396396

397397
class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
398-
// Used to keep track of the constexpr depth, so we know whether to skip
399-
// diagnostics.
400-
unsigned ConstexprDepth = 0;
401398
Sema &SemaRef;
402399
const llvm::SmallPtrSetImpl<const FunctionDecl *> &RecursiveFuncs;
403400

404-
struct ConstexprDepthRAII {
405-
DiagDeviceFunction &DDF;
406-
bool Increment;
407-
408-
ConstexprDepthRAII(DiagDeviceFunction &DDF, bool Increment = true)
409-
: DDF(DDF), Increment(Increment) {
410-
if (Increment)
411-
++DDF.ConstexprDepth;
412-
}
413-
~ConstexprDepthRAII() {
414-
if (Increment)
415-
--DDF.ConstexprDepth;
416-
}
417-
};
418-
419401
public:
420402
DiagDeviceFunction(
421403
Sema &S,
@@ -433,7 +415,7 @@ class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
433415
// instantiation as template functions. It means that
434416
// all functions used by kernel have already been parsed and have
435417
// definitions.
436-
if (RecursiveFuncs.count(Callee) && !ConstexprDepth) {
418+
if (RecursiveFuncs.count(Callee)) {
437419
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
438420
<< Sema::KernelCallRecursiveFunction;
439421
SemaRef.Diag(Callee->getSourceRange().getBegin(),
@@ -486,45 +468,41 @@ class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
486468

487469
// Skip checking rules on variables initialized during constant evaluation.
488470
bool TraverseVarDecl(VarDecl *VD) {
489-
ConstexprDepthRAII R(*this, VD->isConstexpr());
471+
if (VD->isConstexpr())
472+
return true;
490473
return RecursiveASTVisitor::TraverseVarDecl(VD);
491474
}
492475

493476
// Skip checking rules on template arguments, since these are constant
494477
// expressions.
495478
bool TraverseTemplateArgumentLoc(const TemplateArgumentLoc &ArgLoc) {
496-
ConstexprDepthRAII R(*this);
497-
return RecursiveASTVisitor::TraverseTemplateArgumentLoc(ArgLoc);
479+
return true;
498480
}
499481

500482
// Skip checking the static assert, both components are required to be
501483
// constant expressions.
502-
bool TraverseStaticAssertDecl(StaticAssertDecl *D) {
503-
ConstexprDepthRAII R(*this);
504-
return RecursiveASTVisitor::TraverseStaticAssertDecl(D);
505-
}
484+
bool TraverseStaticAssertDecl(StaticAssertDecl *D) { return true; }
506485

507486
// Make sure we skip the condition of the case, since that is a constant
508487
// expression.
509488
bool TraverseCaseStmt(CaseStmt *S) {
510-
{
511-
ConstexprDepthRAII R(*this);
512-
if (!TraverseStmt(S->getLHS()))
513-
return false;
514-
if (!TraverseStmt(S->getRHS()))
515-
return false;
516-
}
517489
return TraverseStmt(S->getSubStmt());
518490
}
519491

520492
// Skip checking the size expr, since a constant array type loc's size expr is
521493
// a constant expression.
522494
bool TraverseConstantArrayTypeLoc(const ConstantArrayTypeLoc &ArrLoc) {
523-
if (!TraverseTypeLoc(ArrLoc.getElementLoc()))
524-
return false;
495+
return true;
496+
}
525497

526-
ConstexprDepthRAII R(*this);
527-
return TraverseStmt(ArrLoc.getSizeExpr());
498+
bool TraverseIfStmt(IfStmt *S) {
499+
if (llvm::Optional<Stmt *> ActiveStmt =
500+
S->getNondiscardedCase(SemaRef.Context)) {
501+
if (*ActiveStmt)
502+
return TraverseStmt(*ActiveStmt);
503+
return true;
504+
}
505+
return RecursiveASTVisitor::TraverseIfStmt(S);
528506
}
529507
};
530508

clang/test/SemaSYCL/allow-constexpr-recursion.cpp

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ sycl::queue q;
88

99
constexpr int constexpr_recurse1(int n);
1010

11-
// expected-note@+1 3{{function implemented using recursion declared here}}
11+
// expected-note@+1 5{{function implemented using recursion declared here}}
1212
constexpr int constexpr_recurse(int n) {
1313
if (n)
1414
return constexpr_recurse1(n - 1);
@@ -20,6 +20,10 @@ constexpr int constexpr_recurse1(int n) {
2020
return constexpr_recurse(n) + 1;
2121
}
2222

23+
constexpr int test_constexpr_context(int n) {
24+
return n;
25+
}
26+
2327
template <int I>
2428
void bar() {}
2529

@@ -55,15 +59,13 @@ void ConstexprIf2() {
5559
// they should not diagnose.
5660
void constexpr_recurse_test() {
5761
constexpr int i = constexpr_recurse(1);
62+
constexpr int j = test_constexpr_context(constexpr_recurse(1));
5863
bar<constexpr_recurse(2)>();
5964
bar2<1, 2, constexpr_recurse(2)>();
6065
static_assert(constexpr_recurse(2) == 105, "");
6166

62-
int j;
6367
switch (105) {
6468
case constexpr_recurse(2):
65-
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
66-
j = constexpr_recurse(5);
6769
break;
6870
}
6971

@@ -78,14 +80,40 @@ void constexpr_recurse_test() {
7880

7981
ConditionallyExplicitCtor c(1);
8082

81-
ConstexprIf1<0>(); // Should not cause a diagnostic.
82-
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
83-
ConstexprIf2<1>();
83+
ConstexprIf1<0>();
84+
85+
int k;
86+
if constexpr (false)
87+
k = constexpr_recurse(1);
88+
else
89+
constexpr int l = test_constexpr_context(constexpr_recurse(1));
8490
}
8591

8692
void constexpr_recurse_test_err() {
8793
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
8894
int i = constexpr_recurse(1);
95+
96+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
97+
ConstexprIf2<1>();
98+
99+
int j, k;
100+
if constexpr (true)
101+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
102+
j = constexpr_recurse(1);
103+
104+
if constexpr (false)
105+
j = constexpr_recurse(1); // Should not diagnose in discarded branch
106+
else
107+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
108+
k = constexpr_recurse(1);
109+
110+
switch (105) {
111+
case constexpr_recurse(2):
112+
constexpr int l = test_constexpr_context(constexpr_recurse(1));
113+
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
114+
j = constexpr_recurse(5);
115+
break;
116+
}
89117
}
90118

91119
int main() {

0 commit comments

Comments
 (0)