Skip to content

Commit 21a8981

Browse files
committed
[SYCL][Fusion] Cache JIT compiled fused kernels
Signed-off-by: Lukas Sommer <[email protected]>
1 parent 4afeea5 commit 21a8981

File tree

9 files changed

+131
-1
lines changed

9 files changed

+131
-1
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//==---- Hashing.h - helper for hashes for JIT internal representations ----==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SYCL_FUSION_JIT_COMPILER_HASHING_H
10+
#define SYCL_FUSION_JIT_COMPILER_HASHING_H
11+
12+
#include "Parameter.h"
13+
14+
#include "llvm/ADT/Hashing.h"
15+
16+
#include <tuple>
17+
#include <vector>
18+
19+
namespace jit_compiler {
20+
inline llvm::hash_code hash_value(const ParameterInternalization &P) {
21+
return llvm::hash_combine(P.LocalSize, P.Intern, P.Param);
22+
}
23+
24+
inline llvm::hash_code hash_value(const Parameter &P) {
25+
return llvm::hash_combine(P.ParamIdx, P.KernelIdx);
26+
}
27+
28+
inline llvm::hash_code hash_value(const JITConstant &C) {
29+
return llvm::hash_combine(C.Param, C.Value);
30+
}
31+
32+
inline llvm::hash_code hash_value(const ParameterIdentity &IP) {
33+
return llvm::hash_combine(IP.LHS, IP.RHS);
34+
}
35+
} // namespace jit_compiler
36+
37+
namespace std {
38+
template <typename T> inline llvm::hash_code hash_value(const vector<T> &V) {
39+
return llvm::hash_combine_range(V.begin(), V.end());
40+
}
41+
42+
template <typename... T> struct hash<tuple<T...>> {
43+
size_t operator()(const tuple<T...> &Tuple) const noexcept {
44+
return llvm::hash_value(Tuple);
45+
}
46+
};
47+
} // namespace std
48+
49+
#endif // SYCL_FUSION_JIT_COMPILER_HASHING_H

sycl-fusion/jit-compiler/include/JITContext.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@
1212
#include <memory>
1313
#include <mutex>
1414
#include <shared_mutex>
15+
#include <string>
16+
#include <tuple>
1517
#include <unordered_map>
18+
#include <vector>
1619

20+
#include "Hashing.h"
1721
#include "Kernel.h"
1822
#include "Parameter.h"
1923

@@ -23,6 +27,10 @@ class LLVMContext;
2327

2428
namespace jit_compiler {
2529

30+
/// Key type of the fused-kernel cache held by JITContext.
/// KernelFusion::fuseKernels builds it, in tuple order, from the names of the
/// kernels to fuse, the parameter identities, the barrier flags, the
/// parameter internalizations and the JIT constants.
using CacheKeyT =
31+
std::tuple<std::vector<std::string>, ParamIdentList, int,
32+
std::vector<ParameterInternalization>, std::vector<JITConstant>>;
33+
2634
///
2735
/// Wrapper around a SPIR-V binary.
2836
class SPIRVBinary {
@@ -51,6 +59,10 @@ class JITContext {
5159

5260
SPIRVBinary &emplaceSPIRVBinary(std::string Binary);
5361

62+
/// Look up \p Identifier in the kernel cache while holding a read lock.
/// \return a copy of the cached kernel info, or an empty optional if there
/// is no entry for \p Identifier.
std::optional<SYCLKernelInfo> getCacheEntry(CacheKeyT &Identifier) const;
63+
64+
/// Store a copy of \p Kernel under \p Identifier while holding a write lock.
/// An entry that already exists for \p Identifier is left unchanged.
void addCacheEntry(CacheKeyT &Identifier, SYCLKernelInfo &Kernel);
65+
5466
private:
5567
// FIXME: Change this to std::shared_mutex after switching to C++17.
5668
using MutexT = std::shared_timed_mutex;
@@ -64,6 +76,10 @@ class JITContext {
6476
MutexT BinariesMutex;
6577

6678
std::vector<SPIRVBinary> Binaries;
79+
80+
// Guards Cache. Declared mutable so that the const member getCacheEntry()
// can still acquire it.
mutable MutexT CacheMutex;
81+
82+
// Fused kernels produced so far, keyed by the fusion request that created
// them. Hashing of the tuple key comes from the std::hash specialization in
// Hashing.h.
std::unordered_map<CacheKeyT, SYCLKernelInfo> Cache;
6783
};
6884
} // namespace jit_compiler
6985

sycl-fusion/jit-compiler/include/Options.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ namespace option {
7676

7777
struct JITEnableVerbose : public OptionBase<OptionID::VerboseOutput, bool> {};
7878

79+
/// Boolean option read by KernelFusion::fuseKernels to decide whether fused
/// kernels are looked up in, and added to, the JIT context's cache.
struct JITEnableCaching : public OptionBase<OptionID::EnableCaching, bool> {};
80+
7981
} // namespace option
8082
} // namespace jit_compiler
8183

sycl-fusion/jit-compiler/lib/JITContext.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,18 @@ SPIRVBinary &JITContext::emplaceSPIRVBinary(std::string Binary) {
3333
Binaries.emplace_back(std::move(Binary));
3434
return Binaries.back();
3535
}
36+
37+
std::optional<SYCLKernelInfo>
38+
JITContext::getCacheEntry(CacheKeyT &Identifier) const {
39+
ReadLockT ReadLock{CacheMutex};
40+
auto Entry = Cache.find(Identifier);
41+
if (Entry != Cache.end()) {
42+
return Entry->second;
43+
}
44+
return {};
45+
}
46+
47+
/// Record a freshly fused kernel in the cache.
///
/// Takes a write lock on the cache for the duration of the insertion.
/// \param Identifier key describing the fusion request.
/// \param Kernel kernel info to store; a copy is kept in the cache.
///
/// An entry already stored under \p Identifier is kept as-is.
void JITContext::addCacheEntry(CacheKeyT &Identifier, SYCLKernelInfo &Kernel) {
  WriteLockT WriteLock{CacheMutex};
  // try_emplace only copies the key and the kernel info when the key is
  // actually absent, whereas emplace builds the node before checking.
  Cache.try_emplace(Identifier, Kernel);
}

sycl-fusion/jit-compiler/lib/KernelFusion.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ FusionResult KernelFusion::fuseKernels(
4949
// available (on a per-thread basis).
5050
ConfigHelper::setConfig(std::move(JITConfig));
5151

52+
bool CachingEnabled = ConfigHelper::get<option::JITEnableCaching>();
53+
CacheKeyT CacheKey{KernelsToFuse, Identities, BarriersFlags, Internalization,
54+
Constants};
55+
if (CachingEnabled) {
56+
std::optional<SYCLKernelInfo> CachedKernel = JITCtx.getCacheEntry(CacheKey);
57+
if (CachedKernel) {
58+
helper::printDebugMessage("Re-using cached JIT kernel");
59+
return FusionResult{*CachedKernel, /*Cached*/ true};
60+
}
61+
helper::printDebugMessage(
62+
"Compiling new kernel, no suitable cached kernel found");
63+
}
64+
5265
SYCLModuleInfo ModuleInfo;
5366
// Copy the kernel information for the input kernels to the module
5467
// information. We could remove the copy, if we removed the const from the
@@ -115,5 +128,9 @@ FusionResult KernelFusion::fuseKernels(
115128
FusedBinaryInfo.BinaryStart = SPIRVBin->address();
116129
FusedBinaryInfo.BinarySize = SPIRVBin->size();
117130

131+
if (CachingEnabled) {
132+
JITCtx.addCacheEntry(CacheKey, FusedKernelInfo);
133+
}
134+
118135
return FusionResult{FusedKernelInfo};
119136
}

sycl/doc/EnvironmentVariables.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ compiler and runtime.
2424
| `SYCL_USM_HOSTPTR_IMPORT` | Integer | Enable by specifying non-zero value. Buffers created with a host pointer will result in host data promotion to USM, improving data transfer performance. To use this feature, also set SYCL_HOST_UNIFIED_MEMORY=1. |
2525
| `SYCL_EAGER_INIT` | Integer | Enable by specifying non-zero value. Tells the SYCL runtime to do as much initialization as possible at object construction, as opposed to lazy initialization on the fly. This may mean doing some redundant work at warmup but ensures the fastest possible execution on the subsequent hot and reportable paths. It also instructs PI plugins to do the same. Default is "0". |
2626
| `SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE` | See [below](#sycl_reduction_preferred_workgroup_size) | Controls the preferred work-group size of reductions. |
27+
| `SYCL_ENABLE_FUSION_CACHING` | '1' or '0' | Enable ('1') or disable ('0') caching of JIT compilations for kernel fusion. Caching avoids repeatedly running the JIT compilation pipeline if the same sequence of kernels is fused multiple times. Default value is '1'. |
2728

2829
`(*) Note: Any means this environment variable is effective when set to any non-null value.`
2930

sycl/source/detail/config.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@ CONFIG(SYCL_QUEUE_THREAD_POOL_SIZE, 4, __SYCL_QUEUE_THREAD_POOL_SIZE)
3939
CONFIG(SYCL_RT_WARNING_LEVEL, 4, __SYCL_RT_WARNING_LEVEL)
4040
CONFIG(SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE, 16, __SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE)
4141
CONFIG(ONEAPI_DEVICE_SELECTOR, 1024, __ONEAPI_DEVICE_SELECTOR)
42+
CONFIG(SYCL_ENABLE_FUSION_CACHING, 1, __SYCL_ENABLE_FUSION_CACHING)

sycl/source/detail/config.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,34 @@ template <> class SYCLConfig<SYCL_REDUCTION_PREFERRED_WORKGROUP_SIZE> {
579579
}
580580
};
581581

582+
template <> class SYCLConfig<SYCL_ENABLE_FUSION_CACHING> {
583+
using BaseT = SYCLConfigBase<SYCL_ENABLE_FUSION_CACHING>;
584+
585+
public:
586+
static bool get() {
587+
constexpr bool DefaultValue = true;
588+
589+
const char *ValStr = getCachedValue();
590+
591+
if (!ValStr)
592+
return DefaultValue;
593+
594+
return ValStr[0] == '1';
595+
}
596+
597+
static void reset() { (void)getCachedValue(/*ResetCache=*/true); }
598+
599+
static const char *getName() { return BaseT::MConfigName; }
600+
601+
private:
602+
static const char *getCachedValue(bool ResetCache = false) {
603+
static const char *ValStr = BaseT::getRawValue();
604+
if (ResetCache)
605+
ValStr = BaseT::getRawValue();
606+
return ValStr;
607+
}
608+
};
609+
582610
#undef INVALID_CONFIG_EXCEPTION
583611

584612
} // namespace detail

sycl/source/detail/jit_compiler.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -751,7 +751,8 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
751751
bool DebugEnabled =
752752
detail::SYCLConfig<detail::SYCL_RT_WARNING_LEVEL>::get() > 0;
753753
JITConfig.set<::jit_compiler::option::JITEnableVerbose>(DebugEnabled);
754-
// TODO: Enable caching in a separate PR.
754+
JITConfig.set<::jit_compiler::option::JITEnableCaching>(
755+
detail::SYCLConfig<detail::SYCL_ENABLE_FUSION_CACHING>::get());
755756

756757
auto FusionResult = ::jit_compiler::KernelFusion::fuseKernels(
757758
*MJITContext, std::move(JITConfig), InputKernelInfo, InputKernelNames,

0 commit comments

Comments
 (0)