@@ -94,10 +94,13 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
94
94
// instantiation as template functions. It means that
95
95
// all functions used by kernel have already been parsed and have
96
96
// definitions.
97
- llvm::SmallPtrSet<FunctionDecl *, 10 > VisitedSet;
98
- if (IsRecursive (Callee, Callee, VisitedSet))
97
+ if (RecursiveSet.count (Callee)) {
99
98
SemaRef.Diag (e->getExprLoc (), diag::err_sycl_restrict) <<
100
99
KernelCallRecursiveFunction;
100
+ SemaRef.Diag (Callee->getSourceRange ().getBegin (),
101
+ diag::note_sycl_recursive_function_declared_here)
102
+ << KernelCallRecursiveFunction;
103
+ }
101
104
102
105
if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
103
106
if (Method->isVirtual ())
@@ -109,7 +112,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
109
112
if (FunctionDecl *Def = Callee->getDefinition ()) {
110
113
if (!Def->hasAttr <SYCLDeviceAttr>()) {
111
114
Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
112
- this ->TraverseStmt (Def->getBody ());
113
115
SemaRef.AddSyclKernel (Def);
114
116
}
115
117
}
@@ -127,7 +129,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
127
129
128
130
if (FunctionDecl *Def = Ctor->getDefinition ()) {
129
131
Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
130
- this ->TraverseStmt (Def->getBody ());
131
132
SemaRef.AddSyclKernel (Def);
132
133
}
133
134
@@ -137,7 +138,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
137
138
138
139
if (FunctionDecl *Def = Dtor->getDefinition ()) {
139
140
Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
140
- this ->TraverseStmt (Def->getBody ());
141
141
SemaRef.AddSyclKernel (Def);
142
142
}
143
143
}
@@ -211,7 +211,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
211
211
} else if (FunctionDecl *Def = FD->getDefinition ()) {
212
212
if (!Def->hasAttr <SYCLDeviceAttr>()) {
213
213
Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
214
- this ->TraverseStmt (Def->getBody ());
215
214
SemaRef.AddSyclKernel (Def);
216
215
}
217
216
}
@@ -257,33 +256,42 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
257
256
258
257
// The call graph for this translation unit.
259
258
CallGraph SYCLCG;
260
- private:
259
+ // The set of functions called by a kernel function.
260
+ llvm::SmallPtrSet<FunctionDecl *, 10 > KernelSet;
261
+ // The set of recursive functions identified while building the
262
+ // kernel set; this is used for error diagnostics.
263
+ llvm::SmallPtrSet<FunctionDecl *, 10 > RecursiveSet;
261
264
// Collects into KernelSet every function reachable through the call
262
265
// graph from CalleeNode, which is called either directly or
263
266
// indirectly from FD. Functions found on a call cycle are recorded
264
267
// in RecursiveSet so that diagnostics can be emitted for them later.
265
- bool IsRecursive (FunctionDecl *CalleeNode, FunctionDecl *FD,
266
- llvm::SmallPtrSet<FunctionDecl *, 10 > VisitedSet) {
268
+ void CollectKernelSet (FunctionDecl *CalleeNode, FunctionDecl *FD,
269
+ llvm::SmallPtrSet<FunctionDecl *, 10 > VisitedSet) {
267
270
// We're currently checking CalleeNode on a different
268
271
// trace through the CallGraph, we avoid infinite recursion
269
- // by using VisitedSet to keep track of this.
270
- if (!VisitedSet.insert (CalleeNode).second )
271
- return false ;
272
+ // by using KernelSet to keep track of this.
273
+ if (!KernelSet.insert (CalleeNode).second )
274
+ // Previously seen, stop recursion.
275
+ return ;
272
276
if (CallGraphNode *N = SYCLCG.getNode (CalleeNode)) {
273
277
for (const CallGraphNode *CI : *N) {
274
278
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CI->getDecl ())) {
275
279
Callee = Callee->getCanonicalDecl ();
276
- if (Callee == FD)
277
- return SemaRef.Diag (FD->getSourceRange ().getBegin (),
278
- diag::note_sycl_recursive_function_declared_here)
279
- << KernelCallRecursiveFunction;
280
- else if (IsRecursive (Callee, FD, VisitedSet))
281
- return true ;
280
+ if (VisitedSet.count (Callee)) {
281
+ // There's a stack frame to visit this Callee above
282
+ // this invocation. Do not recurse here.
283
+ RecursiveSet.insert (Callee);
284
+ RecursiveSet.insert (CalleeNode);
285
+ } else {
286
+ VisitedSet.insert (Callee);
287
+ CollectKernelSet (Callee, FD, VisitedSet);
288
+ VisitedSet.erase (Callee);
289
+ }
282
290
}
283
291
}
284
292
}
285
- return false ;
286
293
}
294
+ private:
287
295
288
296
bool CheckSYCLType (QualType Ty, SourceRange Loc) {
289
297
if (Ty->isVariableArrayType ()) {
@@ -770,13 +778,30 @@ void Sema::ConstructSYCLKernel(FunctionDecl *KernelCallerFunc) {
770
778
CreateSYCLKernelBody (*this , KernelCallerFunc, SYCLKernel);
771
779
SYCLKernel->setBody (SYCLKernelBody);
772
780
AddSyclKernel (SYCLKernel);
781
+ }
782
+
783
+ void Sema::MarkDevice (void ) {
773
784
// Let's mark all called functions with SYCL Device attribute.
774
- MarkDeviceFunction Marker (*this );
775
785
// Create the call graph so we can detect recursion and check the validity
776
786
// of new operator overrides. Add the kernel function itself in case
777
787
// it is recursive.
788
+ MarkDeviceFunction Marker (*this );
778
789
Marker.SYCLCG .addToCallGraph (getASTContext ().getTranslationUnitDecl ());
779
- Marker.TraverseStmt (SYCLKernelBody);
790
+ for (Decl *D : SyclKernels ()) {
791
+ if (auto SYCLKernel = dyn_cast<FunctionDecl>(D)) {
792
+ llvm::SmallPtrSet<FunctionDecl *, 10 > VisitedSet;
793
+ Marker.CollectKernelSet (SYCLKernel, SYCLKernel, VisitedSet);
794
+ }
795
+ }
796
+ for (const auto &elt : Marker.KernelSet ) {
797
+ if (FunctionDecl *Def = elt->getDefinition ()) {
798
+ if (!Def->hasAttr <SYCLDeviceAttr>()) {
799
+ Def->addAttr (SYCLDeviceAttr::CreateImplicit (Context));
800
+ AddSyclKernel (Def);
801
+ }
802
+ Marker.TraverseStmt (Def->getBody ());
803
+ }
804
+ }
780
805
}
781
806
782
807
// -----------------------------------------------------------------------------
0 commit comments