Skip to content

[SYCL] Specialization constants: fix scope, allow redefinition and AOT. #1633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions sycl/include/CL/sycl/detail/device_binary_image.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//==----- device_binary_image.hpp --- SYCL device binary image abstraction -==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#pragma once

#include <CL/sycl/detail/os_util.hpp>
#include <CL/sycl/detail/pi.hpp>

#include <memory>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace detail {

// SYCL RT wrapper over PI binary image.
class RTDeviceBinaryImage : public pi::DeviceBinaryImage {
public:
RTDeviceBinaryImage(OSModuleHandle ModuleHandle)
: pi::DeviceBinaryImage(), ModuleHandle(ModuleHandle) {}
RTDeviceBinaryImage(pi_device_binary Bin, OSModuleHandle ModuleHandle)
: pi::DeviceBinaryImage(Bin), ModuleHandle(ModuleHandle) {}
OSModuleHandle getOSModuleHandle() const { return ModuleHandle; }

~RTDeviceBinaryImage() override {}

bool supportsSpecConstants() const {
return getFormat() == PI_DEVICE_BINARY_TYPE_SPIRV;
}

const pi_device_binary_struct &getRawData() const { return *get(); }

void print() const override {
pi::DeviceBinaryImage::print();
std::cerr << " OSModuleHandle=" << ModuleHandle << "\n";
}

protected:
OSModuleHandle ModuleHandle;
};

// Device binary image whose binary data is allocated at run time. The image
// keeps that data alive itself and releases it when it is destroyed.
class DynRTDeviceBinaryImage : public RTDeviceBinaryImage {
protected:
  // Storage holding the binary data handed over at construction.
  std::unique_ptr<char[]> Data;

public:
  // Takes over DataPtr (DataSize bytes) as the image's binary data; M is the
  // OS module handle forwarded to the base class. Defined out of line.
  DynRTDeviceBinaryImage(std::unique_ptr<char[]> &&DataPtr, size_t DataSize,
                         OSModuleHandle M);
  ~DynRTDeviceBinaryImage() override;

  // Prints the base description, then a marker line identifying this image
  // as dynamically created.
  void print() const override {
    RTDeviceBinaryImage::print();
    std::cerr << " DYNAMICALLY CREATED\n";
  }
};

} // namespace detail
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
22 changes: 14 additions & 8 deletions sycl/include/CL/sycl/detail/spec_constant_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,44 @@
#pragma once

#include <CL/sycl/detail/defines.hpp>
#include <CL/sycl/detail/util.hpp>
#include <CL/sycl/stl.hpp>

#include <iostream>
#include <map>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace detail {

// Represents a specialization constant in SYCL runtime.
// Represents a specialization constant value in SYCL runtime.
// The value is kept as a size plus a small fixed-size byte buffer.
class spec_constant_impl {
public:
// Creates an unset value: Size is 0, so isSet() reports false.
spec_constant_impl(unsigned int ID) : ID(ID), Size(0), Bytes{0} {}
spec_constant_impl() : Size(0), Bytes{0} {};

// Creates a value by forwarding Size and Val to set().
spec_constant_impl(unsigned int ID, size_t Size, const void *Val) : ID(ID) {
set(Size, Val);
}
spec_constant_impl(size_t Size, const void *Val) { set(Size, Val); }

// Assigns the value from Size and Val. Only declared here.
// NOTE(review): presumably copies Size bytes from Val into Bytes - confirm
// against the out-of-line definition.
void set(size_t Size, const void *Val);

unsigned int getID() const { return ID; }
size_t getSize() const { return Size; }
// Returns a pointer to the internal byte buffer, not a copy of it.
const unsigned char *getValuePtr() const { return Bytes; }
// A value counts as set when its recorded size is non-zero.
bool isSet() const { return Size != 0; }

private:
unsigned int ID; // specialization constant's ID (equals to SPIRV ID)
size_t Size; // size of its value
size_t Size; // the size of the spec constant value
// TODO invent more flexible approach to support values of arbitrary type:
unsigned char Bytes[8]; // memory to hold the value bytes
};

// Stream insertion for a spec constant value. Declared here, defined
// elsewhere.
std::ostream &operator<<(std::ostream &Out, const spec_constant_impl &V);

// Used to define specialization constant registry. Must be ordered map, since
// the order of entries matters in stableSerializeSpecConstRegistry.
// Keys are strings; mapped values are spec_constant_impl objects.
using SpecConstRegistryT = std::map<string_class, spec_constant_impl>;

// Takes a registry Reg and an output buffer Dst. Declared here only.
// NOTE(review): presumably writes an order-stable byte encoding of Reg into
// Dst - confirm against the definition.
void stableSerializeSpecConstRegistry(const SpecConstRegistryT &Reg,
SerializedObj &Dst);

} // namespace detail
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
8 changes: 2 additions & 6 deletions sycl/include/CL/sycl/detail/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#ifndef __SYCL_DEVICE_ONLY

#include <CL/sycl/detail/defines.hpp>
#include <CL/sycl/stl.hpp>

#include <cstring>
#include <mutex>
Expand Down Expand Up @@ -52,12 +53,7 @@ struct CmpCStr {
}
};

// Interface to iterate via C strings.
//
// This is a polymorphic base class: implementations override next() and may
// be owned and destroyed through a CStringIterator pointer. The destructor is
// therefore virtual, so that such a deletion runs the implementation's
// destructor instead of being undefined behaviour.
class CStringIterator {
public:
  virtual ~CStringIterator() = default;

  // Get the next string. Returns next string's pointer or nullptr.
  virtual const char *next() = 0;
};
// Byte buffer type (a vector_class of unsigned char) used to hold objects in
// serialized form.
using SerializedObj = sycl::vector_class<unsigned char>;

} // namespace detail
} // namespace sycl
Expand Down
4 changes: 3 additions & 1 deletion sycl/include/CL/sycl/experimental/spec_constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ __SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace experimental {

class spec_const_error : public compile_program_error {};
// Exception class derived from compile_program_error. The using-declaration
// inherits the base class constructors, so it can be constructed with the
// same arguments as compile_program_error.
class spec_const_error : public compile_program_error {
using compile_program_error::compile_program_error;
};

template <typename T, typename ID = T> class spec_constant {
private:
Expand Down
1 change: 1 addition & 0 deletions sycl/source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ set(SYCL_SOURCES
"detail/common.cpp"
"detail/config.cpp"
"detail/context_impl.cpp"
"detail/device_binary_image.cpp"
"detail/device_impl.cpp"
"detail/error_handling/enqueue_kernel.cpp"
"detail/event_impl.cpp"
Expand Down
40 changes: 40 additions & 0 deletions sycl/source/detail/device_binary_image.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//==----- device_binary_image.cpp --- SYCL device binary image abstraction -==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <CL/sycl/detail/pi.hpp>

#include <memory>

#include <CL/sycl/detail/device_binary_image.hpp>

using namespace sycl::detail;

// Builds a binary image descriptor around a run-time allocated blob.
// Ownership of the blob moves from DataPtr into this object; DataSize is the
// blob length in bytes and M is the OS module handle passed to the base class.
// The descriptor allocated here is released by the destructor.
DynRTDeviceBinaryImage::DynRTDeviceBinaryImage(
    std::unique_ptr<char[]> &&DataPtr, size_t DataSize, OSModuleHandle M)
    : RTDeviceBinaryImage(M), Data(std::move(DataPtr)) {
  // Every pointer stored in the descriptor refers into the blob owned by Data.
  auto *ImageBegin = reinterpret_cast<unsigned char *>(Data.get());

  Bin = new pi_device_binary_struct();

  // Fixed header fields.
  Bin->Version = PI_DEVICE_BINARY_VERSION;
  Bin->Kind = PI_DEVICE_BINARY_OFFLOAD_KIND_SYCL;
  Bin->DeviceTargetSpec = PI_DEVICE_BINARY_TARGET_UNKNOWN;
  Bin->CompileOptions = "";
  Bin->LinkOptions = "";

  // No manifest and no entry table are attached to this image.
  Bin->ManifestStart = nullptr;
  Bin->ManifestEnd = nullptr;
  Bin->EntriesBegin = nullptr;
  Bin->EntriesEnd = nullptr;

  // Binary range [BinaryStart, BinaryEnd) covers the whole blob.
  Bin->BinaryStart = ImageBegin;
  Bin->BinaryEnd = Bin->BinaryStart + DataSize;

  // The format is derived from the blob contents rather than supplied by the
  // caller.
  Bin->Format = pi::getBinaryImageFormat(Bin->BinaryStart, DataSize);

  init(Bin);
}

// Releases the descriptor allocated with new in the constructor and leaves
// Bin null. The binary bytes are held by the Data member and are freed when
// that member is destroyed.
DynRTDeviceBinaryImage::~DynRTDeviceBinaryImage() {
  auto *Descriptor = Bin;
  Bin = nullptr;
  delete Descriptor;
}
5 changes: 4 additions & 1 deletion sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

#pragma once

#include <CL/sycl/detail/common.hpp>
#include <CL/sycl/detail/locked.hpp>
#include <CL/sycl/detail/os_util.hpp>
#include <CL/sycl/detail/pi.hpp>
#include <CL/sycl/detail/util.hpp>
#include <detail/platform_impl.hpp>

#include <atomic>
Expand Down Expand Up @@ -51,7 +53,8 @@ class KernelProgramCache {
using PiProgramT = std::remove_pointer<RT::PiProgram>::type;
using PiProgramPtrT = std::atomic<PiProgramT *>;
using ProgramWithBuildStateT = BuildResult<PiProgramT>;
using ProgramCacheT = std::map<OSModuleHandle, ProgramWithBuildStateT>;
using ProgramCacheKeyT = std::pair<SerializedObj, KernelSetId>;
using ProgramCacheT = std::map<ProgramCacheKeyT, ProgramWithBuildStateT>;
using ContextPtr = context_impl *;

using PiKernelT = std::remove_pointer<RT::PiKernel>::type;
Expand Down
38 changes: 32 additions & 6 deletions sycl/source/detail/program_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ void program_impl::build_with_kernel_name(string_class KernelName,
if (is_cacheable_with_options(BuildOptions)) {
MProgramAndKernelCachingAllowed = true;
MProgram = ProgramManager::getInstance().getBuiltPIProgram(
Module, get_context(), KernelName);
Module, get_context(), KernelName, this);
const detail::plugin &Plugin = getPlugin();
Plugin.call<PiApiKind::piProgramRetain>(MProgram);
} else {
Expand Down Expand Up @@ -332,7 +332,6 @@ void program_impl::compile(const string_class &Options) {
check_device_feature_support<info::device::is_compiler_available>(MDevices);
vector_class<RT::PiDevice> Devices(get_pi_devices());
const detail::plugin &Plugin = getPlugin();
ProgramManager::getInstance().flushSpecConstants(MProgram, *MContext);
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piProgramCompile>(
MProgram, Devices.size(), Devices.data(), Options.c_str(), 0, nullptr,
nullptr, nullptr, nullptr);
Expand All @@ -351,7 +350,7 @@ void program_impl::build(const string_class &Options) {
check_device_feature_support<info::device::is_compiler_available>(MDevices);
vector_class<RT::PiDevice> Devices(get_pi_devices());
const detail::plugin &Plugin = getPlugin();
ProgramManager::getInstance().flushSpecConstants(MProgram, *MContext);
ProgramManager::getInstance().flushSpecConstants(*this);
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piProgramBuild>(
MProgram, Devices.size(), Devices.data(), Options.c_str(), nullptr,
nullptr);
Expand Down Expand Up @@ -398,7 +397,7 @@ RT::PiKernel program_impl::get_pi_kernel(const string_class &KernelName) const {

if (is_cacheable()) {
Kernel = ProgramManager::getInstance().getOrCreateKernel(
MProgramModuleHandle, get_context(), KernelName);
MProgramModuleHandle, get_context(), KernelName, this);
} else {
const detail::plugin &Plugin = getPlugin();
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piKernelCreate>(
Expand Down Expand Up @@ -470,11 +469,38 @@ vector_class<device> program_impl::get_info<info::program::devices>() const {

// Sets the value of the specialization constant called Name: obtains the
// spec_constant_impl entry for Name and hands ValSize and ValAddr to its
// set() method.
void program_impl::set_spec_constant_impl(const char *Name, const void *ValAddr,
size_t ValSize) {
spec_constant_impl &SC =
ProgramManager::getInstance().resolveSpecConstant(this, Name);
// Reuse cached programs lock as opposed to introducing a new lock.
auto LockGuard = MContext->getKernelProgramCache().acquireCachedPrograms();
// SpecConstRegistry is an ordered map keyed by name; operator[] creates the
// entry for Name if it does not exist yet.
spec_constant_impl &SC = SpecConstRegistry[Name];
SC.set(ValSize, ValAddr);
}

void program_impl::flush_spec_constants(const RTDeviceBinaryImage &Img,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest moving functionality of this function to program_manager.cpp:991 to avoid dealing with RTDeviceBinaryImage outside of program_manager.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the problem with dealing with RTDeviceBinaryImage outside of program_manager?
Or do you mean the incurred dependence on program_manager.h everywhere RTDeviceBinaryImage is used? That can be fixed by extracting the BinaryImage infra into a separate header.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be nice if program_manager abstracts away low level image details from other parts of the SYCL.
Is there any reason for this function to be a method of program class?
This comment doesn't block PR approval.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spec constants are now bound to a program_impl, this is the main reason. But I agree that it is better to keep as few inter-dependencies as possible. Anyway, the upcoming modules implementation will require significant refactoring in this area - let's revisit this design question as part of the modules design, so as not to refactor twice. Do you agree?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.

RT::PiProgram NativePrg) const {
// iterate via all specialization constants the program's image depends on,
// and set each to current runtime value (if any)
const pi::DeviceBinaryImage::PropertyRange &SCRange = Img.getSpecConstants();
ContextImplPtr Ctx = getSyclObjImpl(get_context());
using SCItTy = pi::DeviceBinaryImage::PropertyRange::ConstIterator;

// Hold the cached-programs lock until the function returns; SpecConstRegistry
// is read inside the loop below.
auto LockGuard = Ctx->getKernelProgramCache().acquireCachedPrograms();

for (SCItTy SCIt : SCRange) {
// The property's name is the key used to look the constant up in this
// program's registry.
const char *SCName = (*SCIt)->Name;
auto SCEntry = SpecConstRegistry.find(SCName);
if (SCEntry == SpecConstRegistry.end())
// spec constant has not been set in user code - SPIRV will use default
continue;
const spec_constant_impl &SC = SCEntry->second;
assert(SC.isSet() && "uninitialized spec constant");
// The property's value, read as a 32-bit unsigned integer, is the ID handed
// to the PI call below.
pi_device_binary_property SCProp = *SCIt;
pi_uint32 ID = pi::DeviceBinaryProperty(SCProp).asUint32();
// A null NativePrg means "use this program's own handle" (getHandleRef()).
// Once replaced, it stays non-null for the remaining iterations.
NativePrg = NativePrg ? NativePrg : getHandleRef();
Ctx->getPlugin().call<PiApiKind::piextProgramSetSpecializationConstant>(
NativePrg, ID, SC.getSize(), SC.getValuePtr());
}
}

} // namespace detail
} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
25 changes: 25 additions & 0 deletions sycl/source/detail/program_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <CL/sycl/context.hpp>
#include <CL/sycl/detail/common_info.hpp>
#include <CL/sycl/detail/kernel_desc.hpp>
#include <CL/sycl/detail/spec_constant_impl.hpp>
#include <CL/sycl/device.hpp>
#include <CL/sycl/program.hpp>
#include <CL/sycl/stl.hpp>
Expand Down Expand Up @@ -292,13 +293,31 @@ class program_impl {
void set_spec_constant_impl(const char *Name, const void *ValAddr,
size_t ValSize);

/// Takes current values of specialization constants and "injects" them into
/// the underlying native program via specialization constant management PI
/// APIs. The native program passed as non-null argument overrides the
/// MProgram native program field.
/// \param Img device binary image corresponding to this program, used to
/// resolve spec constant name to SPIRV integer ID
/// \param NativePrg if not null, used as the flush target, otherwise MProgram
/// is used
void flush_spec_constants(const RTDeviceBinaryImage &Img,
RT::PiProgram NativePrg = nullptr) const;

/// Gives the handle of the OS module this program belongs to. A program
/// belongs to an OS module when it was built from device image(s) that belong
/// to that module.
/// TODO Some programs can be linked from images belonging to different
/// modules. May need a special fake handle for the resulting program.
OSModuleHandle getOSModuleHandle() const {
  return MProgramModuleHandle;
}

/// Forwards this program's specialization constant registry, together with
/// the output buffer \p Dst, to detail::stableSerializeSpecConstRegistry.
void stableSerializeSpecConstRegistry(SerializedObj &Dst) const {
  const SpecConstRegistryT &Registry = SpecConstRegistry;
  detail::stableSerializeSpecConstRegistry(Registry, Dst);
}

/// Reports whether this program's specialization constant registry holds at
/// least one entry, i.e. whether any constant has been set for it.
bool hasSetSpecConstants() const {
  const bool RegistryIsEmpty = SpecConstRegistry.empty();
  return !RegistryIsEmpty;
}

private:
// Delegating constructor used in implementation.
program_impl(ContextImplPtr Context, pi_native_handle InteropProgram,
Expand Down Expand Up @@ -390,6 +409,12 @@ class program_impl {
string_class MBuildOptions;
OSModuleHandle MProgramModuleHandle = OSUtil::ExeModuleHandle;

// Keeps specialization constant map for this program. Spec constant name
// resolution to actual SPIRV integer ID happens at build time, where the
// device binary image is available. Access is guarded by this context's
// program cache lock.
SpecConstRegistryT SpecConstRegistry;

/// Only allow kernel caching for programs constructed with context only (or
/// device list and context) and built with build_with_kernel_type with
/// default build options
Expand Down
Loading