@@ -1109,6 +1109,21 @@ static bool isFreeFunction(SemaSYCL &SemaSYCLRef, const FunctionDecl *FD) {
1109
1109
return false ;
1110
1110
}
1111
1111
1112
+ static int getFreeFunctionRangeDim (SemaSYCL &SemaSYCLRef,
1113
+ const FunctionDecl *FD) {
1114
+ for (auto *IRAttr : FD->specific_attrs <SYCLAddIRAttributesFunctionAttr>()) {
1115
+ SmallVector<std::pair<std::string, std::string>, 4 > NameValuePairs =
1116
+ IRAttr->getAttributeNameValuePairs (SemaSYCLRef.getASTContext ());
1117
+ for (const auto &NameValuePair : NameValuePairs) {
1118
+ if (NameValuePair.first == " sycl-nd-range-kernel" )
1119
+ return std::stoi (NameValuePair.second );
1120
+ if (NameValuePair.first == " sycl-single-task-kernel" )
1121
+ return 0 ;
1122
+ }
1123
+ }
1124
+ return false ;
1125
+ }
1126
+
1112
1127
// Creates a name for the free function kernel function.
1113
1128
// Consider a free function named "MyFunction". The normal device function will
1114
1129
// be given its mangled name, say "_Z10MyFunctionIiEvPT_S0_". The corresponding
@@ -2568,7 +2583,6 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler {
2568
2583
2569
2584
// A type to Create and own the FunctionDecl for the kernel.
2570
2585
class SyclKernelDeclCreator : public SyclKernelFieldHandler {
2571
- bool IsFreeFunction = false ;
2572
2586
FunctionDecl *KernelDecl = nullptr ;
2573
2587
llvm::SmallVector<ParmVarDecl *, 8 > Params;
2574
2588
Sema::ContextRAII FuncContext;
@@ -2788,9 +2802,8 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
2788
2802
public:
2789
2803
static constexpr const bool VisitInsideSimpleContainers = false ;
2790
2804
SyclKernelDeclCreator (SemaSYCL &S, SourceLocation Loc, bool IsInline,
2791
- bool IsSIMDKernel, bool IsFreeFunction,
2792
- FunctionDecl *SYCLKernel)
2793
- : SyclKernelFieldHandler(S), IsFreeFunction(IsFreeFunction),
2805
+ bool IsSIMDKernel, FunctionDecl *SYCLKernel)
2806
+ : SyclKernelFieldHandler(S),
2794
2807
KernelDecl (
2795
2808
createKernelDecl (S.getASTContext(), Loc, IsInline, IsSIMDKernel)),
2796
2809
FuncContext(SemaSYCLRef.SemaRef, KernelDecl) {
@@ -5110,7 +5123,7 @@ void SemaSYCL::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
5110
5123
5111
5124
SyclKernelDeclCreator kernel_decl (*this , KernelObj->getLocation (),
5112
5125
KernelCallerFunc->isInlined (), IsSIMDKernel,
5113
- false /* IsFreeFunction */ , KernelCallerFunc);
5126
+ KernelCallerFunc);
5114
5127
SyclKernelBodyCreator kernel_body (*this , kernel_decl, KernelObj,
5115
5128
KernelCallerFunc, IsSIMDKernel,
5116
5129
CallOperator);
@@ -5152,7 +5165,7 @@ void ConstructFreeFunctionKernel(SemaSYCL &SemaSYCLRef, FunctionDecl *FD) {
5152
5165
false /* IsSIMDKernel*/ );
5153
5166
SyclKernelDeclCreator kernel_decl (SemaSYCLRef, FD->getLocation (),
5154
5167
FD->isInlined (), false /* IsSIMDKernel */ ,
5155
- true /* IsFreeFunction */ , FD);
5168
+ FD);
5156
5169
5157
5170
FreeFunctionKernelBodyCreator kernel_body (SemaSYCLRef, kernel_decl, FD);
5158
5171
@@ -6052,6 +6065,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
6052
6065
6053
6066
O << " #include <sycl/detail/defines_elementary.hpp>\n " ;
6054
6067
O << " #include <sycl/detail/kernel_desc.hpp>\n " ;
6068
+ O << " #include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n " ;
6055
6069
6056
6070
O << " \n " ;
6057
6071
@@ -6301,7 +6315,102 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
6301
6315
O << " } // namespace detail\n " ;
6302
6316
O << " } // namespace _V1\n " ;
6303
6317
O << " } // namespace sycl\n " ;
6304
- O << " \n " ;
6318
+
6319
+ unsigned ShimCounter = 1 ;
6320
+ int FreeFunctionCount = 0 ;
6321
+ for (const KernelDesc &K : KernelDescs) {
6322
+ if (!isFreeFunction (S, K.SyclKernel ))
6323
+ continue ;
6324
+
6325
+ ++FreeFunctionCount;
6326
+ // Generate forward declaration for free function.
6327
+ O << " \n // Definition of " << K.Name << " as a free function kernel\n " ;
6328
+ if (K.SyclKernel ->getLanguageLinkage () == CLanguageLinkage)
6329
+ O << " extern \" C\" " ;
6330
+ std::string ParmList;
6331
+ bool FirstParam = true ;
6332
+ for (ParmVarDecl *Param : K.SyclKernel ->parameters ()) {
6333
+ if (FirstParam)
6334
+ FirstParam = false ;
6335
+ else
6336
+ ParmList += " , " ;
6337
+ ParmList += Param->getType ().getCanonicalType ().getAsString ();
6338
+ }
6339
+ FunctionTemplateDecl *FTD = K.SyclKernel ->getPrimaryTemplate ();
6340
+ Policy.SuppressDefinition = true ;
6341
+ Policy.PolishForDeclaration = true ;
6342
+ if (FTD) {
6343
+ FTD->print (O, Policy);
6344
+ } else {
6345
+ K.SyclKernel ->print (O, Policy);
6346
+ }
6347
+ O << " ;\n " ;
6348
+
6349
+ // Generate a shim function that returns the address of the free function.
6350
+ O << " static constexpr auto __sycl_shim" << ShimCounter << " () {\n " ;
6351
+ O << " return (void (*)(" << ParmList << " ))"
6352
+ << K.SyclKernel ->getIdentifier ()->getName ().data ();
6353
+ if (FTD) {
6354
+ const TemplateArgumentList *TAL =
6355
+ K.SyclKernel ->getTemplateSpecializationArgs ();
6356
+ ArrayRef<TemplateArgument> A = TAL->asArray ();
6357
+ bool FirstParam = true ;
6358
+ O << " <" ;
6359
+ for (auto X : A) {
6360
+ if (FirstParam)
6361
+ FirstParam = false ;
6362
+ else
6363
+ O << " , " ;
6364
+ X.print (Policy, O, true );
6365
+ }
6366
+ O << " >" ;
6367
+ }
6368
+ O << " ;\n " ;
6369
+ O << " }\n " ;
6370
+
6371
+ // Generate is_kernel, is_single_task_kernel and nd_range_kernel functions.
6372
+ O << " namespace sycl {\n " ;
6373
+ O << " template <>\n " ;
6374
+ O << " struct ext::oneapi::experimental::is_kernel<__sycl_shim"
6375
+ << ShimCounter << " ()" ;
6376
+ O << " > {\n " ;
6377
+ O << " static constexpr bool value = true;\n " ;
6378
+ O << " };\n " ;
6379
+ int Dim = getFreeFunctionRangeDim (S, K.SyclKernel );
6380
+ O << " template <>\n " ;
6381
+ if (Dim > 0 )
6382
+ O << " struct ext::oneapi::experimental::is_nd_range_kernel<__sycl_shim"
6383
+ << ShimCounter << " (), " << Dim;
6384
+ else
6385
+ O << " struct "
6386
+ " ext::oneapi::experimental::is_single_task_kernel<__sycl_shim"
6387
+ << ShimCounter << " ()" ;
6388
+ O << " > {\n " ;
6389
+ O << " static constexpr bool value = true;\n " ;
6390
+ O << " };\n " ;
6391
+ O << " }\n " ;
6392
+ ++ShimCounter;
6393
+ }
6394
+
6395
+ if (FreeFunctionCount > 0 ) {
6396
+ O << " \n #include <sycl/kernel_bundle.hpp>\n " ;
6397
+ }
6398
+ ShimCounter = 1 ;
6399
+ for (const KernelDesc &K : KernelDescs) {
6400
+ if (!isFreeFunction (S, K.SyclKernel ))
6401
+ continue ;
6402
+
6403
+ O << " \n // Definition of kernel_id of " << K.Name << " \n " ;
6404
+ O << " namespace sycl {\n " ;
6405
+ O << " template <>\n " ;
6406
+ O << " kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim"
6407
+ << ShimCounter << " ()>() {\n " ;
6408
+ O << " return sycl::detail::get_kernel_id_impl(std::string_view{\" "
6409
+ << K.Name << " \" });\n " ;
6410
+ O << " }\n " ;
6411
+ O << " }\n " ;
6412
+ ++ShimCounter;
6413
+ }
6305
6414
}
6306
6415
6307
6416
bool SYCLIntegrationHeader::emit (StringRef IntHeaderName) {
0 commit comments