Skip to content

Commit 4efe9fc

Browse files
Blower, Melanievladimirlaz
Blower, Melanie
authored andcommitted
[SYCL] Language restrictions for SYCL kernel functions from 6.3 section
- disallow allocation in kernel functions (Overloaded 'new' operations are allowed if no storage is allocated) - disallow recursion in kernel functions Signed-off-by: Blower, Melanie <[email protected]> Signed-off-by: Vladimir Lazarev <[email protected]>
1 parent 6a70b70 commit 4efe9fc

File tree

7 files changed

+357
-24
lines changed

7 files changed

+357
-24
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9532,13 +9532,15 @@ def err_sycl_restrict : Error<
95329532
"|use rtti"
95339533
"|use a non-const static data variable"
95349534
"|call a virtual function"
9535+
"|call a recursive function"
95359536
"|call through a function pointer"
95369537
"|allocate storage"
95379538
"|use exceptions"
95389539
"|use inline assembly}0">;
95399540
def err_sycl_virtual_types : Error<
95409541
"No class with a vtable can be used in a SYCL kernel or any code included in the kernel">;
95419542
def note_sycl_used_here : Note<"used here">;
9543+
def note_sycl_recursive_function_declared_here: Note<"function implemented using recursion declared here">;
95429544
def err_sycl_non_std_layout_type : Error<
95439545
"kernel parameter has non-standard layout class/struct type">;
95449546
} // end of sema component.

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/Support/FileSystem.h"
2222
#include "llvm/Support/Path.h"
2323
#include "llvm/Support/raw_ostream.h"
24+
#include "clang/Analysis/CallGraph.h"
2425

2526
#include <array>
2627

@@ -45,6 +46,7 @@ enum RestrictKind {
4546
KernelRTTI,
4647
KernelNonConstStaticDataVariable,
4748
KernelCallVirtualFunction,
49+
KernelCallRecursiveFunction,
4850
KernelCallFunctionPointer,
4951
KernelAllocateStorage,
5052
KernelUseExceptions,
@@ -85,20 +87,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
8587

8688
bool VisitCallExpr(CallExpr *e) {
8789
for (const auto &Arg : e->arguments())
88-
CheckTypeForVirtual(Arg->getType(), Arg->getSourceRange());
90+
CheckSYCLType(Arg->getType(), Arg->getSourceRange());
8991

9092
if (FunctionDecl *Callee = e->getDirectCallee()) {
93+
Callee = Callee->getCanonicalDecl();
9194
// Remember that all SYCL kernel functions have deferred
9295
// instantiation as template functions. It means that
9396
// all functions used by kernel have already been parsed and have
9497
// definitions.
98+
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet;
99+
if (IsRecursive(Callee, Callee, VisitedSet))
100+
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict) <<
101+
KernelCallRecursiveFunction;
95102

96103
if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
97104
if (Method->isVirtual())
98105
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict) <<
99106
KernelCallVirtualFunction;
100107

101-
CheckTypeForVirtual(Callee->getReturnType(), Callee->getSourceRange());
108+
CheckSYCLType(Callee->getReturnType(), Callee->getSourceRange());
102109

103110
if (FunctionDecl *Def = Callee->getDefinition()) {
104111
if (!Def->hasAttr<SYCLDeviceAttr>()) {
@@ -116,7 +123,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
116123

117124
bool VisitCXXConstructExpr(CXXConstructExpr *E) {
118125
for (const auto &Arg : E->arguments())
119-
CheckTypeForVirtual(Arg->getType(), Arg->getSourceRange());
126+
CheckSYCLType(Arg->getType(), Arg->getSourceRange());
120127

121128
CXXConstructorDecl *Ctor = E->getConstructor();
122129

@@ -150,22 +157,22 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
150157
}
151158

152159
bool VisitTypedefNameDecl(TypedefNameDecl *TD) {
153-
CheckTypeForVirtual(TD->getUnderlyingType(), TD->getLocation());
160+
CheckSYCLType(TD->getUnderlyingType(), TD->getLocation());
154161
return true;
155162
}
156163

157164
bool VisitRecordDecl(RecordDecl *RD) {
158-
CheckTypeForVirtual(QualType{RD->getTypeForDecl(), 0}, RD->getLocation());
165+
CheckSYCLType(QualType{RD->getTypeForDecl(), 0}, RD->getLocation());
159166
return true;
160167
}
161168

162169
bool VisitParmVarDecl(VarDecl *VD) {
163-
CheckTypeForVirtual(VD->getType(), VD->getLocation());
170+
CheckSYCLType(VD->getType(), VD->getLocation());
164171
return true;
165172
}
166173

167174
bool VisitVarDecl(VarDecl *VD) {
168-
CheckTypeForVirtual(VD->getType(), VD->getLocation());
175+
CheckSYCLType(VD->getType(), VD->getLocation());
169176
return true;
170177
}
171178

@@ -180,7 +187,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
180187
}
181188

182189
bool VisitDeclRefExpr(DeclRefExpr *E) {
183-
CheckTypeForVirtual(E->getType(), E->getSourceRange());
190+
CheckSYCLType(E->getType(), E->getSourceRange());
184191
if (VarDecl *VD = dyn_cast<VarDecl>(E->getDecl())) {
185192
bool IsConst = VD->getType().getNonReferenceType().isConstQualified();
186193
if (!IsConst && VD->hasGlobalStorage() && !VD->isStaticLocal() &&
@@ -199,12 +206,17 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
199206
// storage are disallowed in a SYCL kernel. The placement
200207
// new operator and any user-defined overloads that
201208
// do not allocate storage are permitted.
202-
const FunctionDecl *FD = E->getOperatorNew();
203-
if (FD && !FD->isReservedGlobalPlacementOperator()) {
204-
OverloadedOperatorKind Kind = FD->getOverloadedOperator();
205-
if (Kind == OO_New || Kind == OO_Array_New)
206-
SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict) <<
207-
KernelAllocateStorage;
209+
if (FunctionDecl *FD = E->getOperatorNew()) {
210+
if (FD->isReplaceableGlobalAllocationFunction()) {
211+
SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict) <<
212+
KernelAllocateStorage;
213+
} else if (FunctionDecl *Def = FD->getDefinition()) {
214+
if (!Def->hasAttr<SYCLDeviceAttr>()) {
215+
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
216+
this->TraverseStmt(Def->getBody());
217+
SemaRef.AddSyclKernel(Def);
218+
}
219+
}
208220
}
209221
return true;
210222
}
@@ -245,8 +257,42 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
245257
return true;
246258
}
247259

260+
// The call graph for this translation unit.
261+
CallGraph SYCLCG;
248262
private:
249-
bool CheckTypeForVirtual(QualType Ty, SourceRange Loc) {
263+
// Determines whether the function FD is recursive.
264+
// CalleeNode is a function which is called either directly
265+
// or indirectly from FD. If recursion is detected then create
266+
// diagnostic notes on each function as the callstack is unwound.
267+
bool IsRecursive(FunctionDecl *CalleeNode, FunctionDecl *FD,
268+
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet) {
269+
// We're currently checking CalleeNode on a different
270+
// trace through the CallGraph, we avoid infinite recursion
271+
// by using VisitedSet to keep track of this.
272+
if (!VisitedSet.insert(CalleeNode).second)
273+
return false;
274+
if (CallGraphNode *N = SYCLCG.getNode(CalleeNode)) {
275+
for (const CallGraphNode *CI : *N) {
276+
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) {
277+
Callee = Callee->getCanonicalDecl();
278+
if (Callee == FD)
279+
return SemaRef.Diag(FD->getSourceRange().getBegin(),
280+
diag::note_sycl_recursive_function_declared_here)
281+
<< KernelCallRecursiveFunction;
282+
else if (IsRecursive(Callee, FD, VisitedSet))
283+
return true;
284+
}
285+
}
286+
}
287+
return false;
288+
}
289+
290+
bool CheckSYCLType(QualType Ty, SourceRange Loc) {
291+
if (Ty->isVariableArrayType()) {
292+
SemaRef.Diag(Loc.getBegin(), diag::err_vla_unsupported);
293+
return false;
294+
}
295+
250296
while (Ty->isAnyPointerType() || Ty->isArrayType())
251297
Ty = QualType{Ty->getPointeeOrArrayElementType(), 0};
252298

@@ -264,25 +310,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
264310
}
265311

266312
for (const auto &Field : CRD->fields()) {
267-
if (!CheckTypeForVirtual(Field->getType(), Field->getSourceRange())) {
313+
if (!CheckSYCLType(Field->getType(), Field->getSourceRange())) {
268314
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
269315
return false;
270316
}
271317
}
272318
} else if (const auto *RD = Ty->getAsRecordDecl()) {
273319
for (const auto &Field : RD->fields()) {
274-
if (!CheckTypeForVirtual(Field->getType(), Field->getSourceRange())) {
320+
if (!CheckSYCLType(Field->getType(), Field->getSourceRange())) {
275321
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
276322
return false;
277323
}
278324
}
279325
} else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
280326
for (const auto &ParamTy : FPTy->param_types())
281-
if (!CheckTypeForVirtual(ParamTy, Loc))
327+
if (!CheckSYCLType(ParamTy, Loc))
282328
return false;
283-
return CheckTypeForVirtual(FPTy->getReturnType(), Loc);
329+
return CheckSYCLType(FPTy->getReturnType(), Loc);
284330
} else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
285-
return CheckTypeForVirtual(FTy->getReturnType(), Loc);
331+
return CheckSYCLType(FTy->getReturnType(), Loc);
286332
}
287333
return true;
288334
}
@@ -726,6 +772,10 @@ void Sema::ConstructSYCLKernel(FunctionDecl *KernelCallerFunc) {
726772
AddSyclKernel(SYCLKernel);
727773
// Let's mark all called functions with SYCL Device attribute.
728774
MarkDeviceFunction Marker(*this);
775+
// Create the call graph so we can detect recursion and check the validity
776+
// of new operator overrides. Add the kernel function itself in case
777+
// it is recursive.
778+
Marker.SYCLCG.addToCallGraph(getASTContext().getTranslationUnitDecl());
729779
Marker.TraverseStmt(SYCLKernelBody);
730780
}
731781

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// RUN: %clang_cc1 -fcxx-exceptions -fsycl-is-device -Wno-return-type -verify -fsyntax-only -x c++ -emit-llvm-only -std=c++17 %s
2+
3+
// This recursive function is not called from sycl kernel,
4+
// so it should not be diagnosed.
5+
int fib(int n)
6+
{
7+
if (n <= 1)
8+
return n;
9+
return fib(n-1) + fib(n-2);
10+
}
11+
12+
typedef struct S {
13+
template <typename T>
14+
// expected-note@+1 2{{function implemented using recursion declared here}}
15+
T factT(T i, T j)
16+
{
17+
// expected-error@+1 1{{SYCL kernel cannot call a recursive function}}
18+
return factT(j,i);
19+
}
20+
21+
int fact(unsigned i)
22+
{
23+
if (i==0) return 1;
24+
// expected-error@+1 1{{SYCL kernel cannot call a recursive function}}
25+
else return factT<unsigned>(i-1, i);
26+
}
27+
} S_type;
28+
29+
30+
// expected-note@+1 2{{function implemented using recursion declared here}}
31+
int fact(unsigned i);
32+
// expected-note@+1 2{{function implemented using recursion declared here}}
33+
int fact1(unsigned i)
34+
{
35+
if (i==0) return 1;
36+
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
37+
else return fact(i-1) * i;
38+
}
39+
int fact(unsigned i)
40+
{
41+
if (i==0) return 1;
42+
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
43+
else return fact1(i-1) * i;
44+
}
45+
46+
bool isa_B(void) {
47+
S_type s;
48+
49+
unsigned f = s.fact(3);
50+
// expected-error@+1 1{{SYCL kernel cannot call a recursive function}}
51+
unsigned f1 = s.factT<unsigned>(3,4);
52+
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
53+
unsigned g = fact(3);
54+
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
55+
unsigned g1 = fact1(3);
56+
return 0;
57+
}
58+
59+
__attribute__((sycl_kernel)) void kernel1(void) {
60+
isa_B();
61+
}
62+
// expected-note@+1 2{{function implemented using recursion declared here}}
63+
__attribute__((sycl_kernel)) void kernel2(void) {
64+
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
65+
kernel2();
66+
}
67+
__attribute__((sycl_kernel)) void kernel3(void) {
68+
;
69+
}
70+
71+
using myFuncDef = int(int,int);
72+
73+
void usage( myFuncDef functionPtr ) {
74+
kernel1();
75+
}
76+
void usage2( myFuncDef functionPtr ) {
77+
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
78+
kernel2();
79+
}
80+
void usage3( myFuncDef functionPtr ) {
81+
kernel3();
82+
}
83+
84+
int addInt(int n, int m) {
85+
return n+m;
86+
}
87+
88+
template <typename name, typename Func>
89+
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
90+
kernelFunc();
91+
}
92+
93+
template <typename name, typename Func>
94+
// expected-note@+1 2{{function implemented using recursion declared here}}
95+
__attribute__((sycl_kernel)) void kernel_single_task2(Func kernelFunc) {
96+
kernelFunc();
97+
// expected-error@+1 2{{SYCL kernel cannot call a recursive function}}
98+
kernel_single_task2<name, Func>(kernelFunc);
99+
}
100+
101+
int main() {
102+
kernel_single_task<class fake_kernel>([]() { usage( &addInt ); });
103+
kernel_single_task<class fake_kernel>([]() { usage2( &addInt ); });
104+
kernel_single_task2<class fake_kernel>([]() { usage3( &addInt ); });
105+
return fib(5);
106+
}
107+

0 commit comments

Comments
 (0)