Skip to content

Commit 2438f61

Browse files
[SYCL] Refactor semantic checks for variable types (#1513)
The various type checks have been steadily moving out of `CheckSYCLType` which is called by most of the AST Visitor methods. Here we finally move the last lingering type check (for VLAs) into the `CheckSYCLVarType` function and delete `CheckSYCLType` and most of its AST Visitor method callers. Some few of the AST Visitor methods are still used for other checks. Signed-off-by: Chris Perkins <[email protected]>
1 parent 3b8dd54 commit 2438f61

File tree

2 files changed

+21
-106
lines changed

2 files changed

+21
-106
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 4 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,10 @@ static void checkSYCLVarType(Sema &S, QualType Ty, SourceRange Loc,
236236
emitDeferredDiagnosticAndNote(S, Loc, diag::err_typecheck_zero_array_size,
237237
UsedAtLoc);
238238

239+
// variable length arrays
240+
if (Ty->isVariableArrayType())
241+
emitDeferredDiagnosticAndNote(S, Loc, diag::err_vla_unsupported, UsedAtLoc);
242+
239243
// Sub-reference array or pointer, then proceed with that type.
240244
while (Ty->isAnyPointerType() || Ty->isArrayType())
241245
Ty = QualType{Ty->getPointeeOrArrayElementType(), 0};
@@ -284,9 +288,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
284288
: RecursiveASTVisitor<MarkDeviceFunction>(), SemaRef(S) {}
285289

286290
bool VisitCallExpr(CallExpr *e) {
287-
for (const auto &Arg : e->arguments())
288-
CheckSYCLType(Arg->getType(), Arg->getSourceRange());
289-
290291
if (FunctionDecl *Callee = e->getDirectCallee()) {
291292
Callee = Callee->getCanonicalDecl();
292293
assert(Callee && "Device function canonical decl must be available");
@@ -308,8 +309,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
308309
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
309310
<< Sema::KernelCallVirtualFunction;
310311

311-
CheckSYCLType(Callee->getReturnType(), Callee->getSourceRange());
312-
313312
if (auto const *FD = dyn_cast<FunctionDecl>(Callee)) {
314313
// FIXME: We need check all target specified attributes for error if
315314
// that function with attribute can not be called from sycl kernel. The
@@ -338,12 +337,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
338337
return true;
339338
}
340339

341-
bool VisitCXXConstructExpr(CXXConstructExpr *E) {
342-
for (const auto &Arg : E->arguments())
343-
CheckSYCLType(Arg->getType(), Arg->getSourceRange());
344-
return true;
345-
}
346-
347340
bool VisitCXXTypeidExpr(CXXTypeidExpr *E) {
348341
SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict) << Sema::KernelRTTI;
349342
return true;
@@ -354,35 +347,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
354347
return true;
355348
}
356349

357-
bool VisitTypedefNameDecl(TypedefNameDecl *TD) {
358-
CheckSYCLType(TD->getUnderlyingType(), TD->getLocation());
359-
return true;
360-
}
361-
362-
bool VisitRecordDecl(RecordDecl *RD) {
363-
CheckSYCLType(QualType{RD->getTypeForDecl(), 0}, RD->getLocation());
364-
return true;
365-
}
366-
367-
bool VisitParmVarDecl(VarDecl *VD) {
368-
CheckSYCLType(VD->getType(), VD->getLocation());
369-
return true;
370-
}
371-
372-
bool VisitVarDecl(VarDecl *VD) {
373-
CheckSYCLType(VD->getType(), VD->getLocation());
374-
return true;
375-
}
376-
377-
bool VisitDeclRefExpr(DeclRefExpr *E) {
378-
Decl *D = E->getDecl();
379-
if (SemaRef.isKnownGoodSYCLDecl(D))
380-
return true;
381-
382-
CheckSYCLType(E->getType(), E->getSourceRange());
383-
return true;
384-
}
385-
386350
// The call graph for this translation unit.
387351
CallGraph SYCLCG;
388352
// The set of functions called by a kernel function.
@@ -506,64 +470,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
506470
}
507471

508472
private:
509-
bool CheckSYCLType(QualType Ty, SourceRange Loc) {
510-
llvm::DenseSet<QualType> visited;
511-
return CheckSYCLType(Ty, Loc, visited);
512-
}
513-
514-
bool CheckSYCLType(QualType Ty, SourceRange Loc,
515-
llvm::DenseSet<QualType> &Visited) {
516-
if (Ty->isVariableArrayType()) {
517-
SemaRef.Diag(Loc.getBegin(), diag::err_vla_unsupported);
518-
return false;
519-
}
520-
521-
while (Ty->isAnyPointerType() || Ty->isArrayType())
522-
Ty = QualType{Ty->getPointeeOrArrayElementType(), 0};
523-
524-
// Pointers complicate recursion. Add this type to Visited.
525-
// If already there, bail out.
526-
if (!Visited.insert(Ty).second)
527-
return true;
528-
529-
if (const auto *ATy = dyn_cast<AttributedType>(Ty))
530-
return CheckSYCLType(ATy->getModifiedType(), Loc, Visited);
531-
532-
if (const auto *CRD = Ty->getAsCXXRecordDecl()) {
533-
// If the class is a forward declaration - skip it, because otherwise we
534-
// would query property of class with no definition, which results in
535-
// clang crash.
536-
if (!CRD->hasDefinition())
537-
return true;
538-
539-
for (const auto &Field : CRD->fields()) {
540-
if (!CheckSYCLType(Field->getType(), Field->getSourceRange(),
541-
Visited)) {
542-
if (SemaRef.getLangOpts().SYCLIsDevice)
543-
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
544-
return false;
545-
}
546-
}
547-
} else if (const auto *RD = Ty->getAsRecordDecl()) {
548-
for (const auto &Field : RD->fields()) {
549-
if (!CheckSYCLType(Field->getType(), Field->getSourceRange(),
550-
Visited)) {
551-
if (SemaRef.getLangOpts().SYCLIsDevice)
552-
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
553-
return false;
554-
}
555-
}
556-
} else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
557-
for (const auto &ParamTy : FPTy->param_types())
558-
if (!CheckSYCLType(ParamTy, Loc, Visited))
559-
return false;
560-
return CheckSYCLType(FPTy->getReturnType(), Loc, Visited);
561-
} else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
562-
return CheckSYCLType(FTy->getReturnType(), Loc, Visited);
563-
}
564-
return true;
565-
}
566-
567473
Sema &SemaRef;
568474
};
569475

clang/test/SemaSYCL/sycl-restrict.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,16 @@ void no_restriction(int p) {
3636
int index[p + 2];
3737
}
3838
void restriction(int p) {
39-
int index[p + 2]; // expected-error {{variable length arrays are not supported for the current target}}
39+
// This particular violation is nested under two kernels with intermediate function calls.
40+
// e.g. main -> 1stkernel -> usage -> 2ndkernel -> isa_B -> restriction -> !!
41+
// Because the error is in two different kernels, we are given helpful notes for the origination of the error, twice.
42+
// expected-note@#call_usage {{called by 'operator()'}}
43+
// expected-note@#call_kernelFunc {{called by 'kernel_single_task<fake_kernel, (lambda at}}
44+
// expected-note@#call_isa_B 2{{called by 'operator()'}}
45+
// expected-note@#call_rtti_kernel {{called by 'usage'}}
46+
// expected-note@#rtti_kernel 2{{called by 'kernel1<kernel_name, (lambda at }}
47+
// expected-note@#call_vla 2{{called by 'isa_B'}}
48+
int index[p + 2]; // expected-error 2{{variable length arrays are not supported for the current target}}
4049
}
4150
} // namespace Check_VLA_Restriction
4251

@@ -67,8 +76,8 @@ bool isa_B(A *a) {
6776
if (f1 == f2) // expected-note 2{{called by 'isa_B'}}
6877
return false;
6978

70-
Check_VLA_Restriction::restriction(7);
71-
int *ip = new int; // expected-error 2{{SYCL kernel cannot allocate storage}}
79+
Check_VLA_Restriction::restriction(7); //#call_vla
80+
int *ip = new int; // expected-error 2{{SYCL kernel cannot allocate storage}}
7281
int i;
7382
int *p3 = new (&i) int; // no error on placement new
7483
OverloadedNewDelete *x = new (struct OverloadedNewDelete); // expected-note 2{{called by 'isa_B'}}
@@ -79,7 +88,7 @@ bool isa_B(A *a) {
7988

8089
template <typename N, typename L>
8190
__attribute__((sycl_kernel)) void kernel1(L l) {
82-
l(); // expected-note 6{{called by 'kernel1<kernel_name, (lambda at }}
91+
l(); //#rtti_kernel // expected-note 6{{called by 'kernel1<kernel_name, (lambda at }}
8392
}
8493
} // namespace Check_RTTI_Restriction
8594

@@ -189,9 +198,9 @@ void usage(myFuncDef functionPtr) {
189198
// expected-error@+1 {{SYCL kernel cannot use a non-const global variable}}
190199
b.f(); // expected-error {{SYCL kernel cannot call a virtual function}}
191200

192-
Check_RTTI_Restriction::kernel1<class kernel_name>([]() { // expected-note 3{{called by 'usage'}}
201+
Check_RTTI_Restriction::kernel1<class kernel_name>([]() { //#call_rtti_kernel // expected-note 3{{called by 'usage'}}
193202
Check_RTTI_Restriction::A *a;
194-
Check_RTTI_Restriction::isa_B(a); // expected-note 6{{called by 'operator()'}}
203+
Check_RTTI_Restriction::isa_B(a); //#call_isa_B // expected-note 6{{called by 'operator()'}}
195204
});
196205

197206
// ======= Float128 Not Allowed in Kernel ==========
@@ -323,7 +332,7 @@ int use2(a_type ab, a_type *abp) {
323332

324333
template <typename name, typename Func>
325334
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
326-
kernelFunc(); // expected-note 7{{called by 'kernel_single_task<fake_kernel, (lambda at}}
335+
kernelFunc(); //#call_kernelFunc // expected-note 7{{called by 'kernel_single_task<fake_kernel, (lambda at}}
327336
}
328337

329338
int main() {
@@ -340,7 +349,7 @@ int main() {
340349
auto notACrime = &commitInfraction;
341350

342351
kernel_single_task<class fake_kernel>([=]() {
343-
usage(&addInt); // expected-note 5{{called by 'operator()'}}
352+
usage(&addInt); //#call_usage // expected-note 5{{called by 'operator()'}}
344353
a_type *p;
345354
use2(ab, p); // expected-note 2{{called by 'operator()'}}
346355
});

0 commit comments

Comments
 (0)