Skip to content

Commit fca7f5e

Browse files
[SYCL][Graph] Enable specialization constants with graph (#11556)
- Enables specialization constant handling in the SYCL-Graph extension. - Adds E2E tests that verify this behavior. - Removes the unit tests that checked that an unsupported-feature exception was thrown. --------- Co-authored-by: Maxime France-Pillois <[email protected]>
1 parent 1f4ff10 commit fca7f5e

File tree

11 files changed

+535
-69
lines changed

11 files changed

+535
-69
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,10 +1776,6 @@ class __SYCL_EXPORT handler {
17761776
void set_specialization_constant(
17771777
typename std::remove_reference_t<decltype(SpecName)>::value_type Value) {
17781778

1779-
throwIfGraphAssociated<
1780-
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
1781-
sycl_specialization_constants>();
1782-
17831779
setStateSpecConstSet();
17841780

17851781
std::shared_ptr<detail::kernel_bundle_impl> KernelBundleImplPtr =
@@ -1794,10 +1790,6 @@ class __SYCL_EXPORT handler {
17941790
typename std::remove_reference_t<decltype(SpecName)>::value_type
17951791
get_specialization_constant() const {
17961792

1797-
throwIfGraphAssociated<
1798-
ext::oneapi::experimental::detail::UnsupportedGraphFeatures::
1799-
sycl_specialization_constants>();
1800-
18011793
if (isStateExplicitKernelBundle())
18021794
throw sycl::exception(make_error_code(errc::invalid),
18031795
"Specialization constants cannot be read after "

sycl/source/detail/scheduler/commands.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2405,6 +2405,7 @@ pi_int32 enqueueImpCommandBufferKernel(
24052405
pi_kernel PiKernel = nullptr;
24062406
std::mutex *KernelMutex = nullptr;
24072407
pi_program PiProgram = nullptr;
2408+
std::shared_ptr<device_image_impl> DeviceImageImpl = nullptr;
24082409

24092410
auto Kernel = CommandGroup.MSyclKernel;
24102411
auto KernelBundleImplPtr = CommandGroup.MKernelBundle;
@@ -2417,7 +2418,6 @@ pi_int32 enqueueImpCommandBufferKernel(
24172418
// they can simply be launched directly.
24182419
if (KernelBundleImplPtr && !KernelBundleImplPtr->isInterop()) {
24192420
std::shared_ptr<kernel_impl> SyclKernelImpl;
2420-
std::shared_ptr<device_image_impl> DeviceImageImpl;
24212421
auto KernelName = CommandGroup.MKernelName;
24222422
kernel_id KernelID =
24232423
detail::ProgramManager::getInstance().getSYCLKernelID(KernelName);
@@ -2439,13 +2439,12 @@ pi_int32 enqueueImpCommandBufferKernel(
24392439
ContextImpl, DeviceImpl, CommandGroup.MKernelName);
24402440
}
24412441

2442-
auto SetFunc = [&Plugin, &PiKernel, &Ctx, &getMemAllocationFunc](
2443-
sycl::detail::ArgDesc &Arg, size_t NextTrueIndex) {
2444-
sycl::detail::SetArgBasedOnType(
2445-
Plugin, PiKernel,
2446-
nullptr /* TODO: Handle spec constants and pass device image here */
2447-
,
2448-
getMemAllocationFunc, Ctx, false, Arg, NextTrueIndex);
2442+
auto SetFunc = [&Plugin, &PiKernel, &DeviceImageImpl, &Ctx,
2443+
&getMemAllocationFunc](sycl::detail::ArgDesc &Arg,
2444+
size_t NextTrueIndex) {
2445+
sycl::detail::SetArgBasedOnType(Plugin, PiKernel, DeviceImageImpl,
2446+
getMemAllocationFunc, Ctx, false, Arg,
2447+
NextTrueIndex);
24492448
};
24502449
// Copy args for modification
24512450
auto Args = CommandGroup.MArgs;

sycl/source/handler.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,10 @@ bool handler::isStateExplicitKernelBundle() const {
111111
std::shared_ptr<detail::kernel_bundle_impl>
112112
handler::getOrInsertHandlerKernelBundle(bool Insert) const {
113113
if (!MImpl->MKernelBundle && Insert) {
114-
MImpl->MKernelBundle =
115-
detail::getSyclObjImpl(get_kernel_bundle<bundle_state::input>(
116-
MQueue->get_context(), {MQueue->get_device()}, {}));
114+
auto Ctx = MGraph ? MGraph->getContext() : MQueue->get_context();
115+
auto Dev = MGraph ? MGraph->getDevice() : MQueue->get_device();
116+
MImpl->MKernelBundle = detail::getSyclObjImpl(
117+
get_kernel_bundle<bundle_state::input>(Ctx, {Dev}, {}));
117118
}
118119
return MImpl->MKernelBundle;
119120
}
@@ -179,10 +180,10 @@ event handler::finalize() {
179180
// Make sure implicit non-interop kernel bundles have the kernel
180181
if (!KernelBundleImpPtr->isInterop() &&
181182
!MImpl->isStateExplicitKernelBundle()) {
183+
auto Dev = MGraph ? MGraph->getDevice() : MQueue->get_device();
182184
kernel_id KernelID =
183185
detail::ProgramManager::getInstance().getSYCLKernelID(MKernelName);
184-
bool KernelInserted =
185-
KernelBundleImpPtr->add_kernel(KernelID, MQueue->get_device());
186+
bool KernelInserted = KernelBundleImpPtr->add_kernel(KernelID, Dev);
186187
// If kernel was not inserted and the bundle is in input mode we try
187188
// building it and trying to find the kernel in executable mode
188189
if (!KernelInserted &&
@@ -194,8 +195,7 @@ event handler::finalize() {
194195
build(KernelBundle);
195196
KernelBundleImpPtr = detail::getSyclObjImpl(ExecKernelBundle);
196197
setHandlerKernelBundle(KernelBundleImpPtr);
197-
KernelInserted =
198-
KernelBundleImpPtr->add_kernel(KernelID, MQueue->get_device());
198+
KernelInserted = KernelBundleImpPtr->add_kernel(KernelID, Dev);
199199
}
200200
// If the kernel was not found in executable mode we throw an exception
201201
if (!KernelInserted)
@@ -835,7 +835,7 @@ void handler::verifyUsedKernelBundle(const std::string &KernelName) {
835835

836836
kernel_id KernelID = detail::get_kernel_id_impl(KernelName);
837837
device Dev =
838-
(MGraph) ? MGraph->getDevice() : detail::getDeviceFromHandler(*this);
838+
MGraph ? MGraph->getDevice() : detail::getDeviceFromHandler(*this);
839839
if (!UsedKernelBundleImplPtr->has_kernel(KernelID, Dev))
840840
throw sycl::exception(
841841
make_error_code(errc::kernel_not_supported),
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// REQUIRES: cuda || level_zero, gpu
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
// Extra run to check for leaks in Level Zero using ZE_DEBUG
5+
// RUN: %if ext_oneapi_level_zero %{env ZE_DEBUG=4 %{run} %t.out 2>&1 | FileCheck %s %}
6+
//
7+
// CHECK-NOT: LEAK
8+
9+
// The following limitation is not restricted to Sycl-Graph
10+
// but comes from the original test: `SpecConstants/2020/handler-api.cpp`
11+
// FIXME: ACC devices use emulation path, which is not yet supported
12+
// UNSUPPORTED: accelerator
13+
14+
#define GRAPH_E2E_EXPLICIT
15+
16+
#include "../Inputs/spec_constants_handler_api.cpp"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// REQUIRES: cuda || level_zero, gpu
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
// Extra run to check for leaks in Level Zero using ZE_DEBUG
5+
// RUN: %if ext_oneapi_level_zero %{env ZE_DEBUG=4 %{run} %t.out 2>&1 | FileCheck %s %}
6+
//
7+
// CHECK-NOT: LEAK
8+
9+
// The following limitation is not restricted to Sycl-Graph
10+
// but comes from the original test: `SpecConstants/2020/kernel-bundle-api.cpp`
11+
// FIXME: ACC devices use emulation path, which is not yet supported
12+
// UNSUPPORTED: accelerator
13+
// UNSUPPORTED: hip
14+
15+
#define GRAPH_E2E_EXPLICIT
16+
17+
#include "../Inputs/spec_constants_kernel_bundle_api.cpp"
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
// This test is intended to check basic operations with SYCL 2020 specialization
2+
// constants using Graph and sycl::handler and sycl::kernel_handler APIs
3+
// This test was taken from `SpecConstants/2020/handler-api.cpp`.
4+
// Variable names have been changed to meet PascalCase naming convention
5+
// requirements.
6+
7+
#include "../graph_common.hpp"
8+
9+
constexpr sycl::specialization_id<int> IntId;
10+
constexpr sycl::specialization_id<int> IntId2(2);
11+
constexpr sycl::specialization_id<float> FloatId(3.14);
12+
13+
class TestDefaultValuesKernel;
14+
class EmptyKernel;
15+
class TestSetAndGetOnDevice;
16+
17+
bool test_default_values(sycl::queue Queue);
18+
bool test_set_and_get_on_host(sycl::queue Queue);
19+
bool test_set_and_get_on_device(sycl::queue Queue);
20+
21+
bool test_set_and_get_on_device(sycl::queue Queue);
22+
23+
int main() {
24+
auto ExceptionHandler = [&](sycl::exception_list Exceptions) {
25+
for (std::exception_ptr const &E : Exceptions) {
26+
try {
27+
std::rethrow_exception(E);
28+
} catch (sycl::exception const &E) {
29+
std::cout << "An async SYCL exception was caught: " << E.what()
30+
<< std::endl;
31+
std::exit(1);
32+
}
33+
}
34+
};
35+
36+
queue Queue{ExceptionHandler,
37+
{sycl::ext::intel::property::queue::no_immediate_command_list{}}};
38+
39+
unsigned Errors = 0;
40+
if (!test_default_values(Queue)) {
41+
std::cout << "Test for default values of specialization constants failed!"
42+
<< std::endl;
43+
Errors++;
44+
}
45+
46+
if (!test_set_and_get_on_host(Queue)) {
47+
std::cout << "Test for set and get API on host failed!" << std::endl;
48+
Errors++;
49+
}
50+
51+
if (!test_set_and_get_on_device(Queue)) {
52+
std::cout << "Test for set and get API on device failed!" << std::endl;
53+
Errors++;
54+
}
55+
56+
return (Errors == 0) ? 0 : 1;
57+
};
58+
59+
bool test_default_values(sycl::queue Queue) {
60+
sycl::buffer<int> IntBuffer(1);
61+
IntBuffer.set_write_back(false);
62+
sycl::buffer<int> IntBuffer2(1);
63+
IntBuffer2.set_write_back(false);
64+
sycl::buffer<float> FloatBuffer(1);
65+
FloatBuffer.set_write_back(false);
66+
67+
{
68+
exp_ext::command_graph Graph{
69+
Queue.get_context(),
70+
Queue.get_device(),
71+
{exp_ext::property::graph::assume_buffer_outlives_graph{}}};
72+
73+
add_node(Graph, Queue, ([&](sycl::handler &CGH) {
74+
auto IntAcc =
75+
IntBuffer.get_access<sycl::access::mode::write>(CGH);
76+
auto IntAcc2 =
77+
IntBuffer2.get_access<sycl::access::mode::write>(CGH);
78+
auto FloatAcc =
79+
FloatBuffer.get_access<sycl::access::mode::write>(CGH);
80+
81+
CGH.single_task<TestDefaultValuesKernel>(
82+
[=](sycl::kernel_handler KH) {
83+
IntAcc[0] = KH.get_specialization_constant<IntId>();
84+
IntAcc2[0] = KH.get_specialization_constant<IntId2>();
85+
FloatAcc[0] = KH.get_specialization_constant<FloatId>();
86+
});
87+
}));
88+
89+
auto GraphExec = Graph.finalize();
90+
91+
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(GraphExec); });
92+
Queue.wait_and_throw();
93+
}
94+
95+
unsigned Errors = 0;
96+
sycl::host_accessor IntAcc(IntBuffer, sycl::read_only);
97+
if (!check_value(
98+
0, IntAcc[0],
99+
"integer specialization constant (defined without default value)"))
100+
Errors++;
101+
102+
sycl::host_accessor IntAcc2(IntBuffer2, sycl::read_only);
103+
if (!check_value(2, IntAcc2[0], "integer specialization constant"))
104+
Errors++;
105+
106+
sycl::host_accessor FloatAcc(FloatBuffer, sycl::read_only);
107+
if (!check_value(3.14f, FloatAcc[0], "float specialization constant"))
108+
Errors++;
109+
110+
return Errors == 0;
111+
}
112+
113+
bool test_set_and_get_on_host(sycl::queue Queue) {
114+
unsigned Errors = 0;
115+
116+
exp_ext::command_graph Graph{
117+
Queue.get_context(),
118+
Queue.get_device(),
119+
{exp_ext::property::graph::assume_buffer_outlives_graph{}}};
120+
121+
add_node(
122+
Graph, Queue, ([&](sycl::handler &CGH) {
123+
if (!check_value(
124+
0, CGH.get_specialization_constant<IntId>(),
125+
"integer specializaiton constant before setting any value"))
126+
++Errors;
127+
128+
if (!check_value(
129+
3.14f, CGH.get_specialization_constant<FloatId>(),
130+
"float specializaiton constant before setting any value"))
131+
++Errors;
132+
133+
int NewIntValue = 8;
134+
float NewFloatValue = 3.0f;
135+
CGH.set_specialization_constant<IntId>(NewIntValue);
136+
CGH.set_specialization_constant<FloatId>(NewFloatValue);
137+
138+
if (!check_value(
139+
NewIntValue, CGH.get_specialization_constant<IntId>(),
140+
"integer specializaiton constant after setting a new value"))
141+
++Errors;
142+
143+
if (!check_value(
144+
NewFloatValue, CGH.get_specialization_constant<FloatId>(),
145+
"float specializaiton constant after setting a new value"))
146+
++Errors;
147+
148+
CGH.single_task<EmptyKernel>([=]() {});
149+
}));
150+
151+
return Errors == 0;
152+
}
153+
154+
bool test_set_and_get_on_device(sycl::queue Queue) {
155+
sycl::buffer<int> IntBuffer(1);
156+
IntBuffer.set_write_back(false);
157+
sycl::buffer<int> IntBuffer2(1);
158+
IntBuffer2.set_write_back(false);
159+
sycl::buffer<float> FloatBuffer(1);
160+
FloatBuffer.set_write_back(false);
161+
162+
int NewIntValue = 8;
163+
int NewIntValue2 = 0;
164+
float NewFloatValue = 3.0f;
165+
166+
{
167+
exp_ext::command_graph Graph{
168+
Queue.get_context(),
169+
Queue.get_device(),
170+
{exp_ext::property::graph::assume_buffer_outlives_graph{}}};
171+
172+
add_node(
173+
Graph, Queue, ([&](sycl::handler &CGH) {
174+
auto IntAcc = IntBuffer.get_access<sycl::access::mode::write>(CGH);
175+
auto IntAcc2 = IntBuffer2.get_access<sycl::access::mode::write>(CGH);
176+
auto FloatAcc =
177+
FloatBuffer.get_access<sycl::access::mode::write>(CGH);
178+
179+
CGH.set_specialization_constant<IntId>(NewIntValue);
180+
CGH.set_specialization_constant<IntId2>(NewIntValue2);
181+
CGH.set_specialization_constant<FloatId>(NewFloatValue);
182+
183+
CGH.single_task<TestSetAndGetOnDevice>([=](sycl::kernel_handler KH) {
184+
IntAcc[0] = KH.get_specialization_constant<IntId>();
185+
IntAcc2[0] = KH.get_specialization_constant<IntId2>();
186+
FloatAcc[0] = KH.get_specialization_constant<FloatId>();
187+
});
188+
}));
189+
190+
auto GraphExec = Graph.finalize();
191+
192+
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(GraphExec); });
193+
Queue.wait_and_throw();
194+
}
195+
196+
unsigned Errors = 0;
197+
sycl::host_accessor IntAcc(IntBuffer, sycl::read_only);
198+
if (!check_value(NewIntValue, IntAcc[0], "integer specialization constant"))
199+
Errors++;
200+
201+
sycl::host_accessor IntAcc2(IntBuffer2, sycl::read_only);
202+
if (!check_value(NewIntValue2, IntAcc2[0], "integer specialization constant"))
203+
Errors++;
204+
205+
sycl::host_accessor FloatAcc(FloatBuffer, sycl::read_only);
206+
if (!check_value(NewFloatValue, FloatAcc[0], "float specialization constant"))
207+
Errors++;
208+
209+
return Errors == 0;
210+
}

0 commit comments

Comments
 (0)