Skip to content

Commit 0e7639a

Browse files
Ilya Stepykinbader
authored andcommitted
[SYCL] Handle captures by ref in kernel lambda
Captures by reference wasn't properly handled by the compiler which led to a crash. Values captured by ref are not standard layout so if we see one, emit a corresponding error as for other non-standard layout structures/classes. This patch also slightly changes diagnostic for non-standard layout types: in case of explicit capture point to the capture location, otherwise point to the first usage of a captured variable inside a lambda. This should be more informative because standard layout property of a capture inside lambda may not match the one outside(especially for capture by ref). Signed-off-by: Ilya Stepykin <[email protected]>
1 parent be75e56 commit 0e7639a

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,8 @@ static target getAccessTarget(const ClassTemplateSpecializationDecl *AccTy) {
989989
// Fields of kernel object must be initialized with SYCL kernel arguments so
990990
// in the following function we extract types of kernel object fields and add it
991991
// to the array with kernel parameters descriptors.
992-
static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
992+
// Returns true if all arguments are successfully built.
993+
static bool buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
993994
SmallVectorImpl<ParamDesc> &ParamDescs) {
994995
const LambdaCapture *Cpt = KernelObj->captures_begin();
995996
auto CreateAndAddPrmDsc = [&](const FieldDecl *Fld, const QualType &ArgType) {
@@ -1040,6 +1041,7 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
10401041
}
10411042
};
10421043

1044+
bool AllArgsAreValid = true;
10431045
// Run through kernel object fields and create corresponding kernel
10441046
// parameters descriptors. There are a several possible cases:
10451047
// - Kernel object field is a SYCL special object (SYCL accessor or SYCL
@@ -1054,17 +1056,22 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
10541056
QualType ArgTy = Fld->getType();
10551057
if (Util::isSyclAccessorType(ArgTy) || Util::isSyclSamplerType(ArgTy)) {
10561058
createSpecialSYCLObjParamDesc(Fld, ArgTy);
1057-
} else if (ArgTy->isStructureOrClassType()) {
1059+
} else if (!ArgTy->isStandardLayoutType()) {
10581060
// SYCL v1.2.1 s4.8.10 p5:
10591061
// C++ non-standard layout values must not be passed as arguments to a
10601062
// kernel that is compiled for a device.
1061-
if (!ArgTy->isStandardLayoutType()) {
1062-
const DeclaratorDecl *V =
1063-
Cpt ? cast<DeclaratorDecl>(Cpt->getCapturedVar())
1064-
: cast<DeclaratorDecl>(Fld);
1065-
KernelObj->getASTContext().getDiagnostics().Report(
1066-
V->getLocation(), diag::err_sycl_non_std_layout_type);
1067-
}
1063+
const auto &DiagLocation =
1064+
Cpt ? Cpt->getLocation() : cast<DeclaratorDecl>(Fld)->getLocation();
1065+
1066+
Context.getDiagnostics().Report(DiagLocation,
1067+
diag::err_sycl_non_std_layout_type);
1068+
1069+
// Set the flag and continue processing so we can emit error for each
1070+
// invalid argument.
1071+
AllArgsAreValid = false;
1072+
} else if (ArgTy->isStructureOrClassType()) {
1073+
assert(ArgTy->isStandardLayoutType());
1074+
10681075
CreateAndAddPrmDsc(Fld, ArgTy);
10691076

10701077
// Create descriptors for each accessor field in the class or struct
@@ -1077,14 +1084,20 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
10771084
PointeeTy = Context.getQualifiedType(PointeeTy.getUnqualifiedType(),
10781085
Quals);
10791086
QualType ModTy = Context.getPointerType(PointeeTy);
1080-
1087+
10811088
CreateAndAddPrmDsc(Fld, ModTy);
10821089
} else if (ArgTy->isScalarType()) {
10831090
CreateAndAddPrmDsc(Fld, ArgTy);
10841091
} else {
10851092
llvm_unreachable("Unsupported kernel parameter type");
10861093
}
1094+
1095+
// Update capture iterator as we process arguments
1096+
if (Cpt && Cpt != KernelObj->captures_end())
1097+
++Cpt;
10871098
}
1099+
1100+
return AllArgsAreValid;
10881101
}
10891102

10901103
/// Adds necessary data describing given kernel to the integration header.
@@ -1238,7 +1251,8 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
12381251

12391252
// Build list of kernel arguments
12401253
llvm::SmallVector<ParamDesc, 16> ParamDescs;
1241-
buildArgTys(getASTContext(), LE, ParamDescs);
1254+
if (!buildArgTys(getASTContext(), LE, ParamDescs))
1255+
return;
12421256

12431257
// Extract name from kernel caller parameters and mangle it.
12441258
const TemplateArgumentList *TemplateArgs =

clang/test/SemaSYCL/non-std-layout-param.cpp

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,61 @@ __attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
2020

2121

2222
void test() {
23-
// expected-error@+1 {{kernel parameter has non-standard layout class/struct type}}
2423
C C0;
2524
C0.Y=0;
26-
kernel_single_task<class MyKernel>([=] { (void)C0.Y; });
25+
kernel_single_task<class MyKernel>([=] {
26+
// expected-error@+1 {{kernel parameter has non-standard layout class/struct type}}
27+
(void)C0.Y;
28+
});
29+
}
30+
31+
void test_capture_explicit_ref() {
32+
int p = 0;
33+
double q = 0;
34+
float s = 0;
35+
kernel_single_task<class kernel_capture_single_ref>([
36+
// expected-error@+1 {{kernel parameter has non-standard layout class/struct type}}
37+
&p,
38+
q,
39+
// expected-error@+1 {{kernel parameter has non-standard layout class/struct type}}
40+
&s] {
41+
(void) q;
42+
(void) p;
43+
(void) s;
44+
});
2745
}
2846

47+
void test_capture_implicit_refs() {
48+
int p = 0;
49+
double q = 0;
50+
kernel_single_task<class kernel_capture_refs>([&] {
51+
// expected-error@+1 {{kernel parameter has non-standard layout class/struct type}}
52+
(void) p;
53+
// expected-error@+1 {{kernel parameter has non-standard layout class/struct type}}
54+
(void) q;
55+
});
56+
}
57+
58+
struct Kernel {
59+
void operator()() {
60+
(void) c1;
61+
(void) c2;
62+
(void) p;
63+
(void) q;
64+
}
65+
66+
int p;
67+
// expected-error@+1 {{kernel parameter has non-standard layout class/struct type}}
68+
C c1;
69+
70+
int q;
71+
72+
// expected-error@+1 {{kernel parameter has non-standard layout class/struct type}}
73+
C c2;
74+
};
75+
76+
void test_struct_field() {
77+
Kernel k{};
78+
79+
kernel_single_task<class kernel_object>(k);
80+
}

0 commit comments

Comments
 (0)