Skip to content

Commit 1c93c84

Browse files
committed
[SYCL] Support structures as kernel parameters (compiler part).
Signed-off-by: Vladimir Lazarev <[email protected]>
1 parent eae0f80 commit 1c93c84

File tree

4 files changed

+53
-43
lines changed

4 files changed

+53
-43
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9491,4 +9491,6 @@ def err_sycl_kernel_name_class_not_top_level : Error<
94919491
def err_sycl_virtual_types : Error<
94929492
"No class with a vtable can be used in a SYCL kernel or any code included in the kernel">;
94939493
def note_sycl_used_here : Note<"used here">;
9494+
def err_sycl_non_std_layout_type : Error<
9495+
"kernel parameter has non-standard layout class/struct type">;
94949496
} // end of sema component.

clang/include/clang/Sema/Sema.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,9 @@ class SYCLIntegrationHeader {
299299
enum kernel_param_kind_t {
300300
kind_first,
301301
kind_accessor = kind_first,
302-
kind_scalar,
303-
kind_struct,
302+
kind_std_layout,
304303
kind_sampler,
305-
kind_struct_padding, // can be added by the compiler to enforce alignment
306-
kind_last = kind_struct_padding
304+
kind_last = kind_sampler
307305
};
308306

309307
public:

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,24 @@ enum target {
3737
image_array
3838
};
3939

40+
/// Various utilities.
41+
class Util {
42+
public:
43+
// TODO SYCL use AST infrastructure instead of string matching
44+
45+
/// Checks whether given clang type is a sycl accessor class.
46+
static bool isSyclAccessorType(QualType Ty) {
47+
std::string Name = Ty.getCanonicalType().getAsString();
48+
return Name.find("class cl::sycl::accessor") != std::string::npos;
49+
}
50+
51+
/// Checks whether given clang type is a sycl stream class.
52+
static bool isSyclStreamType(QualType Ty) {
53+
std::string Name = Ty.getCanonicalType().getAsString();
54+
return Name == "stream";
55+
}
56+
};
57+
4058
static CXXRecordDecl *getKernelCallerLambdaArg(FunctionDecl *FD) {
4159
auto FirstArg = (*FD->param_begin());
4260
if (FirstArg)
@@ -271,7 +289,7 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
271289

272290
QualType FieldType = Field->getType();
273291
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl();
274-
if (CRD) {
292+
if (CRD && Util::isSyclAccessorType(FieldType)) {
275293
DeclAccessPair FieldDAP = DeclAccessPair::make(Field, AS_none);
276294
// lambda.accessor
277295
auto AccessorME = MemberExpr::Create(
@@ -373,9 +391,11 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
373391
"unsupported accessor and without initialized range");
374392
}
375393
}
376-
} else if (FieldType->isBuiltinType()) {
377-
// If field have built-in type just initialize this field
378-
// with corresponding kernel argument using '=' binary operator.
394+
} else if (CRD || FieldType->isBuiltinType()) {
395+
// If field have built-in or a structure/class type just initialize
396+
// this field with corresponding kernel argument using '=' binary
397+
// operator. The structure/class type must be copy assignable - this
398+
// holds because SYCL kernel lambdas capture arguments by copy.
379399
DeclAccessPair FieldDAP = DeclAccessPair::make(Field, AS_none);
380400
auto Lhs = MemberExpr::Create(
381401
S.Context, LambdaDRE, false, SourceLocation(),
@@ -416,31 +436,13 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
416436
SourceLocation());
417437
}
418438

419-
/// Various utilities.
420-
class Util {
421-
public:
422-
// TODO SYCL use AST infrastructure instead of string matching
423-
424-
/// Checks whether given clang type is a sycl accessor class.
425-
static bool isSyclAccessorType(QualType Ty) {
426-
std::string Name = Ty.getCanonicalType().getAsString();
427-
return Name.find("class cl::sycl::accessor") != std::string::npos;
428-
}
429-
430-
/// Checks whether given clang type is a sycl stream class.
431-
static bool isSyclStreamType(QualType Ty) {
432-
std::string Name = Ty.getCanonicalType().getAsString();
433-
return Name == "stream";
434-
}
435-
};
436-
437439
/// Identifies context of kernel lambda capture visitor function
438440
/// invocation.
439441
enum VisitorContext {
440442
pre_visit,
441443
pre_visit_class_field,
442444
visit_accessor,
443-
visit_scalar,
445+
visit_std_layout,
444446
visit_stream,
445447
post_visit,
446448
};
@@ -508,9 +510,16 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
508510
// stream parameter context
509511
auto F = std::get<visit_stream>(Vis);
510512
F(Cnt, V, *Fld);
513+
} else if (ArgTy->isStructureOrClassType()) {
514+
if (!ArgTy->isStandardLayoutType())
515+
Lambda->getASTContext().getDiagnostics().Report(V->getLocation(),
516+
diag::err_sycl_non_std_layout_type);
517+
// structure or class typed parameter - the same handling as a scalar
518+
auto F = std::get<visit_std_layout>(Vis);
519+
F(Cnt, V, *Fld);
511520
} else if (ArgTy->isScalarType()) {
512521
// scalar typed parameter context
513-
auto F = std::get<visit_scalar>(Vis);
522+
auto F = std::get<visit_std_layout>(Vis);
514523
F(Cnt, V, *Fld);
515524
} else {
516525
llvm_unreachable("unsupported kernel parameter type");
@@ -523,7 +532,7 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
523532
// pre-visit context the same like for accessor
524533
auto F1Range = std::get<pre_visit_class_field>(Vis);
525534
F1Range(Cnt, V, *Fld, AccessorRangeField);
526-
auto FRange = std::get<visit_scalar>(Vis);
535+
auto FRange = std::get<visit_std_layout>(Vis);
527536
FRange(Cnt, V, AccessorRangeField);
528537
// post-visit context
529538
auto F2Range = std::get<post_visit>(Vis);
@@ -568,7 +577,7 @@ static void BuildArgTys(ASTContext &Context, CXXRecordDecl *Lambda,
568577
ActualArgType =
569578
Context.getQualifiedType(PointerType.getUnqualifiedType(), Quals);
570579
},
571-
// visit_scalar
580+
// visit_std_layout
572581
[&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal) {
573582
ActualArgType = CapturedVal->getType();
574583
},
@@ -643,9 +652,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
643652
Knd = SYCLIntegrationHeader::kind_accessor;
644653
Info = static_cast<int>(AccTrg);
645654
},
646-
// visit_scalar
655+
// visit_std_layout
647656
[&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal) {
648-
Knd = SYCLIntegrationHeader::kind_scalar;
657+
// TODO this code (when used to handle a structure-typed scalar) relies
658+
// on the host and device structure layouts and sizes to be the same.
659+
// Need SYCL spec clarification on passing structures as parameters.
660+
Knd = SYCLIntegrationHeader::kind_std_layout;
649661
Info = static_cast<unsigned>(
650662
Ctx.getTypeSizeInChars(CapturedVal->getType()).getQuantity());
651663
},
@@ -740,10 +752,8 @@ static const char *paramKind2Str(KernelParamKind K) {
740752
return "kind_" #x
741753
switch (K) {
742754
CASE(accessor);
743-
CASE(scalar);
744-
CASE(struct);
755+
CASE(std_layout);
745756
CASE(sampler);
746-
CASE(struct_padding);
747757
default:
748758
return "<ERROR>";
749759
}
@@ -766,7 +776,7 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D) {
766776
cast<ClassTemplateDecl>(D)->getTemplatedDecl() : dyn_cast<TagDecl>(D);
767777

768778
if (TD && TD->isCompleteDefinition()) {
769-
// defied class constituting the kernel name is not globally
779+
// defined class constituting the kernel name is not globally
770780
// accessible - contradicts the spec
771781
Diag.Report(D->getSourceRange().getBegin(),
772782
diag::err_sycl_kernel_name_class_not_top_level);

clang/test/CodeGenSYCL/integration_header.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,21 @@
2121
// CHECK: static constexpr
2222
// CHECK-NEXT: const kernel_param_desc_t kernel_signatures[] = {
2323
// CHECK-NEXT: //--- first_kernel
24-
// CHECK-NEXT: { kernel_param_kind_t::kind_scalar, 4, 0 },
24+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 },
2525
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 2014, 4 },
26-
// CHECK-NEXT: { kernel_param_kind_t::kind_scalar, 1, 4 },
26+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 1, 4 },
2727
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 2016, 5 },
28-
// CHECK-NEXT: { kernel_param_kind_t::kind_scalar, 1, 5 },
28+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 1, 5 },
2929
// CHECK-EMPTY:
3030
// CHECK-NEXT: //--- ::second_namespace::second_kernel<char>
31-
// CHECK-NEXT: { kernel_param_kind_t::kind_scalar, 4, 0 },
31+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 },
3232
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 2016, 4 },
33-
// CHECK-NEXT: { kernel_param_kind_t::kind_scalar, 1, 4 },
33+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 1, 4 },
3434
// CHECK-EMPTY:
3535
// CHECK-NEXT: //--- ::third_kernel<1, int, ::point<X> >
36-
// CHECK-NEXT: { kernel_param_kind_t::kind_scalar, 4, 0 },
36+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 },
3737
// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 2016, 4 },
38-
// CHECK-NEXT: { kernel_param_kind_t::kind_scalar, 1, 4 },
38+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 1, 4 },
3939
// CHECK-EMPTY:
4040
// CHECK-NEXT: };
4141
//

0 commit comments

Comments
 (0)