@@ -37,6 +37,24 @@ enum target {
37
37
image_array
38
38
};
39
39
40
+ // / Various utilities.
41
+ class Util {
42
+ public:
43
+ // TODO SYCL use AST infrastructure instead of string matching
44
+
45
+ // / Checks whether given clang type is a sycl accessor class.
46
+ static bool isSyclAccessorType (QualType Ty) {
47
+ std::string Name = Ty.getCanonicalType ().getAsString ();
48
+ return Name.find (" class cl::sycl::accessor" ) != std::string::npos;
49
+ }
50
+
51
+ // / Checks whether given clang type is a sycl stream class.
52
+ static bool isSyclStreamType (QualType Ty) {
53
+ std::string Name = Ty.getCanonicalType ().getAsString ();
54
+ return Name == " stream" ;
55
+ }
56
+ };
57
+
40
58
static CXXRecordDecl *getKernelCallerLambdaArg (FunctionDecl *FD) {
41
59
auto FirstArg = (*FD->param_begin ());
42
60
if (FirstArg)
@@ -271,7 +289,7 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
271
289
272
290
QualType FieldType = Field->getType ();
273
291
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl ();
274
- if (CRD) {
292
+ if (CRD && Util::isSyclAccessorType (FieldType) ) {
275
293
DeclAccessPair FieldDAP = DeclAccessPair::make (Field, AS_none);
276
294
// lambda.accessor
277
295
auto AccessorME = MemberExpr::Create (
@@ -373,9 +391,11 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
373
391
" unsupported accessor and without initialized range" );
374
392
}
375
393
}
376
- } else if (FieldType->isBuiltinType ()) {
377
- // If field have built-in type just initialize this field
378
- // with corresponding kernel argument using '=' binary operator.
394
+ } else if (CRD || FieldType->isBuiltinType ()) {
395
+ // If field have built-in or a structure/class type just initialize
396
+ // this field with corresponding kernel argument using '=' binary
397
+ // operator. The structure/class type must be copy assignable - this
398
+ // holds because SYCL kernel lambdas capture arguments by copy.
379
399
DeclAccessPair FieldDAP = DeclAccessPair::make (Field, AS_none);
380
400
auto Lhs = MemberExpr::Create (
381
401
S.Context , LambdaDRE, false , SourceLocation (),
@@ -416,31 +436,13 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
416
436
SourceLocation ());
417
437
}
418
438
419
- // / Various utilities.
420
- class Util {
421
- public:
422
- // TODO SYCL use AST infrastructure instead of string matching
423
-
424
- // / Checks whether given clang type is a sycl accessor class.
425
- static bool isSyclAccessorType (QualType Ty) {
426
- std::string Name = Ty.getCanonicalType ().getAsString ();
427
- return Name.find (" class cl::sycl::accessor" ) != std::string::npos;
428
- }
429
-
430
- // / Checks whether given clang type is a sycl stream class.
431
- static bool isSyclStreamType (QualType Ty) {
432
- std::string Name = Ty.getCanonicalType ().getAsString ();
433
- return Name == " stream" ;
434
- }
435
- };
436
-
437
439
// / Identifies context of kernel lambda capture visitor function
438
440
// / invocation.
439
441
enum VisitorContext {
440
442
pre_visit,
441
443
pre_visit_class_field,
442
444
visit_accessor,
443
- visit_scalar ,
445
+ visit_std_layout ,
444
446
visit_stream,
445
447
post_visit,
446
448
};
@@ -508,9 +510,16 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
508
510
// stream parameter context
509
511
auto F = std::get<visit_stream>(Vis);
510
512
F (Cnt, V, *Fld);
513
+ } else if (ArgTy->isStructureOrClassType ()) {
514
+ if (!ArgTy->isStandardLayoutType ())
515
+ Lambda->getASTContext ().getDiagnostics ().Report (V->getLocation (),
516
+ diag::err_sycl_non_std_layout_type);
517
+ // structure or class typed parameter - the same handling as a scalar
518
+ auto F = std::get<visit_std_layout>(Vis);
519
+ F (Cnt, V, *Fld);
511
520
} else if (ArgTy->isScalarType ()) {
512
521
// scalar typed parameter context
513
- auto F = std::get<visit_scalar >(Vis);
522
+ auto F = std::get<visit_std_layout >(Vis);
514
523
F (Cnt, V, *Fld);
515
524
} else {
516
525
llvm_unreachable (" unsupported kernel parameter type" );
@@ -523,7 +532,7 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
523
532
// pre-visit context the same like for accessor
524
533
auto F1Range = std::get<pre_visit_class_field>(Vis);
525
534
F1Range (Cnt, V, *Fld, AccessorRangeField);
526
- auto FRange = std::get<visit_scalar >(Vis);
535
+ auto FRange = std::get<visit_std_layout >(Vis);
527
536
FRange (Cnt, V, AccessorRangeField);
528
537
// post-visit context
529
538
auto F2Range = std::get<post_visit>(Vis);
@@ -568,7 +577,7 @@ static void BuildArgTys(ASTContext &Context, CXXRecordDecl *Lambda,
568
577
ActualArgType =
569
578
Context.getQualifiedType (PointerType.getUnqualifiedType (), Quals);
570
579
},
571
- // visit_scalar
580
+ // visit_std_layout
572
581
[&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal) {
573
582
ActualArgType = CapturedVal->getType ();
574
583
},
@@ -643,9 +652,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
643
652
Knd = SYCLIntegrationHeader::kind_accessor;
644
653
Info = static_cast <int >(AccTrg);
645
654
},
646
- // visit_scalar
655
+ // visit_std_layout
647
656
[&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal) {
648
- Knd = SYCLIntegrationHeader::kind_scalar;
657
+ // TODO this code (when used to handle a structure-typed scalar) relies
658
+ // on the host and device structure layouts and sizes to be the same.
659
+ // Need SYCL spec clarification on passing structures as parameters.
660
+ Knd = SYCLIntegrationHeader::kind_std_layout;
649
661
Info = static_cast <unsigned >(
650
662
Ctx.getTypeSizeInChars (CapturedVal->getType ()).getQuantity ());
651
663
},
@@ -740,10 +752,8 @@ static const char *paramKind2Str(KernelParamKind K) {
740
752
return " kind_" #x
741
753
switch (K) {
742
754
CASE (accessor);
743
- CASE (scalar);
744
- CASE (struct );
755
+ CASE (std_layout);
745
756
CASE (sampler);
746
- CASE (struct_padding);
747
757
default :
748
758
return " <ERROR>" ;
749
759
}
@@ -766,7 +776,7 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D) {
766
776
cast<ClassTemplateDecl>(D)->getTemplatedDecl () : dyn_cast<TagDecl>(D);
767
777
768
778
if (TD && TD->isCompleteDefinition ()) {
769
- // defied class constituting the kernel name is not globally
779
+ // defined class constituting the kernel name is not globally
770
780
// accessible - contradicts the spec
771
781
Diag.Report (D->getSourceRange ().getBegin (),
772
782
diag::err_sycl_kernel_name_class_not_top_level);
0 commit comments