Skip to content

Commit d54a65e

Browse files
committed
[SYCL] Specialization constants: fix scope, allow redifinition and AOT.
- Specialization constants are now per-program, as the implemented spec requires. - Program's specialization constant set is added to the program compilation cache key, it gets recompiled on new specialization constant combination. Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent b375b04 commit d54a65e

File tree

11 files changed

+212
-229
lines changed

11 files changed

+212
-229
lines changed

sycl/include/CL/sycl/detail/spec_constant_impl.hpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,44 @@
99
#pragma once
1010

1111
#include <CL/sycl/detail/defines.hpp>
12+
#include <CL/sycl/detail/util.hpp>
13+
#include <CL/sycl/stl.hpp>
1214

1315
#include <iostream>
16+
#include <map>
1417

1518
__SYCL_INLINE_NAMESPACE(cl) {
1619
namespace sycl {
1720
namespace detail {
1821

19-
// Represents a specialization constant in SYCL runtime.
22+
// Represents a specialization constant value in SYCL runtime.
2023
class spec_constant_impl {
2124
public:
22-
spec_constant_impl(unsigned int ID) : ID(ID), Size(0), Bytes{0} {}
25+
spec_constant_impl() = default;
2326

24-
spec_constant_impl(unsigned int ID, size_t Size, const void *Val) : ID(ID) {
25-
set(Size, Val);
26-
}
27+
spec_constant_impl(size_t Size, const void *Val) { set(Size, Val); }
2728

2829
void set(size_t Size, const void *Val);
2930

30-
unsigned int getID() const { return ID; }
3131
size_t getSize() const { return Size; }
3232
const unsigned char *getValuePtr() const { return Bytes; }
3333
bool isSet() const { return Size != 0; }
3434

3535
private:
36-
unsigned int ID; // specialization constant's ID (equals to SPIRV ID)
37-
size_t Size; // size of its value
36+
size_t Size; // size of its value
3837
// TODO invent more flexible approach to support values of arbitrary type:
3938
unsigned char Bytes[8]; // memory to hold the value bytes
4039
};
4140

4241
std::ostream &operator<<(std::ostream &Out, const spec_constant_impl &V);
4342

43+
// Used to define specialization constant registry. Must be ordered map, since
44+
// the order of entries matters in stableSerializeSpecConstRegistry.
45+
using SpecConstRegistryT = std::map<string_class, spec_constant_impl>;
46+
47+
void stableSerializeSpecConstRegistry(const SpecConstRegistryT &Reg,
48+
SerializedObj &Dst);
49+
4450
} // namespace detail
4551
} // namespace sycl
4652
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/CL/sycl/detail/util.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#ifndef __SYCL_DEVICE_ONLY
1212

1313
#include <CL/sycl/detail/defines.hpp>
14+
#include <CL/sycl/stl.hpp>
1415

1516
#include <cstring>
1617
#include <mutex>
@@ -52,12 +53,7 @@ struct CmpCStr {
5253
}
5354
};
5455

55-
// Interface to iterate via C strings.
56-
class CStringIterator {
57-
public:
58-
// Get the next string. Returns next string's pointer or nullptr.
59-
virtual const char *next() = 0;
60-
};
56+
using SerializedObj = sycl::vector_class<unsigned char>;
6157

6258
} // namespace detail
6359
} // namespace sycl

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
#pragma once
1010

11+
#include <CL/sycl/detail/common.hpp>
1112
#include <CL/sycl/detail/locked.hpp>
1213
#include <CL/sycl/detail/os_util.hpp>
1314
#include <CL/sycl/detail/pi.hpp>
15+
#include <CL/sycl/detail/util.hpp>
1416
#include <detail/platform_impl.hpp>
1517

1618
#include <atomic>
@@ -51,7 +53,8 @@ class KernelProgramCache {
5153
using PiProgramT = std::remove_pointer<RT::PiProgram>::type;
5254
using PiProgramPtrT = std::atomic<PiProgramT *>;
5355
using ProgramWithBuildStateT = BuildResult<PiProgramT>;
54-
using ProgramCacheT = std::map<OSModuleHandle, ProgramWithBuildStateT>;
56+
using ProgramCacheKeyT = std::pair<SerializedObj, KernelSetId>;
57+
using ProgramCacheT = std::map<ProgramCacheKeyT, ProgramWithBuildStateT>;
5558
using ContextPtr = context_impl *;
5659

5760
using PiKernelT = std::remove_pointer<RT::PiKernel>::type;

sycl/source/detail/program_impl.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ void program_impl::build_with_kernel_name(string_class KernelName,
231231
if (is_cacheable_with_options(BuildOptions)) {
232232
MProgramAndKernelCachingAllowed = true;
233233
MProgram = ProgramManager::getInstance().getBuiltPIProgram(
234-
Module, get_context(), KernelName);
234+
Module, get_context(), KernelName, this);
235235
const detail::plugin &Plugin = getPlugin();
236236
Plugin.call<PiApiKind::piProgramRetain>(MProgram);
237237
} else {
@@ -332,7 +332,7 @@ void program_impl::compile(const string_class &Options) {
332332
check_device_feature_support<info::device::is_compiler_available>(MDevices);
333333
vector_class<RT::PiDevice> Devices(get_pi_devices());
334334
const detail::plugin &Plugin = getPlugin();
335-
ProgramManager::getInstance().flushSpecConstants(MProgram, *MContext);
335+
ProgramManager::getInstance().flushSpecConstants(*this);
336336
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piProgramCompile>(
337337
MProgram, Devices.size(), Devices.data(), Options.c_str(), 0, nullptr,
338338
nullptr, nullptr, nullptr);
@@ -351,7 +351,7 @@ void program_impl::build(const string_class &Options) {
351351
check_device_feature_support<info::device::is_compiler_available>(MDevices);
352352
vector_class<RT::PiDevice> Devices(get_pi_devices());
353353
const detail::plugin &Plugin = getPlugin();
354-
ProgramManager::getInstance().flushSpecConstants(MProgram, *MContext);
354+
ProgramManager::getInstance().flushSpecConstants(*this);
355355
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piProgramBuild>(
356356
MProgram, Devices.size(), Devices.data(), Options.c_str(), nullptr,
357357
nullptr);
@@ -398,7 +398,7 @@ RT::PiKernel program_impl::get_pi_kernel(const string_class &KernelName) const {
398398

399399
if (is_cacheable()) {
400400
Kernel = ProgramManager::getInstance().getOrCreateKernel(
401-
MProgramModuleHandle, get_context(), KernelName);
401+
MProgramModuleHandle, get_context(), KernelName, this);
402402
} else {
403403
const detail::plugin &Plugin = getPlugin();
404404
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piKernelCreate>(
@@ -470,11 +470,38 @@ vector_class<device> program_impl::get_info<info::program::devices>() const {
470470

471471
void program_impl::set_spec_constant_impl(const char *Name, const void *ValAddr,
472472
size_t ValSize) {
473-
spec_constant_impl &SC =
474-
ProgramManager::getInstance().resolveSpecConstant(this, Name);
473+
// Reuse cached programs lock as opposed to introducing a new lock.
474+
auto LockGuard = MContext->getKernelProgramCache().acquireCachedPrograms();
475+
spec_constant_impl &SC = SpecConstRegistry[Name];
475476
SC.set(ValSize, ValAddr);
476477
}
477478

479+
void program_impl::flush_spec_constants(const RTDeviceBinaryImage &Img,
480+
RT::PiProgram NativePrg) const {
481+
// iterate via all specialization constants the program's image depends on,
482+
// and set each to current runtime value (if any)
483+
const pi::DeviceBinaryImage::PropertyRange &SCRange = Img.getSpecConstants();
484+
ContextImplPtr Ctx = getSyclObjImpl(get_context());
485+
using SCItTy = pi::DeviceBinaryImage::PropertyRange::ConstIterator;
486+
487+
auto LockGuard = Ctx->getKernelProgramCache().acquireCachedPrograms();
488+
489+
for (SCItTy SCIt = SCRange.begin(); SCIt != SCRange.end(); SCIt++) {
490+
const char *SCName = (*SCIt)->Name;
491+
auto SCEntry = SpecConstRegistry.find(SCName);
492+
if (SCEntry == SpecConstRegistry.end())
493+
// spec constant has not been set in user code - SPIRV will use default
494+
continue;
495+
const spec_constant_impl &SC = SCEntry->second;
496+
assert(SC.isSet() && "uninitialized spec constant");
497+
pi_device_binary_property SCProp = *SCIt;
498+
pi_uint32 ID = pi::DeviceBinaryProperty(SCProp).asUint32();
499+
NativePrg = NativePrg ? NativePrg : pi::cast<pi::PiProgram>(get());
500+
Ctx->getPlugin().call<PiApiKind::piextProgramSetSpecializationConstant>(
501+
NativePrg, ID, SC.getSize(), SC.getValuePtr());
502+
}
503+
}
504+
478505
} // namespace detail
479506
} // namespace sycl
480507
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/source/detail/program_impl.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <CL/sycl/context.hpp>
1111
#include <CL/sycl/detail/common_info.hpp>
1212
#include <CL/sycl/detail/kernel_desc.hpp>
13+
#include <CL/sycl/detail/spec_constant_impl.hpp>
1314
#include <CL/sycl/device.hpp>
1415
#include <CL/sycl/program.hpp>
1516
#include <CL/sycl/stl.hpp>
@@ -292,13 +293,28 @@ class program_impl {
292293
void set_spec_constant_impl(const char *Name, const void *ValAddr,
293294
size_t ValSize);
294295

296+
/// Takes current values of specialization constants and "injects" them into
297+
/// the underlying native program program via specialization constant
298+
/// managemment PI APIs. The native program passed as non-null argument
299+
/// overrides the MProgram native program field.
300+
/// \param Img device binary image corresponding to this program, used to
301+
/// resolve spec constant name to SPIRV integer ID
302+
/// \param NativePrg if not null, used as the flush target, otherwise MProgram
303+
/// is used
304+
void flush_spec_constants(const RTDeviceBinaryImage &Img,
305+
RT::PiProgram NativePrg = nullptr) const;
306+
295307
/// Returns the OS module handle this program belongs to. A program belongs to
296308
/// an OS module if it was built from device image(s) belonging to that
297309
/// module.
298310
/// TODO Some programs can be linked from images belonging to different
299311
/// modules. May need a special fake handle for the resulting program.
300312
OSModuleHandle getOSModuleHandle() const { return MProgramModuleHandle; }
301313

314+
void stableSerializeSpecConstRegistry(SerializedObj &Dst) const {
315+
detail::stableSerializeSpecConstRegistry(SpecConstRegistry, Dst);
316+
}
317+
302318
private:
303319
// Deligating Constructor used in Implementation.
304320
program_impl(ContextImplPtr Context, pi_native_handle InteropProgram,
@@ -390,6 +406,12 @@ class program_impl {
390406
string_class MBuildOptions;
391407
OSModuleHandle MProgramModuleHandle = OSUtil::ExeModuleHandle;
392408

409+
// Keeps specialization constant map for this program. Spec constant name
410+
// resolution to actual SPIRV integer ID happens at build time, where the
411+
// device binary image is available. Access is guarded by this context's
412+
// program cache lock.
413+
SpecConstRegistryT SpecConstRegistry;
414+
393415
/// Only allow kernel caching for programs constructed with context only (or
394416
/// device list and context) and built with build_with_kernel_type with
395417
/// default build options

0 commit comments

Comments
 (0)