Skip to content

Commit ce9fe95

Browse files
committed
[SYCL] Comment kernel wrapper generation code
Signed-off-by: Mariya Podchishchaeva <[email protected]>
1 parent 969129f commit ce9fe95

File tree

1 file changed

+105
-40
lines changed

1 file changed

+105
-40
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 105 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,9 @@ class KernelBodyTransform : public TreeTransform<KernelBodyTransform> {
396396
Sema &SemaRef;
397397
};
398398

399-
static FunctionDecl *CreateSYCLKernelDeclaration(ASTContext &Context,
400-
StringRef Name,
401-
ArrayRef<ParamDesc> ParamDescs) {
399+
static FunctionDecl *
400+
CreateSYCLKernelDeclaration(ASTContext &Context, StringRef Name,
401+
ArrayRef<ParamDesc> ParamDescs) {
402402

403403
DeclContext *DC = Context.getTranslationUnitDecl();
404404
QualType RetTy = Context.VoidTy;
@@ -448,24 +448,30 @@ static CXXMethodDecl *getInitMethod(const CXXRecordDecl *CRD) {
448448
return InitMethod;
449449
}
450450

451-
static CompoundStmt *
452-
CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *KernelDecl) {
451+
// Creates body for new SYCL kernel. This body contains initialization of kernel
452+
// object fields with kernel parameters and a slightly transformed body of the
453+
// kernel caller function.
454+
static CompoundStmt *CreateSYCLKernelBody(Sema &S,
455+
FunctionDecl *KernelCallerFunc,
456+
DeclContext *KernelDecl) {
453457
llvm::SmallVector<Stmt *, 16> BodyStmts;
454458
CXXRecordDecl *LC = getKernelObjectType(KernelCallerFunc);
455459
assert(LC && "Kernel object must be available");
456460
TypeSourceInfo *TSInfo = LC->isLambda() ? LC->getLambdaTypeInfo() : nullptr;
461+
457462
// Create a local kernel object (lambda or functor) assembled from the
458463
// incoming formal parameters
459464
auto KernelObjClone = VarDecl::Create(
460-
S.Context, KernelDecl, SourceLocation(), SourceLocation(), LC->getIdentifier(),
461-
QualType(LC->getTypeForDecl(), 0), TSInfo, SC_None);
465+
S.Context, KernelDecl, SourceLocation(), SourceLocation(),
466+
LC->getIdentifier(), QualType(LC->getTypeForDecl(), 0), TSInfo, SC_None);
462467
Stmt *DS = new (S.Context) DeclStmt(DeclGroupRef(KernelObjClone),
463468
SourceLocation(), SourceLocation());
464469
BodyStmts.push_back(DS);
465470
auto KernelObjCloneRef =
466471
DeclRefExpr::Create(S.Context, NestedNameSpecifierLoc(), SourceLocation(),
467472
KernelObjClone, false, DeclarationNameInfo(),
468473
QualType(LC->getTypeForDecl(), 0), VK_LValue);
474+
469475
auto KernelFuncDecl = dyn_cast<FunctionDecl>(KernelDecl);
470476
assert(KernelFuncDecl && "No kernel function declaration?");
471477
auto KernelFuncParam =
@@ -484,11 +490,13 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *Kerne
484490
// initialize them. We create call of __init method and pass built kernel
485491
// arguments as parameters to the __init method.
486492
auto getExprForSpecialSYCLObj = [&](const QualType &paramTy,
487-
FieldDecl *Field,
488-
const CXXRecordDecl *CRD, Expr *Base) {
493+
FieldDecl *Field,
494+
const CXXRecordDecl *CRD,
495+
Expr *Base) {
489496
// All special SYCL objects must have __init method
490497
CXXMethodDecl *InitMethod = getInitMethod(CRD);
491-
assert(InitMethod && "The accessor/sampler must have the __init method");
498+
assert(InitMethod &&
499+
"The accessor/sampler must have the __init method");
492500
unsigned NumParams = InitMethod->getNumParams();
493501
llvm::SmallVector<DeclRefExpr *, 4> ParamDREs(NumParams);
494502
auto KFP = KernelFuncParam;
@@ -503,8 +511,8 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *Kerne
503511
DeclAccessPair FieldDAP = DeclAccessPair::make(Field, AS_none);
504512
// [kernel_obj or wrapper object].special_obj
505513
auto AccessorME = MemberExpr::Create(
506-
S.Context, Base, false, SourceLocation(),
507-
NestedNameSpecifierLoc(), SourceLocation(), Field, FieldDAP,
514+
S.Context, Base, false, SourceLocation(), NestedNameSpecifierLoc(),
515+
SourceLocation(), Field, FieldDAP,
508516
DeclarationNameInfo(Field->getDeclName(), SourceLocation()),
509517
nullptr, Field->getType(), VK_LValue, OK_Ordinary);
510518

@@ -555,7 +563,7 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *Kerne
555563
// or parameter of the previous processed accessor object.
556564
KernelFuncParam++;
557565
getExprForSpecialSYCLObj(FldType, WrapperFld, WrapperFldCRD,
558-
Base);
566+
Base);
559567
} else {
560568
// Field is a structure or class so change the wrapper object
561569
// and recursively search for accessor field.
@@ -574,13 +582,26 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *Kerne
574582
}
575583
};
576584

585+
// Run through kernel object fields and add initialization for them using
586+
// built kernel parameters. There are several possible cases:
587+
// - Kernel object field is a SYCL special object (SYCL accessor or SYCL
588+
// sampler). These objects have a special initialization scheme - using
589+
// __init method.
590+
// - Kernel object field has a scalar type. In this case we should add
591+
// simple initialization using binary '=' operator.
592+
// - Kernel object field has a structure or class type. Same handling as
593+
// a scalar but we should check if this structure/class contains
594+
// accessors and add initialization for them properly.
577595
QualType FieldType = Field->getType();
578596
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl();
579597
if (Util::isSyclAccessorType(FieldType) ||
580598
Util::isSyclSamplerType(FieldType)) {
581599
getExprForSpecialSYCLObj(FieldType, Field, CRD, KernelObjCloneRef);
600+
} else if (Util::isSyclStreamType(FieldType)) {
601+
// TODO add support for streams
602+
llvm_unreachable("Streams not supported yet");
582603
} else if (CRD || FieldType->isScalarType()) {
583-
// If field have built-in or a structure/class type just initialize
604+
// If field has built-in or a structure/class type just initialize
584605
// this field with corresponding kernel argument using '=' binary
585606
// operator. The structure/class type must be copy assignable - this
586607
// holds because SYCL kernel lambdas capture arguments by copy.
@@ -609,14 +630,14 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *Kerne
609630
if (CRD)
610631
getExprForWrappedAccessorInit(CRD, Lhs);
611632
} else {
612-
llvm_unreachable("unsupported field type");
633+
llvm_unreachable("Unsupported field type");
613634
}
614635
KernelFuncParam++;
615636
}
616637
}
617638

618-
// In kernel caller function lambda/functior is function parameter, we need
619-
// to replace all refs to this lambda/functor with our kernel object clone
639+
// In the kernel caller function kernel object is a function parameter, so we
640+
// need to replace all refs to this kernel object with refs to our clone
620641
// declared inside kernel body.
621642
Stmt *FunctionBody = KernelCallerFunc->getBody();
622643
ParmVarDecl *KernelObjParam = *(KernelCallerFunc->param_begin());
@@ -654,22 +675,26 @@ static target getAccessTarget(const ClassTemplateSpecializationDecl *AccTy) {
654675
AccTy->getTemplateArgs()[3].getAsIntegral().getExtValue());
655676
}
656677

678+
// Creates list of kernel parameters descriptors using KernelObj (kernel object)
679+
// Fields of kernel object must be initialized with SYCL kernel arguments so
680+
// in the following function we extract types of kernel object fields and add them
681+
// to the array with kernel parameters descriptors.
657682
static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
658683
SmallVectorImpl<ParamDesc> &ParamDescs) {
659684
const LambdaCapture *Cpt = KernelObj->captures_begin();
660685
auto CreateAndAddPrmDsc = [&](const FieldDecl *Fld, const QualType &ArgType) {
661-
// create a parameter descriptor and append it to the result
686+
// Create a parameter descriptor and append it to the result
662687
ParamDescs.push_back(makeParamDesc(Fld, ArgType));
663688
};
664689

665-
// Create a parameter descriptor for SYCL special object - SYCL accessor or
690+
// Creates a parameter descriptor for SYCL special object - SYCL accessor or
666691
// sampler.
667692
// All special SYCL objects must have __init method. We extract types for
668693
// kernel parameters from __init method parameters. We will use __init method
669694
// and kernel parameters which we build here to initialize special objects in
670695
// the kernel body.
671696
auto createSpecialSYCLObjParamDesc = [&](const FieldDecl *Fld,
672-
const QualType &ArgTy) {
697+
const QualType &ArgTy) {
673698
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl();
674699
assert(RecordDecl && "Special SYCL object must be of a record type");
675700

@@ -682,6 +707,8 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
682707
}
683708
};
684709

710+
// Create parameter descriptor for accessor in case when it's wrapped with
711+
// some class.
685712
// TODO: Do we need to support the case when a sampler is wrapped with some class or
686713
// struct?
687714
std::function<void(const FieldDecl *, const QualType &ArgTy)>
@@ -703,27 +730,39 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
703730
}
704731
};
705732

733+
// Run through kernel object fields and create corresponding kernel
734+
// parameter descriptors. There are several possible cases:
735+
// - Kernel object field is a SYCL special object (SYCL accessor or SYCL
736+
// sampler). These objects have a special initialization scheme - using
737+
// __init method.
738+
// - Kernel object field has a scalar type. In this case we should add
739+
// kernel parameter with the same type.
740+
// - Kernel object field has a structure or class type. Same handling as a
741+
// scalar but we should check if this structure/class contains accessors
742+
// and add parameter descriptors for them properly.
706743
for (const auto *Fld : KernelObj->fields()) {
707744
QualType ArgTy = Fld->getType();
708745
if (Util::isSyclAccessorType(ArgTy) || Util::isSyclSamplerType(ArgTy)) {
709746
createSpecialSYCLObjParamDesc(Fld, ArgTy);
710747
} else if (ArgTy->isStructureOrClassType()) {
748+
// SYCL v1.2.1 s4.8.10 p5:
749+
// C++ non-standard layout values must not be passed as arguments to a
750+
// kernel that is compiled for a device.
711751
if (!ArgTy->isStandardLayoutType()) {
712752
const DeclaratorDecl *V =
713753
Cpt ? cast<DeclaratorDecl>(Cpt->getCapturedVar())
714754
: cast<DeclaratorDecl>(Fld);
715755
KernelObj->getASTContext().getDiagnostics().Report(
716756
V->getLocation(), diag::err_sycl_non_std_layout_type);
717757
}
718-
// structure or class typed parameter - the same handling as a scalar
719758
CreateAndAddPrmDsc(Fld, ArgTy);
720-
// create descriptors for each accessor field in the class or struct
759+
760+
// Create descriptors for each accessor field in the class or struct
721761
createParamDescForWrappedAccessors(Fld, ArgTy);
722762
} else if (ArgTy->isScalarType()) {
723-
// scalar typed parameter
724763
CreateAndAddPrmDsc(Fld, ArgTy);
725764
} else {
726-
llvm_unreachable("unsupported kernel parameter type");
765+
llvm_unreachable("Unsupported kernel parameter type");
727766
}
728767
}
729768
}
@@ -743,18 +782,18 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
743782
H.startKernel(Name, NameType);
744783

745784
auto populateHeaderForAccessor = [&](const QualType &ArgTy, uint64_t Offset) {
746-
// The parameter is a SYCL accessor object.
747-
// The Info field of the parameter descriptor for accessor contains
748-
// two template parameters packed into thid integer field:
749-
// - target (e.g. global_buffer, constant_buffer, local);
750-
// - dimension of the accessor.
751-
const auto *AccTy = ArgTy->getAsCXXRecordDecl();
752-
assert(AccTy && "accessor must be of a record type");
753-
const auto *AccTmplTy = cast<ClassTemplateSpecializationDecl>(AccTy);
754-
int Dims = static_cast<int>(
755-
AccTmplTy->getTemplateArgs()[1].getAsIntegral().getExtValue());
756-
int Info = getAccessTarget(AccTmplTy) | (Dims << 11);
757-
H.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info, Offset);
785+
// The parameter is a SYCL accessor object.
786+
// The Info field of the parameter descriptor for accessor contains
787+
// two template parameters packed into this integer field:
788+
// - target (e.g. global_buffer, constant_buffer, local);
789+
// - dimension of the accessor.
790+
const auto *AccTy = ArgTy->getAsCXXRecordDecl();
791+
assert(AccTy && "accessor must be of a record type");
792+
const auto *AccTmplTy = cast<ClassTemplateSpecializationDecl>(AccTy);
793+
int Dims = static_cast<int>(
794+
AccTmplTy->getTemplateArgs()[1].getAsIntegral().getExtValue());
795+
int Info = getAccessTarget(AccTmplTy) | (Dims << 11);
796+
H.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info, Offset);
758797
};
759798

760799
std::function<void(const QualType &, uint64_t Offset)>
@@ -847,22 +886,48 @@ static std::string constructKernelName(QualType KernelNameType,
847886
return Out.str();
848887
}
849888

889+
// Generates the "kernel wrapper" using KernelCallerFunc (kernel caller
890+
// function) defined in SYCL headers.
891+
// A "kernel wrapper" function contains the body of the kernel caller function,
892+
// receives OpenCL like parameters and additionally does some manipulation to
893+
// initialize captured lambda/functor fields with these parameters.
894+
// SYCL runtime marks kernel caller function with sycl_kernel attribute.
895+
// To be able to generate "kernel wrapper" from KernelCallerFunc we put
896+
// the following requirements on the function which SYCL runtime can mark with
897+
// sycl_kernel attribute:
898+
// - Must be template function with at least two template parameters.
899+
// First parameter must represent "unique kernel name"
900+
// Second parameter must be the function object type
901+
// - Must have only one function parameter - function object.
902+
//
903+
// Example of kernel caller function:
904+
// template <typename KernelName, typename KernelType/*, ...*/>
905+
// __attribute__((sycl_kernel)) void kernel_caller_function(KernelType
906+
// KernelFuncObj) {
907+
// KernelFuncObj();
908+
// }
909+
//
910+
// In the code below we call "kernel wrapper" SYCLKernel.
911+
//
850912
void Sema::ConstructSYCLKernel(FunctionDecl *KernelCallerFunc) {
851-
// TODO: Case when kernel is functor
852913
CXXRecordDecl *LE = getKernelObjectType(KernelCallerFunc);
853914
assert(LE && "invalid kernel caller");
915+
916+
// Build list of kernel arguments
854917
llvm::SmallVector<ParamDesc, 16> ParamDescs;
855918
buildArgTys(getASTContext(), LE, ParamDescs);
856-
// Get Name for our kernel.
919+
920+
// Extract name from kernel caller parameters and mangle it.
857921
const TemplateArgumentList *TemplateArgs =
858922
KernelCallerFunc->getTemplateSpecializationArgs();
859923
assert(TemplateArgs && "No template argument info");
860-
// The first template argument always describes the kernel name - whether
861-
// it is lambda or functor.
862924
QualType KernelNameType = TypeName::getFullyQualifiedType(
863925
TemplateArgs->get(0).getAsType(), getASTContext(), true);
864926
std::string Name = constructKernelName(KernelNameType, getASTContext());
927+
928+
// TODO Maybe don't emit integration header inside the Sema?
865929
populateIntHeader(getSyclIntegrationHeader(), Name, KernelNameType, LE);
930+
866931
FunctionDecl *SYCLKernel =
867932
CreateSYCLKernelDeclaration(getASTContext(), Name, ParamDescs);
868933

0 commit comments

Comments
 (0)