Skip to content

Commit ba0da8b

Browse files
author
Alexander Batashev
authored
[SYCL] Fix spec constants in object and executable kernel bundles (#3853)
Allow set specialization constants values to be retrieved from kernel_bundles in object or executable states.
1 parent dbc6b57 commit ba0da8b

File tree

4 files changed

+51
-17
lines changed

4 files changed

+51
-17
lines changed

sycl/source/detail/device_image_impl.hpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ namespace detail {
3636
// of specialization constants for it
3737
class device_image_impl {
3838
public:
39+
// The struct maps specialization ID to offset in the binary blob where value
40+
// for this spec const should be.
41+
struct SpecConstDescT {
42+
unsigned int ID = 0;
43+
unsigned int CompositeOffset = 0;
44+
unsigned int Size = 0;
45+
unsigned int BlobOffset = 0;
46+
bool IsSet = false;
47+
};
48+
49+
using SpecConstMapT = std::map<std::string, std::vector<SpecConstDescT>>;
50+
3951
device_image_impl(const RTDeviceBinaryImage *BinImage, context Context,
4052
std::vector<device> Devices, bundle_state State,
4153
std::vector<kernel_id> KernelIDs, RT::PiProgram Program)
@@ -45,6 +57,16 @@ class device_image_impl {
4557
updateSpecConstSymMap();
4658
}
4759

60+
device_image_impl(const RTDeviceBinaryImage *BinImage, context Context,
61+
std::vector<device> Devices, bundle_state State,
62+
std::vector<kernel_id> KernelIDs, RT::PiProgram Program,
63+
const SpecConstMapT &SpecConstMap,
64+
const std::vector<unsigned char> &SpecConstsBlob)
65+
: MBinImage(BinImage), MContext(std::move(Context)),
66+
MDevices(std::move(Devices)), MState(State), MProgram(Program),
67+
MKernelIDs(std::move(KernelIDs)), MSpecConstsBlob(SpecConstsBlob),
68+
MSpecConstSymMap(SpecConstMap) {}
69+
4870
bool has_kernel(const kernel_id &KernelIDCand) const noexcept {
4971
return std::binary_search(MKernelIDs.begin(), MKernelIDs.end(),
5072
KernelIDCand, LessByNameComp{});
@@ -76,16 +98,6 @@ class device_image_impl {
7698
return false;
7799
}
78100

79-
// The struct maps specialization ID to offset in the binary blob where value
80-
// for this spec const should be.
81-
struct SpecConstDescT {
82-
unsigned int ID = 0;
83-
unsigned int CompositeOffset = 0;
84-
unsigned int Size = 0;
85-
unsigned int BlobOffset = 0;
86-
bool IsSet = false;
87-
};
88-
89101
bool has_specialization_constant(const char *SpecName) const noexcept {
90102
// Lock the mutex to prevent when one thread in the middle of writing a
91103
// new value while another thread is reading the value to pass it to
@@ -182,8 +194,7 @@ class device_image_impl {
182194
return MSpecConstsBuffer;
183195
}
184196

185-
const std::map<std::string, std::vector<SpecConstDescT>> &
186-
get_spec_const_data_ref() const noexcept {
197+
const SpecConstMapT &get_spec_const_data_ref() const noexcept {
187198
return MSpecConstSymMap;
188199
}
189200

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ static bool checkAllDevicesHaveAspect(const std::vector<device> &Devices,
5555
// objects.
5656
class kernel_bundle_impl {
5757

58+
using SpecConstMapT = std::map<std::string, std::vector<unsigned char>>;
59+
5860
void common_ctor_checks(bundle_state State) {
5961
const bool AllDevicesInTheContext =
6062
checkAllDevicesAreInContext(MDevices, MContext);
@@ -105,6 +107,8 @@ class kernel_bundle_impl {
105107
bundle_state TargetState)
106108
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)) {
107109

110+
MSpecConstValues = getSyclObjImpl(InputBundle)->get_spec_const_map_ref();
111+
108112
const std::vector<device> &InputBundleDevices =
109113
getSyclObjImpl(InputBundle)->get_devices();
110114
const bool AllDevsAssociatedWithInputBundle =
@@ -207,6 +211,14 @@ class kernel_bundle_impl {
207211

208212
MDeviceImages = detail::ProgramManager::getInstance().link(
209213
std::move(DeviceImages), MDevices, PropList);
214+
215+
for (const kernel_bundle<bundle_state::object> &Bundle : ObjectBundles) {
216+
const KernelBundleImplPtr BundlePtr = getSyclObjImpl(Bundle);
217+
for (const std::pair<const std::string, std::vector<unsigned char>>
218+
&SpecConst : BundlePtr->MSpecConstValues) {
219+
MSpecConstValues[SpecConst.first] = SpecConst.second;
220+
}
221+
}
210222
}
211223

212224
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
@@ -258,8 +270,6 @@ class kernel_bundle_impl {
258270

259271
std::sort(MDeviceImages.begin(), MDeviceImages.end(),
260272
LessByHash<device_image_plain>{});
261-
const auto DevImgIt =
262-
std::unique(MDeviceImages.begin(), MDeviceImages.end());
263273

264274
if (get_bundle_state() == bundle_state::input) {
265275
// Copy spec constants values from the device images to be removed.
@@ -285,6 +295,9 @@ class kernel_bundle_impl {
285295
MergeSpecConstants);
286296
}
287297

298+
const auto DevImgIt =
299+
std::unique(MDeviceImages.begin(), MDeviceImages.end());
300+
288301
// Remove duplicate device images.
289302
MDeviceImages.erase(DevImgIt, MDeviceImages.end());
290303

@@ -459,13 +472,17 @@ class kernel_bundle_impl {
459472
: detail::getSyclObjImpl(MDeviceImages[0])->get_state();
460473
}
461474

475+
const SpecConstMapT &get_spec_const_map_ref() const noexcept {
476+
return MSpecConstValues;
477+
}
478+
462479
private:
463480
context MContext;
464481
std::vector<device> MDevices;
465482
std::vector<device_image_plain> MDeviceImages;
466483
// This map stores values for specialization constants, that are missing
467484
// from any device image.
468-
std::map<std::string, std::vector<unsigned char>> MSpecConstValues;
485+
SpecConstMapT MSpecConstValues;
469486
};
470487

471488
} // namespace detail

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,7 +1405,9 @@ ProgramManager::compile(const device_image_plain &DeviceImage,
14051405

14061406
DeviceImageImplPtr ObjectImpl = std::make_shared<detail::device_image_impl>(
14071407
InputImpl->get_bin_image_ref(), InputImpl->get_context(), Devs,
1408-
bundle_state::object, InputImpl->get_kernel_ids_ref(), Prog);
1408+
bundle_state::object, InputImpl->get_kernel_ids_ref(), Prog,
1409+
InputImpl->get_spec_const_data_ref(),
1410+
InputImpl->get_spec_const_blob_ref());
14091411

14101412
std::vector<pi_device> PIDevices;
14111413
PIDevices.reserve(Devs.size());
@@ -1652,7 +1654,9 @@ device_image_plain ProgramManager::build(const device_image_plain &DeviceImage,
16521654

16531655
DeviceImageImplPtr ExecImpl = std::make_shared<detail::device_image_impl>(
16541656
InputImpl->get_bin_image_ref(), Context, Devs, bundle_state::executable,
1655-
InputImpl->get_kernel_ids_ref(), ResProgram);
1657+
InputImpl->get_kernel_ids_ref(), ResProgram,
1658+
InputImpl->get_spec_const_data_ref(),
1659+
InputImpl->get_spec_const_blob_ref());
16561660

16571661
return createSyclObjFromImpl<device_image_plain>(ExecImpl);
16581662
}

sycl/test/on-device/basic_tests/specialization_constants/host_apis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ int main() {
5858
KernelBundle.set_specialization_constant<SpecConst2>(1);
5959
{
6060
auto ExecBundle = sycl::build(KernelBundle);
61+
assert(ExecBundle.get_specialization_constant<SpecConst1>() == 1);
62+
assert(ExecBundle.get_specialization_constant<SpecConst2>() == 1);
6163
sycl::buffer<int, 1> Buf{sycl::range{1}};
6264
Q.submit([&](sycl::handler &CGH) {
6365
CGH.use_kernel_bundle(ExecBundle);

0 commit comments

Comments
 (0)