Skip to content

[SYCL] Remove duplicate devices on submission to kernel_bundle API functions #4790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sycl/include/CL/sycl/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace sycl {
class device_selector;
namespace detail {
class device_impl;
// Forward-declared with a deduced return type so that class device can grant
// it friendship below; the definition shown in kernel_bundle.hpp calls
// device::getNative().
auto getDeviceComparisonLambda();
}

/// The SYCL device class encapsulates a single SYCL device on which kernels
Expand Down Expand Up @@ -215,6 +216,8 @@ class __SYCL_EXPORT device {

template <class T>
friend T detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);

friend auto detail::getDeviceComparisonLambda();
};

} // namespace sycl
Expand Down
42 changes: 35 additions & 7 deletions sycl/include/CL/sycl/kernel_bundle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <cassert>
#include <functional>
#include <memory>
#include <set>
#include <vector>

__SYCL_INLINE_NAMESPACE(cl) {
Expand Down Expand Up @@ -375,6 +376,21 @@ namespace detail {
__SYCL_EXPORT detail::KernelBundleImplPtr
get_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
bundle_state State);

/// Returns the ordering predicate used to deduplicate devices in an ordered
/// associative container such as std::set.
///
/// Two devices are equivalent under this predicate exactly when their native
/// handles are equal. The predicate has to be a strict weak ordering: an
/// inequality test (`!=`) claims "a before b" and "b before a" at the same
/// time for any two distinct devices, which breaks the container's
/// requirements and makes its behaviour undefined. std::less<> is used rather
/// than a bare `<` so that the order is total even if the native handle turns
/// out to be a pointer type.
///
/// \return a callable taking two devices and returning true when the first
/// one's native handle orders before the second one's.
inline auto getDeviceComparisonLambda() {
  return [](device a, device b) {
    return std::less<>{}(a.getNative(), b.getNative());
  };
}

/// Returns a copy of \p Devs in which every device appears only once.
///
/// Devices are deduplicated by inserting them into a std::set ordered by
/// getDeviceComparisonLambda(), so the result is in that comparator's order,
/// not necessarily in the order of \p Devs.
///
/// The result is returned as a plain (non-const) value: a const-qualified
/// by-value return cannot be moved from, so callers would have to copy it.
///
/// \param Devs devices to deduplicate; may be empty and may contain repeats.
/// \return the devices of \p Devs that are distinct under the comparator.
inline std::vector<device>
removeDuplicateDevices(const std::vector<device> &Devs) {
  auto compareDevices = getDeviceComparisonLambda();
  std::set<device, decltype(compareDevices)> UniqueDeviceSet(
      Devs.begin(), Devs.end(), compareDevices);
  return std::vector<device>(UniqueDeviceSet.begin(), UniqueDeviceSet.end());
}

} // namespace detail

/// A kernel bundle in state State which contains all of the kernels in the
Expand All @@ -384,8 +400,10 @@ get_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
template <bundle_state State>
kernel_bundle<State> get_kernel_bundle(const context &Ctx,
const std::vector<device> &Devs) {
// Collapse repeated entries in Devs before the device list is passed on.
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

// NOTE(review): this listing is a rendered diff without +/- markers, so the
// two calls below appear to be the pre-change line (passing Devs) followed by
// its replacement (passing UniqueDevices) - confirm against the real file.
detail::KernelBundleImplPtr Impl =
detail::get_kernel_bundle_impl(Ctx, Devs, State);
detail::get_kernel_bundle_impl(Ctx, UniqueDevices, State);

// Wrap the implementation pointer in the public kernel_bundle type.
return detail::createSyclObjFromImpl<kernel_bundle<State>>(Impl);
}
Expand Down Expand Up @@ -417,8 +435,10 @@ template <bundle_state State>
kernel_bundle<State>
get_kernel_bundle(const context &Ctx, const std::vector<device> &Devs,
const std::vector<kernel_id> &KernelIDs) {
// Collapse repeated entries in Devs before the device list is passed on.
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

// NOTE(review): this listing is a rendered diff without +/- markers, so the
// two calls below appear to be the pre-change line (passing Devs) followed by
// its replacement (passing UniqueDevices) - confirm against the real file.
detail::KernelBundleImplPtr Impl =
detail::get_kernel_bundle_impl(Ctx, Devs, KernelIDs, State);
detail::get_kernel_bundle_impl(Ctx, UniqueDevices, KernelIDs, State);
// Wrap the implementation pointer in the public kernel_bundle type.
return detail::createSyclObjFromImpl<kernel_bundle<State>>(Impl);
}

Expand Down Expand Up @@ -459,14 +479,16 @@ template <bundle_state State, typename SelectorT>
kernel_bundle<State> get_kernel_bundle(const context &Ctx,
const std::vector<device> &Devs,
SelectorT Selector) {
// Collapse repeated entries in Devs before the device list is passed on.
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

// Adapt the user's selector, which takes a device_image<State>, to the
// implementation-level callback, which receives a DeviceImageImplPtr.
detail::DevImgSelectorImpl SelectorWrapper =
[Selector](const detail::DeviceImageImplPtr &DevImg) {
return Selector(
detail::createSyclObjFromImpl<sycl::device_image<State>>(DevImg));
};

// NOTE(review): this listing is a rendered diff without +/- markers, so the
// two Impl declarations below appear to be the pre-change statement (passing
// Devs) followed by its replacement (passing UniqueDevices) - confirm against
// the real file.
detail::KernelBundleImplPtr Impl =
detail::get_kernel_bundle_impl(Ctx, Devs, State, SelectorWrapper);
detail::KernelBundleImplPtr Impl = detail::get_kernel_bundle_impl(
Ctx, UniqueDevices, State, SelectorWrapper);

// Wrap the implementation pointer in the public kernel_bundle type.
return detail::createSyclObjFromImpl<sycl::kernel_bundle<State>>(Impl);
}
Expand Down Expand Up @@ -589,8 +611,10 @@ compile_impl(const kernel_bundle<bundle_state::input> &InputBundle,
// Forwards InputBundle, the device list and PropList to detail::compile_impl
// and wraps the returned implementation pointer in a
// kernel_bundle<bundle_state::object>.
inline kernel_bundle<bundle_state::object>
compile(const kernel_bundle<bundle_state::input> &InputBundle,
const std::vector<device> &Devs, const property_list &PropList = {}) {
// Collapse repeated entries in Devs before the device list is passed on.
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

// NOTE(review): this listing is a rendered diff without +/- markers, so the
// two calls below appear to be the pre-change line (passing Devs) followed by
// its replacement (passing UniqueDevices) - confirm against the real file.
detail::KernelBundleImplPtr Impl =
detail::compile_impl(InputBundle, Devs, PropList);
detail::compile_impl(InputBundle, UniqueDevices, PropList);
return detail::createSyclObjFromImpl<
kernel_bundle<sycl::bundle_state::object>>(Impl);
}
Expand Down Expand Up @@ -622,8 +646,10 @@ link_impl(const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
// Forwards ObjectBundles, the device list and PropList to detail::link_impl
// and wraps the returned implementation pointer in a
// kernel_bundle<bundle_state::executable>.
inline kernel_bundle<bundle_state::executable>
link(const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
const std::vector<device> &Devs, const property_list &PropList = {}) {
// Collapse repeated entries in Devs before the device list is passed on.
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

// NOTE(review): this listing is a rendered diff without +/- markers, so the
// two calls below appear to be the pre-change line (passing Devs) followed by
// its replacement (passing UniqueDevices) - confirm against the real file.
detail::KernelBundleImplPtr Impl =
detail::link_impl(ObjectBundles, Devs, PropList);
detail::link_impl(ObjectBundles, UniqueDevices, PropList);
return detail::createSyclObjFromImpl<
kernel_bundle<sycl::bundle_state::executable>>(Impl);
}
Expand Down Expand Up @@ -667,8 +693,10 @@ build_impl(const kernel_bundle<bundle_state::input> &InputBundle,
// Forwards InputBundle, the device list and PropList to detail::build_impl
// and wraps the returned implementation pointer in a
// kernel_bundle<bundle_state::executable>.
inline kernel_bundle<bundle_state::executable>
build(const kernel_bundle<bundle_state::input> &InputBundle,
const std::vector<device> &Devs, const property_list &PropList = {}) {
// Collapse repeated entries in Devs before the device list is passed on.
std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

// NOTE(review): this listing is a rendered diff without +/- markers, so the
// two calls below appear to be the pre-change line (passing Devs) followed by
// its replacement (passing UniqueDevices) - confirm against the real file.
detail::KernelBundleImplPtr Impl =
detail::build_impl(InputBundle, Devs, PropList);
detail::build_impl(InputBundle, UniqueDevices, PropList);
return detail::createSyclObjFromImpl<
kernel_bundle<sycl::bundle_state::executable>>(Impl);
}
Expand Down
18 changes: 12 additions & 6 deletions sycl/plugins/level_zero/pi_level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3407,8 +3407,10 @@ pi_result piProgramCreateWithBinary(
PI_ASSERT(Program, PI_INVALID_PROGRAM);

// For now we support only one device.
if (NumDevices != 1)
die("piProgramCreateWithBinary: level_zero supports only one device.");
if (NumDevices != 1) {
zePrint("piProgramCreateWithBinary: level_zero supports only one device.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify why it is not OK to use "die" here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's mostly a matter of getting the devices to behave similarly. Right now, only the OpenCL GPU device supports compilation for multiple unique devices. All the other devices, which do not support it, throw an exception when asked, except Level Zero, which calls die().

return PI_INVALID_VALUE;
}
if (!Binaries[0] || !Lengths[0]) {
if (BinaryStatus)
*BinaryStatus = PI_INVALID_VALUE;
Expand Down Expand Up @@ -3605,8 +3607,10 @@ pi_result piProgramLink(pi_context Context, pi_uint32 NumDevices,

// We only support one device with Level Zero currently.
pi_device Device = Context->Devices[0];
if (NumDevices != 1)
die("piProgramLink: level_zero supports only one device.");
if (NumDevices != 1) {
zePrint("piProgramLink: level_zero supports only one device.");
return PI_INVALID_VALUE;
}

PI_ASSERT(DeviceList && DeviceList[0] == Device, PI_INVALID_DEVICE);
PI_ASSERT(!PFnNotify && !UserData, PI_INVALID_VALUE);
Expand Down Expand Up @@ -3783,8 +3787,10 @@ static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices,
// We only support build to one device with Level Zero now.
// TODO: we should eventually build to the possibly multiple root
// devices in the context.
if (NumDevices != 1)
die("compileOrBuild: level_zero supports only one device.");
if (NumDevices != 1) {
zePrint("compileOrBuild: level_zero supports only one device.");
return PI_INVALID_VALUE;
}

PI_ASSERT(DeviceList, PI_INVALID_DEVICE);

Expand Down