19
19
20
20
#include < cassert>
21
21
#include < memory>
22
+ #include < set>
22
23
#include < vector>
23
24
24
25
__SYCL_INLINE_NAMESPACE (cl) {
@@ -375,6 +376,21 @@ namespace detail {
375
376
// Declaration only: the definition is not in this header section. Takes the
// context, the list of devices and the requested bundle state, and returns a
// pointer to the kernel bundle implementation object.
__SYCL_EXPORT detail::KernelBundleImplPtr
get_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
                       bundle_state State);
379
+
380
+ inline auto getDeviceComparisonLambda () {
381
+ return [](device a, device b) { return a.getNative () != b.getNative (); };
382
+ }
383
+
384
+ inline const std::vector<device>
385
+ removeDuplicateDevices (const std::vector<device> &Devs) {
386
+ auto compareDevices = getDeviceComparisonLambda ();
387
+ std::set<device, decltype (compareDevices)> UniqueDeviceSet (
388
+ Devs.begin (), Devs.end (), compareDevices);
389
+ std::vector<device> UniqueDevices (UniqueDeviceSet.begin (),
390
+ UniqueDeviceSet.end ());
391
+ return UniqueDevices;
392
+ }
393
+
378
394
} // namespace detail
379
395
380
396
/// A kernel bundle in state State which contains all of the kernels in the
@@ -384,8 +400,10 @@ get_kernel_bundle_impl(const context &Ctx, const std::vector<device> &Devs,
384
400
template <bundle_state State>
385
401
kernel_bundle<State> get_kernel_bundle (const context &Ctx,
386
402
const std::vector<device> &Devs) {
403
+ std::vector<device> UniqueDevices = detail::removeDuplicateDevices (Devs);
404
+
387
405
detail::KernelBundleImplPtr Impl =
388
- detail::get_kernel_bundle_impl (Ctx, Devs , State);
406
+ detail::get_kernel_bundle_impl (Ctx, UniqueDevices , State);
389
407
390
408
return detail::createSyclObjFromImpl<kernel_bundle<State>>(Impl);
391
409
}
@@ -417,8 +435,10 @@ template <bundle_state State>
417
435
kernel_bundle<State>
get_kernel_bundle(const context &Ctx, const std::vector<device> &Devs,
                  const std::vector<kernel_id> &KernelIDs) {
  // Drop repeated entries from the caller-supplied device list so the
  // implementation receives each device at most once.
  std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

  // KernelIDs is forwarded unchanged.
  detail::KernelBundleImplPtr Impl =
      detail::get_kernel_bundle_impl(Ctx, UniqueDevices, KernelIDs, State);

  // Wrap the implementation pointer in the public kernel_bundle type.
  return detail::createSyclObjFromImpl<kernel_bundle<State>>(Impl);
}
424
444
@@ -459,14 +479,16 @@ template <bundle_state State, typename SelectorT>
459
479
kernel_bundle<State> get_kernel_bundle(const context &Ctx,
                                       const std::vector<device> &Devs,
                                       SelectorT Selector) {
  // Drop repeated entries from the caller-supplied device list so the
  // implementation receives each device at most once.
  std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

  // Adapt the user selector, which takes a public device_image<State>, to the
  // implementation-level callback, which takes a DeviceImageImplPtr.
  detail::DevImgSelectorImpl SelectorWrapper =
      [Selector](const detail::DeviceImageImplPtr &DevImg) {
        return Selector(
            detail::createSyclObjFromImpl<sycl::device_image<State>>(DevImg));
      };

  detail::KernelBundleImplPtr Impl = detail::get_kernel_bundle_impl(
      Ctx, UniqueDevices, State, SelectorWrapper);

  // Wrap the implementation pointer in the public kernel_bundle type.
  return detail::createSyclObjFromImpl<sycl::kernel_bundle<State>>(Impl);
}
@@ -589,8 +611,10 @@ compile_impl(const kernel_bundle<bundle_state::input> &InputBundle,
589
611
inline kernel_bundle<bundle_state::object>
compile(const kernel_bundle<bundle_state::input> &InputBundle,
        const std::vector<device> &Devs, const property_list &PropList = {}) {
  // Drop repeated entries from the caller-supplied device list so
  // compile_impl receives each device at most once.
  std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

  detail::KernelBundleImplPtr Impl =
      detail::compile_impl(InputBundle, UniqueDevices, PropList);

  // Wrap the implementation pointer in an object-state kernel_bundle.
  return detail::createSyclObjFromImpl<
      kernel_bundle<sycl::bundle_state::object>>(Impl);
}
@@ -622,8 +646,10 @@ link_impl(const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
622
646
inline kernel_bundle<bundle_state::executable>
link(const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
     const std::vector<device> &Devs, const property_list &PropList = {}) {
  // Drop repeated entries from the caller-supplied device list so link_impl
  // receives each device at most once.
  std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

  detail::KernelBundleImplPtr Impl =
      detail::link_impl(ObjectBundles, UniqueDevices, PropList);

  // Wrap the implementation pointer in an executable-state kernel_bundle.
  return detail::createSyclObjFromImpl<
      kernel_bundle<sycl::bundle_state::executable>>(Impl);
}
@@ -667,8 +693,10 @@ build_impl(const kernel_bundle<bundle_state::input> &InputBundle,
667
693
inline kernel_bundle<bundle_state::executable>
build(const kernel_bundle<bundle_state::input> &InputBundle,
      const std::vector<device> &Devs, const property_list &PropList = {}) {
  // Drop repeated entries from the caller-supplied device list so build_impl
  // receives each device at most once.
  std::vector<device> UniqueDevices = detail::removeDuplicateDevices(Devs);

  detail::KernelBundleImplPtr Impl =
      detail::build_impl(InputBundle, UniqueDevices, PropList);

  // Wrap the implementation pointer in an executable-state kernel_bundle.
  return detail::createSyclObjFromImpl<
      kernel_bundle<sycl::bundle_state::executable>>(Impl);
}
0 commit comments