Skip to content

Commit b38ec2a

Browse files
authored
[SYCL] Free function traits and APIs (#13885)
This PR implements the traits and other APIs for free function support. The free function spec is here: https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_free_function_kernels.asciidoc. Later PRs will add the remaining parts of the spec. This is the current status: 1. Free functions are supported at file scope only. 2. The SYCL_EXTERNAL markup is needed for free functions in addition to the SYCL_EXT_ONEAPI_FUNCTION_PROPERTY property. 3. The compiler does not yet diagnose an error if the application violates any of the restrictions listed in the specification under the section "Defining a free function kernel". 4. Device code generation is supported for scalars and USM pointers only. It is not supported for complex kernel argument types requiring decomposition like accessor, local_accessor, or stream. 5. The implementation has not been tested to handle the case when a kernel argument is optimized away. The switch -fno-sycl-dead-args-optimization could be used to disable this optimization, if needed 6. The kernel information descriptor info::kernel::num_args cannot yet be used to query the number of arguments in a free function kernel.
1 parent c2ebf84 commit b38ec2a

File tree

9 files changed

+708
-63
lines changed

9 files changed

+708
-63
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 116 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,6 +1109,21 @@ static bool isFreeFunction(SemaSYCL &SemaSYCLRef, const FunctionDecl *FD) {
11091109
return false;
11101110
}
11111111

1112+
static int getFreeFunctionRangeDim(SemaSYCL &SemaSYCLRef,
1113+
const FunctionDecl *FD) {
1114+
for (auto *IRAttr : FD->specific_attrs<SYCLAddIRAttributesFunctionAttr>()) {
1115+
SmallVector<std::pair<std::string, std::string>, 4> NameValuePairs =
1116+
IRAttr->getAttributeNameValuePairs(SemaSYCLRef.getASTContext());
1117+
for (const auto &NameValuePair : NameValuePairs) {
1118+
if (NameValuePair.first == "sycl-nd-range-kernel")
1119+
return std::stoi(NameValuePair.second);
1120+
if (NameValuePair.first == "sycl-single-task-kernel")
1121+
return 0;
1122+
}
1123+
}
1124+
return false;
1125+
}
1126+
11121127
// Creates a name for the free function kernel function.
11131128
// Consider a free function named "MyFunction". The normal device function will
11141129
// be given its mangled name, say "_Z10MyFunctionIiEvPT_S0_". The corresponding
@@ -2568,7 +2583,6 @@ class SyclKernelPointerHandler : public SyclKernelFieldHandler {
25682583

25692584
// A type to Create and own the FunctionDecl for the kernel.
25702585
class SyclKernelDeclCreator : public SyclKernelFieldHandler {
2571-
bool IsFreeFunction = false;
25722586
FunctionDecl *KernelDecl = nullptr;
25732587
llvm::SmallVector<ParmVarDecl *, 8> Params;
25742588
Sema::ContextRAII FuncContext;
@@ -2788,9 +2802,8 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
27882802
public:
27892803
static constexpr const bool VisitInsideSimpleContainers = false;
27902804
SyclKernelDeclCreator(SemaSYCL &S, SourceLocation Loc, bool IsInline,
2791-
bool IsSIMDKernel, bool IsFreeFunction,
2792-
FunctionDecl *SYCLKernel)
2793-
: SyclKernelFieldHandler(S), IsFreeFunction(IsFreeFunction),
2805+
bool IsSIMDKernel, FunctionDecl *SYCLKernel)
2806+
: SyclKernelFieldHandler(S),
27942807
KernelDecl(
27952808
createKernelDecl(S.getASTContext(), Loc, IsInline, IsSIMDKernel)),
27962809
FuncContext(SemaSYCLRef.SemaRef, KernelDecl) {
@@ -5110,7 +5123,7 @@ void SemaSYCL::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
51105123

51115124
SyclKernelDeclCreator kernel_decl(*this, KernelObj->getLocation(),
51125125
KernelCallerFunc->isInlined(), IsSIMDKernel,
5113-
false /*IsFreeFunction*/, KernelCallerFunc);
5126+
KernelCallerFunc);
51145127
SyclKernelBodyCreator kernel_body(*this, kernel_decl, KernelObj,
51155128
KernelCallerFunc, IsSIMDKernel,
51165129
CallOperator);
@@ -5152,7 +5165,7 @@ void ConstructFreeFunctionKernel(SemaSYCL &SemaSYCLRef, FunctionDecl *FD) {
51525165
false /*IsSIMDKernel*/);
51535166
SyclKernelDeclCreator kernel_decl(SemaSYCLRef, FD->getLocation(),
51545167
FD->isInlined(), false /*IsSIMDKernel */,
5155-
true /*IsFreeFunction*/, FD);
5168+
FD);
51565169

51575170
FreeFunctionKernelBodyCreator kernel_body(SemaSYCLRef, kernel_decl, FD);
51585171

@@ -6052,6 +6065,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
60526065

60536066
O << "#include <sycl/detail/defines_elementary.hpp>\n";
60546067
O << "#include <sycl/detail/kernel_desc.hpp>\n";
6068+
O << "#include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n";
60556069

60566070
O << "\n";
60576071

@@ -6301,7 +6315,102 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
63016315
O << "} // namespace detail\n";
63026316
O << "} // namespace _V1\n";
63036317
O << "} // namespace sycl\n";
6304-
O << "\n";
6318+
6319+
unsigned ShimCounter = 1;
6320+
int FreeFunctionCount = 0;
6321+
for (const KernelDesc &K : KernelDescs) {
6322+
if (!isFreeFunction(S, K.SyclKernel))
6323+
continue;
6324+
6325+
++FreeFunctionCount;
6326+
// Generate forward declaration for free function.
6327+
O << "\n// Definition of " << K.Name << " as a free function kernel\n";
6328+
if (K.SyclKernel->getLanguageLinkage() == CLanguageLinkage)
6329+
O << "extern \"C\" ";
6330+
std::string ParmList;
6331+
bool FirstParam = true;
6332+
for (ParmVarDecl *Param : K.SyclKernel->parameters()) {
6333+
if (FirstParam)
6334+
FirstParam = false;
6335+
else
6336+
ParmList += ", ";
6337+
ParmList += Param->getType().getCanonicalType().getAsString();
6338+
}
6339+
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
6340+
Policy.SuppressDefinition = true;
6341+
Policy.PolishForDeclaration = true;
6342+
if (FTD) {
6343+
FTD->print(O, Policy);
6344+
} else {
6345+
K.SyclKernel->print(O, Policy);
6346+
}
6347+
O << ";\n";
6348+
6349+
// Generate a shim function that returns the address of the free function.
6350+
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n";
6351+
O << " return (void (*)(" << ParmList << "))"
6352+
<< K.SyclKernel->getIdentifier()->getName().data();
6353+
if (FTD) {
6354+
const TemplateArgumentList *TAL =
6355+
K.SyclKernel->getTemplateSpecializationArgs();
6356+
ArrayRef<TemplateArgument> A = TAL->asArray();
6357+
bool FirstParam = true;
6358+
O << "<";
6359+
for (auto X : A) {
6360+
if (FirstParam)
6361+
FirstParam = false;
6362+
else
6363+
O << ", ";
6364+
X.print(Policy, O, true);
6365+
}
6366+
O << ">";
6367+
}
6368+
O << ";\n";
6369+
O << "}\n";
6370+
6371+
// Generate is_kernel, is_single_task_kernel and nd_range_kernel functions.
6372+
O << "namespace sycl {\n";
6373+
O << "template <>\n";
6374+
O << "struct ext::oneapi::experimental::is_kernel<__sycl_shim"
6375+
<< ShimCounter << "()";
6376+
O << "> {\n";
6377+
O << " static constexpr bool value = true;\n";
6378+
O << "};\n";
6379+
int Dim = getFreeFunctionRangeDim(S, K.SyclKernel);
6380+
O << "template <>\n";
6381+
if (Dim > 0)
6382+
O << "struct ext::oneapi::experimental::is_nd_range_kernel<__sycl_shim"
6383+
<< ShimCounter << "(), " << Dim;
6384+
else
6385+
O << "struct "
6386+
"ext::oneapi::experimental::is_single_task_kernel<__sycl_shim"
6387+
<< ShimCounter << "()";
6388+
O << "> {\n";
6389+
O << " static constexpr bool value = true;\n";
6390+
O << "};\n";
6391+
O << "}\n";
6392+
++ShimCounter;
6393+
}
6394+
6395+
if (FreeFunctionCount > 0) {
6396+
O << "\n#include <sycl/kernel_bundle.hpp>\n";
6397+
}
6398+
ShimCounter = 1;
6399+
for (const KernelDesc &K : KernelDescs) {
6400+
if (!isFreeFunction(S, K.SyclKernel))
6401+
continue;
6402+
6403+
O << "\n// Definition of kernel_id of " << K.Name << "\n";
6404+
O << "namespace sycl {\n";
6405+
O << "template <>\n";
6406+
O << "kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim"
6407+
<< ShimCounter << "()>() {\n";
6408+
O << " return sycl::detail::get_kernel_id_impl(std::string_view{\""
6409+
<< K.Name << "\"});\n";
6410+
O << "}\n";
6411+
O << "}\n";
6412+
++ShimCounter;
6413+
}
63056414
}
63066415

63076416
bool SYCLIntegrationHeader::emit(StringRef IntHeaderName) {

clang/test/CodeGenSYCL/free_function_int_header.cpp

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,126 @@ template <> void ff_3<double>(double *ptr, double start, double end) {
8282
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 8, 16 },
8383

8484
// CHECK: { kernel_param_kind_t::kind_invalid, -987654321, -987654321 },
85-
// CHECK-NEXT: };
85+
// CHECK-NEXT: };
86+
87+
// CHECK: Definition of _Z18__sycl_kernel_ff_2Piii as a free function kernel
88+
// CHECK-NEXT: void ff_2(int *ptr, int start, int end);
89+
// CHECK-NEXT: static constexpr auto __sycl_shim1() {
90+
// CHECK-NEXT: return (void (*)(int *, int, int))ff_2;
91+
// CHECK-NEXT: }
92+
// CHECK-NEXT: namespace sycl {
93+
// CHECK-NEXT: template <>
94+
// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim1()> {
95+
// CHECK-NEXT: static constexpr bool value = true;
96+
// CHECK-NEXT: };
97+
// CHECK-NEXT: template <>
98+
// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim1()> {
99+
// CHECK-NEXT: static constexpr bool value = true;
100+
// CHECK-NEXT: };
101+
// CHECK-NEXT: }
102+
103+
// CHECK: Definition of _Z18__sycl_kernel_ff_2Piiii as a free function kernel
104+
// CHECK-NEXT: void ff_2(int *ptr, int start, int end, int value);
105+
// CHECK-NEXT: static constexpr auto __sycl_shim2() {
106+
// CHECK-NEXT: return (void (*)(int *, int, int, int))ff_2;
107+
// CHECK-NEXT: }
108+
// CHECK-NEXT: namespace sycl {
109+
// CHECK-NEXT: template <>
110+
// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim2()> {
111+
// CHECK-NEXT: static constexpr bool value = true;
112+
// CHECK-NEXT: };
113+
// CHECK-NEXT: template <>
114+
// CHECK-NEXT: struct ext::oneapi::experimental::is_single_task_kernel<__sycl_shim2()> {
115+
// CHECK-NEXT: static constexpr bool value = true;
116+
// CHECK-NEXT: };
117+
// CHECK-NEXT: }
118+
119+
// CHECK: Definition of _Z18__sycl_kernel_ff_3IiEvPT_S0_S0_ as a free function kernel
120+
// CHECK-NEXT: template <typename T> void ff_3(T *ptr, T start, T end);
121+
// CHECK-NEXT: static constexpr auto __sycl_shim3() {
122+
// CHECK-NEXT: return (void (*)(int *, int, int))ff_3<int>;
123+
// CHECK-NEXT: }
124+
// CHECK-NEXT: namespace sycl {
125+
// CHECK-NEXT: template <>
126+
// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim3()> {
127+
// CHECK-NEXT: static constexpr bool value = true;
128+
// CHECK-NEXT: };
129+
// CHECK-NEXT: template <>
130+
// CHECK-NEXT: struct ext::oneapi::experimental::is_nd_range_kernel<__sycl_shim3(), 2> {
131+
// CHECK-NEXT: static constexpr bool value = true;
132+
// CHECK-NEXT: };
133+
// CHECK-NEXT: }
134+
135+
// CHECK: Definition of _Z18__sycl_kernel_ff_3IfEvPT_S0_S0_ as a free function kernel
136+
// CHECK-NEXT: template <typename T> void ff_3(T *ptr, T start, T end);
137+
// CHECK-NEXT: static constexpr auto __sycl_shim4() {
138+
// CHECK-NEXT: return (void (*)(float *, float, float))ff_3<float>;
139+
// CHECK-NEXT: }
140+
// CHECK-NEXT: namespace sycl {
141+
// CHECK-NEXT: template <>
142+
// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim4()> {
143+
// CHECK-NEXT: static constexpr bool value = true;
144+
// CHECK-NEXT: };
145+
// CHECK-NEXT: template <>
146+
// CHECK-NEXT: struct ext::oneapi::experimental::is_nd_range_kernel<__sycl_shim4(), 2> {
147+
// CHECK-NEXT: static constexpr bool value = true;
148+
// CHECK-NEXT: };
149+
// CHECK-NEXT: }
150+
151+
// CHECK: Definition of _Z18__sycl_kernel_ff_3IdEvPT_S0_S0_ as a free function kernel
152+
// CHECK-NEXT: template <typename T> void ff_3(T *ptr, T start, T end);
153+
// CHECK-NEXT: static constexpr auto __sycl_shim5() {
154+
// CHECK-NEXT: return (void (*)(double *, double, double))ff_3<double>;
155+
// CHECK-NEXT: }
156+
// CHECK-NEXT: namespace sycl {
157+
// CHECK-NEXT: template <>
158+
// CHECK-NEXT: struct ext::oneapi::experimental::is_kernel<__sycl_shim5()> {
159+
// CHECK-NEXT: static constexpr bool value = true;
160+
// CHECK-NEXT: };
161+
// CHECK-NEXT: template <>
162+
// CHECK-NEXT: struct ext::oneapi::experimental::is_nd_range_kernel<__sycl_shim5(), 2> {
163+
// CHECK-NEXT: static constexpr bool value = true;
164+
// CHECK-NEXT: };
165+
// CHECK-NEXT: }
166+
167+
// CHECK: #include <sycl/kernel_bundle.hpp>
168+
169+
// CHECK: Definition of kernel_id of _Z18__sycl_kernel_ff_2Piii
170+
// CHECK-NEXT: namespace sycl {
171+
// CHECK-NEXT: template <>
172+
// CHECK-NEXT: kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim1()>() {
173+
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z18__sycl_kernel_ff_2Piii"});
174+
// CHECK-NEXT: }
175+
// CHECK-NEXT: }
176+
177+
// CHECK: Definition of kernel_id of _Z18__sycl_kernel_ff_2Piiii
178+
// CHECK-NEXT: namespace sycl {
179+
// CHECK-NEXT: template <>
180+
// CHECK-NEXT: kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim2()>() {
181+
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z18__sycl_kernel_ff_2Piiii"});
182+
// CHECK-NEXT: }
183+
// CHECK-NEXT: }
184+
185+
// CHECK: Definition of kernel_id of _Z18__sycl_kernel_ff_3IiEvPT_S0_S0_
186+
// CHECK-NEXT: namespace sycl {
187+
// CHECK-NEXT: template <>
188+
// CHECK-NEXT: kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim3()>() {
189+
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z18__sycl_kernel_ff_3IiEvPT_S0_S0_"});
190+
// CHECK-NEXT: }
191+
// CHECK-NEXT: }
192+
193+
// CHECK: Definition of kernel_id of _Z18__sycl_kernel_ff_3IfEvPT_S0_S0_
194+
// CHECK-NEXT: namespace sycl {
195+
// CHECK-NEXT: template <>
196+
// CHECK-NEXT: kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim4()>() {
197+
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z18__sycl_kernel_ff_3IfEvPT_S0_S0_"});
198+
// CHECK-NEXT: }
199+
// CHECK-NEXT: }
200+
201+
// CHECK: Definition of kernel_id of _Z18__sycl_kernel_ff_3IdEvPT_S0_S0_
202+
// CHECK-NEXT: namespace sycl {
203+
// CHECK-NEXT: template <>
204+
// CHECK-NEXT: kernel_id ext::oneapi::experimental::get_kernel_id<__sycl_shim5()>() {
205+
// CHECK-NEXT: return sycl::detail::get_kernel_id_impl(std::string_view{"_Z18__sycl_kernel_ff_3IdEvPT_S0_S0_"});
206+
// CHECK-NEXT: }
207+
// CHECK-NEXT: }
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//==-------- free_function_traits.hpp - SYCL free function queries --------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
namespace sycl {
12+
inline namespace _V1 {
13+
namespace ext::oneapi::experimental {
14+
15+
template <auto *Func, int Dims> struct is_nd_range_kernel {
16+
static constexpr bool value = false;
17+
};
18+
19+
template <auto *Func> struct is_single_task_kernel {
20+
static constexpr bool value = false;
21+
};
22+
23+
template <auto *Func, int Dims>
24+
inline constexpr bool is_nd_range_kernel_v =
25+
is_nd_range_kernel<Func, Dims>::value;
26+
27+
template <auto *Func>
28+
inline constexpr bool is_single_task_kernel_v =
29+
is_single_task_kernel<Func>::value;
30+
31+
template <auto *Func> struct is_kernel {
32+
static constexpr bool value = false;
33+
};
34+
35+
template <auto *Func>
36+
inline constexpr bool is_kernel_v = is_kernel<Func>::value;
37+
38+
} // namespace ext::oneapi::experimental
39+
} // namespace _V1
40+
} // namespace sycl
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//==-------- free_function_traits.hpp - SYCL free function queries --------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
namespace sycl {
12+
inline namespace _V1 {
13+
namespace ext::oneapi::experimental {
14+
15+
template <auto *Func, int Dims> struct is_nd_range_kernel {
16+
static constexpr bool value = false;
17+
};
18+
19+
template <auto *Func> struct is_single_task_kernel {
20+
static constexpr bool value = false;
21+
};
22+
23+
template <auto *Func, int Dims>
24+
inline constexpr bool is_nd_range_kernel_v =
25+
is_nd_range_kernel<Func, Dims>::value;
26+
27+
template <auto *Func>
28+
inline constexpr bool is_single_task_kernel_v =
29+
is_single_task_kernel<Func>::value;
30+
31+
template <auto *Func> struct is_kernel {
32+
static constexpr bool value = false;
33+
};
34+
35+
template <auto *Func>
36+
inline constexpr bool is_kernel_v = is_kernel<Func>::value;
37+
38+
} // namespace ext::oneapi::experimental
39+
} // namespace _V1
40+
} // namespace sycl

0 commit comments

Comments
 (0)