Skip to content

Commit c099e47

Browse files
[SYCL] Improve visited node tracking during graph traversal (#2067)
The main bottleneck of the graph cleanup mechanism was the tracking of already visited nodes with a std::set. This patch adds a Command member variable that allows various algorithms to mark the visited nodes during graph traversal with far less overhead. Signed-off-by: Sergey Semenov <[email protected]>
1 parent 4ba61d0 commit c099e47

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

sycl/source/detail/scheduler/commands.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ class Command {
219219
bool MIsBlockable = false;
220220
/// Counts the number of memory objects this command is a leaf for.
221221
unsigned MLeafCounter = 0;
222+
/// Used for marking the node as visited during graph traversal.
223+
bool MVisited = false;
222224

223225
enum class BlockReason : int { HostAccessor = 0, HostTask };
224226

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,22 @@ Scheduler::GraphBuilder::GraphBuilder() {
9292
}
9393
}
9494

95+
/// Flags \p Cmd as visited for the current graph traversal and remembers it in
/// \p Visited so that the flag can be cleared again once the traversal ends.
///
/// \param Cmd     command to flag; it is dereferenced, so it must not be null.
/// \param Visited commands flagged so far; \p Cmd is appended when newly
///                flagged.
/// \return true if this call flagged the command, false if it was already
///         flagged (in which case nothing is modified).
static bool markNodeAsVisited(Command *Cmd, std::vector<Command *> &Visited) {
  const bool FirstVisit = !Cmd->MVisited;
  if (FirstVisit) {
    Cmd->MVisited = true;
    Visited.push_back(Cmd);
  }
  return FirstVisit;
}
102+
103+
static void unmarkVisitedNodes(std::vector<Command *> &Visited) {
104+
for (Command *Cmd : Visited)
105+
Cmd->MVisited = false;
106+
}
107+
95108
static void printDotRecursive(std::fstream &Stream,
96-
std::set<Command *> &Visited, Command *Cmd) {
97-
if (!Visited.insert(Cmd).second)
109+
std::vector<Command *> &Visited, Command *Cmd) {
110+
if (!markNodeAsVisited(Cmd, Visited))
98111
return;
99112
for (Command *User : Cmd->MUsers) {
100113
if (User)
@@ -114,13 +127,15 @@ void Scheduler::GraphBuilder::printGraphAsDot(const char *ModeName) {
114127
std::fstream Stream(FileName, std::ios::out);
115128
Stream << "strict digraph {" << std::endl;
116129

117-
std::set<Command *> Visited;
130+
std::vector<Command *> Visited;
118131

119132
for (SYCLMemObjI *MemObject : MMemObjs)
120133
for (Command *AllocaCmd : MemObject->MRecord->MAllocaCommands)
121134
printDotRecursive(Stream, Visited, AllocaCmd);
122135

123136
Stream << "}" << std::endl;
137+
138+
unmarkVisitedNodes(Visited);
124139
}
125140

126141
MemObjRecord *Scheduler::GraphBuilder::getMemObjRecord(SYCLMemObjI *MemObject) {
@@ -449,7 +464,7 @@ Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record,
449464
const Requirement *Req,
450465
const ContextImplPtr &Context) {
451466
std::set<Command *> RetDeps;
452-
std::set<Command *> Visited;
467+
std::vector<Command *> Visited;
453468
const bool ReadOnlyReq = Req->MAccessMode == access::mode::read;
454469

455470
std::vector<Command *> ToAnalyze{Record->MWriteLeaves.begin(),
@@ -490,11 +505,12 @@ Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record,
490505
break;
491506
}
492507

493-
if (Visited.insert(Dep.MDepCommand).second)
508+
if (markNodeAsVisited(Dep.MDepCommand, Visited))
494509
NewAnalyze.push_back(Dep.MDepCommand);
495510
}
496511
ToAnalyze.insert(ToAnalyze.end(), NewAnalyze.begin(), NewAnalyze.end());
497512
}
513+
unmarkVisitedNodes(Visited);
498514
return RetDeps;
499515
}
500516

@@ -808,20 +824,20 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
808824
return;
809825

810826
std::queue<Command *> ToVisit;
811-
std::set<Command *> Visited;
827+
std::vector<Command *> Visited;
812828
std::vector<Command *> CmdsToDelete;
813829
// First, mark all allocas for deletion and their direct users for traversal
814830
// Dependencies of the users will be cleaned up during the traversal
815831
for (Command *AllocaCmd : AllocaCommands) {
816-
Visited.insert(AllocaCmd);
832+
markNodeAsVisited(AllocaCmd, Visited);
817833

818834
for (Command *UserCmd : AllocaCmd->MUsers)
819835
// Linked alloca cmd may be in users of this alloca. We're not going to
820836
// visit it.
821837
if (UserCmd->getType() != Command::CommandType::ALLOCA)
822838
ToVisit.push(UserCmd);
823839
else
824-
Visited.insert(UserCmd);
840+
markNodeAsVisited(UserCmd, Visited);
825841

826842
CmdsToDelete.push_back(AllocaCmd);
827843
// These commands will be deleted later, clear users now to avoid
@@ -835,7 +851,7 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
835851
AllocaCommandBase *LinkedCmd = AllocaCmd->MLinkedAllocaCmd;
836852

837853
if (LinkedCmd) {
838-
assert(Visited.count(LinkedCmd));
854+
assert(LinkedCmd->MVisited);
839855

840856
for (DepDesc &Dep : AllocaCmd->MDeps)
841857
if (Dep.MDepCommand)
@@ -848,7 +864,7 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
848864
Command *Cmd = ToVisit.front();
849865
ToVisit.pop();
850866

851-
if (!Visited.insert(Cmd).second)
867+
if (!markNodeAsVisited(Cmd, Visited))
852868
continue;
853869

854870
for (Command *UserCmd : Cmd->MUsers)
@@ -885,6 +901,8 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
885901
}
886902
}
887903

904+
unmarkVisitedNodes(Visited);
905+
888906
for (Command *Cmd : CmdsToDelete) {
889907
Cmd->getEvent()->setCommand(nullptr);
890908
delete Cmd;
@@ -893,14 +911,14 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
893911

894912
void Scheduler::GraphBuilder::cleanupFinishedCommands(Command *FinishedCmd) {
895913
std::queue<Command *> CmdsToVisit({FinishedCmd});
896-
std::set<Command *> Visited;
914+
std::vector<Command *> Visited;
897915

898916
// Traverse the graph using BFS
899917
while (!CmdsToVisit.empty()) {
900918
Command *Cmd = CmdsToVisit.front();
901919
CmdsToVisit.pop();
902920

903-
if (!Visited.insert(Cmd).second)
921+
if (!markNodeAsVisited(Cmd, Visited))
904922
continue;
905923

906924
for (const DepDesc &Dep : Cmd->MDeps) {
@@ -932,8 +950,10 @@ void Scheduler::GraphBuilder::cleanupFinishedCommands(Command *FinishedCmd) {
932950
}
933951
Cmd->getEvent()->setCommand(nullptr);
934952

953+
Visited.pop_back();
935954
delete Cmd;
936955
}
956+
unmarkVisitedNodes(Visited);
937957
}
938958

939959
void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) {

0 commit comments

Comments
 (0)