Skip to content

Commit cd419b8

Browse files
sergey-semenov authored and KornevNikita committed
[SYCL] Support image dependencies in kernel bundles (#16228)
Add kernel bundle support for image dependencies, which are used for dynamic linking and device virtual function features.
1 parent 44808ef commit cd419b8

File tree

12 files changed

+1001
-562
lines changed

12 files changed

+1001
-562
lines changed

sycl/source/detail/device_binary_image.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,7 @@ class RTDeviceBinaryImage {
119119
ConstIterator begin() const { return ConstIterator(Begin); }
120120
ConstIterator end() const { return ConstIterator(End); }
121121
size_t size() const { return std::distance(begin(), end()); }
122+
bool empty() const { return begin() == end(); }
122123
friend class RTDeviceBinaryImage;
123124
bool isAvailable() const { return !(Begin == nullptr); }
124125

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 87 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,7 @@ class kernel_bundle_impl {
8686

8787
MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
8888
MContext, MDevices, State);
89+
fillUniqueDeviceImages();
8990
}
9091

9192
// Interop constructor used by make_kernel
@@ -103,7 +104,8 @@ class kernel_bundle_impl {
103104
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
104105
device_image_plain &DevImage)
105106
: kernel_bundle_impl(Ctx, Devs) {
106-
MDeviceImages.push_back(DevImage);
107+
MDeviceImages.emplace_back(DevImage);
108+
MUniqueDeviceImages.emplace_back(DevImage);
107109
}
108110

109111
// Matches sycl::build and sycl::compile
@@ -115,10 +117,12 @@ class kernel_bundle_impl {
115117
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
116118
MState(TargetState) {
117119

118-
MSpecConstValues = getSyclObjImpl(InputBundle)->get_spec_const_map_ref();
120+
const std::shared_ptr<kernel_bundle_impl> &InputBundleImpl =
121+
getSyclObjImpl(InputBundle);
122+
MSpecConstValues = InputBundleImpl->get_spec_const_map_ref();
119123

120124
const std::vector<device> &InputBundleDevices =
121-
getSyclObjImpl(InputBundle)->get_devices();
125+
InputBundleImpl->get_devices();
122126
const bool AllDevsAssociatedWithInputBundle =
123127
std::all_of(MDevices.begin(), MDevices.end(),
124128
[&InputBundleDevices](const device &Dev) {
@@ -132,24 +136,37 @@ class kernel_bundle_impl {
132136
"Not all devices are in the set of associated "
133137
"devices for input bundle or vector of devices is empty");
134138

135-
for (const device_image_plain &DeviceImage : InputBundle) {
139+
for (const DevImgPlainWithDeps &DevImgWithDeps :
140+
InputBundleImpl->MDeviceImages) {
136141
// Skip images which are not compatible with devices provided
137-
if (std::none_of(
138-
MDevices.begin(), MDevices.end(),
139-
[&DeviceImage](const device &Dev) {
140-
return getSyclObjImpl(DeviceImage)->compatible_with_device(Dev);
141-
}))
142+
if (std::none_of(MDevices.begin(), MDevices.end(),
143+
[&DevImgWithDeps](const device &Dev) {
144+
return getSyclObjImpl(DevImgWithDeps.getMain())
145+
->compatible_with_device(Dev);
146+
}))
142147
continue;
143148

144149
switch (TargetState) {
145-
case bundle_state::object:
146-
MDeviceImages.push_back(detail::ProgramManager::getInstance().compile(
147-
DeviceImage, MDevices, PropList));
150+
case bundle_state::object: {
151+
DevImgPlainWithDeps CompiledImgWithDeps =
152+
detail::ProgramManager::getInstance().compile(DevImgWithDeps,
153+
MDevices, PropList);
154+
155+
MUniqueDeviceImages.insert(MUniqueDeviceImages.end(),
156+
CompiledImgWithDeps.begin(),
157+
CompiledImgWithDeps.end());
158+
MDeviceImages.push_back(std::move(CompiledImgWithDeps));
148159
break;
149-
case bundle_state::executable:
150-
MDeviceImages.push_back(detail::ProgramManager::getInstance().build(
151-
DeviceImage, MDevices, PropList));
160+
}
161+
162+
case bundle_state::executable: {
163+
device_image_plain BuiltImg =
164+
detail::ProgramManager::getInstance().build(DevImgWithDeps,
165+
MDevices, PropList);
166+
MDeviceImages.emplace_back(BuiltImg);
167+
MUniqueDeviceImages.push_back(BuiltImg);
152168
break;
169+
}
153170
case bundle_state::input:
154171
case bundle_state::ext_oneapi_source:
155172
throw exception(make_error_code(errc::runtime),
@@ -158,6 +175,7 @@ class kernel_bundle_impl {
158175
break;
159176
}
160177
}
178+
removeDuplicateImages();
161179
}
162180

163181
// Matches sycl::link
@@ -201,7 +219,7 @@ class kernel_bundle_impl {
201219
"Not all devices are in the set of associated "
202220
"devices for input bundles");
203221

204-
// TODO: Unify with c'tor for sycl::comile and sycl::build by calling
222+
// TODO: Unify with c'tor for sycl::compile and sycl::build by calling
205223
// sycl::join on vector of kernel_bundles
206224

207225
// The loop below just links each device image separately, not linking any
@@ -213,23 +231,27 @@ class kernel_bundle_impl {
213231
// undefined symbols, then the logic in this loop will need to be changed.
214232
for (const kernel_bundle<bundle_state::object> &ObjectBundle :
215233
ObjectBundles) {
216-
for (const device_image_plain &DeviceImage : ObjectBundle) {
234+
for (const DevImgPlainWithDeps &DeviceImageWithDeps :
235+
getSyclObjImpl(ObjectBundle)->MDeviceImages) {
217236

218237
// Skip images which are not compatible with devices provided
219238
if (std::none_of(MDevices.begin(), MDevices.end(),
220-
[&DeviceImage](const device &Dev) {
221-
return getSyclObjImpl(DeviceImage)
239+
[&DeviceImageWithDeps](const device &Dev) {
240+
return getSyclObjImpl(DeviceImageWithDeps.getMain())
222241
->compatible_with_device(Dev);
223242
}))
224243
continue;
225244

226245
std::vector<device_image_plain> LinkedResults =
227-
detail::ProgramManager::getInstance().link(DeviceImage, MDevices,
228-
PropList);
246+
detail::ProgramManager::getInstance().link(DeviceImageWithDeps,
247+
MDevices, PropList);
229248
MDeviceImages.insert(MDeviceImages.end(), LinkedResults.begin(),
230249
LinkedResults.end());
250+
MUniqueDeviceImages.insert(MUniqueDeviceImages.end(),
251+
LinkedResults.begin(), LinkedResults.end());
231252
}
232253
}
254+
removeDuplicateImages();
233255

234256
for (const kernel_bundle<bundle_state::object> &Bundle : ObjectBundles) {
235257
const KernelBundleImplPtr BundlePtr = getSyclObjImpl(Bundle);
@@ -249,6 +271,7 @@ class kernel_bundle_impl {
249271

250272
MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
251273
MContext, MDevices, KernelIDs, State);
274+
fillUniqueDeviceImages();
252275
}
253276

254277
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
@@ -259,6 +282,7 @@ class kernel_bundle_impl {
259282

260283
MDeviceImages = detail::ProgramManager::getInstance().getSYCLDeviceImages(
261284
MContext, MDevices, Selector, State);
285+
fillUniqueDeviceImages();
262286
}
263287

264288
// C'tor matches sycl::join API
@@ -287,11 +311,10 @@ class kernel_bundle_impl {
287311
Bundle->MDeviceImages.end());
288312
}
289313

290-
std::sort(MDeviceImages.begin(), MDeviceImages.end(),
291-
LessByHash<device_image_plain>{});
314+
fillUniqueDeviceImages();
292315

293316
if (get_bundle_state() == bundle_state::input) {
294-
// Copy spec constants values from the device images to be removed.
317+
// Copy spec constants values from the device images.
295318
auto MergeSpecConstants = [this](const device_image_plain &Img) {
296319
const detail::DeviceImageImplPtr &ImgImpl = getSyclObjImpl(Img);
297320
const std::map<std::string,
@@ -310,16 +333,9 @@ class kernel_bundle_impl {
310333
SpecConst.second.back().Size);
311334
}
312335
};
313-
std::for_each(MDeviceImages.begin(), MDeviceImages.end(),
314-
MergeSpecConstants);
336+
std::for_each(begin(), end(), MergeSpecConstants);
315337
}
316338

317-
const auto DevImgIt =
318-
std::unique(MDeviceImages.begin(), MDeviceImages.end());
319-
320-
// Remove duplicate device images.
321-
MDeviceImages.erase(DevImgIt, MDeviceImages.end());
322-
323339
for (const detail::KernelBundleImplPtr &Bundle : Bundles) {
324340
for (const std::pair<const std::string, std::vector<unsigned char>>
325341
&SpecConst : Bundle->MSpecConstValues) {
@@ -605,7 +621,7 @@ class kernel_bundle_impl {
605621

606622
assert(MDeviceImages.size() > 0);
607623
const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
608-
detail::getSyclObjImpl(MDeviceImages[0]);
624+
detail::getSyclObjImpl(MDeviceImages[0].getMain());
609625
ur_program_handle_t UrProgram = DeviceImageImpl->get_ur_program_ref();
610626
ContextImplPtr ContextImpl = getSyclObjImpl(MContext);
611627
const AdapterPtr &Adapter = ContextImpl->getAdapter();
@@ -634,7 +650,7 @@ class kernel_bundle_impl {
634650
// Collect kernel ids from all device images, then remove duplicates
635651

636652
std::vector<kernel_id> Result;
637-
for (const device_image_plain &DeviceImage : MDeviceImages) {
653+
for (const device_image_plain &DeviceImage : MUniqueDeviceImages) {
638654
const std::vector<kernel_id> &KernelIDs =
639655
getSyclObjImpl(DeviceImage)->get_kernel_ids();
640656

@@ -662,8 +678,9 @@ class kernel_bundle_impl {
662678
// Used to track if any of the candidate images has specialization values
663679
// set.
664680
bool SpecConstsSet = false;
665-
for (auto &DeviceImage : MDeviceImages) {
666-
if (!DeviceImage.has_kernel(KernelID))
681+
for (const DevImgPlainWithDeps &DeviceImageWithDeps : MDeviceImages) {
682+
const device_image_plain &DeviceImage = DeviceImageWithDeps.getMain();
683+
if (!DeviceImageWithDeps.getMain().has_kernel(KernelID))
667684
continue;
668685

669686
const auto DeviceImageImpl = detail::getSyclObjImpl(DeviceImage);
@@ -718,39 +735,38 @@ class kernel_bundle_impl {
718735
}
719736

720737
bool has_kernel(const kernel_id &KernelID) const noexcept {
721-
return std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
738+
return std::any_of(begin(), end(),
722739
[&KernelID](const device_image_plain &DeviceImage) {
723740
return DeviceImage.has_kernel(KernelID);
724741
});
725742
}
726743

727744
bool has_kernel(const kernel_id &KernelID, const device &Dev) const noexcept {
728745
return std::any_of(
729-
MDeviceImages.begin(), MDeviceImages.end(),
746+
begin(), end(),
730747
[&KernelID, &Dev](const device_image_plain &DeviceImage) {
731748
return DeviceImage.has_kernel(KernelID, Dev);
732749
});
733750
}
734751

735752
bool contains_specialization_constants() const noexcept {
736753
return std::any_of(
737-
MDeviceImages.begin(), MDeviceImages.end(),
738-
[](const device_image_plain &DeviceImage) {
754+
begin(), end(), [](const device_image_plain &DeviceImage) {
739755
return getSyclObjImpl(DeviceImage)->has_specialization_constants();
740756
});
741757
}
742758

743759
bool native_specialization_constant() const noexcept {
744760
return contains_specialization_constants() &&
745-
std::all_of(MDeviceImages.begin(), MDeviceImages.end(),
761+
std::all_of(begin(), end(),
746762
[](const device_image_plain &DeviceImage) {
747763
return getSyclObjImpl(DeviceImage)
748764
->all_specialization_constant_native();
749765
});
750766
}
751767

752768
bool has_specialization_constant(const char *SpecName) const noexcept {
753-
return std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
769+
return std::any_of(begin(), end(),
754770
[SpecName](const device_image_plain &DeviceImage) {
755771
return getSyclObjImpl(DeviceImage)
756772
->has_specialization_constant(SpecName);
@@ -761,7 +777,7 @@ class kernel_bundle_impl {
761777
const void *Value,
762778
size_t Size) noexcept {
763779
if (has_specialization_constant(SpecName))
764-
for (const device_image_plain &DeviceImage : MDeviceImages)
780+
for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
765781
getSyclObjImpl(DeviceImage)
766782
->set_specialization_constant_raw_value(SpecName, Value);
767783
else {
@@ -773,7 +789,7 @@ class kernel_bundle_impl {
773789

774790
void get_specialization_constant_raw_value(const char *SpecName,
775791
void *ValueRet) const noexcept {
776-
for (const device_image_plain &DeviceImage : MDeviceImages)
792+
for (const device_image_plain &DeviceImage : MUniqueDeviceImages)
777793
if (getSyclObjImpl(DeviceImage)->has_specialization_constant(SpecName)) {
778794
getSyclObjImpl(DeviceImage)
779795
->get_specialization_constant_raw_value(SpecName, ValueRet);
@@ -796,21 +812,21 @@ class kernel_bundle_impl {
796812

797813
bool is_specialization_constant_set(const char *SpecName) const noexcept {
798814
bool SetInDevImg =
799-
std::any_of(MDeviceImages.begin(), MDeviceImages.end(),
815+
std::any_of(begin(), end(),
800816
[SpecName](const device_image_plain &DeviceImage) {
801817
return getSyclObjImpl(DeviceImage)
802818
->is_specialization_constant_set(SpecName);
803819
});
804820
return SetInDevImg || MSpecConstValues.count(std::string{SpecName}) != 0;
805821
}
806822

807-
const device_image_plain *begin() const { return MDeviceImages.data(); }
823+
const device_image_plain *begin() const { return MUniqueDeviceImages.data(); }
808824

809825
const device_image_plain *end() const {
810-
return MDeviceImages.data() + MDeviceImages.size();
826+
return MUniqueDeviceImages.data() + MUniqueDeviceImages.size();
811827
}
812828

813-
size_t size() const noexcept { return MDeviceImages.size(); }
829+
size_t size() const noexcept { return MUniqueDeviceImages.size(); }
814830

815831
bundle_state get_bundle_state() const { return MState; }
816832

@@ -827,7 +843,7 @@ class kernel_bundle_impl {
827843

828844
// First try and get images in current bundle state
829845
const bundle_state BundleState = get_bundle_state();
830-
std::vector<device_image_plain> NewDevImgs =
846+
std::vector<DevImgPlainWithDeps> NewDevImgs =
831847
detail::ProgramManager::getInstance().getSYCLDeviceImages(
832848
MContext, {Dev}, {KernelID}, BundleState);
833849

@@ -836,21 +852,38 @@ class kernel_bundle_impl {
836852
return false;
837853

838854
// Propagate already set specialization constants to the new images
839-
for (device_image_plain &DevImg : NewDevImgs)
840-
for (auto SpecConst : MSpecConstValues)
841-
getSyclObjImpl(DevImg)->set_specialization_constant_raw_value(
842-
SpecConst.first.c_str(), SpecConst.second.data());
855+
for (DevImgPlainWithDeps &DevImgWithDeps : NewDevImgs)
856+
for (device_image_plain &DevImg : DevImgWithDeps)
857+
for (auto SpecConst : MSpecConstValues)
858+
getSyclObjImpl(DevImg)->set_specialization_constant_raw_value(
859+
SpecConst.first.c_str(), SpecConst.second.data());
843860

844861
// Add the images to the collection
845862
MDeviceImages.insert(MDeviceImages.end(), NewDevImgs.begin(),
846863
NewDevImgs.end());
864+
removeDuplicateImages();
847865
return true;
848866
}
849867

850868
private:
869+
void fillUniqueDeviceImages() {
870+
assert(MUniqueDeviceImages.empty());
871+
for (const DevImgPlainWithDeps &Imgs : MDeviceImages)
872+
MUniqueDeviceImages.insert(MUniqueDeviceImages.end(), Imgs.begin(),
873+
Imgs.end());
874+
removeDuplicateImages();
875+
}
876+
void removeDuplicateImages() {
877+
std::sort(MUniqueDeviceImages.begin(), MUniqueDeviceImages.end(),
878+
LessByHash<device_image_plain>{});
879+
const auto It =
880+
std::unique(MUniqueDeviceImages.begin(), MUniqueDeviceImages.end());
881+
MUniqueDeviceImages.erase(It, MUniqueDeviceImages.end());
882+
}
851883
context MContext;
852884
std::vector<device> MDevices;
853-
std::vector<device_image_plain> MDeviceImages;
885+
std::vector<DevImgPlainWithDeps> MDeviceImages;
886+
std::vector<device_image_plain> MUniqueDeviceImages;
854887
// This map stores values for specialization constants, that are missing
855888
// from any device image.
856889
SpecConstMapT MSpecConstValues;

0 commit comments

Comments
 (0)