Skip to content

Commit 4688cb3

Browse files
authored
[SYCL] Fix thread safety issues in scheduler (#2281)
Signed-off-by: Alexander Flegontov <[email protected]>
1 parent 449308d commit 4688cb3

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

sycl/source/detail/scheduler/commands.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,12 +1199,8 @@ AllocaCommandBase *ExecCGCommand::getAllocaForReq(Requirement *Req) {
11991199
throw runtime_error("Alloca for command not found", PI_INVALID_OPERATION);
12001200
}
12011201

1202-
void ExecCGCommand::flushStreams() {
1203-
assert(MCommandGroup->getType() == CG::KERNEL && "Expected kernel");
1204-
for (auto StreamImplPtr :
1205-
((CGExecKernel *)MCommandGroup.get())->getStreams()) {
1206-
StreamImplPtr->flush();
1207-
}
1202+
vector_class<StreamImplPtr> ExecCGCommand::getStreams() const {
1203+
return ((CGExecKernel *)MCommandGroup.get())->getStreams();
12081204
}
12091205

12101206
cl_int UpdateHostRequirementCommand::enqueueImp() {

sycl/source/detail/scheduler/commands.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class DispatchHostTask;
3232
using QueueImplPtr = std::shared_ptr<detail::queue_impl>;
3333
using EventImplPtr = std::shared_ptr<detail::event_impl>;
3434
using ContextImplPtr = std::shared_ptr<detail::context_impl>;
35+
using StreamImplPtr = std::shared_ptr<detail::stream_impl>;
3536

3637
class Command;
3738
class AllocaCommand;
@@ -480,7 +481,7 @@ class ExecCGCommand : public Command {
480481
public:
481482
ExecCGCommand(std::unique_ptr<detail::CG> CommandGroup, QueueImplPtr Queue);
482483

483-
void flushStreams();
484+
vector_class<StreamImplPtr> getStreams() const;
484485

485486
void printDot(std::ostream &Stream) const final override;
486487
void emitInstrumentationData() final override;

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <CL/sycl/device_selector.hpp>
1111
#include <detail/queue_impl.hpp>
1212
#include <detail/scheduler/scheduler.hpp>
13+
#include <detail/stream_impl.hpp>
1314

1415
#include <memory>
1516
#include <mutex>
@@ -63,12 +64,14 @@ void Scheduler::waitForRecordToFinish(MemObjRecord *Record) {
6364

6465
EventImplPtr Scheduler::addCG(std::unique_ptr<detail::CG> CommandGroup,
6566
QueueImplPtr Queue) {
66-
Command *NewCmd = nullptr;
67+
EventImplPtr NewEvent = nullptr;
6768
const bool IsKernel = CommandGroup->getType() == CG::KERNEL;
69+
vector_class<StreamImplPtr> Streams;
6870
{
6971
std::unique_lock<std::shared_timed_mutex> Lock(MGraphLock, std::defer_lock);
7072
lockSharedTimedMutex(Lock);
7173

74+
Command *NewCmd = nullptr;
7275
switch (CommandGroup->getType()) {
7376
case CG::UPDATE_HOST:
7477
NewCmd = MGraphBuilder.addCGUpdateHost(std::move(CommandGroup),
@@ -80,22 +83,30 @@ EventImplPtr Scheduler::addCG(std::unique_ptr<detail::CG> CommandGroup,
8083
default:
8184
NewCmd = MGraphBuilder.addCG(std::move(CommandGroup), std::move(Queue));
8285
}
86+
NewEvent = NewCmd->getEvent();
8387
}
8488

8589
{
8690
std::shared_lock<std::shared_timed_mutex> Lock(MGraphLock);
8791

88-
// TODO: Check if lazy mode.
89-
EnqueueResultT Res;
90-
bool Enqueued = GraphProcessor::enqueueCommand(NewCmd, Res);
91-
if (!Enqueued && EnqueueResultT::SyclEnqueueFailed == Res.MResult)
92-
throw runtime_error("Enqueue process failed.", PI_INVALID_OPERATION);
92+
Command *NewCmd = static_cast<Command *>(NewEvent->getCommand());
93+
if (NewCmd) {
94+
// TODO: Check if lazy mode.
95+
EnqueueResultT Res;
96+
bool Enqueued = GraphProcessor::enqueueCommand(NewCmd, Res);
97+
if (!Enqueued && EnqueueResultT::SyclEnqueueFailed == Res.MResult)
98+
throw runtime_error("Enqueue process failed.", PI_INVALID_OPERATION);
99+
100+
if (IsKernel)
101+
Streams = ((ExecCGCommand *)NewCmd)->getStreams();
102+
}
93103
}
94104

95-
if (IsKernel)
96-
((ExecCGCommand *)NewCmd)->flushStreams();
105+
for (auto StreamImplPtr : Streams) {
106+
StreamImplPtr->flush();
107+
}
97108

98-
return NewCmd->getEvent();
109+
return NewEvent;
99110
}
100111

101112
EventImplPtr Scheduler::addCopyBack(Requirement *Req) {

0 commit comments

Comments
 (0)