Skip to content

Commit 0bd1bf9

Browse files
committed
[SYCL] Add per-kernel mutex to fix the race when setting kernel parameters in parallel
Signed-off-by: Alexander Flegontov <[email protected]>
1 parent 2e4b1d3 commit 0bd1bf9

File tree

6 files changed

+43
-21
lines changed

6 files changed

+43
-21
lines changed

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,15 @@ class KernelProgramCache {
5858
using ContextPtr = context_impl *;
5959

6060
using PiKernelT = std::remove_pointer<RT::PiKernel>::type;
61+
62+
/// Kernel cache entry: a BuildResult<PiKernelT> extended with a mutex that
/// belongs to that one kernel.
struct BuildResultKernel : public BuildResult<PiKernelT> {
63+
/// Per-kernel mutex. ProgramManager::getOrCreateKernel returns its address
/// alongside the kernel handle; ExecCGCommand::enqueueImp locks it before
/// processing the kernel arguments and unlocks it further down, ahead of the
/// error check.
std::mutex MKernelMutex;
64+
65+
/// Forwards both arguments unchanged to the BuildResult base constructor.
/// NOTE(review): P is presumably the initial kernel pointer and S the initial
/// build state -- confirm against the BuildResult definition.
BuildResultKernel(PiKernelT *P, int S) : BuildResult(P, S) {}
66+
};
67+
6168
using PiKernelPtrT = std::atomic<PiKernelT *>;
62-
using KernelWithBuildStateT = BuildResult<PiKernelT>;
69+
using KernelWithBuildStateT = BuildResultKernel;
6370
using KernelByNameT = std::map<string_class, KernelWithBuildStateT>;
6471
using KernelCacheT = std::map<RT::PiProgram, KernelByNameT>;
6572

sycl/source/detail/program_impl.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,9 @@ RT::PiKernel program_impl::get_pi_kernel(const string_class &KernelName) const {
396396
RT::PiKernel Kernel;
397397

398398
if (is_cacheable()) {
399-
Kernel = ProgramManager::getInstance().getOrCreateKernel(
400-
MProgramModuleHandle, get_context(), KernelName, this);
399+
std::tie(Kernel, std::ignore) =
400+
ProgramManager::getInstance().getOrCreateKernel(
401+
MProgramModuleHandle, get_context(), KernelName, this);
401402
getPlugin().call<PiApiKind::piKernelRetain>(Kernel);
402403
} else {
403404
const detail::plugin &Plugin = getPlugin();

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,9 @@ RetT *waitUntilBuilt(KernelProgramCache &Cache,
167167
/// cache. Accepts nothing. Return pointer to built entity.
168168
template <typename RetT, typename ExceptionT, typename KeyT, typename AcquireFT,
169169
typename GetCacheFT, typename BuildFT>
170-
RetT *getOrBuild(KernelProgramCache &KPCache, KeyT &&CacheKey,
171-
AcquireFT &&Acquire, GetCacheFT &&GetCache, BuildFT &&Build) {
170+
KernelProgramCache::BuildResult<RetT> *
171+
getOrBuild(KernelProgramCache &KPCache, KeyT &&CacheKey, AcquireFT &&Acquire,
172+
GetCacheFT &&GetCache, BuildFT &&Build) {
172173
bool InsertionTookPlace;
173174
KernelProgramCache::BuildResult<RetT> *BuildResult;
174175

@@ -190,7 +191,7 @@ RetT *getOrBuild(KernelProgramCache &KPCache, KeyT &&CacheKey,
190191
RetT *Result = waitUntilBuilt<ExceptionT>(KPCache, BuildResult);
191192

192193
if (Result)
193-
return Result;
194+
return BuildResult;
194195

195196
// The previous build failed, although no SYCL exception was raised.
196197
// We might try to build once more.
@@ -220,7 +221,7 @@ RetT *getOrBuild(KernelProgramCache &KPCache, KeyT &&CacheKey,
220221

221222
KPCache.notifyAllBuild();
222223

223-
return Desired;
224+
return BuildResult;
224225
} catch (const exception &Ex) {
225226
BuildResult->Error.Msg = Ex.what();
226227
BuildResult->Error.Code = Ex.get_cl_code();
@@ -400,14 +401,15 @@ RT::PiProgram ProgramManager::getBuiltPIProgram(OSModuleHandle M,
400401
if (Prg)
401402
Prg->stableSerializeSpecConstRegistry(SpecConsts);
402403

403-
return getOrBuild<PiProgramT, compile_program_error>(
404+
auto BuildResult = getOrBuild<PiProgramT, compile_program_error>(
404405
Cache, KeyT(std::move(SpecConsts), KSId), AcquireF, GetF, BuildF);
406+
return BuildResult->Ptr.load();
405407
}
406408

407-
RT::PiKernel ProgramManager::getOrCreateKernel(OSModuleHandle M,
408-
const context &Context,
409-
const string_class &KernelName,
410-
const program_impl *Prg) {
409+
std::pair<RT::PiKernel, std::mutex *>
410+
ProgramManager::getOrCreateKernel(OSModuleHandle M, const context &Context,
411+
const string_class &KernelName,
412+
const program_impl *Prg) {
411413
if (DbgProgMgr > 0) {
412414
std::cerr << ">>> ProgramManager::getOrCreateKernel(" << M << ", "
413415
<< getRawSyclObjImpl(Context) << ", " << KernelName << ")\n";
@@ -441,8 +443,10 @@ RT::PiKernel ProgramManager::getOrCreateKernel(OSModuleHandle M,
441443
return Result;
442444
};
443445

444-
return getOrBuild<PiKernelT, invalid_object_error>(Cache, KernelName,
445-
AcquireF, GetF, BuildF);
446+
auto BuildResult = static_cast<KernelProgramCache::BuildResultKernel *>(
447+
getOrBuild<PiKernelT, invalid_object_error>(Cache, KernelName, AcquireF,
448+
GetF, BuildF));
449+
return std::make_pair(BuildResult->Ptr.load(), &(BuildResult->MKernelMutex));
446450
}
447451

448452
RT::PiProgram

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ class ProgramManager {
7777
RT::PiProgram getBuiltPIProgram(OSModuleHandle M, const context &Context,
7878
const string_class &KernelName,
7979
const program_impl *Prg = nullptr);
80-
RT::PiKernel getOrCreateKernel(OSModuleHandle M, const context &Context,
81-
const string_class &KernelName,
82-
const program_impl *Prg);
80+
std::pair<RT::PiKernel, std::mutex *>
81+
getOrCreateKernel(OSModuleHandle M, const context &Context,
82+
const string_class &KernelName, const program_impl *Prg);
8383
RT::PiProgram getPiProgramFromPiKernel(RT::PiKernel Kernel,
8484
const ContextImplPtr Context);
8585

sycl/source/detail/scheduler/commands.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,15 +1799,19 @@ cl_int ExecCGCommand::enqueueImp() {
17991799
sycl::context Context = MQueue->get_context();
18001800
const detail::plugin &Plugin = MQueue->getPlugin();
18011801
RT::PiKernel Kernel = nullptr;
1802+
std::mutex *KernelMutex = nullptr;
18021803

18031804
if (nullptr != ExecKernel->MSyclKernel) {
18041805
assert(ExecKernel->MSyclKernel->get_info<info::kernel::context>() ==
18051806
Context);
18061807
Kernel = ExecKernel->MSyclKernel->getHandleRef();
1807-
} else
1808-
Kernel = detail::ProgramManager::getInstance().getOrCreateKernel(
1809-
ExecKernel->MOSModuleHandle, Context, ExecKernel->MKernelName,
1810-
nullptr);
1808+
} else {
1809+
std::tie(Kernel, KernelMutex) =
1810+
detail::ProgramManager::getInstance().getOrCreateKernel(
1811+
ExecKernel->MOSModuleHandle, Context, ExecKernel->MKernelName,
1812+
nullptr);
1813+
KernelMutex->lock();
1814+
}
18111815

18121816
for (ArgDesc &Arg : ExecKernel->MArgs) {
18131817
switch (Arg.MType) {
@@ -1863,6 +1867,9 @@ cl_int ExecCGCommand::enqueueImp() {
18631867
&NDRDesc.GlobalSize[0], HasLocalSize ? &NDRDesc.LocalSize[0] : nullptr,
18641868
RawEvents.size(), RawEvents.empty() ? nullptr : &RawEvents[0], &Event);
18651869

1870+
if (KernelMutex != nullptr)
1871+
KernelMutex->unlock();
1872+
18661873
if (PI_SUCCESS != Error) {
18671874
// If we have got non-success error code, let's analyze it to emit nice
18681875
// exception explaining what was wrong

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ EventImplPtr Scheduler::addCG(std::unique_ptr<detail::CG> CommandGroup,
8080
default:
8181
NewCmd = MGraphBuilder.addCG(std::move(CommandGroup), std::move(Queue));
8282
}
83+
}
8384

85+
{
86+
std::shared_lock<std::shared_timed_mutex> Lock(MGraphLock);
8487
// TODO: Check if lazy mode.
8588
EnqueueResultT Res;
8689
bool Enqueued = GraphProcessor::enqueueCommand(NewCmd, Res);

0 commit comments

Comments
 (0)