Skip to content

Commit 42a5a35

Browse files
Switch the solution to adding another mark to command nodes
This approach shows better performance with floating-point reduction, which is cleanup-intensive. Signed-off-by: Sergey Semenov <[email protected]>
1 parent 08dceea commit 42a5a35

File tree

2 files changed

+29
-24
lines changed

2 files changed

+29
-24
lines changed

sycl/source/detail/scheduler/commands.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,16 @@ class Command {
219219
bool MIsBlockable = false;
220220
/// Counts the number of memory objects this command is a leaf for.
221221
unsigned MLeafCounter = 0;
222-
/// Used for marking the node as visited during graph traversal.
223-
bool MVisited = false;
222+
223+
/// Per-node flags that the graph builder sets and clears while it walks
/// the command graph.
struct Marks {
224+
/// Set while the node is recorded as visited by the current graph
/// traversal; cleared again once the traversal is finished.
225+
bool MVisited = false;
226+
/// Set during cleanup when the node has been selected for deletion.
227+
bool MToBeDeleted = false;
228+
};
229+
/// Traversal/cleanup marks of this node (visited and to-be-deleted flags).
230+
Marks MMarks;
231+
224232

225233
enum class BlockReason : int { HostAccessor = 0, HostTask };
226234

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,26 @@ Scheduler::GraphBuilder::GraphBuilder() {
9393
}
9494

9595
/// Marks \p Cmd as visited and appends it to \p Visited, unless it already
/// carries the visited mark.
/// \return true if the command was newly marked; false if it had been visited
/// before, in which case \p Visited is left unchanged.
static bool markNodeAsVisited(Command *Cmd, std::vector<Command *> &Visited) {
96-
if (Cmd->MVisited)
96+
if (Cmd->MMarks.MVisited)
9797
return false;
98-
Cmd->MVisited = true;
98+
Cmd->MMarks.MVisited = true;
9999
Visited.push_back(Cmd);
100100
return true;
101101
}
102102

103103
/// Clears the visited mark on every command listed in \p Visited.
static void unmarkVisitedNodes(std::vector<Command *> &Visited) {
104104
for (Command *Cmd : Visited)
105-
Cmd->MVisited = false;
105+
Cmd->MMarks.MVisited = false;
106+
}
107+
108+
/// Finalizes a traversal: every command listed in \p Visited is either
/// destroyed (when it is marked with MToBeDeleted) or has its visited mark
/// cleared.
///
/// \param Visited commands recorded as visited during the traversal. Entries
///        that were marked for deletion are dangling pointers once this
///        function returns and must not be dereferenced by the caller.
static void handleVisitedNodes(std::vector<Command *> &Visited) {
  for (Command *Cmd : Visited) {
    if (Cmd->MMarks.MToBeDeleted) {
      // Reset the command pointer stored in the associated event before the
      // command itself is destroyed.
      Cmd->getEvent()->setCommand(nullptr);
      delete Cmd;
      // Cmd is freed at this point; resetting its visited mark would be a
      // use-after-free, so move on to the next entry.
      continue;
    }
    Cmd->MMarks.MVisited = false;
  }
}
107117

108118
static void printDotRecursive(std::fstream &Stream,
@@ -825,7 +835,6 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
825835

826836
std::queue<Command *> ToVisit;
827837
std::vector<Command *> Visited;
828-
std::vector<Command *> CmdsToDelete;
829838
// First, mark all allocas for deletion and their direct users for traversal
830839
// Dependencies of the users will be cleaned up during the traversal
831840
for (Command *AllocaCmd : AllocaCommands) {
@@ -839,7 +848,7 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
839848
else
840849
markNodeAsVisited(UserCmd, Visited);
841850

842-
CmdsToDelete.push_back(AllocaCmd);
851+
AllocaCmd->MMarks.MToBeDeleted = true;
843852
// These commands will be deleted later, clear users now to avoid
844853
// updating them during edge removal
845854
AllocaCmd->MUsers.clear();
@@ -851,7 +860,7 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
851860
AllocaCommandBase *LinkedCmd = AllocaCmd->MLinkedAllocaCmd;
852861

853862
if (LinkedCmd) {
854-
assert(LinkedCmd->MVisited);
863+
assert(LinkedCmd->MMarks.MVisited);
855864

856865
for (DepDesc &Dep : AllocaCmd->MDeps)
857866
if (Dep.MDepCommand)
@@ -896,22 +905,16 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
896905
// If all dependencies have been removed this way, mark the command for
897906
// deletion
898907
if (Cmd->MDeps.empty()) {
899-
CmdsToDelete.push_back(Cmd);
908+
Cmd->MMarks.MToBeDeleted = true;
900909
Cmd->MUsers.clear();
901910
}
902911
}
903912

904-
unmarkVisitedNodes(Visited);
905-
906-
for (Command *Cmd : CmdsToDelete) {
907-
Cmd->getEvent()->setCommand(nullptr);
908-
delete Cmd;
909-
}
913+
handleVisitedNodes(Visited);
910914
}
911915

912916
void Scheduler::GraphBuilder::cleanupFinishedCommands(Command *FinishedCmd) {
913917
std::queue<Command *> CmdsToVisit({FinishedCmd});
914-
std::vector<Command *> CmdsToDelete;
915918
std::vector<Command *> Visited;
916919

917920
// Traverse the graph using BFS
@@ -950,15 +953,9 @@ void Scheduler::GraphBuilder::cleanupFinishedCommands(Command *FinishedCmd) {
950953
DepCmd->MUsers.erase(Cmd);
951954
}
952955

953-
CmdsToDelete.push_back(Cmd);
954-
Visited.pop_back();
955-
}
956-
unmarkVisitedNodes(Visited);
957-
958-
for (Command *Cmd : CmdsToDelete) {
959-
Cmd->getEvent()->setCommand(nullptr);
960-
delete Cmd;
956+
Cmd->MMarks.MToBeDeleted = true;
961957
}
958+
handleVisitedNodes(Visited);
962959
}
963960

964961
void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) {

0 commit comments

Comments
 (0)