Skip to content

Commit 1bedf88

Browse files
schittirromanovvlad
authored andcommitted
[SYCL] Fixing device check in program link constructor
This patch gets sycl devices via context in program interoperability constructor and sorts devices in program link constructor to check that all the programs in the list use the same devices It also addresses the case where program is constructed using only a subset of devices associated with sycl context Signed-off-by: Sindhu Chittireddy <[email protected]>
1 parent 0df0754 commit 1bedf88

File tree

4 files changed

+48
-4
lines changed

4 files changed

+48
-4
lines changed

sycl/include/CL/sycl/detail/device_host.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ class device_host : public device_impl {
2424
cl_device_id get() const override {
2525
throw invalid_object_error("This instance of device is a host instance");
2626
}
27+
cl_device_id &getHandleRef() override {
28+
throw invalid_object_error("This instance of device is a host instance");
29+
}
2730

2831
bool is_host() const override { return true; }
2932

sycl/include/CL/sycl/detail/device_impl.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ class device_impl {
2929

3030
virtual cl_device_id get() const = 0;
3131

32+
// Returns underlying native device object (if any) w/o reference count
33+
// modification. Caller must ensure the returned object lives on stack only.
34+
// It can also be safely passed to the underlying native runtime API.
35+
// Warning. Returned reference will be invalid if device_impl was destroyed.
36+
virtual cl_device_id &getHandleRef() = 0;
37+
3238
virtual bool is_host() const = 0;
3339

3440
virtual bool is_cpu() const = 0;

sycl/include/CL/sycl/detail/device_opencl.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ class device_opencl : public device_impl {
5757
return id;
5858
}
5959

60+
cl_device_id &getHandleRef() override{
61+
return id;
62+
}
63+
6064
bool is_host() const override { return false; }
6165

6266
bool is_cpu() const override { return (type == CL_DEVICE_TYPE_CPU); }

sycl/include/CL/sycl/detail/program_impl.hpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,23 @@ class program_impl {
5151
}
5252
Context = ProgramList[0]->Context;
5353
Devices = ProgramList[0]->Devices;
54+
std::vector<device> DevicesSorted;
55+
if (!is_host()) {
56+
DevicesSorted = sort_devices_by_cl_device_id(Devices);
57+
}
5458
for (const auto &Prg : ProgramList) {
5559
Prg->throw_if_state_is_not(program_state::compiled);
5660
if (Prg->Context != Context) {
5761
throw invalid_object_error(
5862
"Not all programs are associated with the same context");
5963
}
60-
if (Prg->Devices != Devices) {
61-
throw invalid_object_error(
62-
"Not all programs are associated with the same devices");
64+
if (!is_host()) {
65+
std::vector<device> PrgDevicesSorted =
66+
sort_devices_by_cl_device_id(Prg->Devices);
67+
if (PrgDevicesSorted != DevicesSorted) {
68+
throw invalid_object_error(
69+
"Not all programs are associated with the same devices");
70+
}
6371
}
6472
}
6573

@@ -92,7 +100,20 @@ class program_impl {
92100
CHECK_OCL_CODE(clGetProgramInfo(ClProgram, CL_PROGRAM_DEVICES,
93101
sizeof(cl_device_id) * NumDevices,
94102
ClDevices.data(), nullptr));
95-
Devices = vector_class<device>(ClDevices.begin(), ClDevices.end());
103+
vector_class<device> SyclContextDevices = Context.get_devices();
104+
105+
// Keep only the subset of the devices (associated with context) that
106+
// were actually used to create the program.
107+
// This is possible when clCreateProgramWithBinary is used.
108+
auto NewEnd = std::remove_if(
109+
SyclContextDevices.begin(), SyclContextDevices.end(),
110+
[&ClDevices](const sycl::device &Dev) {
111+
return ClDevices.end() ==
112+
std::find(ClDevices.begin(), ClDevices.end(),
113+
detail::getSyclObjImpl(Dev)->getHandleRef());
114+
});
115+
SyclContextDevices.erase(NewEnd, SyclContextDevices.end());
116+
Devices = SyclContextDevices;
96117
// TODO check build for each device instead
97118
cl_program_binary_type BinaryType;
98119
CHECK_OCL_CODE(clGetProgramBuildInfo(
@@ -371,6 +392,16 @@ class program_impl {
371392
return ClKernel;
372393
}
373394

395+
std::vector<device>
396+
sort_devices_by_cl_device_id(vector_class<device> Devices) {
397+
std::sort(Devices.begin(), Devices.end(),
398+
[](const device &id1, const device &id2) {
399+
return (detail::getSyclObjImpl(id1)->getHandleRef() <
400+
detail::getSyclObjImpl(id2)->getHandleRef());
401+
});
402+
return Devices;
403+
}
404+
374405
void throw_if_state_is(program_state State) const {
375406
if (this->State == State) {
376407
throw invalid_object_error("Invalid program state");

0 commit comments

Comments
 (0)