Skip to content

Commit 9e39902

Browse files
author
LU-JOHN
authored
[SYCL] Do not internalize kernels when supporting dynamic linking (#15307)
Do not internalize kernels when supporting dynamic linking. Kernels must be visible so that host code can find them. --------- Signed-off-by: Lu, John <[email protected]>
1 parent 464b077 commit 9e39902

File tree

3 files changed

+293
-6
lines changed

3 files changed

+293
-6
lines changed

llvm/lib/SYCLLowerIR/ModuleSplitter.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,11 +667,13 @@ void ModuleDesc::restoreLinkageOfDirectInvokeSimdTargets() {
667667
// the transformation safe.
668668
static bool mustPreserveGV(const GlobalValue &GV) {
669669
if (const Function *F = dyn_cast<Function>(&GV)) {
670-
// When dynamic linking is supported, we internalize everything that can
671-
// not be imported which also means that there is no point of having it
670+
// When dynamic linking is supported, we internalize everything (except
671+
// kernels which are the entry points from host code to device code) that
672+
// cannot be imported which also means that there is no point of having it
672673
// visible outside of the current module.
673674
if (AllowDeviceImageDependencies)
674-
return canBeImportedFunction(*F);
675+
return F->getCallingConv() == CallingConv::SPIR_KERNEL ||
676+
canBeImportedFunction(*F);
675677

676678
// Otherwise, we are being even more aggressive: SYCL modules are expected
677679
// to be self-contained, meaning that they have no external dependencies.

sycl/doc/design/SharedLibraries.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,9 @@ if the Function "can be imported". A `canBeImportedFunction` is:
160160
161161
1. Not an intrinsic
162162
2. Name does not start with "__"
163-
3. Demangled name does not start with "__"
164-
4. Must be a `SYCL_EXTERNAL` function
163+
3. Is not a SPIRV, SYCL, or ESIMD builtin function
164+
4. Demangled name does not start with "__"
165+
5. Must be a `SYCL_EXTERNAL` function
165166
166167
More information about `SYCL_EXTERNAL` can be found in:
167168
https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#subsec:syclexternal
@@ -173,7 +174,9 @@ the following modifications:
173174
function. Instead the dependency is recorded in the imported symbols property list.
174175
- An image that provides a `canBeImportedFunction` has the symbol recorded in the exported
175176
symbols property list.
176-
- All functions symbols that are not `canBeImportedFunction` are internalized
177+
- All functions symbols that are not `canBeImportedFunction` and are not kernels are internalized.
178+
Note that kernel functions should not be included in `canBeImportedFunction` since kernels
179+
are only callable by host code, and thus would never need to be imported into a device image.
177180
178181
179182
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
// Ensure -fsycl-allow-device-dependencies can work with free function kernels.
2+
3+
// REQUIRES: aspect-usm_shared_allocations
4+
// RUN: %{build} -o %t.out -fsycl-allow-device-dependencies
5+
// RUN: %{run} %t.out
6+
7+
// The name mangling for free function kernels currently does not work with PTX.
8+
// UNSUPPORTED: cuda
9+
10+
#include <iostream>
11+
#include <sycl/detail/core.hpp>
12+
#include <sycl/ext/oneapi/free_function_queries.hpp>
13+
#include <sycl/usm.hpp>
14+
15+
using namespace sycl;
16+
17+
void printUSM(int *usmPtr, int size) {
18+
std::cout << "usmPtr[] = {";
19+
for (int i = 0; i < size; i++) {
20+
std::cout << usmPtr[i] << ", ";
21+
}
22+
std::cout << "}\n";
23+
}
24+
25+
bool checkUSM(int *usmPtr, int size, int *Result) {
26+
bool Pass = true;
27+
for (int i = 0; i < size; i++) {
28+
if (usmPtr[i] != Result[i]) {
29+
Pass = false;
30+
break;
31+
}
32+
}
33+
if (Pass)
34+
return true;
35+
36+
std::cout << "Expected = {";
37+
for (int i = 0; i < size; i++) {
38+
std::cout << Result[i] << ", ";
39+
}
40+
std::cout << "}\n";
41+
std::cout << "Result = {";
42+
for (int i = 0; i < size; i++) {
43+
std::cout << usmPtr[i] << ", ";
44+
}
45+
std::cout << "}\n";
46+
return false;
47+
}
48+
49+
extern "C" SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
50+
(ext::oneapi::experimental::single_task_kernel)) void ff_0(int *ptr,
51+
int start,
52+
int end) {
53+
for (int i = start; i <= end; i++)
54+
ptr[i] = start + end;
55+
}
56+
57+
bool test_0(queue Queue) {
58+
constexpr int Range = 10;
59+
int *usmPtr = malloc_shared<int>(Range, Queue);
60+
int start = 3;
61+
int end = 5;
62+
int Result[Range] = {0, 0, 0, 8, 8, 8, 0, 0, 0, 0};
63+
range<1> R1{Range};
64+
65+
memset(usmPtr, 0, Range * sizeof(int));
66+
Queue.submit([&](handler &Handler) {
67+
Handler.single_task([=]() {
68+
for (int i = start; i <= end; i++)
69+
usmPtr[i] = start + end;
70+
});
71+
});
72+
Queue.wait();
73+
bool PassA = checkUSM(usmPtr, Range, Result);
74+
// TODO: Avoid printing anything if test passes to reduce I/O.
75+
std::cout << "Test 0a: " << (PassA ? "PASS" : "FAIL") << std::endl;
76+
77+
bool PassB = false;
78+
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
79+
#ifndef __SYCL_DEVICE_ONLY__
80+
kernel_bundle Bundle =
81+
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
82+
kernel_id Kernel_id = ext::oneapi::experimental::get_kernel_id<ff_0>();
83+
kernel Kernel = Bundle.get_kernel(Kernel_id);
84+
memset(usmPtr, 0, Range * sizeof(int));
85+
Queue.submit([&](handler &Handler) {
86+
Handler.set_arg(0, usmPtr);
87+
Handler.set_arg(1, start);
88+
Handler.set_arg(2, end);
89+
Handler.single_task(Kernel);
90+
});
91+
Queue.wait();
92+
PassB = checkUSM(usmPtr, Range, Result);
93+
// TODO: Avoid printing anything if test passes to reduce I/O.
94+
std::cout << "Test 0b: " << (PassB ? "PASS" : "FAIL") << std::endl;
95+
96+
free(usmPtr, Queue);
97+
#endif
98+
return PassA && PassB;
99+
}
100+
101+
// Overloaded free function definition.
102+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
103+
(ext::oneapi::experimental::nd_range_kernel<1>))
104+
void ff_1(int *ptr, int start, int end) {
105+
nd_item<1> Item = ext::oneapi::this_work_item::get_nd_item<1>();
106+
id<1> GId = Item.get_global_id();
107+
ptr[GId.get(0)] = GId.get(0) + start + end;
108+
}
109+
110+
bool test_1(queue Queue) {
111+
constexpr int Range = 10;
112+
int *usmPtr = malloc_shared<int>(Range, Queue);
113+
int start = 3;
114+
int Result[Range] = {13, 14, 15, 16, 17, 18, 19, 20, 21, 22};
115+
nd_range<1> R1{{Range}, {1}};
116+
117+
memset(usmPtr, 0, Range * sizeof(int));
118+
Queue.submit([&](handler &Handler) {
119+
Handler.parallel_for(R1, [=](nd_item<1> Item) {
120+
id<1> GId = Item.get_global_id();
121+
usmPtr[GId.get(0)] = GId.get(0) + start + Range;
122+
});
123+
});
124+
Queue.wait();
125+
bool PassA = checkUSM(usmPtr, Range, Result);
126+
// TODO: Avoid printing anything if test passes to reduce I/O.
127+
std::cout << "Test 1a: " << (PassA ? "PASS" : "FAIL") << std::endl;
128+
129+
bool PassB = false;
130+
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
131+
#ifndef __SYCL_DEVICE_ONLY__
132+
kernel_bundle Bundle =
133+
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
134+
kernel_id Kernel_id = ext::oneapi::experimental::get_kernel_id<(
135+
void (*)(int *, int, int))ff_1>();
136+
kernel Kernel = Bundle.get_kernel(Kernel_id);
137+
memset(usmPtr, 0, Range * sizeof(int));
138+
Queue.submit([&](handler &Handler) {
139+
Handler.set_arg(0, usmPtr);
140+
Handler.set_arg(1, start);
141+
Handler.set_arg(2, Range);
142+
Handler.parallel_for(R1, Kernel);
143+
});
144+
Queue.wait();
145+
PassB = checkUSM(usmPtr, Range, Result);
146+
// TODO: Avoid printing anything if test passes to reduce I/O.
147+
std::cout << "Test 1b: " << (PassB ? "PASS" : "FAIL") << std::endl;
148+
149+
free(usmPtr, Queue);
150+
#endif
151+
return PassA && PassB;
152+
}
153+
154+
// Overloaded free function definition.
155+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
156+
(ext::oneapi::experimental::nd_range_kernel<2>))
157+
void ff_1(int *ptr, int start) {
158+
int(&ptr2D)[4][4] = *reinterpret_cast<int(*)[4][4]>(ptr);
159+
nd_item<2> Item = ext::oneapi::this_work_item::get_nd_item<2>();
160+
id<2> GId = Item.get_global_id();
161+
id<2> LId = Item.get_local_id();
162+
ptr2D[GId.get(0)][GId.get(1)] = LId.get(0) + LId.get(1) + start;
163+
}
164+
165+
bool test_2(queue Queue) {
166+
constexpr int Range = 16;
167+
int *usmPtr = malloc_shared<int>(Range, Queue);
168+
int value = 55;
169+
int Result[Range] = {55, 56, 55, 56, 56, 57, 56, 57,
170+
55, 56, 55, 56, 56, 57, 56, 57};
171+
nd_range<2> R2{range<2>{4, 4}, range<2>{2, 2}};
172+
173+
memset(usmPtr, 0, Range * sizeof(int));
174+
Queue.submit([&](handler &Handler) {
175+
Handler.parallel_for(R2, [=](nd_item<2> Item) {
176+
int(&ptr2D)[4][4] = *reinterpret_cast<int(*)[4][4]>(usmPtr);
177+
id<2> GId = Item.get_global_id();
178+
id<2> LId = Item.get_local_id();
179+
ptr2D[GId.get(0)][GId.get(1)] = LId.get(0) + LId.get(1) + value;
180+
});
181+
});
182+
Queue.wait();
183+
bool PassA = checkUSM(usmPtr, Range, Result);
184+
// TODO: Avoid printing anything if test passes to reduce I/O.
185+
std::cout << "Test 2a: " << (PassA ? "PASS" : "FAIL") << std::endl;
186+
187+
bool PassB = false;
188+
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
189+
#ifndef __SYCL_DEVICE_ONLY__
190+
kernel_bundle Bundle =
191+
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
192+
kernel_id Kernel_id =
193+
ext::oneapi::experimental::get_kernel_id<(void (*)(int *, int))ff_1>();
194+
kernel Kernel = Bundle.get_kernel(Kernel_id);
195+
memset(usmPtr, 0, Range * sizeof(int));
196+
Queue.submit([&](handler &Handler) {
197+
Handler.set_arg(0, usmPtr);
198+
Handler.set_arg(1, value);
199+
Handler.parallel_for(R2, Kernel);
200+
});
201+
Queue.wait();
202+
PassB = checkUSM(usmPtr, Range, Result);
203+
// TODO: Avoid printing anything if test passes to reduce I/O.
204+
std::cout << "Test 2b: " << (PassB ? "PASS" : "FAIL") << std::endl;
205+
206+
free(usmPtr, Queue);
207+
#endif
208+
return PassA && PassB;
209+
}
210+
211+
// Templated free function definition.
212+
template <typename T>
213+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
214+
(ext::oneapi::experimental::nd_range_kernel<2>))
215+
void ff_3(T *ptr, T start) {
216+
int(&ptr2D)[4][4] = *reinterpret_cast<int(*)[4][4]>(ptr);
217+
nd_item<2> Item = ext::oneapi::this_work_item::get_nd_item<2>();
218+
id<2> GId = Item.get_global_id();
219+
id<2> LId = Item.get_local_id();
220+
ptr2D[GId.get(0)][GId.get(1)] = LId.get(0) + LId.get(1) + start;
221+
}
222+
223+
// Explicit instantiation with "int*".
224+
template void ff_3(int *ptr, int start);
225+
226+
bool test_3(queue Queue) {
227+
constexpr int Range = 16;
228+
int *usmPtr = malloc_shared<int>(Range, Queue);
229+
int value = 55;
230+
int Result[Range] = {55, 56, 55, 56, 56, 57, 56, 57,
231+
55, 56, 55, 56, 56, 57, 56, 57};
232+
nd_range<2> R2{range<2>{4, 4}, range<2>{2, 2}};
233+
234+
memset(usmPtr, 0, Range * sizeof(int));
235+
Queue.submit([&](handler &Handler) {
236+
Handler.parallel_for(R2, [=](nd_item<2> Item) {
237+
int(&ptr2D)[4][4] = *reinterpret_cast<int(*)[4][4]>(usmPtr);
238+
id<2> GId = Item.get_global_id();
239+
id<2> LId = Item.get_local_id();
240+
ptr2D[GId.get(0)][GId.get(1)] = LId.get(0) + LId.get(1) + value;
241+
});
242+
});
243+
Queue.wait();
244+
bool PassA = checkUSM(usmPtr, Range, Result);
245+
// TODO: Avoid printing anything if test passes to reduce I/O.
246+
std::cout << "Test 3a: " << (PassA ? "PASS" : "FAIL") << std::endl;
247+
248+
bool PassB = false;
249+
// TODO: Avoid using __SYCL_DEVICE_ONLY__ or give rationale with a comment
250+
#ifndef __SYCL_DEVICE_ONLY__
251+
kernel_bundle Bundle =
252+
get_kernel_bundle<bundle_state::executable>(Queue.get_context());
253+
kernel_id Kernel_id = ext::oneapi::experimental::get_kernel_id<(
254+
void (*)(int *, int))ff_3<int>>();
255+
kernel Kernel = Bundle.get_kernel(Kernel_id);
256+
memset(usmPtr, 0, Range * sizeof(int));
257+
Queue.submit([&](handler &Handler) {
258+
Handler.set_arg(0, usmPtr);
259+
Handler.set_arg(1, value);
260+
Handler.parallel_for(R2, Kernel);
261+
});
262+
Queue.wait();
263+
PassB = checkUSM(usmPtr, Range, Result);
264+
// TODO: Avoid printing anything if test passes to reduce I/O.
265+
std::cout << "Test 3b: " << (PassB ? "PASS" : "FAIL") << std::endl;
266+
267+
free(usmPtr, Queue);
268+
#endif
269+
return PassA && PassB;
270+
}
271+
272+
int main() {
273+
queue Queue;
274+
275+
bool Pass = true;
276+
Pass &= test_0(Queue);
277+
Pass &= test_1(Queue);
278+
Pass &= test_2(Queue);
279+
Pass &= test_3(Queue);
280+
281+
return Pass ? 0 : 1;
282+
}

0 commit comments

Comments
 (0)