Skip to content

Commit 26acbe2

Browse files
committed
Address comments
1 parent f8bbd7b commit 26acbe2

File tree

2 files changed

+22
-34
lines changed

2 files changed

+22
-34
lines changed

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ template <class T> struct LessByHash {
3333
}
3434
};
3535

36-
static bool CheckAllDevicesAreInContext(const std::vector<device> &Devices,
36+
static bool checkAllDevicesAreInContext(const std::vector<device> &Devices,
3737
const context &Context) {
3838
const std::vector<device> &ContextDevices = Context.get_devices();
3939
return std::all_of(
@@ -43,7 +43,7 @@ static bool CheckAllDevicesAreInContext(const std::vector<device> &Devices,
4343
});
4444
}
4545

46-
static bool CheckAllDevicesHaveAspect(const std::vector<device> &Devices,
46+
static bool checkAllDevicesHaveAspect(const std::vector<device> &Devices,
4747
aspect Aspect) {
4848
return std::all_of(Devices.begin(), Devices.end(),
4949
[&Aspect](const device &Dev) { return Dev.has(Aspect); });
@@ -56,20 +56,20 @@ class kernel_bundle_impl {
5656

5757
void common_ctor_checks(bundle_state State) {
5858
const bool AllDevicesInTheContext =
59-
CheckAllDevicesAreInContext(MDevices, MContext);
59+
checkAllDevicesAreInContext(MDevices, MContext);
6060
if (MDevices.empty() || !AllDevicesInTheContext)
6161
throw sycl::exception(
6262
make_error_code(errc::invalid),
6363
"Not all devices are associated with the context or "
6464
"vector of devices is empty");
6565

6666
if (bundle_state::input == State &&
67-
!CheckAllDevicesHaveAspect(MDevices, aspect::online_compiler))
67+
!checkAllDevicesHaveAspect(MDevices, aspect::online_compiler))
6868
throw sycl::exception(make_error_code(errc::invalid),
6969
"Not all devices have aspect::online_compiler");
7070

7171
if (bundle_state::object == State &&
72-
!CheckAllDevicesHaveAspect(MDevices, aspect::online_linker))
72+
!checkAllDevicesHaveAspect(MDevices, aspect::online_linker))
7373
throw sycl::exception(make_error_code(errc::invalid),
7474
"Not all devices have aspect::online_linker");
7575
}
@@ -155,6 +155,7 @@ class kernel_bundle_impl {
155155
// devices for any of the bundles in ObjectBundles
156156
const bool AllDevsAssociatedWithInputBundles = std::all_of(
157157
MDevices.begin(), MDevices.end(), [&ObjectBundles](const device &Dev) {
158+
// Number of devices is expected to be small
158159
return std::all_of(
159160
ObjectBundles.begin(), ObjectBundles.end(),
160161
[&Dev](const kernel_bundle<bundle_state::object> &KernelBundle) {

sycl/source/kernel_bundle.cpp

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,18 @@ join_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles) {
117117

118118
bool has_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
119119
bundle_state State) {
120-
const bool AllDevicesInTheContext = CheckAllDevicesAreInContext(Devs, Ctx);
120+
// Check that all requested devices are associated with the context
121+
const bool AllDevicesInTheContext = checkAllDevicesAreInContext(Devs, Ctx);
121122
if (Devs.empty() || !AllDevicesInTheContext)
122123
throw sycl::exception(make_error_code(errc::invalid),
123124
"Not all devices are associated with the context or "
124125
"vector of devices is empty");
125126

126127
if (bundle_state::input == State &&
127-
!CheckAllDevicesHaveAspect(Devs, aspect::online_compiler))
128+
!checkAllDevicesHaveAspect(Devs, aspect::online_compiler))
128129
return false;
129130
if (bundle_state::object == State &&
130-
!CheckAllDevicesHaveAspect(Devs, aspect::online_linker))
131+
!checkAllDevicesHaveAspect(Devs, aspect::online_linker))
131132
return false;
132133

133134
const std::vector<device_image_plain> DeviceImages =
@@ -143,7 +144,8 @@ bool has_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
143144
bool has_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
144145
const std::vector<kernel_id> &KernelIds,
145146
bundle_state State) {
146-
const bool AllDevicesInTheContext = CheckAllDevicesAreInContext(Devs, Ctx);
147+
// Check that all requested devices are associated with the context
148+
const bool AllDevicesInTheContext = checkAllDevicesAreInContext(Devs, Ctx);
147149

148150
if (Devs.empty() || !AllDevicesInTheContext)
149151
throw sycl::exception(make_error_code(errc::invalid),
@@ -218,33 +220,18 @@ std::vector<sycl::device> find_device_intersection(
218220
// for all bundles
219221
std::vector<sycl::device> IntersectDevices;
220222
std::vector<unsigned int> DevsCounters;
223+
std::map<device, unsigned int, LessByHash<device>> DevCounters;
221224
for (const sycl::kernel_bundle<bundle_state::object> &ObjectBundle :
222225
ObjectBundles)
223-
// Increment counter in "DevsCounters" each time a device is seen
224-
for (const sycl::device &Device : ObjectBundle.get_devices()) {
225-
auto It =
226-
std::find(IntersectDevices.begin(), IntersectDevices.end(), Device);
227-
if (IntersectDevices.end() != It) {
228-
assert((size_t)(std::distance(IntersectDevices.begin(), It) + 1) ==
229-
DevsCounters.size());
230-
++DevsCounters[std::distance(IntersectDevices.begin(), It)];
231-
continue;
232-
}
233-
IntersectDevices.push_back(Device);
234-
DevsCounters.push_back(1);
235-
}
236-
237-
// If for some device counter is less than ObjectBundles.size() it means some
238-
// bundle doesn't have it - remove such a device from the final result
239-
size_t NewSize = DevsCounters.size();
240-
for (size_t Idx = 0; Idx < NewSize; ++Idx) {
241-
if (ObjectBundles.size() == DevsCounters[Idx])
242-
continue;
243-
244-
std::swap(IntersectDevices[Idx], IntersectDevices.back());
245-
--NewSize;
246-
}
247-
IntersectDevices.resize(NewSize);
226+
// Increment counter in "DevCounters" each time a device is seen
227+
for (const sycl::device &Device : ObjectBundle.get_devices())
228+
DevCounters[Device]++;
229+
230+
// If some device counter is less than ObjectBundles.size() then some bundle
231+
// doesn't have it - do not add such a device to the final result
232+
for (const std::pair<const device, unsigned int> &It : DevCounters)
233+
if (ObjectBundles.size() == It.second)
234+
IntersectDevices.push_back(It.first);
248235

249236
return IntersectDevices;
250237
}

0 commit comments

Comments
 (0)