Skip to content

Commit 971fecd

Browse files
Blower, Melanie authored and vladimirlaz committed
[SYCL] fix MarkFunction ASTConsumer issue with delayed instantiations
Signed-off-by: Vladimir Lazarev <[email protected]>
Signed-off-by: Blower, Melanie <[email protected]>
1 parent e878f1d commit 971fecd

File tree

10 files changed: +209 −110 lines changed

clang/include/clang/Analysis/CallGraph.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,7 @@ class CallGraph : public RecursiveASTVisitor<CallGraph> {
131131

132132
bool shouldWalkTypesOfTypeLocs() const { return false; }
133133
bool shouldVisitTemplateInstantiations() const { return true; }
134+
bool shouldVisitImplicitCode() const { return true; }
134135

135136
private:
136137
/// Add the given declaration to the call graph.

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11146,6 +11146,7 @@ class Sema {
1114611146
}
1114711147

1114811148
void ConstructSYCLKernel(FunctionDecl *KernelCallerFunc);
11149+
void MarkDevice(void);
1114911150
};
1115011151

1115111152
/// RAII object that enters a new expression evaluation context.

clang/lib/Analysis/CallGraph.cpp

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,30 @@ class CGBuilder : public StmtVisitor<CGBuilder> {
7979
VisitChildren(CE);
8080
}
8181

82+
void VisitLambdaExpr(LambdaExpr *LE) {
83+
if (CXXMethodDecl *MD = LE->getCallOperator())
84+
G->VisitFunctionDecl(MD);
85+
}
86+
87+
void VisitCXXNewExpr(CXXNewExpr *E) {
88+
if (FunctionDecl *FD = E->getOperatorNew())
89+
addCalledDecl(FD);
90+
VisitChildren(E);
91+
}
92+
93+
void VisitCXXConstructExpr(CXXConstructExpr *E) {
94+
CXXConstructorDecl *Ctor = E->getConstructor();
95+
if (FunctionDecl *Def = Ctor->getDefinition())
96+
addCalledDecl(Def);
97+
const auto *ConstructedType = Ctor->getParent();
98+
if (ConstructedType->hasUserDeclaredDestructor()) {
99+
CXXDestructorDecl *Dtor = ConstructedType->getDestructor();
100+
if (FunctionDecl *Def = Dtor->getDefinition())
101+
addCalledDecl(Def);
102+
}
103+
VisitChildren(E);
104+
}
105+
82106
// Adds may-call edges for the ObjC message sends.
83107
void VisitObjCMessageExpr(ObjCMessageExpr *ME) {
84108
if (ObjCInterfaceDecl *IDecl = ME->getReceiverInterface()) {

clang/lib/Sema/Sema.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -921,6 +921,9 @@ void Sema::ActOnEndOfTranslationUnit() {
921921
if (getLangOpts().SYCLIsDevice && SyclIntHeader != nullptr) {
922922
SyclIntHeader->emit(getLangOpts().SYCLIntHeader);
923923
}
924+
if (getLangOpts().SYCLIsDevice)
925+
MarkDevice();
926+
924927

925928
assert(LateParsedInstantiations.empty() &&
926929
"end of TU template instantiation should not create more "

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 46 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -94,10 +94,13 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
9494
// instantiation as template functions. It means that
9595
// all functions used by kernel have already been parsed and have
9696
// definitions.
97-
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet;
98-
if (IsRecursive(Callee, Callee, VisitedSet))
97+
if (RecursiveSet.count(Callee)) {
9998
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict) <<
10099
KernelCallRecursiveFunction;
100+
SemaRef.Diag(Callee->getSourceRange().getBegin(),
101+
diag::note_sycl_recursive_function_declared_here)
102+
<< KernelCallRecursiveFunction;
103+
}
101104

102105
if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
103106
if (Method->isVirtual())
@@ -109,7 +112,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
109112
if (FunctionDecl *Def = Callee->getDefinition()) {
110113
if (!Def->hasAttr<SYCLDeviceAttr>()) {
111114
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
112-
this->TraverseStmt(Def->getBody());
113115
SemaRef.AddSyclKernel(Def);
114116
}
115117
}
@@ -127,7 +129,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
127129

128130
if (FunctionDecl *Def = Ctor->getDefinition()) {
129131
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
130-
this->TraverseStmt(Def->getBody());
131132
SemaRef.AddSyclKernel(Def);
132133
}
133134

@@ -137,7 +138,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
137138

138139
if (FunctionDecl *Def = Dtor->getDefinition()) {
139140
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
140-
this->TraverseStmt(Def->getBody());
141141
SemaRef.AddSyclKernel(Def);
142142
}
143143
}
@@ -211,7 +211,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
211211
} else if (FunctionDecl *Def = FD->getDefinition()) {
212212
if (!Def->hasAttr<SYCLDeviceAttr>()) {
213213
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
214-
this->TraverseStmt(Def->getBody());
215214
SemaRef.AddSyclKernel(Def);
216215
}
217216
}
@@ -257,33 +256,42 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
257256

258257
// The call graph for this translation unit.
259258
CallGraph SYCLCG;
260-
private:
259+
// The set of functions called by a kernel function.
260+
llvm::SmallPtrSet<FunctionDecl *, 10> KernelSet;
261+
// The set of recursive functions identified while building the
262+
// kernel set, this is used for error diagnostics.
263+
llvm::SmallPtrSet<FunctionDecl *, 10> RecursiveSet;
261264
// Determines whether the function FD is recursive.
262265
// CalleeNode is a function which is called either directly
263266
// or indirectly from FD. If recursion is detected then create
264267
// diagnostic notes on each function as the callstack is unwound.
265-
bool IsRecursive(FunctionDecl *CalleeNode, FunctionDecl *FD,
266-
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet) {
268+
void CollectKernelSet(FunctionDecl *CalleeNode, FunctionDecl *FD,
269+
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet) {
267270
// We're currently checking CalleeNode on a different
268271
// trace through the CallGraph, we avoid infinite recursion
269-
// by using VisitedSet to keep track of this.
270-
if (!VisitedSet.insert(CalleeNode).second)
271-
return false;
272+
// by using KernelSet to keep track of this.
273+
if (!KernelSet.insert(CalleeNode).second)
274+
// Previously seen, stop recursion.
275+
return;
272276
if (CallGraphNode *N = SYCLCG.getNode(CalleeNode)) {
273277
for (const CallGraphNode *CI : *N) {
274278
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) {
275279
Callee = Callee->getCanonicalDecl();
276-
if (Callee == FD)
277-
return SemaRef.Diag(FD->getSourceRange().getBegin(),
278-
diag::note_sycl_recursive_function_declared_here)
279-
<< KernelCallRecursiveFunction;
280-
else if (IsRecursive(Callee, FD, VisitedSet))
281-
return true;
280+
if (VisitedSet.count(Callee)) {
281+
// There's a stack frame to visit this Callee above
282+
// this invocation. Do not recurse here.
283+
RecursiveSet.insert(Callee);
284+
RecursiveSet.insert(CalleeNode);
285+
} else {
286+
VisitedSet.insert(Callee);
287+
CollectKernelSet(Callee, FD, VisitedSet);
288+
VisitedSet.erase(Callee);
289+
}
282290
}
283291
}
284292
}
285-
return false;
286293
}
294+
private:
287295

288296
bool CheckSYCLType(QualType Ty, SourceRange Loc) {
289297
if (Ty->isVariableArrayType()) {
@@ -770,13 +778,30 @@ void Sema::ConstructSYCLKernel(FunctionDecl *KernelCallerFunc) {
770778
CreateSYCLKernelBody(*this, KernelCallerFunc, SYCLKernel);
771779
SYCLKernel->setBody(SYCLKernelBody);
772780
AddSyclKernel(SYCLKernel);
781+
}
782+
783+
void Sema::MarkDevice(void) {
773784
// Let's mark all called functions with SYCL Device attribute.
774-
MarkDeviceFunction Marker(*this);
775785
// Create the call graph so we can detect recursion and check the validity
776786
// of new operator overrides. Add the kernel function itself in case
777787
// it is recursive.
788+
MarkDeviceFunction Marker(*this);
778789
Marker.SYCLCG.addToCallGraph(getASTContext().getTranslationUnitDecl());
779-
Marker.TraverseStmt(SYCLKernelBody);
790+
for (Decl *D : SyclKernels()) {
791+
if (auto SYCLKernel = dyn_cast<FunctionDecl>(D)) {
792+
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet;
793+
Marker.CollectKernelSet(SYCLKernel, SYCLKernel, VisitedSet);
794+
}
795+
}
796+
for (const auto &elt : Marker.KernelSet) {
797+
if (FunctionDecl *Def = elt->getDefinition()) {
798+
if (!Def->hasAttr<SYCLDeviceAttr>()) {
799+
Def->addAttr(SYCLDeviceAttr::CreateImplicit(Context));
800+
AddSyclKernel(Def);
801+
}
802+
Marker.TraverseStmt(Def->getBody());
803+
}
804+
}
780805
}
781806

782807
// -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)