Skip to content

Commit c222497

Browse files
[SYCL] Remove duplicate devices on submission to kernel_bundle API functions (#4790)
When a list of devices is submitted for compilation, section 4.11.7 of the SYCL 2020 spec states that the compilation is performed for each unique device, with duplicate devices removed. Here we ensure that the devices sent to the backend are unique. Also, Level Zero doesn't support compilation for more than one device, and it presently calls die() in that case; this commit changes it to print a message and return PI_INVALID_VALUE instead. Signed-off-by: Chris Perkins <[email protected]>
1 parent df1ff7a commit c222497

File tree

3 files changed

+50
-13
lines changed

3 files changed

+50
-13
lines changed

sycl/include/CL/sycl/device.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace sycl {
2525
class device_selector;
2626
namespace detail {
2727
class device_impl;
28+
auto getDeviceComparisonLambda();
2829
}
2930

3031
/// The SYCL device class encapsulates a single SYCL device on which kernels
@@ -215,6 +216,8 @@ class __SYCL_EXPORT device {
215216

216217
template <class T>
217218
friend T detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);
219+
220+
friend auto detail::getDeviceComparisonLambda();
218221
};
219222

220223
} // namespace sycl

sycl/include/CL/sycl/kernel_bundle.hpp

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <cassert>
2121
#include <memory>
22+
#include <set>
2223
#include <vector>
2324

2425
__SYCL_INLINE_NAMESPACE(cl) {
@@ -375,6 +376,21 @@ namespace detail {
375376
__SYCL_EXPORT detail::KernelBundleImplPtr
376377
get_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
377378
bundle_state State);
379+
380+
inline auto getDeviceComparisonLambda() {
381+
return [](device a, device b) { return a.getNative() != b.getNative(); };
382+
}
383+
384+
inline const std::vector<device>
385+
removeDuplicateDevices(const std::vector<device> &Devs) {
386+
auto compareDevices = getDeviceComparisonLambda();
387+
std::set<device, decltype(compareDevices)> UniqueDeviceSet(
388+
Devs.begin(), Devs.end(), compareDevices);
389+
std::vector<device> UniqueDevices(UniqueDeviceSet.begin(),
390+
UniqueDeviceSet.end());
391+
return UniqueDevices;
392+
}
393+
378394
} // namespace detail
379395

380396
/// A kernel bundle in state State which contains all of the kernels in the
@@ -384,8 +400,10 @@ get_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
384400
template <bundle_state State>
385401
kernel_bundle<State> get_kernel_bundle(const context &Ctx,
386402
const std::vector<device> &Devs) {
403+
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);
404+
387405
detail::KernelBundleImplPtr Impl =
388-
detail::get_kernel_bundle_impl(Ctx, Devs, State);
406+
detail::get_kernel_bundle_impl(Ctx, UniqueDevices, State);
389407

390408
return detail::createSyclObjFromImpl<kernel_bundle<State>>(Impl);
391409
}
@@ -417,8 +435,10 @@ template <bundle_state State>
417435
kernel_bundle<State>
418436
get_kernel_bundle(const context &Ctx, const std::vector<device> &Devs,
419437
const std::vector<kernel_id> &KernelIDs) {
438+
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);
439+
420440
detail::KernelBundleImplPtr Impl =
421-
detail::get_kernel_bundle_impl(Ctx, Devs, KernelIDs, State);
441+
detail::get_kernel_bundle_impl(Ctx, UniqueDevices, KernelIDs, State);
422442
return detail::createSyclObjFromImpl<kernel_bundle<State>>(Impl);
423443
}
424444

@@ -459,14 +479,16 @@ template <bundle_state State, typename SelectorT>
459479
kernel_bundle<State> get_kernel_bundle(const context &Ctx,
460480
const std::vector<device> &Devs,
461481
SelectorT Selector) {
482+
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);
483+
462484
detail::DevImgSelectorImpl SelectorWrapper =
463485
[Selector](const detail::DeviceImageImplPtr &DevImg) {
464486
return Selector(
465487
detail::createSyclObjFromImpl<sycl::device_image<State>>(DevImg));
466488
};
467489

468-
detail::KernelBundleImplPtr Impl =
469-
detail::get_kernel_bundle_impl(Ctx, Devs, State, SelectorWrapper);
490+
detail::KernelBundleImplPtr Impl = detail::get_kernel_bundle_impl(
491+
Ctx, UniqueDevices, State, SelectorWrapper);
470492

471493
return detail::createSyclObjFromImpl<sycl::kernel_bundle<State>>(Impl);
472494
}
@@ -589,8 +611,10 @@ compile_impl(const kernel_bundle<bundle_state::input> &InputBundle,
589611
inline kernel_bundle<bundle_state::object>
590612
compile(const kernel_bundle<bundle_state::input> &InputBundle,
591613
const std::vector<device> &Devs, const property_list &PropList = {}) {
614+
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);
615+
592616
detail::KernelBundleImplPtr Impl =
593-
detail::compile_impl(InputBundle, Devs, PropList);
617+
detail::compile_impl(InputBundle, UniqueDevices, PropList);
594618
return detail::createSyclObjFromImpl<
595619
kernel_bundle<sycl::bundle_state::object>>(Impl);
596620
}
@@ -622,8 +646,10 @@ link_impl(const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
622646
inline kernel_bundle<bundle_state::executable>
623647
link(const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
624648
const std::vector<device> &Devs, const property_list &PropList = {}) {
649+
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);
650+
625651
detail::KernelBundleImplPtr Impl =
626-
detail::link_impl(ObjectBundles, Devs, PropList);
652+
detail::link_impl(ObjectBundles, UniqueDevices, PropList);
627653
return detail::createSyclObjFromImpl<
628654
kernel_bundle<sycl::bundle_state::executable>>(Impl);
629655
}
@@ -667,8 +693,10 @@ build_impl(const kernel_bundle<bundle_state::input> &InputBundle,
667693
inline kernel_bundle<bundle_state::executable>
668694
build(const kernel_bundle<bundle_state::input> &InputBundle,
669695
const std::vector<device> &Devs, const property_list &PropList = {}) {
696+
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);
697+
670698
detail::KernelBundleImplPtr Impl =
671-
detail::build_impl(InputBundle, Devs, PropList);
699+
detail::build_impl(InputBundle, UniqueDevices, PropList);
672700
return detail::createSyclObjFromImpl<
673701
kernel_bundle<sycl::bundle_state::executable>>(Impl);
674702
}

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3407,8 +3407,10 @@ pi_result piProgramCreateWithBinary(
34073407
PI_ASSERT(Program, PI_INVALID_PROGRAM);
34083408

34093409
// For now we support only one device.
3410-
if (NumDevices != 1)
3411-
die("piProgramCreateWithBinary: level_zero supports only one device.");
3410+
if (NumDevices != 1) {
3411+
zePrint("piProgramCreateWithBinary: level_zero supports only one device.");
3412+
return PI_INVALID_VALUE;
3413+
}
34123414
if (!Binaries[0] || !Lengths[0]) {
34133415
if (BinaryStatus)
34143416
*BinaryStatus = PI_INVALID_VALUE;
@@ -3605,8 +3607,10 @@ pi_result piProgramLink(pi_context Context, pi_uint32 NumDevices,
36053607

36063608
// We only support one device with Level Zero currently.
36073609
pi_device Device = Context->Devices[0];
3608-
if (NumDevices != 1)
3609-
die("piProgramLink: level_zero supports only one device.");
3610+
if (NumDevices != 1) {
3611+
zePrint("piProgramLink: level_zero supports only one device.");
3612+
return PI_INVALID_VALUE;
3613+
}
36103614

36113615
PI_ASSERT(DeviceList && DeviceList[0] == Device, PI_INVALID_DEVICE);
36123616
PI_ASSERT(!PFnNotify && !UserData, PI_INVALID_VALUE);
@@ -3783,8 +3787,10 @@ static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices,
37833787
// We only support build to one device with Level Zero now.
37843788
// TODO: we should eventually build to the possibly multiple root
37853789
// devices in the context.
3786-
if (NumDevices != 1)
3787-
die("compileOrBuild: level_zero supports only one device.");
3790+
if (NumDevices != 1) {
3791+
zePrint("compileOrBuild: level_zero supports only one device.");
3792+
return PI_INVALID_VALUE;
3793+
}
37883794

37893795
PI_ASSERT(DeviceList, PI_INVALID_DEVICE);
37903796

0 commit comments

Comments
 (0)