Skip to content

Commit 08b20c1

Browse files
committed
[SYCL] Implement virtual-table prohibit in SYCL device code.
According to SYCL specification, virtual tables are illegal in device code. This accomplishes that in a couple of ways. First, there is a code-gen assert that will prevent emission of virtual tables always. Second, it prevents virtual tables from being 'used', so non-SYCL kernel code cannot cause v-tables to be emitted. Finally, SYCL-specific functions are checked for usage of polymorphic functions and a diagnostic is emitted. Signed-off-by: Vladimir Lazarev <[email protected]>
1 parent 33ff936 commit 08b20c1

File tree

8 files changed

+252
-4
lines changed

8 files changed

+252
-4
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9488,4 +9488,7 @@ def err_sycl_attribute_address_space_invalid : Error<
94889488
def err_sycl_kernel_name_class_not_top_level : Error<
94899489
"kernel name class and its template argument classes' declarations can only "
94909490
"nest in a namespace: %0">;
9491+
def err_sycl_virtual_types : Error<
9492+
"No class with a vtable can be used in a SYCL kernel or any code included in the kernel">;
9493+
def note_sycl_used_here : Note<"used here">;
94919494
} // end of sema component.

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,10 @@ llvm::StructType *CodeGenTypes::ConvertRecordDeclType(const RecordDecl *RD) {
719719
return Ty;
720720
}
721721

722+
assert((!Context.getLangOpts().SYCL || !isa<CXXRecordDecl>(RD) ||
723+
!dyn_cast<CXXRecordDecl>(RD)->isPolymorphic()) &&
724+
"Types with virtual functions not allowed in SYCL");
725+
722726
// Okay, this is a definition of a type. Compile the implementation now.
723727
bool InsertResult = RecordsBeingLaidOut.insert(Key).second;
724728
(void)InsertResult;

clang/lib/Sema/SemaDeclCXX.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15013,6 +15013,10 @@ void Sema::MarkVTableUsed(SourceLocation Loc, CXXRecordDecl *Class,
1501315013
return;
1501415014
}
1501515015

15016+
// No VTable usage is legal in SYCL, so don't bother marking them used.
15017+
if (getLangOpts().SYCL)
15018+
return;
15019+
1501615020
// Try to insert this class into the map.
1501715021
LoadExternalVTableUses();
1501815022
Class = Class->getCanonicalDecl();

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,27 +49,113 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
4949
MarkDeviceFunction(Sema &S)
5050
: RecursiveASTVisitor<MarkDeviceFunction>(), SemaRef(S) {}
5151
bool VisitCallExpr(CallExpr *e) {
52+
for (const auto &Arg : e->arguments())
53+
CheckTypeForVirtual(Arg->getType(), Arg->getSourceRange());
54+
5255
if (FunctionDecl *Callee = e->getDirectCallee()) {
5356
// Remember that all SYCL kernel functions have deferred
5457
// instantiation as template functions. It means that
5558
// all functions used by kernel have already been parsed and have
5659
// definitions.
60+
61+
CheckTypeForVirtual(Callee->getReturnType(), Callee->getSourceRange());
62+
5763
if (FunctionDecl *Def = Callee->getDefinition()) {
5864
if (!Def->hasAttr<SYCLDeviceAttr>()) {
5965
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
6066
this->TraverseStmt(Def->getBody());
61-
// But because parser works with top level declarations and CodeGen
62-
// already saw and ignored our function without device attribute we
63-
// need to add this function into SYCL kernels array to show it
64-
// this function again.
6567
SemaRef.AddSyclKernel(Def);
6668
}
6769
}
6870
}
6971
return true;
7072
}
7173

74+
bool VisitCXXConstructExpr(CXXConstructExpr *E) {
75+
for (const auto &Arg : E->arguments())
76+
CheckTypeForVirtual(Arg->getType(), Arg->getSourceRange());
77+
78+
CXXConstructorDecl *Ctor = E->getConstructor();
79+
80+
if (FunctionDecl *Def = Ctor->getDefinition()) {
81+
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
82+
this->TraverseStmt(Def->getBody());
83+
SemaRef.AddSyclKernel(Def);
84+
}
85+
86+
const auto *ConstructedType = Ctor->getParent();
87+
if (ConstructedType->hasUserDeclaredDestructor()) {
88+
CXXDestructorDecl *Dtor = ConstructedType->getDestructor();
89+
90+
if (FunctionDecl *Def = Dtor->getDefinition()) {
91+
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
92+
this->TraverseStmt(Def->getBody());
93+
SemaRef.AddSyclKernel(Def);
94+
}
95+
}
96+
return true;
97+
}
98+
99+
bool VisitTypedefNameDecl(TypedefNameDecl *TD) {
100+
CheckTypeForVirtual(TD->getUnderlyingType(), TD->getLocation());
101+
return true;
102+
}
103+
104+
bool VisitRecordDecl(RecordDecl *RD) {
105+
CheckTypeForVirtual(QualType{RD->getTypeForDecl(), 0}, RD->getLocation());
106+
return true;
107+
}
108+
109+
bool VisitParmVarDecl(VarDecl *VD) {
110+
CheckTypeForVirtual(VD->getType(), VD->getLocation());
111+
return true;
112+
}
113+
114+
bool VisitVarDecl(VarDecl *VD) {
115+
CheckTypeForVirtual(VD->getType(), VD->getLocation());
116+
return true;
117+
}
118+
119+
bool VisitDeclRefExpr(DeclRefExpr *E) {
120+
CheckTypeForVirtual(E->getType(), E->getSourceRange());
121+
return true;
122+
}
123+
72124
private:
125+
bool CheckTypeForVirtual(QualType Ty, SourceRange Loc) {
126+
while (Ty->isAnyPointerType() || Ty->isArrayType())
127+
Ty = QualType{Ty->getPointeeOrArrayElementType(), 0};
128+
129+
if (const auto *CRD = Ty->getAsCXXRecordDecl()) {
130+
if (CRD->isPolymorphic()) {
131+
SemaRef.Diag(CRD->getLocation(), diag::err_sycl_virtual_types);
132+
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
133+
return false;
134+
}
135+
136+
for (const auto &Field : CRD->fields()) {
137+
if (!CheckTypeForVirtual(Field->getType(), Field->getSourceRange())) {
138+
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
139+
return false;
140+
}
141+
}
142+
} else if (const auto *RD = Ty->getAsRecordDecl()) {
143+
for (const auto &Field : RD->fields()) {
144+
if (!CheckTypeForVirtual(Field->getType(), Field->getSourceRange())) {
145+
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
146+
return false;
147+
}
148+
}
149+
} else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
150+
for (const auto &ParamTy : FPTy->param_types())
151+
if (!CheckTypeForVirtual(ParamTy, Loc))
152+
return false;
153+
return CheckTypeForVirtual(FPTy->getReturnType(), Loc);
154+
} else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
155+
return CheckTypeForVirtual(FTy->getReturnType(), Loc);
156+
}
157+
return true;
158+
}
73159
Sema &SemaRef;
74160
};
75161

clang/test/SemaSYCL/no-vtables.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -verify -fsyntax-only -x c++ -emit-llvm-only %s
2+
// expected-no-diagnostics
3+
// Should never fail, since the type is never used in kernel code.
4+
5+
struct Base {
6+
virtual void f(){}
7+
};
8+
9+
struct Inherit : Base {
10+
virtual void f() override {}
11+
};
12+
13+
void always_uses() {
14+
Inherit u;
15+
}
16+
17+
void usage() {
18+
}
19+
20+
21+
template <typename name, typename Func>
22+
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
23+
kernelFunc();
24+
}
25+
int main() {
26+
always_uses();
27+
kernel_single_task<class fake_kernel>([]() { usage(); });
28+
return 0;
29+
}
30+

clang/test/SemaSYCL/no-vtables2.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -Wno-return-type -verify -fsyntax-only -x c++ -emit-llvm-only %s
2+
3+
struct Base {
4+
virtual void f() const {}
5+
};
6+
7+
// expected-error@+1 9{{No class with a vtable can be used in a SYCL kernel or any code included in the kernel}}
8+
struct Inherit : Base {
9+
virtual void f() const override {}
10+
};
11+
12+
Inherit always_uses() {
13+
Inherit u;
14+
}
15+
16+
static constexpr Inherit IH;
17+
18+
// expected-note@+1{{used here}}
19+
Inherit *usage_child(){}
20+
21+
// expected-note@+1{{used here}}
22+
Inherit usage() {
23+
// expected-note@+1{{used here}}
24+
Inherit u;
25+
// expected-note@+1{{used here}}
26+
Inherit *u_ptr;
27+
28+
// expected-note@+1{{used here}}
29+
using foo = Inherit;
30+
// expected-note@+1{{used here}}
31+
typedef Inherit bar;
32+
// expected-note@+1{{used here}}
33+
IH.f();
34+
35+
// expected-note@+1{{used here}}
36+
usage_child();
37+
}
38+
39+
40+
template <typename name, typename Func>
41+
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
42+
kernelFunc();
43+
}
44+
int main() {
45+
// expected-note@+1{{used here}}
46+
kernel_single_task<class fake_kernel>([]() { usage(); });
47+
return 0;
48+
}
49+

clang/test/SemaSYCL/no-vtables3.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -Wno-return-type -verify -fsyntax-only -x c++ -emit-llvm-only %s
2+
3+
struct Base {
4+
virtual void f() const {}
5+
};
6+
7+
// expected-error@+1 3{{No class with a vtable can be used in a SYCL kernel or any code included in the kernel}}
8+
struct Inherit : Base {
9+
virtual void f() const override {}
10+
};
11+
12+
struct Wrapper{
13+
Wrapper() {
14+
// expected-note@+1{{used here}}
15+
Inherit IH;
16+
}
17+
18+
void Func() {
19+
// expected-note@+1{{used here}}
20+
Inherit IH;
21+
}
22+
23+
~Wrapper() {
24+
// expected-note@+1{{used here}}
25+
Inherit IH;
26+
}
27+
};
28+
29+
void usage() {
30+
Wrapper WR;
31+
WR.Func();
32+
}
33+
34+
template <typename name, typename Func>
35+
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
36+
kernelFunc();
37+
}
38+
int main() {
39+
kernel_single_task<class fake_kernel>([]() { usage(); });
40+
return 0;
41+
}
42+

clang/test/SemaSYCL/no-vtables4.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -Wno-return-type -verify -fsyntax-only -x c++ -emit-llvm-only %s
2+
3+
struct Base {
4+
virtual void f() const {}
5+
};
6+
7+
// expected-error@+1{{No class with a vtable can be used in a SYCL kernel or any code included in the kernel}}
8+
struct Inherit : Base {
9+
virtual void f() const override {}
10+
};
11+
12+
struct Wrapper{
13+
// expected-note@+1{{used here}}
14+
Inherit I;
15+
};
16+
17+
void usage() {
18+
// expected-note@+1{{used here}}
19+
Wrapper WR;
20+
}
21+
22+
template <typename name, typename Func>
23+
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
24+
kernelFunc();
25+
}
26+
int main() {
27+
kernel_single_task<class fake_kernel>([]() { usage(); });
28+
return 0;
29+
}
30+

0 commit comments

Comments
 (0)