Skip to content

Commit 6da42a0

Browse files
[SYCL] Add type checking to SYCL accessors. (#1741)
Signed-off-by: Chris Perkins <[email protected]>
1 parent 357e9c8 commit 6da42a0

File tree

3 files changed

+125
-28
lines changed

3 files changed

+125
-28
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10816,7 +10816,6 @@ def err_sycl_restrict : Error<
1081610816
"}0">;
1081710817
def err_sycl_virtual_types : Error<
1081810818
"No class with a vtable can be used in a SYCL kernel or any code included in the kernel">;
10819-
def note_sycl_used_here : Note<"used here">;
1082010819
def note_sycl_recursive_function_declared_here: Note<"function implemented using recursion declared here">;
1082110820
def err_sycl_non_trivially_copy_ctor_dtor_type
1082210821
: Error<"kernel parameter has non-trivially %select{copy "

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -210,19 +210,9 @@ static bool isZeroSizedArray(QualType Ty) {
210210
return false;
211211
}
212212

213-
static Sema::DeviceDiagBuilder
214-
emitDeferredDiagnosticAndNote(Sema &S, SourceRange Loc, unsigned DiagID,
215-
SourceRange UsedAtLoc) {
216-
Sema::DeviceDiagBuilder builder =
217-
S.SYCLDiagIfDeviceCode(Loc.getBegin(), DiagID);
218-
if (UsedAtLoc.isValid())
219-
S.SYCLDiagIfDeviceCode(UsedAtLoc.getBegin(), diag::note_sycl_used_here);
220-
return builder;
221-
}
222-
223-
static void checkSYCLVarType(Sema &S, QualType Ty, SourceRange Loc,
224-
llvm::DenseSet<QualType> Visited,
225-
SourceRange UsedAtLoc = SourceRange()) {
213+
static void checkSYCLType(Sema &S, QualType Ty, SourceRange Loc,
214+
llvm::DenseSet<QualType> Visited,
215+
SourceRange UsedAtLoc = SourceRange()) {
226216
// Not all variable types are supported inside SYCL kernels,
227217
// for example the quad type __float128 will cause errors in the
228218
// SPIR-V translation phase.
@@ -233,16 +223,21 @@ static void checkSYCLVarType(Sema &S, QualType Ty, SourceRange Loc,
233223
// different location than the variable declaration and we need to
234224
// inform the user of both, e.g. struct member usage vs declaration.
235225

226+
bool Emitting = false;
227+
236228
//--- check types ---
237229

238230
// zero length arrays
239-
if (isZeroSizedArray(Ty))
240-
emitDeferredDiagnosticAndNote(S, Loc, diag::err_typecheck_zero_array_size,
241-
UsedAtLoc);
231+
if (isZeroSizedArray(Ty)) {
232+
S.SYCLDiagIfDeviceCode(Loc.getBegin(), diag::err_typecheck_zero_array_size);
233+
Emitting = true;
234+
}
242235

243236
// variable length arrays
244-
if (Ty->isVariableArrayType())
245-
emitDeferredDiagnosticAndNote(S, Loc, diag::err_vla_unsupported, UsedAtLoc);
237+
if (Ty->isVariableArrayType()) {
238+
S.SYCLDiagIfDeviceCode(Loc.getBegin(), diag::err_vla_unsupported);
239+
Emitting = true;
240+
}
246241

247242
// Sub-reference array or pointer, then proceed with that type.
248243
while (Ty->isAnyPointerType() || Ty->isArrayType())
@@ -253,9 +248,14 @@ static void checkSYCLVarType(Sema &S, QualType Ty, SourceRange Loc,
253248
Ty->isSpecificBuiltinType(BuiltinType::UInt128) ||
254249
Ty->isSpecificBuiltinType(BuiltinType::LongDouble) ||
255250
(Ty->isSpecificBuiltinType(BuiltinType::Float128) &&
256-
!S.Context.getTargetInfo().hasFloat128Type()))
257-
emitDeferredDiagnosticAndNote(S, Loc, diag::err_type_unsupported, UsedAtLoc)
251+
!S.Context.getTargetInfo().hasFloat128Type())) {
252+
S.SYCLDiagIfDeviceCode(Loc.getBegin(), diag::err_type_unsupported)
258253
<< Ty.getUnqualifiedType().getCanonicalType();
254+
Emitting = true;
255+
}
256+
257+
if (Emitting && UsedAtLoc.isValid())
258+
S.SYCLDiagIfDeviceCode(UsedAtLoc.getBegin(), diag::note_used_here);
259259

260260
//--- now recurse ---
261261
// Pointers complicate recursion. Add this type to Visited.
@@ -264,16 +264,15 @@ static void checkSYCLVarType(Sema &S, QualType Ty, SourceRange Loc,
264264
return;
265265

266266
if (const auto *ATy = dyn_cast<AttributedType>(Ty))
267-
return checkSYCLVarType(S, ATy->getModifiedType(), Loc, Visited);
267+
return checkSYCLType(S, ATy->getModifiedType(), Loc, Visited);
268268

269269
if (const auto *RD = Ty->getAsRecordDecl()) {
270270
for (const auto &Field : RD->fields())
271-
checkSYCLVarType(S, Field->getType(), Field->getSourceRange(), Visited,
272-
Loc);
271+
checkSYCLType(S, Field->getType(), Field->getSourceRange(), Visited, Loc);
273272
} else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
274273
for (const auto &ParamTy : FPTy->param_types())
275-
checkSYCLVarType(S, ParamTy, Loc, Visited);
276-
checkSYCLVarType(S, FPTy->getReturnType(), Loc, Visited);
274+
checkSYCLType(S, ParamTy, Loc, Visited);
275+
checkSYCLType(S, FPTy->getReturnType(), Loc, Visited);
277276
}
278277
}
279278

@@ -284,7 +283,7 @@ void Sema::checkSYCLDeviceVarDecl(VarDecl *Var) {
284283
SourceRange Loc = Var->getLocation();
285284
llvm::DenseSet<QualType> Visited;
286285

287-
checkSYCLVarType(*this, Ty, Loc, Visited);
286+
checkSYCLType(*this, Ty, Loc, Visited);
288287
}
289288

290289
class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
@@ -805,6 +804,22 @@ class SyclKernelFieldChecker
805804
bool IsInvalid = false;
806805
DiagnosticsEngine &Diag;
807806

807+
void checkAccessorType(QualType Ty, SourceRange Loc) {
808+
assert(Util::isSyclAccessorType(Ty) &&
809+
"Should only be called on SYCL accessor types.");
810+
811+
const RecordDecl *RecD = Ty->getAsRecordDecl();
812+
if (const ClassTemplateSpecializationDecl *CTSD =
813+
dyn_cast<ClassTemplateSpecializationDecl>(RecD)) {
814+
const TemplateArgumentList &TAL = CTSD->getTemplateArgs();
815+
TemplateArgument TA = TAL.get(0);
816+
const QualType TemplateArgTy = TA.getAsType();
817+
818+
llvm::DenseSet<QualType> Visited;
819+
checkSYCLType(SemaRef, TemplateArgTy, Loc, Visited);
820+
}
821+
}
822+
808823
public:
809824
SyclKernelFieldChecker(Sema &S)
810825
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
@@ -836,6 +851,15 @@ class SyclKernelFieldChecker
836851
}
837852
}
838853

854+
void handleSyclAccessorType(const CXXBaseSpecifier &BS,
855+
QualType FieldTy) final {
856+
checkAccessorType(FieldTy, BS.getBeginLoc());
857+
}
858+
859+
void handleSyclAccessorType(FieldDecl *FD, QualType FieldTy) final {
860+
checkAccessorType(FieldTy, FD->getLocation());
861+
}
862+
839863
// We should be able to handle this, so we made it part of the visitor, but
840864
// this is 'to be implemented'.
841865
void handleArrayType(FieldDecl *FD, QualType FieldTy) final {
@@ -1543,7 +1567,9 @@ Sema::DeviceDiagBuilder Sema::SYCLDiagIfDeviceCode(SourceLocation Loc,
15431567
"Should only be called during SYCL compilation");
15441568
FunctionDecl *FD = dyn_cast<FunctionDecl>(getCurLexicalContext());
15451569
DeviceDiagBuilder::Kind DiagKind = [this, FD] {
1546-
if (ConstructingOpenCLKernel || !FD)
1570+
if (ConstructingOpenCLKernel)
1571+
return DeviceDiagBuilder::K_ImmediateWithCallStack;
1572+
if (!FD)
15471573
return DeviceDiagBuilder::K_Nop;
15481574
if (getEmissionStatus(FD) == Sema::FunctionEmissionStatus::Emitted)
15491575
return DeviceDiagBuilder::K_ImmediateWithCallStack;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: %clang_cc1 -I %S/Inputs -fsycl -triple spir64 -fsycl-is-device -verify -fsyntax-only %s
2+
//
3+
// Ensure SYCL type restrictions are applied to accessors as well.
4+
5+
#include <sycl.hpp>
6+
7+
using namespace cl::sycl;
8+
9+
template <typename name, typename Func>
10+
__attribute__((sycl_kernel)) void kernel(Func kernelFunc) {
11+
kernelFunc();
12+
}
13+
14+
//alias template
15+
template <typename...>
16+
using int128alias_t = __uint128_t;
17+
18+
//templated return type
19+
template <typename T>
20+
T bar() { return T(); };
21+
22+
//typedef
23+
typedef __float128 trickyFloatType;
24+
25+
//struct
26+
struct Mesh {
27+
__int128 prohib; //#struct_member
28+
};
29+
30+
int main() {
31+
accessor<int, 1, access::mode::read_write> ok_acc;
32+
// -- accessors using prohibited types
33+
accessor<__float128, 1, access::mode::read_write> f128_acc;
34+
accessor<__int128, 1, access::mode::read_write> i128_acc;
35+
accessor<long double, 1, access::mode::read_write> ld_acc;
36+
// -- pointers, aliases, auto, typedef, decltype of prohibited type
37+
accessor<__int128 *, 1, access::mode::read_write> i128Ptr_acc;
38+
accessor<int128alias_t<int>, 1, access::mode::read_write> aliased_acc;
39+
accessor<trickyFloatType, 1, access::mode::read_write> typedef_acc;
40+
auto V = bar<__int128>();
41+
accessor<decltype(V), 1, access::mode::read_write> declty_acc;
42+
// -- Accessor of struct that contains a prohibited type.
43+
accessor<Mesh, 1, access::mode::read_write> struct_acc;
44+
45+
kernel<class use_local>(
46+
[=]() {
47+
ok_acc.use();
48+
49+
// -- accessors using prohibited types
50+
// expected-error@+1 {{'__float128' is not supported on this target}}
51+
f128_acc.use();
52+
// expected-error@+1 {{'__int128' is not supported on this target}}
53+
i128_acc.use();
54+
// expected-error@+1 {{'long double' is not supported on this target}}
55+
ld_acc.use();
56+
57+
// -- pointers, aliases, auto, typedef, decltype of prohibited type
58+
// expected-error@+1 {{'__int128' is not supported on this target}}
59+
i128Ptr_acc.use();
60+
// expected-error@+1 {{'unsigned __int128' is not supported on this target}}
61+
aliased_acc.use();
62+
// expected-error@+1 {{'__float128' is not supported on this target}}
63+
typedef_acc.use();
64+
// expected-error@+1 {{'__int128' is not supported on this target}}
65+
declty_acc.use();
66+
67+
// -- Accessor of struct that contains a prohibited type.
68+
// expected-error@#struct_member {{'__int128' is not supported on this target}}
69+
// expected-note@+1 {{used here}}
70+
struct_acc.use();
71+
});
72+
}

0 commit comments

Comments
 (0)