@@ -117,17 +117,18 @@ join_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles) {
117
117
118
118
bool has_kernel_bundle_impl (const context &Ctx, const std::vector<device> &Devs,
119
119
bundle_state State) {
120
- const bool AllDevicesInTheContext = CheckAllDevicesAreInContext (Devs, Ctx);
120
+ // Check that all requested devices are associated with the context
121
+ const bool AllDevicesInTheContext = checkAllDevicesAreInContext (Devs, Ctx);
121
122
if (Devs.empty () || !AllDevicesInTheContext)
122
123
throw sycl::exception (make_error_code (errc::invalid),
123
124
" Not all devices are associated with the context or "
124
125
" vector of devices is empty" );
125
126
126
127
if (bundle_state::input == State &&
127
- !CheckAllDevicesHaveAspect (Devs, aspect::online_compiler))
128
+ !checkAllDevicesHaveAspect (Devs, aspect::online_compiler))
128
129
return false ;
129
130
if (bundle_state::object == State &&
130
- !CheckAllDevicesHaveAspect (Devs, aspect::online_linker))
131
+ !checkAllDevicesHaveAspect (Devs, aspect::online_linker))
131
132
return false ;
132
133
133
134
const std::vector<device_image_plain> DeviceImages =
@@ -143,7 +144,8 @@ bool has_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
143
144
bool has_kernel_bundle_impl (const context &Ctx, const std::vector<device> &Devs,
144
145
const std::vector<kernel_id> &KernelIds,
145
146
bundle_state State) {
146
- const bool AllDevicesInTheContext = CheckAllDevicesAreInContext (Devs, Ctx);
147
+ // Check that all requested devices are associated with the context
148
+ const bool AllDevicesInTheContext = checkAllDevicesAreInContext (Devs, Ctx);
147
149
148
150
if (Devs.empty () || !AllDevicesInTheContext)
149
151
throw sycl::exception (make_error_code (errc::invalid),
@@ -218,33 +220,18 @@ std::vector<sycl::device> find_device_intersection(
218
220
// for all bundles
219
221
std::vector<sycl::device> IntersectDevices;
220
222
std::vector<unsigned int > DevsCounters;
223
+ std::map<device, unsigned int , LessByHash<device>> DevCounters;
221
224
for (const sycl::kernel_bundle<bundle_state::object> &ObjectBundle :
222
225
ObjectBundles)
223
- // Increment counter in "DevsCounters" each time a device is seen
224
- for (const sycl::device &Device : ObjectBundle.get_devices ()) {
225
- auto It =
226
- std::find (IntersectDevices.begin (), IntersectDevices.end (), Device);
227
- if (IntersectDevices.end () != It) {
228
- assert ((size_t )(std::distance (IntersectDevices.begin (), It) + 1 ) ==
229
- DevsCounters.size ());
230
- ++DevsCounters[std::distance (IntersectDevices.begin (), It)];
231
- continue ;
232
- }
233
- IntersectDevices.push_back (Device);
234
- DevsCounters.push_back (1 );
235
- }
236
-
237
- // If for some device counter is less than ObjectBundles.size() it means some
238
- // bundle doesn't have it - remove such a device from the final result
239
- size_t NewSize = DevsCounters.size ();
240
- for (size_t Idx = 0 ; Idx < NewSize; ++Idx) {
241
- if (ObjectBundles.size () == DevsCounters[Idx])
242
- continue ;
243
-
244
- std::swap (IntersectDevices[Idx], IntersectDevices.back ());
245
- --NewSize;
246
- }
247
- IntersectDevices.resize (NewSize);
226
+ // Increment counter in "DevCounters" each time a device is seen
227
+ for (const sycl::device &Device : ObjectBundle.get_devices ())
228
+ DevCounters[Device]++;
229
+
230
+ // If some device counter is less than ObjectBundles.size() then some bundle
231
+ // doesn't have it - do not add such a device to the final result
232
+ for (const std::pair<const device, unsigned int > &It : DevCounters)
233
+ if (ObjectBundles.size () == It.second )
234
+ IntersectDevices.push_back (It.first );
248
235
249
236
return IntersectDevices;
250
237
}
0 commit comments