21
21
#include " llvm/Support/FileSystem.h"
22
22
#include " llvm/Support/Path.h"
23
23
#include " llvm/Support/raw_ostream.h"
24
+ #include " clang/Analysis/CallGraph.h"
24
25
25
26
#include < array>
26
27
@@ -45,6 +46,7 @@ enum RestrictKind {
45
46
KernelRTTI,
46
47
KernelNonConstStaticDataVariable,
47
48
KernelCallVirtualFunction,
49
+ KernelCallRecursiveFunction,
48
50
KernelCallFunctionPointer,
49
51
KernelAllocateStorage,
50
52
KernelUseExceptions,
@@ -85,20 +87,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
85
87
86
88
bool VisitCallExpr (CallExpr *e) {
87
89
for (const auto &Arg : e->arguments ())
88
- CheckTypeForVirtual (Arg->getType (), Arg->getSourceRange ());
90
+ CheckSYCLType (Arg->getType (), Arg->getSourceRange ());
89
91
90
92
if (FunctionDecl *Callee = e->getDirectCallee ()) {
93
+ Callee = Callee->getCanonicalDecl ();
91
94
// Remember that all SYCL kernel functions have deferred
92
95
// instantiation as template functions. It means that
93
96
// all functions used by kernel have already been parsed and have
94
97
// definitions.
98
+ llvm::SmallPtrSet<FunctionDecl *, 10 > VisitedSet;
99
+ if (IsRecursive (Callee, Callee, VisitedSet))
100
+ SemaRef.Diag (e->getExprLoc (), diag::err_sycl_restrict) <<
101
+ KernelCallRecursiveFunction;
95
102
96
103
if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
97
104
if (Method->isVirtual ())
98
105
SemaRef.Diag (e->getExprLoc (), diag::err_sycl_restrict) <<
99
106
KernelCallVirtualFunction;
100
107
101
- CheckTypeForVirtual (Callee->getReturnType (), Callee->getSourceRange ());
108
+ CheckSYCLType (Callee->getReturnType (), Callee->getSourceRange ());
102
109
103
110
if (FunctionDecl *Def = Callee->getDefinition ()) {
104
111
if (!Def->hasAttr <SYCLDeviceAttr>()) {
@@ -116,7 +123,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
116
123
117
124
bool VisitCXXConstructExpr (CXXConstructExpr *E) {
118
125
for (const auto &Arg : E->arguments ())
119
- CheckTypeForVirtual (Arg->getType (), Arg->getSourceRange ());
126
+ CheckSYCLType (Arg->getType (), Arg->getSourceRange ());
120
127
121
128
CXXConstructorDecl *Ctor = E->getConstructor ();
122
129
@@ -150,22 +157,22 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
150
157
}
151
158
152
159
bool VisitTypedefNameDecl (TypedefNameDecl *TD) {
153
- CheckTypeForVirtual (TD->getUnderlyingType (), TD->getLocation ());
160
+ CheckSYCLType (TD->getUnderlyingType (), TD->getLocation ());
154
161
return true ;
155
162
}
156
163
157
164
bool VisitRecordDecl (RecordDecl *RD) {
158
- CheckTypeForVirtual (QualType{RD->getTypeForDecl (), 0 }, RD->getLocation ());
165
+ CheckSYCLType (QualType{RD->getTypeForDecl (), 0 }, RD->getLocation ());
159
166
return true ;
160
167
}
161
168
162
169
bool VisitParmVarDecl (VarDecl *VD) {
163
- CheckTypeForVirtual (VD->getType (), VD->getLocation ());
170
+ CheckSYCLType (VD->getType (), VD->getLocation ());
164
171
return true ;
165
172
}
166
173
167
174
bool VisitVarDecl (VarDecl *VD) {
168
- CheckTypeForVirtual (VD->getType (), VD->getLocation ());
175
+ CheckSYCLType (VD->getType (), VD->getLocation ());
169
176
return true ;
170
177
}
171
178
@@ -180,7 +187,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
180
187
}
181
188
182
189
bool VisitDeclRefExpr (DeclRefExpr *E) {
183
- CheckTypeForVirtual (E->getType (), E->getSourceRange ());
190
+ CheckSYCLType (E->getType (), E->getSourceRange ());
184
191
if (VarDecl *VD = dyn_cast<VarDecl>(E->getDecl ())) {
185
192
bool IsConst = VD->getType ().getNonReferenceType ().isConstQualified ();
186
193
if (!IsConst && VD->hasGlobalStorage () && !VD->isStaticLocal () &&
@@ -199,12 +206,17 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
199
206
// storage are disallowed in a SYCL kernel. The placement
200
207
// new operator and any user-defined overloads that
201
208
// do not allocate storage are permitted.
202
- const FunctionDecl *FD = E->getOperatorNew ();
203
- if (FD && !FD->isReservedGlobalPlacementOperator ()) {
204
- OverloadedOperatorKind Kind = FD->getOverloadedOperator ();
205
- if (Kind == OO_New || Kind == OO_Array_New)
206
- SemaRef.Diag (E->getExprLoc (), diag::err_sycl_restrict) <<
207
- KernelAllocateStorage;
209
+ if (FunctionDecl *FD = E->getOperatorNew ()) {
210
+ if (FD->isReplaceableGlobalAllocationFunction ()) {
211
+ SemaRef.Diag (E->getExprLoc (), diag::err_sycl_restrict) <<
212
+ KernelAllocateStorage;
213
+ } else if (FunctionDecl *Def = FD->getDefinition ()) {
214
+ if (!Def->hasAttr <SYCLDeviceAttr>()) {
215
+ Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
216
+ this ->TraverseStmt (Def->getBody ());
217
+ SemaRef.AddSyclKernel (Def);
218
+ }
219
+ }
208
220
}
209
221
return true ;
210
222
}
@@ -245,8 +257,42 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
245
257
return true ;
246
258
}
247
259
260
+ // The call graph for this translation unit.
261
+ CallGraph SYCLCG;
248
262
private:
249
- bool CheckTypeForVirtual (QualType Ty, SourceRange Loc) {
263
+ // Determines whether the function FD is recursive.
264
+ // CalleeNode is a function which is called either directly
265
+ // or indirectly from FD. If recursion is detected then create
266
+ // diagnostic notes on each function as the callstack is unwound.
267
+ bool IsRecursive (FunctionDecl *CalleeNode, FunctionDecl *FD,
268
+ llvm::SmallPtrSet<FunctionDecl *, 10 > VisitedSet) {
269
+ // We're currently checking CalleeNode on a different
270
+ // trace through the CallGraph, we avoid infinite recursion
271
+ // by using VisitedSet to keep track of this.
272
+ if (!VisitedSet.insert (CalleeNode).second )
273
+ return false ;
274
+ if (CallGraphNode *N = SYCLCG.getNode (CalleeNode)) {
275
+ for (const CallGraphNode *CI : *N) {
276
+ if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CI->getDecl ())) {
277
+ Callee = Callee->getCanonicalDecl ();
278
+ if (Callee == FD)
279
+ return SemaRef.Diag (FD->getSourceRange ().getBegin (),
280
+ diag::note_sycl_recursive_function_declared_here)
281
+ << KernelCallRecursiveFunction;
282
+ else if (IsRecursive (Callee, FD, VisitedSet))
283
+ return true ;
284
+ }
285
+ }
286
+ }
287
+ return false ;
288
+ }
289
+
290
+ bool CheckSYCLType (QualType Ty, SourceRange Loc) {
291
+ if (Ty->isVariableArrayType ()) {
292
+ SemaRef.Diag (Loc.getBegin (), diag::err_vla_unsupported);
293
+ return false ;
294
+ }
295
+
250
296
while (Ty->isAnyPointerType () || Ty->isArrayType ())
251
297
Ty = QualType{Ty->getPointeeOrArrayElementType (), 0 };
252
298
@@ -264,25 +310,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
264
310
}
265
311
266
312
for (const auto &Field : CRD->fields ()) {
267
- if (!CheckTypeForVirtual (Field->getType (), Field->getSourceRange ())) {
313
+ if (!CheckSYCLType (Field->getType (), Field->getSourceRange ())) {
268
314
SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
269
315
return false ;
270
316
}
271
317
}
272
318
} else if (const auto *RD = Ty->getAsRecordDecl ()) {
273
319
for (const auto &Field : RD->fields ()) {
274
- if (!CheckTypeForVirtual (Field->getType (), Field->getSourceRange ())) {
320
+ if (!CheckSYCLType (Field->getType (), Field->getSourceRange ())) {
275
321
SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
276
322
return false ;
277
323
}
278
324
}
279
325
} else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
280
326
for (const auto &ParamTy : FPTy->param_types ())
281
- if (!CheckTypeForVirtual (ParamTy, Loc))
327
+ if (!CheckSYCLType (ParamTy, Loc))
282
328
return false ;
283
- return CheckTypeForVirtual (FPTy->getReturnType (), Loc);
329
+ return CheckSYCLType (FPTy->getReturnType (), Loc);
284
330
} else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
285
- return CheckTypeForVirtual (FTy->getReturnType (), Loc);
331
+ return CheckSYCLType (FTy->getReturnType (), Loc);
286
332
}
287
333
return true ;
288
334
}
@@ -726,6 +772,10 @@ void Sema::ConstructSYCLKernel(FunctionDecl *KernelCallerFunc) {
726
772
AddSyclKernel (SYCLKernel);
727
773
// Let's mark all called functions with SYCL Device attribute.
728
774
MarkDeviceFunction Marker (*this );
775
+ // Create the call graph so we can detect recursion and check the validity
776
+ // of new operator overrides. Add the kernel function itself in case
777
+ // it is recursive.
778
+ Marker.SYCLCG .addToCallGraph (getASTContext ().getTranslationUnitDecl ());
729
779
Marker.TraverseStmt (SYCLKernelBody);
730
780
}
731
781
0 commit comments