Skip to content

[SYCL] Fix undefined behaviour during graph cleanup #2157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions sycl/source/detail/scheduler/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,15 @@ class Command {
bool MIsBlockable = false;
/// Counts the number of memory objects this command is a leaf for.
unsigned MLeafCounter = 0;
/// Used for marking the node as visited during graph traversal.
bool MVisited = false;

struct Marks {
/// Used for marking the node as visited during graph traversal.
bool MVisited = false;
/// Used for marking the node for deletion during cleanup.
bool MToBeDeleted = false;
};
/// Used for marking the node during graph traversal.
Marks MMarks;

enum class BlockReason : int { HostAccessor = 0, HostTask };

Expand Down
36 changes: 19 additions & 17 deletions sycl/source/detail/scheduler/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,26 @@ Scheduler::GraphBuilder::GraphBuilder() {
}

static bool markNodeAsVisited(Command *Cmd, std::vector<Command *> &Visited) {
if (Cmd->MVisited)
if (Cmd->MMarks.MVisited)
return false;
Cmd->MVisited = true;
Cmd->MMarks.MVisited = true;
Visited.push_back(Cmd);
return true;
}

static void unmarkVisitedNodes(std::vector<Command *> &Visited) {
for (Command *Cmd : Visited)
Cmd->MVisited = false;
Cmd->MMarks.MVisited = false;
}

static void handleVisitedNodes(std::vector<Command *> &Visited) {
for (Command *Cmd : Visited) {
if (Cmd->MMarks.MToBeDeleted) {
Cmd->getEvent()->setCommand(nullptr);
delete Cmd;
} else
Cmd->MMarks.MVisited = false;
}
}

static void printDotRecursive(std::fstream &Stream,
Expand Down Expand Up @@ -825,7 +835,6 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {

std::queue<Command *> ToVisit;
std::vector<Command *> Visited;
std::vector<Command *> CmdsToDelete;
// First, mark all allocas for deletion and their direct users for traversal
// Dependencies of the users will be cleaned up during the traversal
for (Command *AllocaCmd : AllocaCommands) {
Expand All @@ -839,7 +848,7 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
else
markNodeAsVisited(UserCmd, Visited);

CmdsToDelete.push_back(AllocaCmd);
AllocaCmd->MMarks.MToBeDeleted = true;
// These commands will be deleted later, clear users now to avoid
// updating them during edge removal
AllocaCmd->MUsers.clear();
Expand All @@ -851,7 +860,7 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
AllocaCommandBase *LinkedCmd = AllocaCmd->MLinkedAllocaCmd;

if (LinkedCmd) {
assert(LinkedCmd->MVisited);
assert(LinkedCmd->MMarks.MVisited);

for (DepDesc &Dep : AllocaCmd->MDeps)
if (Dep.MDepCommand)
Expand Down Expand Up @@ -896,17 +905,12 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
// If all dependencies have been removed this way, mark the command for
// deletion
if (Cmd->MDeps.empty()) {
CmdsToDelete.push_back(Cmd);
Cmd->MMarks.MToBeDeleted = true;
Cmd->MUsers.clear();
}
}

unmarkVisitedNodes(Visited);

for (Command *Cmd : CmdsToDelete) {
Cmd->getEvent()->setCommand(nullptr);
delete Cmd;
}
handleVisitedNodes(Visited);
}

void Scheduler::GraphBuilder::cleanupFinishedCommands(Command *FinishedCmd) {
Expand Down Expand Up @@ -948,12 +952,10 @@ void Scheduler::GraphBuilder::cleanupFinishedCommands(Command *FinishedCmd) {
Command *DepCmd = Dep.MDepCommand;
DepCmd->MUsers.erase(Cmd);
}
Cmd->getEvent()->setCommand(nullptr);

Visited.pop_back();
delete Cmd;
Cmd->MMarks.MToBeDeleted = true;
}
unmarkVisitedNodes(Visited);
handleVisitedNodes(Visited);
}

void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) {
Expand Down