@@ -86,6 +86,7 @@ class kernel_bundle_impl {
86
86
87
87
MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
88
88
MContext, MDevices, State);
89
+ fillUniqueDeviceImages ();
89
90
}
90
91
91
92
// Interop constructor used by make_kernel
@@ -103,7 +104,8 @@ class kernel_bundle_impl {
103
104
kernel_bundle_impl (context Ctx, std::vector<device> Devs,
104
105
device_image_plain &DevImage)
105
106
: kernel_bundle_impl(Ctx, Devs) {
106
- MDeviceImages.push_back (DevImage);
107
+ MDeviceImages.emplace_back (DevImage);
108
+ MUniqueDeviceImages.emplace_back (DevImage);
107
109
}
108
110
109
111
// Matches sycl::build and sycl::compile
@@ -115,10 +117,12 @@ class kernel_bundle_impl {
115
117
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
116
118
MState (TargetState) {
117
119
118
- MSpecConstValues = getSyclObjImpl (InputBundle)->get_spec_const_map_ref ();
120
+ const std::shared_ptr<kernel_bundle_impl> &InputBundleImpl =
121
+ getSyclObjImpl (InputBundle);
122
+ MSpecConstValues = InputBundleImpl->get_spec_const_map_ref ();
119
123
120
124
const std::vector<device> &InputBundleDevices =
121
- getSyclObjImpl (InputBundle) ->get_devices ();
125
+ InputBundleImpl ->get_devices ();
122
126
const bool AllDevsAssociatedWithInputBundle =
123
127
std::all_of (MDevices.begin (), MDevices.end (),
124
128
[&InputBundleDevices](const device &Dev) {
@@ -132,24 +136,37 @@ class kernel_bundle_impl {
132
136
" Not all devices are in the set of associated "
133
137
" devices for input bundle or vector of devices is empty" );
134
138
135
- for (const device_image_plain &DeviceImage : InputBundle) {
139
+ for (const DevImgPlainWithDeps &DevImgWithDeps :
140
+ InputBundleImpl->MDeviceImages ) {
136
141
// Skip images which are not compatible with devices provided
137
- if (std::none_of (
138
- MDevices. begin (), MDevices. end (),
139
- [&DeviceImage]( const device &Dev) {
140
- return getSyclObjImpl (DeviceImage) ->compatible_with_device (Dev);
141
- }))
142
+ if (std::none_of (MDevices. begin (), MDevices. end (),
143
+ [&DevImgWithDeps]( const device &Dev) {
144
+ return getSyclObjImpl (DevImgWithDeps. getMain ())
145
+ ->compatible_with_device (Dev);
146
+ }))
142
147
continue ;
143
148
144
149
switch (TargetState) {
145
- case bundle_state::object:
146
- MDeviceImages.push_back (detail::ProgramManager::getInstance ().compile (
147
- DeviceImage, MDevices, PropList));
150
+ case bundle_state::object: {
151
+ DevImgPlainWithDeps CompiledImgWithDeps =
152
+ detail::ProgramManager::getInstance ().compile (DevImgWithDeps,
153
+ MDevices, PropList);
154
+
155
+ MUniqueDeviceImages.insert (MUniqueDeviceImages.end (),
156
+ CompiledImgWithDeps.begin (),
157
+ CompiledImgWithDeps.end ());
158
+ MDeviceImages.push_back (std::move (CompiledImgWithDeps));
148
159
break ;
149
- case bundle_state::executable:
150
- MDeviceImages.push_back (detail::ProgramManager::getInstance ().build (
151
- DeviceImage, MDevices, PropList));
160
+ }
161
+
162
+ case bundle_state::executable: {
163
+ device_image_plain BuiltImg =
164
+ detail::ProgramManager::getInstance ().build (DevImgWithDeps,
165
+ MDevices, PropList);
166
+ MDeviceImages.emplace_back (BuiltImg);
167
+ MUniqueDeviceImages.push_back (BuiltImg);
152
168
break ;
169
+ }
153
170
case bundle_state::input:
154
171
case bundle_state::ext_oneapi_source:
155
172
throw exception (make_error_code (errc::runtime),
@@ -158,6 +175,7 @@ class kernel_bundle_impl {
158
175
break ;
159
176
}
160
177
}
178
+ removeDuplicateImages ();
161
179
}
162
180
163
181
// Matches sycl::link
@@ -201,7 +219,7 @@ class kernel_bundle_impl {
201
219
" Not all devices are in the set of associated "
202
220
" devices for input bundles" );
203
221
204
- // TODO: Unify with c'tor for sycl::comile and sycl::build by calling
222
+ // TODO: Unify with c'tor for sycl::compile and sycl::build by calling
205
223
// sycl::join on vector of kernel_bundles
206
224
207
225
// The loop below just links each device image separately, not linking any
@@ -213,23 +231,27 @@ class kernel_bundle_impl {
213
231
// undefined symbols, then the logic in this loop will need to be changed.
214
232
for (const kernel_bundle<bundle_state::object> &ObjectBundle :
215
233
ObjectBundles) {
216
- for (const device_image_plain &DeviceImage : ObjectBundle) {
234
+ for (const DevImgPlainWithDeps &DeviceImageWithDeps :
235
+ getSyclObjImpl (ObjectBundle)->MDeviceImages ) {
217
236
218
237
// Skip images which are not compatible with devices provided
219
238
if (std::none_of (MDevices.begin (), MDevices.end (),
220
- [&DeviceImage ](const device &Dev) {
221
- return getSyclObjImpl (DeviceImage )
239
+ [&DeviceImageWithDeps ](const device &Dev) {
240
+ return getSyclObjImpl (DeviceImageWithDeps. getMain () )
222
241
->compatible_with_device (Dev);
223
242
}))
224
243
continue ;
225
244
226
245
std::vector<device_image_plain> LinkedResults =
227
- detail::ProgramManager::getInstance ().link (DeviceImage, MDevices ,
228
- PropList);
246
+ detail::ProgramManager::getInstance ().link (DeviceImageWithDeps ,
247
+ MDevices, PropList);
229
248
MDeviceImages.insert (MDeviceImages.end (), LinkedResults.begin (),
230
249
LinkedResults.end ());
250
+ MUniqueDeviceImages.insert (MUniqueDeviceImages.end (),
251
+ LinkedResults.begin (), LinkedResults.end ());
231
252
}
232
253
}
254
+ removeDuplicateImages ();
233
255
234
256
for (const kernel_bundle<bundle_state::object> &Bundle : ObjectBundles) {
235
257
const KernelBundleImplPtr BundlePtr = getSyclObjImpl (Bundle);
@@ -249,6 +271,7 @@ class kernel_bundle_impl {
249
271
250
272
MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
251
273
MContext, MDevices, KernelIDs, State);
274
+ fillUniqueDeviceImages ();
252
275
}
253
276
254
277
kernel_bundle_impl (context Ctx, std::vector<device> Devs,
@@ -259,6 +282,7 @@ class kernel_bundle_impl {
259
282
260
283
MDeviceImages = detail::ProgramManager::getInstance ().getSYCLDeviceImages (
261
284
MContext, MDevices, Selector, State);
285
+ fillUniqueDeviceImages ();
262
286
}
263
287
264
288
// C'tor matches sycl::join API
@@ -287,11 +311,10 @@ class kernel_bundle_impl {
287
311
Bundle->MDeviceImages .end ());
288
312
}
289
313
290
- std::sort (MDeviceImages.begin (), MDeviceImages.end (),
291
- LessByHash<device_image_plain>{});
314
+ fillUniqueDeviceImages ();
292
315
293
316
if (get_bundle_state () == bundle_state::input) {
294
- // Copy spec constants values from the device images to be removed .
317
+ // Copy spec constants values from the device images.
295
318
auto MergeSpecConstants = [this ](const device_image_plain &Img) {
296
319
const detail::DeviceImageImplPtr &ImgImpl = getSyclObjImpl (Img);
297
320
const std::map<std::string,
@@ -310,16 +333,9 @@ class kernel_bundle_impl {
310
333
SpecConst.second .back ().Size );
311
334
}
312
335
};
313
- std::for_each (MDeviceImages.begin (), MDeviceImages.end (),
314
- MergeSpecConstants);
336
+ std::for_each (begin (), end (), MergeSpecConstants);
315
337
}
316
338
317
- const auto DevImgIt =
318
- std::unique (MDeviceImages.begin (), MDeviceImages.end ());
319
-
320
- // Remove duplicate device images.
321
- MDeviceImages.erase (DevImgIt, MDeviceImages.end ());
322
-
323
339
for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
324
340
for (const std::pair<const std::string, std::vector<unsigned char >>
325
341
&SpecConst : Bundle->MSpecConstValues ) {
@@ -605,7 +621,7 @@ class kernel_bundle_impl {
605
621
606
622
assert (MDeviceImages.size () > 0 );
607
623
const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
608
- detail::getSyclObjImpl (MDeviceImages[0 ]);
624
+ detail::getSyclObjImpl (MDeviceImages[0 ]. getMain () );
609
625
ur_program_handle_t UrProgram = DeviceImageImpl->get_ur_program_ref ();
610
626
ContextImplPtr ContextImpl = getSyclObjImpl (MContext);
611
627
const AdapterPtr &Adapter = ContextImpl->getAdapter ();
@@ -634,7 +650,7 @@ class kernel_bundle_impl {
634
650
// Collect kernel ids from all device images, then remove duplicates
635
651
636
652
std::vector<kernel_id> Result;
637
- for (const device_image_plain &DeviceImage : MDeviceImages ) {
653
+ for (const device_image_plain &DeviceImage : MUniqueDeviceImages ) {
638
654
const std::vector<kernel_id> &KernelIDs =
639
655
getSyclObjImpl (DeviceImage)->get_kernel_ids ();
640
656
@@ -662,8 +678,9 @@ class kernel_bundle_impl {
662
678
// Used to track if any of the candidate images has specialization values
663
679
// set.
664
680
bool SpecConstsSet = false ;
665
- for (auto &DeviceImage : MDeviceImages) {
666
- if (!DeviceImage.has_kernel (KernelID))
681
+ for (const DevImgPlainWithDeps &DeviceImageWithDeps : MDeviceImages) {
682
+ const device_image_plain &DeviceImage = DeviceImageWithDeps.getMain ();
683
+ if (!DeviceImageWithDeps.getMain ().has_kernel (KernelID))
667
684
continue ;
668
685
669
686
const auto DeviceImageImpl = detail::getSyclObjImpl (DeviceImage);
@@ -718,39 +735,38 @@ class kernel_bundle_impl {
718
735
}
719
736
720
737
bool has_kernel (const kernel_id &KernelID) const noexcept {
721
- return std::any_of (MDeviceImages. begin (), MDeviceImages. end (),
738
+ return std::any_of (begin (), end (),
722
739
[&KernelID](const device_image_plain &DeviceImage) {
723
740
return DeviceImage.has_kernel (KernelID);
724
741
});
725
742
}
726
743
727
744
bool has_kernel (const kernel_id &KernelID, const device &Dev) const noexcept {
728
745
return std::any_of (
729
- MDeviceImages. begin (), MDeviceImages. end (),
746
+ begin (), end (),
730
747
[&KernelID, &Dev](const device_image_plain &DeviceImage) {
731
748
return DeviceImage.has_kernel (KernelID, Dev);
732
749
});
733
750
}
734
751
735
752
bool contains_specialization_constants () const noexcept {
736
753
return std::any_of (
737
- MDeviceImages.begin (), MDeviceImages.end (),
738
- [](const device_image_plain &DeviceImage) {
754
+ begin (), end (), [](const device_image_plain &DeviceImage) {
739
755
return getSyclObjImpl (DeviceImage)->has_specialization_constants ();
740
756
});
741
757
}
742
758
743
759
bool native_specialization_constant () const noexcept {
744
760
return contains_specialization_constants () &&
745
- std::all_of (MDeviceImages. begin (), MDeviceImages. end (),
761
+ std::all_of (begin (), end (),
746
762
[](const device_image_plain &DeviceImage) {
747
763
return getSyclObjImpl (DeviceImage)
748
764
->all_specialization_constant_native ();
749
765
});
750
766
}
751
767
752
768
bool has_specialization_constant (const char *SpecName) const noexcept {
753
- return std::any_of (MDeviceImages. begin (), MDeviceImages. end (),
769
+ return std::any_of (begin (), end (),
754
770
[SpecName](const device_image_plain &DeviceImage) {
755
771
return getSyclObjImpl (DeviceImage)
756
772
->has_specialization_constant (SpecName);
@@ -761,7 +777,7 @@ class kernel_bundle_impl {
761
777
const void *Value,
762
778
size_t Size) noexcept {
763
779
if (has_specialization_constant (SpecName))
764
- for (const device_image_plain &DeviceImage : MDeviceImages )
780
+ for (const device_image_plain &DeviceImage : MUniqueDeviceImages )
765
781
getSyclObjImpl (DeviceImage)
766
782
->set_specialization_constant_raw_value (SpecName, Value);
767
783
else {
@@ -773,7 +789,7 @@ class kernel_bundle_impl {
773
789
774
790
void get_specialization_constant_raw_value (const char *SpecName,
775
791
void *ValueRet) const noexcept {
776
- for (const device_image_plain &DeviceImage : MDeviceImages )
792
+ for (const device_image_plain &DeviceImage : MUniqueDeviceImages )
777
793
if (getSyclObjImpl (DeviceImage)->has_specialization_constant (SpecName)) {
778
794
getSyclObjImpl (DeviceImage)
779
795
->get_specialization_constant_raw_value (SpecName, ValueRet);
@@ -796,21 +812,21 @@ class kernel_bundle_impl {
796
812
797
813
bool is_specialization_constant_set (const char *SpecName) const noexcept {
798
814
bool SetInDevImg =
799
- std::any_of (MDeviceImages. begin (), MDeviceImages. end (),
815
+ std::any_of (begin (), end (),
800
816
[SpecName](const device_image_plain &DeviceImage) {
801
817
return getSyclObjImpl (DeviceImage)
802
818
->is_specialization_constant_set (SpecName);
803
819
});
804
820
return SetInDevImg || MSpecConstValues.count (std::string{SpecName}) != 0 ;
805
821
}
806
822
807
- const device_image_plain *begin () const { return MDeviceImages .data (); }
823
+ const device_image_plain *begin () const { return MUniqueDeviceImages .data (); }
808
824
809
825
const device_image_plain *end () const {
810
- return MDeviceImages .data () + MDeviceImages .size ();
826
+ return MUniqueDeviceImages .data () + MUniqueDeviceImages .size ();
811
827
}
812
828
813
- size_t size () const noexcept { return MDeviceImages .size (); }
829
+ size_t size () const noexcept { return MUniqueDeviceImages .size (); }
814
830
815
831
bundle_state get_bundle_state () const { return MState; }
816
832
@@ -827,7 +843,7 @@ class kernel_bundle_impl {
827
843
828
844
// First try and get images in current bundle state
829
845
const bundle_state BundleState = get_bundle_state ();
830
- std::vector<device_image_plain > NewDevImgs =
846
+ std::vector<DevImgPlainWithDeps > NewDevImgs =
831
847
detail::ProgramManager::getInstance ().getSYCLDeviceImages (
832
848
MContext, {Dev}, {KernelID}, BundleState);
833
849
@@ -836,21 +852,38 @@ class kernel_bundle_impl {
836
852
return false ;
837
853
838
854
// Propagate already set specialization constants to the new images
839
- for (device_image_plain &DevImg : NewDevImgs)
840
- for (auto SpecConst : MSpecConstValues)
841
- getSyclObjImpl (DevImg)->set_specialization_constant_raw_value (
842
- SpecConst.first .c_str (), SpecConst.second .data ());
855
+ for (DevImgPlainWithDeps &DevImgWithDeps : NewDevImgs)
856
+ for (device_image_plain &DevImg : DevImgWithDeps)
857
+ for (auto SpecConst : MSpecConstValues)
858
+ getSyclObjImpl (DevImg)->set_specialization_constant_raw_value (
859
+ SpecConst.first .c_str (), SpecConst.second .data ());
843
860
844
861
// Add the images to the collection
845
862
MDeviceImages.insert (MDeviceImages.end (), NewDevImgs.begin (),
846
863
NewDevImgs.end ());
864
+ removeDuplicateImages ();
847
865
return true ;
848
866
}
849
867
850
868
private:
869
+ void fillUniqueDeviceImages () {
870
+ assert (MUniqueDeviceImages.empty ());
871
+ for (const DevImgPlainWithDeps &Imgs : MDeviceImages)
872
+ MUniqueDeviceImages.insert (MUniqueDeviceImages.end (), Imgs.begin (),
873
+ Imgs.end ());
874
+ removeDuplicateImages ();
875
+ }
876
+ void removeDuplicateImages () {
877
+ std::sort (MUniqueDeviceImages.begin (), MUniqueDeviceImages.end (),
878
+ LessByHash<device_image_plain>{});
879
+ const auto It =
880
+ std::unique (MUniqueDeviceImages.begin (), MUniqueDeviceImages.end ());
881
+ MUniqueDeviceImages.erase (It, MUniqueDeviceImages.end ());
882
+ }
851
883
context MContext;
852
884
std::vector<device> MDevices;
853
- std::vector<device_image_plain> MDeviceImages;
885
+ std::vector<DevImgPlainWithDeps> MDeviceImages;
886
+ std::vector<device_image_plain> MUniqueDeviceImages;
854
887
// This map stores values for specialization constants, that are missing
855
888
// from any device image.
856
889
SpecConstMapT MSpecConstValues;
0 commit comments