Skip to content

[SYCL] Fix deadlock in ProgramManager class #2131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,25 @@ class KernelProgramCache {
/// The pointer is non-null if and only if the entity is usable.
/// The state of the entity is set by the user of the cache instance.
/// Currently there is only a single user: the ProgramManager class.
template<typename T> struct BuildResult {
template <typename T> struct BuildResult {
std::atomic<T *> Ptr;
std::atomic<int> State;
BuildError Error;

/// Condition variable used to signal that the build result is ready.
/// A per-object (i.e. per-kernel or per-program) condition variable is used
/// instead of a global one in order to eliminate the following deadlock.
/// A thread T1 waiting for build result BR1 to become ready may be awakened
/// (because the condition variable is global) by another thread that made
/// build result BR2 ready. Meanwhile, the thread that made BR1 ready
/// notifies everyone via the global condition variable, and T1 misses this
/// notification because it is not blocked inside condition_variable::wait()
/// at that moment. T1 then goes back to sleep and waits until either a
/// spurious wake-up occurs or another thread wakes it up.
std::condition_variable MBuildCV;
/// A mutex to be employed along with MBuildCV.
std::mutex MBuildResultMutex;

BuildResult(T* P, int S) : Ptr{P}, State{S}, Error{"", 0} {}
};

Expand All @@ -59,14 +73,8 @@ class KernelProgramCache {

using PiKernelT = std::remove_pointer<RT::PiKernel>::type;

struct BuildResultKernel : public BuildResult<PiKernelT> {
std::mutex MKernelMutex;

BuildResultKernel(PiKernelT *P, int S) : BuildResult(P, S) {}
};

using PiKernelPtrT = std::atomic<PiKernelT *>;
using KernelWithBuildStateT = BuildResultKernel;
using KernelWithBuildStateT = BuildResult<PiKernelT>;
using KernelByNameT = std::map<string_class, KernelWithBuildStateT>;
using KernelCacheT = std::map<RT::PiProgram, KernelByNameT>;

Expand All @@ -82,21 +90,21 @@ class KernelProgramCache {
return {MKernelsPerProgramCache, MKernelsPerProgramCacheMutex};
}

template <class Predicate> void waitUntilBuilt(Predicate Pred) const {
std::unique_lock<std::mutex> Lock(MBuildCVMutex);
template <typename T, class Predicate>
void waitUntilBuilt(BuildResult<T> &BR, Predicate Pred) const {
std::unique_lock<std::mutex> Lock(BR.MBuildResultMutex);

MBuildCV.wait(Lock, Pred);
BR.MBuildCV.wait(Lock, Pred);
}

void notifyAllBuild() const { MBuildCV.notify_all(); }
template <typename T> void notifyAllBuild(BuildResult<T> &BR) const {
BR.MBuildCV.notify_all();
}

private:
std::mutex MProgramCacheMutex;
std::mutex MKernelsPerProgramCacheMutex;

mutable std::condition_variable MBuildCV;
mutable std::mutex MBuildCVMutex;

ProgramCacheT MCachedPrograms;
KernelCacheT MKernelsPerProgramCache;
ContextPtr MParentContext;
Expand Down
16 changes: 8 additions & 8 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ RetT *waitUntilBuilt(KernelProgramCache &Cache,
KernelProgramCache::BuildResult<RetT> *BuildResult) {
// Any thread that finds nullptr in the cache waits here until the build
// reaches a terminal state (BS_Done or BS_Failed).
Cache.waitUntilBuilt([BuildResult]() {
Cache.waitUntilBuilt(*BuildResult, [BuildResult]() {
int State = BuildResult->State.load();

return State == BS_Done || State == BS_Failed;
Expand Down Expand Up @@ -212,7 +212,7 @@ getOrBuild(KernelProgramCache &KPCache, KeyT &&CacheKey, AcquireFT &&Acquire,

BuildResult->State.store(BS_Done);

KPCache.notifyAllBuild();
KPCache.notifyAllBuild(*BuildResult);

return BuildResult;
} catch (const exception &Ex) {
Expand All @@ -221,13 +221,13 @@ getOrBuild(KernelProgramCache &KPCache, KeyT &&CacheKey, AcquireFT &&Acquire,

BuildResult->State.store(BS_Failed);

KPCache.notifyAllBuild();
KPCache.notifyAllBuild(*BuildResult);

std::rethrow_exception(std::current_exception());
} catch (...) {
BuildResult->State.store(BS_Failed);

KPCache.notifyAllBuild();
KPCache.notifyAllBuild(*BuildResult);

std::rethrow_exception(std::current_exception());
}
Expand Down Expand Up @@ -445,10 +445,10 @@ ProgramManager::getOrCreateKernel(OSModuleHandle M, const context &Context,
return Result;
};

auto BuildResult = static_cast<KernelProgramCache::BuildResultKernel *>(
getOrBuild<PiKernelT, invalid_object_error>(Cache, KernelName, AcquireF,
GetF, BuildF));
return std::make_pair(BuildResult->Ptr.load(), &(BuildResult->MKernelMutex));
auto BuildResult = getOrBuild<PiKernelT, invalid_object_error>(
Cache, KernelName, AcquireF, GetF, BuildF);
return std::make_pair(BuildResult->Ptr.load(),
&(BuildResult->MBuildResultMutex));
}

RT::PiProgram
Expand Down