Skip to content

[SYCL][Fusion] Cache JIT compiled fused kernels #8051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions sycl-fusion/jit-compiler/include/Hashing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//==---- Hashing.h - helper for hashes for JIT internal representations ----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SYCL_FUSION_JIT_COMPILER_HASHING_H
#define SYCL_FUSION_JIT_COMPILER_HASHING_H

#include "Parameter.h"

#include "llvm/ADT/Hashing.h"

#include <tuple>
#include <vector>

namespace jit_compiler {
// hash_value overloads for the JIT-internal parameter descriptions. They are
// found through argument-dependent lookup by llvm::hash_combine and
// llvm::hash_combine_range, which is why they live in the same namespace as
// the types they hash.

/// Hash a Parameter from its parameter index and its kernel index.
///
/// Defined before the other overloads on purpose: those pass members named
/// Param/LHS/RHS (presumably of type Parameter - confirm in Parameter.h) to
/// llvm::hash_combine, which resolves hash_value for them via ADL. Having this
/// overload declared first keeps that lookup valid at every point of
/// instantiation rather than only at the end of the translation unit.
inline llvm::hash_code hash_value(const Parameter &P) {
  return llvm::hash_combine(P.ParamIdx, P.KernelIdx);
}

/// Hash a ParameterInternalization from its local size, its internalization
/// kind and the parameter it refers to.
inline llvm::hash_code hash_value(const ParameterInternalization &P) {
  return llvm::hash_combine(P.LocalSize, P.Intern, P.Param);
}

/// Hash a JITConstant from the parameter it refers to and its value.
inline llvm::hash_code hash_value(const JITConstant &C) {
  return llvm::hash_combine(C.Param, C.Value);
}

/// Hash a ParameterIdentity from its left-hand and right-hand side.
inline llvm::hash_code hash_value(const ParameterIdentity &IP) {
  return llvm::hash_combine(IP.LHS, IP.RHS);
}
} // namespace jit_compiler

namespace std {
template <typename T> inline llvm::hash_code hash_value(const vector<T> &V) {
return llvm::hash_combine_range(V.begin(), V.end());
}

template <typename... T> struct hash<tuple<T...>> {
size_t operator()(const tuple<T...> &Tuple) const noexcept {
return llvm::hash_value(Tuple);
}
};
} // namespace std

#endif // SYCL_FUSION_JIT_COMPILER_HASHING_H
16 changes: 16 additions & 0 deletions sycl-fusion/jit-compiler/include/JITContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "Hashing.h"
#include "Kernel.h"
#include "Parameter.h"

Expand All @@ -23,6 +27,10 @@ class LLVMContext;

namespace jit_compiler {

/// Key identifying one fusion request in the JIT kernel cache.
///
/// KernelFusion::fuseKernels builds it, in this order, from: the list of
/// kernels to fuse (as strings), the parameter identities, the barriers flags,
/// the parameter internalizations and the JIT constants.
using CacheKeyT =
std::tuple<std::vector<std::string>, ParamIdentList, int,
std::vector<ParameterInternalization>, std::vector<JITConstant>>;

///
/// Wrapper around a SPIR-V binary.
class SPIRVBinary {
Expand Down Expand Up @@ -51,6 +59,10 @@ class JITContext {

SPIRVBinary &emplaceSPIRVBinary(std::string Binary);

/// Look up the kernel cached for \p Identifier.
/// \return a copy of the cached SYCLKernelInfo, or an empty optional if the
/// cache holds no entry for this key.
std::optional<SYCLKernelInfo> getCacheEntry(CacheKeyT &Identifier) const;

/// Store \p Kernel in the cache under \p Identifier. An entry that already
/// exists for an equal key is kept and not overwritten.
void addCacheEntry(CacheKeyT &Identifier, SYCLKernelInfo &Kernel);

private:
// FIXME: Change this to std::shared_mutex after switching to C++17.
using MutexT = std::shared_timed_mutex;
Expand All @@ -64,6 +76,10 @@ class JITContext {
MutexT BinariesMutex;

std::vector<SPIRVBinary> Binaries;

// Guards Cache. Declared mutable so that the const member getCacheEntry can
// still acquire a lock on it.
mutable MutexT CacheMutex;

// Previously fused kernels, keyed by the inputs of the fusion request.
std::unordered_map<CacheKeyT, SYCLKernelInfo> Cache;
};
} // namespace jit_compiler

Expand Down
2 changes: 2 additions & 0 deletions sycl-fusion/jit-compiler/include/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ namespace option {

struct JITEnableVerbose : public OptionBase<OptionID::VerboseOutput, bool> {};

/// Boolean option: when set, KernelFusion::fuseKernels looks up and stores
/// fused kernels in the JITContext cache instead of always recompiling.
struct JITEnableCaching : public OptionBase<OptionID::EnableCaching, bool> {};

} // namespace option
} // namespace jit_compiler

Expand Down
15 changes: 15 additions & 0 deletions sycl-fusion/jit-compiler/lib/JITContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,18 @@ SPIRVBinary &JITContext::emplaceSPIRVBinary(std::string Binary) {
Binaries.emplace_back(std::move(Binary));
return Binaries.back();
}

/// Look up the kernel cached for \p Identifier.
///
/// A ReadLockT on CacheMutex is held for the duration of the lookup.
///
/// \return a copy of the cached SYCLKernelInfo, or an empty optional if the
/// cache holds no entry for this key.
std::optional<SYCLKernelInfo>
JITContext::getCacheEntry(CacheKeyT &Identifier) const {
  ReadLockT Guard{CacheMutex};
  const auto Pos = Cache.find(Identifier);
  // Cache miss: report it with an empty optional.
  if (Pos == Cache.end())
    return std::nullopt;
  return Pos->second;
}

/// Store \p Kernel in the cache under \p Identifier.
///
/// A WriteLockT on CacheMutex is held for the duration of the insertion. If an
/// entry with an equal key is already present, it is kept unchanged and
/// \p Kernel is not stored.
void JITContext::addCacheEntry(CacheKeyT &Identifier, SYCLKernelInfo &Kernel) {
  WriteLockT WriteLock{CacheMutex};
  // try_emplace does not construct the key/value pair when the key already
  // exists, so a duplicate insertion does not copy the key tuple and the
  // kernel info only to discard them again.
  Cache.try_emplace(Identifier, Kernel);
}
17 changes: 17 additions & 0 deletions sycl-fusion/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ FusionResult KernelFusion::fuseKernels(
// available (on a per-thread basis).
ConfigHelper::setConfig(std::move(JITConfig));

bool CachingEnabled = ConfigHelper::get<option::JITEnableCaching>();
CacheKeyT CacheKey{KernelsToFuse, Identities, BarriersFlags, Internalization,
Constants};
if (CachingEnabled) {
std::optional<SYCLKernelInfo> CachedKernel = JITCtx.getCacheEntry(CacheKey);
if (CachedKernel) {
helper::printDebugMessage("Re-using cached JIT kernel");
return FusionResult{*CachedKernel, /*Cached*/ true};
}
helper::printDebugMessage(
"Compiling new kernel, no suitable cached kernel found");
}

SYCLModuleInfo ModuleInfo;
// Copy the kernel information for the input kernels to the module
// information. We could remove the copy, if we removed the const from the
Expand Down Expand Up @@ -115,5 +128,9 @@ FusionResult KernelFusion::fuseKernels(
FusedBinaryInfo.BinaryStart = SPIRVBin->address();
FusedBinaryInfo.BinarySize = SPIRVBin->size();

if (CachingEnabled) {
JITCtx.addCacheEntry(CacheKey, FusedKernelInfo);
}

return FusionResult{FusedKernelInfo};
}
1 change: 1 addition & 0 deletions sycl/doc/EnvironmentVariables.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ compiler and runtime.
| `SYCL_USM_HOSTPTR_IMPORT` | Integer | Enable by specifying non-zero value. Buffers created with a host pointer will result in host data promotion to USM, improving data transfer performance. To use this feature, also set SYCL_HOST_UNIFIED_MEMORY=1. |
| `SYCL_EAGER_INIT` | Integer | Enable by specifying a non-zero value. Tells the SYCL runtime to perform as much initialization as possible at object construction, as opposed to initializing lazily on the fly. This may mean doing some redundant work at warmup, but ensures the fastest possible execution on the subsequent hot and reportable paths. It also instructs PI plugins to do the same. Default is "0". |
| `SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE` | See [below](#sycl_reduction_preferred_workgroup_size) | Controls the preferred work-group size of reductions. |
| `SYCL_ENABLE_FUSION_CACHING` | '1' or '0' | Enable ('1') or disable ('0') caching of JIT compilations for kernel fusion. Caching avoids repeatedly running the JIT compilation pipeline if the same sequence of kernels is fused multiple times. Default value is '1'. |

`(*) Note: Any means this environment variable is effective when set to any non-null value.`

Expand Down
1 change: 1 addition & 0 deletions sycl/source/detail/config.def
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ CONFIG(SYCL_QUEUE_THREAD_POOL_SIZE, 4, __SYCL_QUEUE_THREAD_POOL_SIZE)
CONFIG(SYCL_RT_WARNING_LEVEL, 4, __SYCL_RT_WARNING_LEVEL)
CONFIG(SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE, 16, __SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE)
CONFIG(ONEAPI_DEVICE_SELECTOR, 1024, __ONEAPI_DEVICE_SELECTOR)
CONFIG(SYCL_ENABLE_FUSION_CACHING, 1, __SYCL_ENABLE_FUSION_CACHING)
28 changes: 28 additions & 0 deletions sycl/source/detail/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,34 @@ template <> class SYCLConfig<SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE> {
}
};

// Accessor for the SYCL_ENABLE_FUSION_CACHING configuration value, which
// switches caching of JIT-compiled fused kernels on or off.
template <> class SYCLConfig<SYCL_ENABLE_FUSION_CACHING> {
  using BaseT = SYCLConfigBase<SYCL_ENABLE_FUSION_CACHING>;

public:
  // Returns true when no raw value is available (caching is on by default);
  // otherwise returns true only if the raw value starts with '1'.
  static bool get() {
    constexpr bool DefaultValue = true;

    const char *RawValue = getCachedValue();
    return RawValue ? (RawValue[0] == '1') : DefaultValue;
  }

  // Re-reads the raw value, replacing the one cached by earlier calls.
  static void reset() { (void)getCachedValue(/*ResetCache=*/true); }

  // Name of this configuration entry as stored in the base class.
  static const char *getName() { return BaseT::MConfigName; }

private:
  // Returns the raw value, which is fetched once on first use and then kept
  // in a function-local static. Passing ResetCache = true fetches it again.
  static const char *getCachedValue(bool ResetCache = false) {
    static const char *Cached = BaseT::getRawValue();
    if (!ResetCache)
      return Cached;
    Cached = BaseT::getRawValue();
    return Cached;
  }
};

#undef INVALID_CONFIG_EXCEPTION

} // namespace detail
Expand Down
3 changes: 2 additions & 1 deletion sycl/source/detail/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,8 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
bool DebugEnabled =
detail::SYCLConfig<detail::SYCL_RT_WARNING_LEVEL>::get() > 0;
JITConfig.set<::jit_compiler::option::JITEnableVerbose>(DebugEnabled);
// TODO: Enable caching in a separate PR.
JITConfig.set<::jit_compiler::option::JITEnableCaching>(
detail::SYCLConfig<detail::SYCL_ENABLE_FUSION_CACHING>::get());

auto FusionResult = ::jit_compiler::KernelFusion::fuseKernels(
*MJITContext, std::move(JITConfig), InputKernelInfo, InputKernelNames,
Expand Down