Skip to content

Commit c16705a

Browse files
authored
[SYCL] Reduce the time to get a kernel from cache (#4186)
Signed-off-by: Alexander Flegontov <[email protected]>
1 parent a50f45b commit c16705a

File tree

6 files changed

+212
-21
lines changed

6 files changed

+212
-21
lines changed

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
#include <mutex>
#include <tuple>
#include <type_traits>
#include <utility>
2323

24+
// For testing purposes
25+
class MockKernelProgramCache;
26+
2427
__SYCL_INLINE_NAMESPACE(cl) {
2528
namespace sycl {
2629
namespace detail {
@@ -79,6 +82,13 @@ class KernelProgramCache {
7982
using KernelByNameT = std::map<std::string, KernelWithBuildStateT>;
8083
using KernelCacheT = std::map<RT::PiProgram, KernelByNameT>;
8184

85+
using KernelFastCacheKeyT =
86+
std::tuple<SerializedObj, OSModuleHandle, RT::PiDevice, std::string,
87+
std::string>;
88+
using KernelFastCacheValT =
89+
std::tuple<RT::PiKernel, std::mutex *, RT::PiProgram>;
90+
using KernelFastCacheT = std::map<KernelFastCacheKeyT, KernelFastCacheValT>;
91+
8292
~KernelProgramCache();
8393

8494
void setContextPtr(const ContextPtr &AContext) { MParentContext = AContext; }
@@ -102,13 +112,35 @@ class KernelProgramCache {
102112
BR.MBuildCV.notify_all();
103113
}
104114

115+
template <typename KeyT>
116+
KernelFastCacheValT tryToGetKernelFast(KeyT &&CacheKey) {
117+
std::unique_lock<std::mutex> Lock(MKernelFastCacheMutex);
118+
auto It = MKernelFastCache.find(CacheKey);
119+
if (It != MKernelFastCache.end()) {
120+
return It->second;
121+
}
122+
return std::make_tuple(nullptr, nullptr, nullptr);
123+
}
124+
125+
template <typename KeyT, typename ValT>
126+
void saveKernel(KeyT &&CacheKey, ValT &&CacheVal) {
127+
std::unique_lock<std::mutex> Lock(MKernelFastCacheMutex);
128+
// if no insertion took place, thus some other thread has already inserted
129+
// smth in the cache
130+
MKernelFastCache.emplace(CacheKey, CacheVal);
131+
}
132+
105133
private:
106134
std::mutex MProgramCacheMutex;
107135
std::mutex MKernelsPerProgramCacheMutex;
108136

109137
ProgramCacheT MCachedPrograms;
110138
KernelCacheT MKernelsPerProgramCache;
111139
ContextPtr MParentContext;
140+
141+
std::mutex MKernelFastCacheMutex;
142+
KernelFastCacheT MKernelFastCache;
143+
friend class ::MockKernelProgramCache;
112144
};
113145
} // namespace detail
114146
} // namespace sycl

sycl/source/detail/program_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ RT::PiKernel program_impl::get_pi_kernel(const std::string &KernelName) const {
439439
RT::PiKernel Kernel = nullptr;
440440

441441
if (is_cacheable()) {
442-
std::tie(Kernel, std::ignore) =
442+
std::tie(Kernel, std::ignore, std::ignore) =
443443
ProgramManager::getInstance().getOrCreateKernel(
444444
MProgramModuleHandle, detail::getSyclObjImpl(get_context()),
445445
detail::getSyclObjImpl(get_devices()[0]), KernelName, this);

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -519,24 +519,41 @@ RT::PiProgram ProgramManager::getBuiltPIProgram(
519519
return BuildResult->Ptr.load();
520520
}
521521

522-
std::pair<RT::PiKernel, std::mutex *> ProgramManager::getOrCreateKernel(
523-
OSModuleHandle M, const ContextImplPtr &Context,
524-
const DeviceImplPtr &Device, const std::string &KernelName,
525-
const program_impl *Prg) {
522+
std::tuple<RT::PiKernel, std::mutex *, RT::PiProgram>
523+
ProgramManager::getOrCreateKernel(OSModuleHandle M,
524+
const ContextImplPtr &ContextImpl,
525+
const DeviceImplPtr &DeviceImpl,
526+
const std::string &KernelName,
527+
const program_impl *Prg) {
526528
if (DbgProgMgr > 0) {
527529
std::cerr << ">>> ProgramManager::getOrCreateKernel(" << M << ", "
528-
<< Context.get() << ", " << Device.get() << ", " << KernelName
529-
<< ")\n";
530+
<< ContextImpl.get() << ", " << DeviceImpl.get() << ", "
531+
<< KernelName << ")\n";
530532
}
531533

532-
RT::PiProgram Program =
533-
getBuiltPIProgram(M, Context, Device, KernelName, Prg);
534-
535534
using PiKernelT = KernelProgramCache::PiKernelT;
536535
using KernelCacheT = KernelProgramCache::KernelCacheT;
537536
using KernelByNameT = KernelProgramCache::KernelByNameT;
538537

539-
KernelProgramCache &Cache = Context->getKernelProgramCache();
538+
KernelProgramCache &Cache = ContextImpl->getKernelProgramCache();
539+
540+
std::string CompileOpts, LinkOpts;
541+
SerializedObj SpecConsts;
542+
if (Prg) {
543+
CompileOpts = Prg->get_build_options();
544+
Prg->stableSerializeSpecConstRegistry(SpecConsts);
545+
}
546+
applyOptionsFromEnvironment(CompileOpts, LinkOpts);
547+
const RT::PiDevice PiDevice = DeviceImpl->getHandleRef();
548+
549+
auto key = std::make_tuple(std::move(SpecConsts), M, PiDevice,
550+
CompileOpts + LinkOpts, KernelName);
551+
auto ret_tuple = Cache.tryToGetKernelFast(key);
552+
if (std::get<0>(ret_tuple))
553+
return ret_tuple;
554+
555+
RT::PiProgram Program =
556+
getBuiltPIProgram(M, ContextImpl, DeviceImpl, KernelName, Prg);
540557

541558
auto AcquireF = [](KernelProgramCache &Cache) {
542559
return Cache.acquireKernelsPerProgramCache();
@@ -545,12 +562,12 @@ std::pair<RT::PiKernel, std::mutex *> ProgramManager::getOrCreateKernel(
545562
[&Program](const Locked<KernelCacheT> &LockedCache) -> KernelByNameT & {
546563
return LockedCache.get()[Program];
547564
};
548-
auto BuildF = [&Program, &KernelName, &Context] {
565+
auto BuildF = [&Program, &KernelName, &ContextImpl] {
549566
PiKernelT *Result = nullptr;
550567

551568
// TODO need some user-friendly error/exception
552569
// instead of currently obscure one
553-
const detail::plugin &Plugin = Context->getPlugin();
570+
const detail::plugin &Plugin = ContextImpl->getPlugin();
554571
Plugin.call<PiApiKind::piKernelCreate>(Program, KernelName.c_str(),
555572
&Result);
556573

@@ -564,8 +581,10 @@ std::pair<RT::PiKernel, std::mutex *> ProgramManager::getOrCreateKernel(
564581

565582
auto BuildResult = getOrBuild<PiKernelT, invalid_object_error>(
566583
Cache, KernelName, AcquireF, GetF, BuildF);
567-
return std::make_pair(BuildResult->Ptr.load(),
568-
&(BuildResult->MBuildResultMutex));
584+
auto ret_val = std::make_tuple(BuildResult->Ptr.load(),
585+
&(BuildResult->MBuildResultMutex), Program);
586+
Cache.saveKernel(key, ret_val);
587+
return ret_val;
569588
}
570589

571590
RT::PiProgram

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class ProgramManager {
106106
const property_list &PropList,
107107
bool JITCompilationIsRequired = false);
108108

109-
std::pair<RT::PiKernel, std::mutex *>
109+
std::tuple<RT::PiKernel, std::mutex *, RT::PiProgram>
110110
getOrCreateKernel(OSModuleHandle M, const ContextImplPtr &ContextImpl,
111111
const DeviceImplPtr &DeviceImpl,
112112
const std::string &KernelName, const program_impl *Prg);

sycl/source/detail/scheduler/commands.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,21 +2063,18 @@ cl_int ExecCGCommand::enqueueImp() {
20632063
Program = SyclProg->getHandleRef();
20642064
if (SyclProg->is_cacheable()) {
20652065
RT::PiKernel FoundKernel = nullptr;
2066-
std::tie(FoundKernel, KernelMutex) =
2066+
std::tie(FoundKernel, KernelMutex, std::ignore) =
20672067
detail::ProgramManager::getInstance().getOrCreateKernel(
20682068
ExecKernel->MOSModuleHandle, ContextImpl, DeviceImpl,
20692069
ExecKernel->MKernelName, SyclProg.get());
20702070
assert(FoundKernel == Kernel);
20712071
} else
20722072
KnownProgram = false;
20732073
} else {
2074-
std::tie(Kernel, KernelMutex) =
2074+
std::tie(Kernel, KernelMutex, Program) =
20752075
detail::ProgramManager::getInstance().getOrCreateKernel(
20762076
ExecKernel->MOSModuleHandle, ContextImpl, DeviceImpl,
20772077
ExecKernel->MKernelName, nullptr);
2078-
MQueue->getPlugin().call<PiApiKind::piKernelGetInfo>(
2079-
Kernel, PI_KERNEL_INFO_PROGRAM, sizeof(RT::PiProgram), &Program,
2080-
nullptr);
20812078
}
20822079

20832080
pi_result Error = PI_SUCCESS;

sycl/unittests/kernel-and-program/Cache.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,146 @@ TEST_F(KernelAndProgramCacheTest, KernelNegativeSource) {
434434
CtxImpl->getKernelProgramCache().acquireKernelsPerProgramCache().get();
435435
EXPECT_EQ(Cache.size(), 0U) << "Expect empty cache for kernels";
436436
}
437+
438+
// The fast-cache tests reuse the KernelAndProgramCacheTest fixture under a
// separate suite name.
using KernelAndProgramFastCacheTest = KernelAndProgramCacheTest;
439+
440+
class MockKernelProgramCache : public detail::KernelProgramCache {
441+
public:
442+
static detail::KernelProgramCache::KernelFastCacheT &
443+
getFastCache(detail::KernelProgramCache &cache) {
444+
return (reinterpret_cast<MockKernelProgramCache &>(cache)).get();
445+
}
446+
447+
detail::KernelProgramCache::KernelFastCacheT &get() {
448+
return this->MKernelFastCache;
449+
}
450+
};
451+
452+
// Check that kernels built without options are cached.
453+
TEST_F(KernelAndProgramFastCacheTest, KernelPositive) {
454+
if (Plt.is_host() || Plt.get_backend() != backend::opencl) {
455+
return;
456+
}
457+
458+
context Ctx{Plt};
459+
auto CtxImpl = detail::getSyclObjImpl(Ctx);
460+
461+
globalCtx.reset(new TestCtx{CtxImpl->getHandleRef()});
462+
463+
program Prg{Ctx};
464+
465+
Prg.build_with_kernel_type<TestKernel>();
466+
kernel Ker = Prg.get_kernel<TestKernel>();
467+
detail::KernelProgramCache::KernelFastCacheT &Cache =
468+
MockKernelProgramCache::getFastCache(CtxImpl->getKernelProgramCache());
469+
EXPECT_EQ(Cache.size(), 1U) << "Expect non-empty cache for kernels";
470+
}
471+
472+
// Check that kernels built with options are cached.
473+
TEST_F(KernelAndProgramFastCacheTest, KernelPositiveBuildOpts) {
474+
if (Plt.is_host() || Plt.get_backend() != backend::opencl) {
475+
return;
476+
}
477+
478+
context Ctx{Plt};
479+
auto CtxImpl = detail::getSyclObjImpl(Ctx);
480+
481+
globalCtx.reset(new TestCtx{CtxImpl->getHandleRef()});
482+
483+
program Prg{Ctx};
484+
485+
Prg.build_with_kernel_type<TestKernel>("-g");
486+
487+
kernel Ker = Prg.get_kernel<TestKernel>();
488+
detail::KernelProgramCache::KernelFastCacheT &Cache =
489+
MockKernelProgramCache::getFastCache(CtxImpl->getKernelProgramCache());
490+
EXPECT_EQ(Cache.size(), 1U) << "Expect non-empty cache for kernels";
491+
}
492+
493+
// Check that kernels built with compile options are not cached.
494+
TEST_F(KernelAndProgramFastCacheTest, KernelNegativeCompileOpts) {
495+
if (Plt.is_host() || Plt.get_backend() != backend::opencl) {
496+
return;
497+
}
498+
499+
context Ctx{Plt};
500+
auto CtxImpl = detail::getSyclObjImpl(Ctx);
501+
502+
globalCtx.reset(new TestCtx{CtxImpl->getHandleRef()});
503+
504+
program Prg{Ctx};
505+
506+
Prg.compile_with_kernel_type<TestKernel>("-g");
507+
Prg.link();
508+
kernel Ker = Prg.get_kernel<TestKernel>();
509+
detail::KernelProgramCache::KernelFastCacheT &Cache =
510+
MockKernelProgramCache::getFastCache(CtxImpl->getKernelProgramCache());
511+
EXPECT_EQ(Cache.size(), 0U) << "Expect empty cache for kernels";
512+
}
513+
514+
// Check that kernels built with link options are not cached.
515+
TEST_F(KernelAndProgramFastCacheTest, KernelNegativeLinkOpts) {
516+
if (Plt.is_host() || Plt.get_backend() != backend::opencl) {
517+
return;
518+
}
519+
520+
context Ctx{Plt};
521+
auto CtxImpl = detail::getSyclObjImpl(Ctx);
522+
523+
globalCtx.reset(new TestCtx{CtxImpl->getHandleRef()});
524+
525+
program Prg{Ctx};
526+
527+
Prg.compile_with_kernel_type<TestKernel>();
528+
Prg.link("-g");
529+
kernel Ker = Prg.get_kernel<TestKernel>();
530+
detail::KernelProgramCache::KernelFastCacheT &Cache =
531+
MockKernelProgramCache::getFastCache(CtxImpl->getKernelProgramCache());
532+
EXPECT_EQ(Cache.size(), 0U) << "Expect empty cache for kernels";
533+
}
534+
535+
// Check that kernels are not cached if program is created from multiple
536+
// programs.
537+
TEST_F(KernelAndProgramFastCacheTest, KernelNegativeLinkedProgs) {
538+
if (Plt.is_host() || Plt.get_backend() != backend::opencl) {
539+
return;
540+
}
541+
542+
context Ctx{Plt};
543+
auto CtxImpl = detail::getSyclObjImpl(Ctx);
544+
545+
globalCtx.reset(new TestCtx{CtxImpl->getHandleRef()});
546+
547+
program Prg1{Ctx};
548+
program Prg2{Ctx};
549+
550+
Prg1.compile_with_kernel_type<TestKernel>();
551+
Prg2.compile_with_kernel_type<TestKernel2>();
552+
program Prg({Prg1, Prg2});
553+
kernel Ker = Prg.get_kernel<TestKernel>();
554+
555+
detail::KernelProgramCache::KernelFastCacheT &Cache =
556+
MockKernelProgramCache::getFastCache(CtxImpl->getKernelProgramCache());
557+
EXPECT_EQ(Cache.size(), 0U) << "Expect empty cache for kernels";
558+
}
559+
560+
// Check that kernels created from source are not cached.
561+
TEST_F(KernelAndProgramFastCacheTest, KernelNegativeSource) {
562+
if (Plt.is_host() || Plt.get_backend() != backend::opencl) {
563+
return;
564+
}
565+
566+
context Ctx{Plt};
567+
auto CtxImpl = detail::getSyclObjImpl(Ctx);
568+
569+
globalCtx.reset(new TestCtx{CtxImpl->getHandleRef()});
570+
571+
program Prg{Ctx};
572+
573+
Prg.build_with_source("");
574+
kernel Ker = Prg.get_kernel("test");
575+
576+
detail::KernelProgramCache::KernelFastCacheT &Cache =
577+
MockKernelProgramCache::getFastCache(CtxImpl->getKernelProgramCache());
578+
EXPECT_EQ(Cache.size(), 0U) << "Expect empty cache for kernels";
579+
}

0 commit comments

Comments
 (0)