Skip to content

[SYCL] Allow recursive function calls in a constexpr context. #2105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,24 @@ static int64_t getIntExprValue(const Expr *E, ASTContext &Ctx) {
}

class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
// Used to keep track of the constexpr depth, so we know whether to skip
// diagnostics.
unsigned ConstexprDepth = 0;
struct ConstexprDepthRAII {
MarkDeviceFunction &MDF;
bool Increment;

ConstexprDepthRAII(MarkDeviceFunction &MDF, bool Increment = true)
: MDF(MDF), Increment(Increment) {
if (Increment)
++MDF.ConstexprDepth;
}
~ConstexprDepthRAII() {
if (Increment)
--MDF.ConstexprDepth;
}
};

public:
MarkDeviceFunction(Sema &S)
: RecursiveASTVisitor<MarkDeviceFunction>(), SemaRef(S) {}
Expand All @@ -335,7 +353,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
// instantiation as template functions. It means that
// all functions used by kernel have already been parsed and have
// definitions.
if (RecursiveSet.count(Callee)) {
if (RecursiveSet.count(Callee) && !ConstexprDepth) {
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
<< Sema::KernelCallRecursiveFunction;
SemaRef.Diag(Callee->getSourceRange().getBegin(),
Expand Down Expand Up @@ -386,6 +404,49 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
return true;
}

// Skip checking rules on variables initialized during constant evaluation.
bool TraverseVarDecl(VarDecl *VD) {
ConstexprDepthRAII R(*this, VD->isConstexpr());
return RecursiveASTVisitor::TraverseVarDecl(VD);
}

// Skip checking rules on template arguments, since these are constant
// expressions.
bool TraverseTemplateArgumentLoc(const TemplateArgumentLoc &ArgLoc) {
ConstexprDepthRAII R(*this);
return RecursiveASTVisitor::TraverseTemplateArgumentLoc(ArgLoc);
}

// Skip checking the static assert, both components are required to be
// constant expressions.
bool TraverseStaticAssertDecl(StaticAssertDecl *D) {
ConstexprDepthRAII R(*this);
return RecursiveASTVisitor::TraverseStaticAssertDecl(D);
}

// Make sure we skip the condition of the case, since that is a constant
// expression.
bool TraverseCaseStmt(CaseStmt *S) {
{
ConstexprDepthRAII R(*this);
if (!TraverseStmt(S->getLHS()))
return false;
if (!TraverseStmt(S->getRHS()))
return false;
}
return TraverseStmt(S->getSubStmt());
}

// Skip checking the size expr, since a constant array type loc's size expr is
// a constant expression.
bool TraverseConstantArrayTypeLoc(const ConstantArrayTypeLoc &ArrLoc) {
if (!TraverseTypeLoc(ArrLoc.getElementLoc()))
return false;

ConstexprDepthRAII R(*this);
return TraverseStmt(ArrLoc.getSizeExpr());
}

// The call graph for this translation unit.
CallGraph SYCLCG;
// The set of functions called by a kernel function.
Expand Down
76 changes: 76 additions & 0 deletions clang/test/SemaSYCL/allow-constexpr-recursion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// RUN: %clang_cc1 -fsycl -fsycl-is-device -fcxx-exceptions -Wno-return-type -verify -fsyntax-only -std=c++20 -Werror=vla %s

template <typename name, typename Func>
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
kernelFunc();
}

// expected-note@+1{{function implemented using recursion declared here}}
constexpr int constexpr_recurse1(int n);

// expected-note@+1 3{{function implemented using recursion declared here}}
constexpr int constexpr_recurse(int n) {
if (n)
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
return constexpr_recurse1(n - 1);
return 103;
}

constexpr int constexpr_recurse1(int n) {
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
return constexpr_recurse(n) + 1;
}

template <int I>
void bar() {}

template <int... args>
void bar2() {}

enum class SomeE {
Value = constexpr_recurse(5)
};

struct ConditionallyExplicitCtor {
explicit(constexpr_recurse(5) == 103) ConditionallyExplicitCtor(int i) {}
};

void conditionally_noexcept() noexcept(constexpr_recurse(5)) {}

// All of the uses of constexpr_recurse here are forced constant expressions, so
// they should not diagnose.
void constexpr_recurse_test() {
constexpr int i = constexpr_recurse(1);
bar<constexpr_recurse(2)>();
bar2<1, 2, constexpr_recurse(2)>();
static_assert(constexpr_recurse(2) == 105, "");

int j;
switch (105) {
case constexpr_recurse(2):
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
j = constexpr_recurse(5);
break;
}

SomeE e = SomeE::Value;

int ce_array[constexpr_recurse(5)];

conditionally_noexcept();

if constexpr ((bool)SomeE::Value) {
}

ConditionallyExplicitCtor c(1);
}

void constexpr_recurse_test_err() {
// expected-error@+1{{SYCL kernel cannot call a recursive function}}
int i = constexpr_recurse(1);
}

int main() {
kernel_single_task<class fake_kernel>([]() { constexpr_recurse_test(); });
kernel_single_task<class fake_kernel>([]() { constexpr_recurse_test_err(); });
}