Skip to content

Commit 7f76000

Browse files
[SYCL] Emit forward declaration of template specialization into integration header (#18929)
This PR fixes bug with templated kernel free function and its specialization, i.e. forward declaration of specialization is emitted into integration header too.
1 parent 332430a commit 7f76000

File tree

6 files changed

+226
-27
lines changed

6 files changed

+226
-27
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6617,9 +6617,18 @@ class FreeFunctionPrinter {
66176617
// function
66186618
NSInserted = true;
66196619
}
6620+
if (FD->isFunctionTemplateSpecialization() &&
6621+
FD->isThisDeclarationADefinition())
6622+
O << "template <> ";
66206623
O << TemplateParameters;
66216624
O << FD->getReturnType().getAsString() << " ";
6622-
O << FD->getNameAsString() << "(" << Args << ");";
6625+
FD->printName(O, Policy);
6626+
if (FD->isFunctionTemplateSpecialization() &&
6627+
FD->isThisDeclarationADefinition())
6628+
O << getTemplateSpecializationArgString(
6629+
FD->getTemplateSpecializationArgs());
6630+
6631+
O << "(" << Args << ");";
66236632
if (NSInserted) {
66246633
O << "\n";
66256634
PrintNSClosingBraces(O, FD);
@@ -6641,35 +6650,49 @@ class FreeFunctionPrinter {
66416650
if (NSInserted)
66426651
PrintNamespaces(O, FD, /*isPrintNamesOnly=*/true);
66436652
O << FD->getIdentifier()->getName().data();
6644-
if (FD->getPrimaryTemplate()) {
6645-
std::string Buffer;
6646-
llvm::raw_string_ostream StringStream(Buffer);
6647-
const TemplateArgumentList *TAL = FD->getTemplateSpecializationArgs();
6648-
ArrayRef<TemplateArgument> A = TAL->asArray();
6649-
bool FirstParam = true;
6650-
for (const auto &X : A) {
6651-
if (FirstParam)
6652-
FirstParam = false;
6653-
else if (X.getKind() == TemplateArgument::Pack) {
6654-
for (const auto &PackArg : X.pack_elements()) {
6655-
StringStream << ", ";
6656-
PackArg.print(Policy, StringStream, true);
6657-
}
6658-
continue;
6659-
} else {
6653+
if (FD->getPrimaryTemplate())
6654+
O << getTemplateSpecializationArgString(
6655+
FD->getTemplateSpecializationArgs());
6656+
}
6657+
6658+
private:
6659+
/// Helper method to get string with template types
6660+
/// \param TAL The template argument list.
6661+
/// \returns string Example:
6662+
/// \code
6663+
/// template <typename T1, typename T2>
6664+
/// void foo(T1 a, T2 b);
6665+
/// \endcode
6666+
/// returns string "<T1, T2>"
6667+
/// If TAL is nullptr, returns empty string.
6668+
std::string
6669+
getTemplateSpecializationArgString(const TemplateArgumentList *TAL) {
6670+
if (!TAL)
6671+
return "";
6672+
std::string Buffer;
6673+
llvm::raw_string_ostream StringStream(Buffer);
6674+
ArrayRef<TemplateArgument> A = TAL->asArray();
6675+
bool FirstParam = true;
6676+
for (const auto &X : A) {
6677+
if (FirstParam)
6678+
FirstParam = false;
6679+
else if (X.getKind() == TemplateArgument::Pack) {
6680+
for (const auto &PackArg : X.pack_elements()) {
66606681
StringStream << ", ";
6682+
PackArg.print(Policy, StringStream, /*IncludeType*/ true);
66616683
}
6684+
continue;
6685+
} else
6686+
StringStream << ", ";
66626687

6663-
X.print(Policy, StringStream, true);
6664-
}
6665-
StringStream.flush();
6666-
if (Buffer.front() != '<')
6667-
Buffer = "<" + Buffer + ">";
6668-
O << Buffer;
6688+
X.print(Policy, StringStream, /*IncludeType*/ true);
66696689
}
6690+
StringStream.flush();
6691+
if (Buffer.front() != '<')
6692+
Buffer = "<" + Buffer + ">";
6693+
return Buffer;
66706694
}
66716695

6672-
private:
66736696
/// Helper method to get arguments of templated function as a string
66746697
/// \param Parameters Array of parameters of the function.
66756698
/// \param Policy Printing policy.
@@ -7083,6 +7106,10 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
70837106
FreeFunctionPrinter FFPrinter(O, Policy);
70847107
if (FTD) {
70857108
FFPrinter.printFreeFunctionDeclaration(FTD, S);
7109+
if (const auto kind = K.SyclKernel->getTemplateSpecializationKind();
7110+
K.SyclKernel->isFunctionTemplateSpecialization() &&
7111+
kind == TSK_ExplicitSpecialization)
7112+
FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames);
70867113
} else {
70877114
FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames);
70887115
}

clang/test/CodeGenSYCL/free_function_int_header.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ void ff_22(AliasType start, AliasType *ptr) {
491491
// CHECK: Definition of _Z18__sycl_kernel_ff_3IdEvPT_S0_S0_ as a free function kernel
492492
// CHECK: Forward declarations of kernel and its argument types:
493493
// CHECK: template <typename T> void ff_3(T * ptr, T start, T end);
494+
// CHECK: template <> void ff_3<double>(double * ptr, double start, double end);
494495
// CHECK-NEXT: static constexpr auto __sycl_shim5() {
495496
// CHECK-NEXT: return (void (*)(double *, double, double))ff_3<double>;
496497
// CHECK-NEXT: }

sycl/test-e2e/DeviceImageDependencies/free_function_kernels.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Ensure -fsycl-allow-device-dependencies can work with free function kernels.
22

33
// REQUIRES: aspect-usm_shared_allocations
4-
// RUN: %{build} -o %t.out -fsycl-allow-device-image-dependencies
4+
// RUN: %{build} --save-temps -o %t.out -fsycl-allow-device-image-dependencies
55
// RUN: %{run} %t.out
66

77
#include <iostream>

sycl/test-e2e/FreeFunctionKernels/address_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %{build} -o %t.out
1+
// RUN: %{build} --save-temps -o %t.out
22
// RUN: %{run} %t.out
33

44
#include <sycl/detail/core.hpp>

sycl/test-e2e/FreeFunctionKernels/free_function_kernels_as_device_host_functions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// REQUIRES: aspect-usm_shared_allocations
2-
// RUN: %{build} -o %t.out
2+
// RUN: %{build} --save-temps -o %t.out
33
// RUN: %{run} %t.out
44

55
// This test verifies whether free function kernel can be used as device
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
// REQUIRES: aspect-usm_shared_allocations
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
5+
#include <sycl/detail/core.hpp>
6+
#include <sycl/ext/oneapi/free_function_queries.hpp>
7+
#include <sycl/kernel_bundle.hpp>
8+
#include <sycl/usm.hpp>
9+
10+
using namespace sycl;
11+
12+
struct TestStruct {
13+
int x;
14+
float y;
15+
16+
TestStruct(int a, float b) : x(a), y(b) {}
17+
};
18+
19+
namespace A::B::C {
20+
class TestClass {
21+
int a;
22+
float b;
23+
24+
public:
25+
TestClass(int x, float y) : a(x), b(y) {}
26+
27+
void setA(int x) { a = x; }
28+
void setB(float y) { b = y; }
29+
};
30+
} // namespace A::B::C
31+
32+
template <typename T>
33+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
34+
(ext::oneapi::experimental::nd_range_kernel<1>))
35+
void sum(T arg) {}
36+
37+
template <>
38+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
39+
(ext::oneapi::experimental::nd_range_kernel<1>))
40+
void sum<int>(int arg) {
41+
arg = 42;
42+
}
43+
44+
template <>
45+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
46+
(ext::oneapi::experimental::nd_range_kernel<1>))
47+
void sum<int *>(int *arg) {
48+
*arg = 42;
49+
}
50+
51+
template <int, typename T>
52+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
53+
(ext::oneapi::experimental::nd_range_kernel<1>))
54+
void sum1(T arg) {}
55+
56+
template <>
57+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
58+
(ext::oneapi::experimental::nd_range_kernel<1>))
59+
void sum1<3, float>(float arg) {
60+
arg = 3.14f + static_cast<float>(3);
61+
}
62+
63+
template <> void sum<float>(float arg) { arg = 3.14f; }
64+
65+
template <> void sum<TestStruct>(TestStruct arg) {
66+
arg.x = 100;
67+
arg.y = 2.0f;
68+
}
69+
70+
template <> void sum<A::B::C::TestClass>(A::B::C::TestClass arg) {
71+
arg.setA(10);
72+
arg.setB(5.0f);
73+
}
74+
75+
template <typename T>
76+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
77+
(ext::oneapi::experimental::single_task_kernel))
78+
void F(int X) {
79+
volatile T Y = static_cast<T>(X);
80+
}
81+
82+
template <>
83+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
84+
(ext::oneapi::experimental::single_task_kernel))
85+
void F<float>(int X) {
86+
volatile float Y = static_cast<float>(X);
87+
}
88+
89+
template <typename... Args>
90+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
91+
(ext::oneapi::experimental::single_task_kernel))
92+
void variadic_templated(Args... args) {}
93+
94+
template <>
95+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
96+
(ext::oneapi::experimental::single_task_kernel))
97+
void variadic_templated<double>(double b) {
98+
b = 20.0f;
99+
}
100+
101+
template <auto *Func, typename T> void test_func() {
102+
queue Q;
103+
kernel_bundle bundle =
104+
get_kernel_bundle<bundle_state::executable>(Q.get_context());
105+
kernel_id id = ext::oneapi::experimental::get_kernel_id<Func>();
106+
kernel Kernel = bundle.get_kernel(id);
107+
Q.submit([&](handler &h) {
108+
h.set_args(static_cast<T>(4));
109+
h.parallel_for(nd_range{{1}, {1}}, Kernel);
110+
});
111+
}
112+
113+
template <typename T> void test_func_custom_type() {
114+
queue Q;
115+
kernel_bundle bundle =
116+
get_kernel_bundle<bundle_state::executable>(Q.get_context());
117+
kernel_id id = ext::oneapi::experimental::get_kernel_id<sum<T>>();
118+
kernel Kernel = bundle.get_kernel(id);
119+
Q.submit([&](handler &h) {
120+
h.set_args(T(1, 2.0f));
121+
h.parallel_for(nd_range{{1}, {1}}, Kernel);
122+
});
123+
}
124+
125+
void test_accessor() {
126+
sycl::queue Q;
127+
constexpr size_t N = 4;
128+
int data[N] = {0, 1, 2, 3};
129+
kernel_bundle bundle =
130+
get_kernel_bundle<bundle_state::executable>(Q.get_context());
131+
kernel_id id = ext::oneapi::experimental::get_kernel_id<
132+
sum1<3, sycl::accessor<int, 1>>>();
133+
kernel Kernel = bundle.get_kernel(id);
134+
sycl::buffer<int, 1> buf(data, sycl::range<1>(N));
135+
Q.submit([&](handler &h) {
136+
auto acc = buf.get_access<sycl::access::mode::write>(h);
137+
h.set_args(acc);
138+
h.parallel_for(nd_range{{1}, {1}}, Kernel);
139+
});
140+
}
141+
142+
void test_shared() {
143+
sycl::queue Q;
144+
int *data = sycl::malloc_shared<int>(4, Q);
145+
146+
kernel_bundle bundle =
147+
get_kernel_bundle<bundle_state::executable>(Q.get_context());
148+
kernel_id id = ext::oneapi::experimental::get_kernel_id<sum<int *>>();
149+
kernel Kernel = bundle.get_kernel(id);
150+
Q.submit([&](handler &h) {
151+
h.set_args(data);
152+
h.parallel_for(nd_range{{1}, {1}}, Kernel);
153+
});
154+
sycl::free(data, Q);
155+
}
156+
157+
int main() {
158+
test_func<sum<int>, int>();
159+
test_func<sum<float>, float>();
160+
test_func<sum<uint32_t>, uint32_t>();
161+
test_func<sum<char>, char>();
162+
test_func_custom_type<TestStruct>();
163+
test_func_custom_type<A::B::C::TestClass>();
164+
test_func<F<float>, float>();
165+
test_func<F<uint32_t>, uint32_t>();
166+
test_func<variadic_templated<double>, int>();
167+
test_func<sum1<3, float>, float>();
168+
test_accessor();
169+
test_shared();
170+
return 0;
171+
}

0 commit comments

Comments
 (0)