Skip to content

Commit c22e34b

Browse files
authored
[SYCL] Specialization constants: fix scope, allow redefinition and AOT. (#1633)
* [SYCL] Specialization constants: fix scope, allow redefinition and AOT. - Specialization constants are now per-program, as the implemented spec requires. - Program's specialization constant set is added to the program compilation cache key, so the program gets recompiled on each new specialization constant combination. Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent 40287bf commit c22e34b

16 files changed

+343
-312
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//==----- device_binary_image.hpp --- SYCL device binary image abstraction -==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#pragma once
9+
10+
#include <CL/sycl/detail/os_util.hpp>
11+
#include <CL/sycl/detail/pi.hpp>
12+
13+
#include <memory>
14+
15+
__SYCL_INLINE_NAMESPACE(cl) {
16+
namespace sycl {
17+
namespace detail {
18+
19+
// SYCL RT wrapper over PI binary image. In addition to the wrapped
// pi::DeviceBinaryImage it stores an OSModuleHandle supplied at construction.
20+
class RTDeviceBinaryImage : public pi::DeviceBinaryImage {
21+
public:
22+
// Default-constructs the base image and records ModuleHandle.
RTDeviceBinaryImage(OSModuleHandle ModuleHandle)
23+
: pi::DeviceBinaryImage(), ModuleHandle(ModuleHandle) {}
24+
// Wraps the PI device binary Bin and records ModuleHandle.
RTDeviceBinaryImage(pi_device_binary Bin, OSModuleHandle ModuleHandle)
25+
: pi::DeviceBinaryImage(Bin), ModuleHandle(ModuleHandle) {}
26+
// Returns the module handle passed to the constructor.
OSModuleHandle getOSModuleHandle() const { return ModuleHandle; }
27+
28+
~RTDeviceBinaryImage() override {}
29+
30+
// Specialization constants are reported as supported only when the image
// format is PI_DEVICE_BINARY_TYPE_SPIRV.
bool supportsSpecConstants() const {
31+
return getFormat() == PI_DEVICE_BINARY_TYPE_SPIRV;
32+
}
33+
34+
// Returns the underlying PI binary descriptor by dereferencing get().
// NOTE(review): presumably requires that a binary has been set - confirm.
const pi_device_binary_struct &getRawData() const { return *get(); }
35+
36+
// Calls the base class print(), then writes the module handle to std::cerr.
void print() const override {
37+
pi::DeviceBinaryImage::print();
38+
std::cerr << " OSModuleHandle=" << ModuleHandle << "\n";
39+
}
40+
41+
protected:
42+
// Module handle given at construction; exposed via getOSModuleHandle().
OSModuleHandle ModuleHandle;
43+
};
44+
45+
// Dynamically allocated device binary image, which de-allocates its binary
46+
// data in destructor. The raw bytes are held by the Data member.
47+
class DynRTDeviceBinaryImage : public RTDeviceBinaryImage {
48+
public:
49+
// \param DataPtr buffer with the binary bytes, passed as an rvalue
// \param DataSize number of bytes in the buffer
// \param M module handle forwarded to RTDeviceBinaryImage
DynRTDeviceBinaryImage(std::unique_ptr<char[]> &&DataPtr, size_t DataSize,
50+
OSModuleHandle M);
51+
~DynRTDeviceBinaryImage() override;
52+
53+
// Calls RTDeviceBinaryImage::print(), then writes a marker line to std::cerr
// identifying the image as dynamically created.
void print() const override {
54+
RTDeviceBinaryImage::print();
55+
std::cerr << " DYNAMICALLY CREATED\n";
56+
}
57+
58+
protected:
59+
// Owning pointer to the binary bytes of this image.
std::unique_ptr<char[]> Data;
60+
};
61+
62+
} // namespace detail
63+
} // namespace sycl
64+
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/CL/sycl/detail/spec_constant_impl.hpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,44 @@
99
#pragma once
1010

1111
#include <CL/sycl/detail/defines.hpp>
12+
#include <CL/sycl/detail/util.hpp>
13+
#include <CL/sycl/stl.hpp>
1214

1315
#include <iostream>
16+
#include <map>
1417

1518
__SYCL_INLINE_NAMESPACE(cl) {
1619
namespace sycl {
1720
namespace detail {
1821

19-
// Represents a specialization constant in SYCL runtime.
22+
// Represents a specialization constant value in SYCL runtime.
2023
class spec_constant_impl {
2124
public:
22-
spec_constant_impl(unsigned int ID) : ID(ID), Size(0), Bytes{0} {}
25+
spec_constant_impl() : Size(0), Bytes{0} {};
2326

24-
spec_constant_impl(unsigned int ID, size_t Size, const void *Val) : ID(ID) {
25-
set(Size, Val);
26-
}
27+
spec_constant_impl(size_t Size, const void *Val) { set(Size, Val); }
2728

2829
void set(size_t Size, const void *Val);
2930

30-
unsigned int getID() const { return ID; }
3131
size_t getSize() const { return Size; }
3232
const unsigned char *getValuePtr() const { return Bytes; }
3333
bool isSet() const { return Size != 0; }
3434

3535
private:
36-
unsigned int ID; // specialization constant's ID (equals to SPIRV ID)
37-
size_t Size; // size of its value
36+
size_t Size; // the size of the spec constant value
3837
// TODO invent more flexible approach to support values of arbitrary type:
3938
unsigned char Bytes[8]; // memory to hold the value bytes
4039
};
4140

4241
std::ostream &operator<<(std::ostream &Out, const spec_constant_impl &V);
4342

43+
// Used to define specialization constant registry. Must be ordered map, since
44+
// the order of entries matters in stableSerializeSpecConstRegistry.
45+
using SpecConstRegistryT = std::map<string_class, spec_constant_impl>;
46+
47+
void stableSerializeSpecConstRegistry(const SpecConstRegistryT &Reg,
48+
SerializedObj &Dst);
49+
4450
} // namespace detail
4551
} // namespace sycl
4652
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/CL/sycl/detail/util.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#ifndef __SYCL_DEVICE_ONLY
1212

1313
#include <CL/sycl/detail/defines.hpp>
14+
#include <CL/sycl/stl.hpp>
1415

1516
#include <cstring>
1617
#include <mutex>
@@ -52,12 +53,7 @@ struct CmpCStr {
5253
}
5354
};
5455

55-
// Interface to iterate via C strings.
56-
class CStringIterator {
57-
public:
58-
// Get the next string. Returns next string's pointer or nullptr.
59-
virtual const char *next() = 0;
60-
};
56+
using SerializedObj = sycl::vector_class<unsigned char>;
6157

6258
} // namespace detail
6359
} // namespace sycl

sycl/include/CL/sycl/experimental/spec_constant.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ __SYCL_INLINE_NAMESPACE(cl) {
2424
namespace sycl {
2525
namespace experimental {
2626

27-
class spec_const_error : public compile_program_error {};
27+
class spec_const_error : public compile_program_error {
28+
using compile_program_error::compile_program_error;
29+
};
2830

2931
template <typename T, typename ID = T> class spec_constant {
3032
private:

sycl/source/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ set(SYCL_SOURCES
106106
"detail/common.cpp"
107107
"detail/config.cpp"
108108
"detail/context_impl.cpp"
109+
"detail/device_binary_image.cpp"
109110
"detail/device_impl.cpp"
110111
"detail/error_handling/enqueue_kernel.cpp"
111112
"detail/event_impl.cpp"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//==----- device_binary_image.cpp --- SYCL device binary image abstraction -==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <CL/sycl/detail/pi.hpp>
10+
11+
#include <memory>
12+
13+
#include <CL/sycl/detail/device_binary_image.hpp>
14+
15+
using namespace sycl::detail;
16+
17+
// Builds a PI binary descriptor around a caller-provided buffer.
// Moves DataPtr into the Data member, allocates a pi_device_binary_struct,
// fills in its fields so that BinaryStart/BinaryEnd delimit the DataSize bytes
// of the buffer, detects the format with pi::getBinaryImageFormat and finally
// passes the descriptor to init().
// NOTE(review): the descriptor is allocated with a raw new; if
// getBinaryImageFormat or init can throw, it would not be freed here because
// the destructor does not run for a partially constructed object - confirm.
DynRTDeviceBinaryImage::DynRTDeviceBinaryImage(
18+
std::unique_ptr<char[]> &&DataPtr, size_t DataSize, OSModuleHandle M)
19+
: RTDeviceBinaryImage(M) {
20+
Data = std::move(DataPtr);
21+
// Released with delete in the destructor.
Bin = new pi_device_binary_struct();
22+
Bin->Version = PI_DEVICE_BINARY_VERSION;
23+
Bin->Kind = PI_DEVICE_BINARY_OFFLOAD_KIND_SYCL;
24+
Bin->DeviceTargetSpec = PI_DEVICE_BINARY_TARGET_UNKNOWN;
25+
// Empty compile and link options.
Bin->CompileOptions = "";
26+
Bin->LinkOptions = "";
27+
// No manifest is attached to the descriptor.
Bin->ManifestStart = nullptr;
28+
Bin->ManifestEnd = nullptr;
29+
// The binary range points into the buffer held by Data.
Bin->BinaryStart = reinterpret_cast<unsigned char *>(Data.get());
30+
Bin->BinaryEnd = Bin->BinaryStart + DataSize;
31+
// No entries table is attached to the descriptor.
Bin->EntriesBegin = nullptr;
32+
Bin->EntriesEnd = nullptr;
33+
// The format is derived from the binary bytes rather than passed in.
Bin->Format = pi::getBinaryImageFormat(Bin->BinaryStart, DataSize);
34+
init(Bin);
35+
}
36+
37+
// Frees the descriptor allocated with new in the constructor and clears the
// pointer. The binary bytes themselves are released by the Data unique_ptr
// member.
DynRTDeviceBinaryImage::~DynRTDeviceBinaryImage() {
38+
delete Bin;
39+
Bin = nullptr;
40+
}

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
#pragma once
1010

11+
#include <CL/sycl/detail/common.hpp>
1112
#include <CL/sycl/detail/locked.hpp>
1213
#include <CL/sycl/detail/os_util.hpp>
1314
#include <CL/sycl/detail/pi.hpp>
15+
#include <CL/sycl/detail/util.hpp>
1416
#include <detail/platform_impl.hpp>
1517

1618
#include <atomic>
@@ -51,7 +53,8 @@ class KernelProgramCache {
5153
using PiProgramT = std::remove_pointer<RT::PiProgram>::type;
5254
using PiProgramPtrT = std::atomic<PiProgramT *>;
5355
using ProgramWithBuildStateT = BuildResult<PiProgramT>;
54-
using ProgramCacheT = std::map<OSModuleHandle, ProgramWithBuildStateT>;
56+
using ProgramCacheKeyT = std::pair<SerializedObj, KernelSetId>;
57+
using ProgramCacheT = std::map<ProgramCacheKeyT, ProgramWithBuildStateT>;
5558
using ContextPtr = context_impl *;
5659

5760
using PiKernelT = std::remove_pointer<RT::PiKernel>::type;

sycl/source/detail/program_impl.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ void program_impl::build_with_kernel_name(string_class KernelName,
231231
if (is_cacheable_with_options(BuildOptions)) {
232232
MProgramAndKernelCachingAllowed = true;
233233
MProgram = ProgramManager::getInstance().getBuiltPIProgram(
234-
Module, get_context(), KernelName);
234+
Module, get_context(), KernelName, this);
235235
const detail::plugin &Plugin = getPlugin();
236236
Plugin.call<PiApiKind::piProgramRetain>(MProgram);
237237
} else {
@@ -332,7 +332,6 @@ void program_impl::compile(const string_class &Options) {
332332
check_device_feature_support<info::device::is_compiler_available>(MDevices);
333333
vector_class<RT::PiDevice> Devices(get_pi_devices());
334334
const detail::plugin &Plugin = getPlugin();
335-
ProgramManager::getInstance().flushSpecConstants(MProgram, *MContext);
336335
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piProgramCompile>(
337336
MProgram, Devices.size(), Devices.data(), Options.c_str(), 0, nullptr,
338337
nullptr, nullptr, nullptr);
@@ -351,7 +350,7 @@ void program_impl::build(const string_class &Options) {
351350
check_device_feature_support<info::device::is_compiler_available>(MDevices);
352351
vector_class<RT::PiDevice> Devices(get_pi_devices());
353352
const detail::plugin &Plugin = getPlugin();
354-
ProgramManager::getInstance().flushSpecConstants(MProgram, *MContext);
353+
ProgramManager::getInstance().flushSpecConstants(*this);
355354
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piProgramBuild>(
356355
MProgram, Devices.size(), Devices.data(), Options.c_str(), nullptr,
357356
nullptr);
@@ -398,7 +397,7 @@ RT::PiKernel program_impl::get_pi_kernel(const string_class &KernelName) const {
398397

399398
if (is_cacheable()) {
400399
Kernel = ProgramManager::getInstance().getOrCreateKernel(
401-
MProgramModuleHandle, get_context(), KernelName);
400+
MProgramModuleHandle, get_context(), KernelName, this);
402401
} else {
403402
const detail::plugin &Plugin = getPlugin();
404403
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piKernelCreate>(
@@ -470,11 +469,38 @@ vector_class<device> program_impl::get_info<info::program::devices>() const {
470469

471470
void program_impl::set_spec_constant_impl(const char *Name, const void *ValAddr,
472471
size_t ValSize) {
473-
spec_constant_impl &SC =
474-
ProgramManager::getInstance().resolveSpecConstant(this, Name);
472+
// Reuse cached programs lock as opposed to introducing a new lock.
473+
auto LockGuard = MContext->getKernelProgramCache().acquireCachedPrograms();
474+
spec_constant_impl &SC = SpecConstRegistry[Name];
475475
SC.set(ValSize, ValAddr);
476476
}
477477

478+
// Pushes the values recorded in SpecConstRegistry to a native program.
// For every specialization constant property listed by Img the registry is
// searched by the property name; when an entry exists, its size and value
// bytes are passed to piextProgramSetSpecializationConstant together with the
// integer ID read from the property. Properties without a registry entry are
// skipped.
// \param Img device binary image whose spec constant properties are walked
// \param NativePrg target program; when null, getHandleRef() is used instead
void program_impl::flush_spec_constants(const RTDeviceBinaryImage &Img,
479+
RT::PiProgram NativePrg) const {
480+
// iterate over all specialization constants the program's image depends on,
481+
// and set each to current runtime value (if any)
482+
const pi::DeviceBinaryImage::PropertyRange &SCRange = Img.getSpecConstants();
483+
ContextImplPtr Ctx = getSyclObjImpl(get_context());
484+
using SCItTy = pi::DeviceBinaryImage::PropertyRange::ConstIterator;
485+
486+
// Hold the cached programs lock for the whole walk over SpecConstRegistry.
auto LockGuard = Ctx->getKernelProgramCache().acquireCachedPrograms();
487+
488+
for (SCItTy SCIt : SCRange) {
489+
const char *SCName = (*SCIt)->Name;
490+
auto SCEntry = SpecConstRegistry.find(SCName);
491+
if (SCEntry == SpecConstRegistry.end())
492+
// spec constant has not been set in user code - SPIRV will use default
493+
continue;
494+
const spec_constant_impl &SC = SCEntry->second;
495+
assert(SC.isSet() && "uninitialized spec constant");
496+
pi_device_binary_property SCProp = *SCIt;
497+
// The property value carries the integer ID handed to the plugin call.
pi_uint32 ID = pi::DeviceBinaryProperty(SCProp).asUint32();
498+
// Fall back to this program's own handle when no explicit target was given;
// this is only evaluated once a constant actually has to be set.
NativePrg = NativePrg ? NativePrg : getHandleRef();
499+
Ctx->getPlugin().call<PiApiKind::piextProgramSetSpecializationConstant>(
500+
NativePrg, ID, SC.getSize(), SC.getValuePtr());
501+
}
502+
}
503+
478504
} // namespace detail
479505
} // namespace sycl
480506
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/source/detail/program_impl.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <CL/sycl/context.hpp>
1111
#include <CL/sycl/detail/common_info.hpp>
1212
#include <CL/sycl/detail/kernel_desc.hpp>
13+
#include <CL/sycl/detail/spec_constant_impl.hpp>
1314
#include <CL/sycl/device.hpp>
1415
#include <CL/sycl/program.hpp>
1516
#include <CL/sycl/stl.hpp>
@@ -292,13 +293,31 @@ class program_impl {
292293
void set_spec_constant_impl(const char *Name, const void *ValAddr,
293294
size_t ValSize);
294295

296+
/// Takes current values of specialization constants and "injects" them into
297+
/// the underlying native program via specialization constant
298+
/// management PI APIs. The native program passed as non-null argument
299+
/// overrides the MProgram native program field.
300+
/// \param Img device binary image corresponding to this program, used to
301+
/// resolve spec constant name to SPIRV integer ID
302+
/// \param NativePrg if not null, used as the flush target, otherwise MProgram
303+
/// is used
304+
void flush_spec_constants(const RTDeviceBinaryImage &Img,
305+
RT::PiProgram NativePrg = nullptr) const;
306+
295307
/// Returns the OS module handle this program belongs to. A program belongs to
296308
/// an OS module if it was built from device image(s) belonging to that
297309
/// module.
298310
/// TODO Some programs can be linked from images belonging to different
299311
/// modules. May need a special fake handle for the resulting program.
300312
OSModuleHandle getOSModuleHandle() const { return MProgramModuleHandle; }
301313

314+
/// Serializes this program's specialization constant registry into \p Dst by
/// delegating to detail::stableSerializeSpecConstRegistry.
/// NOTE(review): no lock is taken here while SpecConstRegistry is read;
/// presumably the caller holds the program cache lock - confirm.
void stableSerializeSpecConstRegistry(SerializedObj &Dst) const {
315+
detail::stableSerializeSpecConstRegistry(SpecConstRegistry, Dst);
316+
}
317+
318+
/// Tells whether a specialization constant has been set for this program,
/// i.e. whether SpecConstRegistry holds at least one entry.
319+
bool hasSetSpecConstants() const { return !SpecConstRegistry.empty(); }
320+
302321
private:
303322
// Delegating constructor used in implementation.
304323
program_impl(ContextImplPtr Context, pi_native_handle InteropProgram,
@@ -390,6 +409,12 @@ class program_impl {
390409
string_class MBuildOptions;
391410
OSModuleHandle MProgramModuleHandle = OSUtil::ExeModuleHandle;
392411

412+
// Keeps specialization constant map for this program. Spec constant name
413+
// resolution to actual SPIRV integer ID happens at build time, where the
414+
// device binary image is available. Access is guarded by this context's
415+
// program cache lock.
416+
SpecConstRegistryT SpecConstRegistry;
417+
393418
/// Only allow kernel caching for programs constructed with context only (or
394419
/// device list and context) and built with build_with_kernel_type with
395420
/// default build options

0 commit comments

Comments
 (0)