@@ -92,9 +92,22 @@ Scheduler::GraphBuilder::GraphBuilder() {
92
92
}
93
93
}
94
94
95
+ static bool markNodeAsVisited (Command *Cmd, std::vector<Command *> &Visited) {
96
+ if (Cmd->MVisited )
97
+ return false ;
98
+ Cmd->MVisited = true ;
99
+ Visited.push_back (Cmd);
100
+ return true ;
101
+ }
102
+
103
+ static void unmarkVisitedNodes (std::vector<Command *> &Visited) {
104
+ for (Command *Cmd : Visited)
105
+ Cmd->MVisited = false ;
106
+ }
107
+
95
108
static void printDotRecursive (std::fstream &Stream,
96
- std::set <Command *> &Visited, Command *Cmd) {
97
- if (!Visited. insert (Cmd). second )
109
+ std::vector <Command *> &Visited, Command *Cmd) {
110
+ if (!markNodeAsVisited (Cmd, Visited) )
98
111
return ;
99
112
for (Command *User : Cmd->MUsers ) {
100
113
if (User)
@@ -114,13 +127,15 @@ void Scheduler::GraphBuilder::printGraphAsDot(const char *ModeName) {
114
127
std::fstream Stream (FileName, std::ios::out);
115
128
Stream << " strict digraph {" << std::endl;
116
129
117
- std::set <Command *> Visited;
130
+ std::vector <Command *> Visited;
118
131
119
132
for (SYCLMemObjI *MemObject : MMemObjs)
120
133
for (Command *AllocaCmd : MemObject->MRecord ->MAllocaCommands )
121
134
printDotRecursive (Stream, Visited, AllocaCmd);
122
135
123
136
Stream << " }" << std::endl;
137
+
138
+ unmarkVisitedNodes (Visited);
124
139
}
125
140
126
141
MemObjRecord *Scheduler::GraphBuilder::getMemObjRecord (SYCLMemObjI *MemObject) {
@@ -449,7 +464,7 @@ Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record,
449
464
const Requirement *Req,
450
465
const ContextImplPtr &Context) {
451
466
std::set<Command *> RetDeps;
452
- std::set <Command *> Visited;
467
+ std::vector <Command *> Visited;
453
468
const bool ReadOnlyReq = Req->MAccessMode == access::mode::read;
454
469
455
470
std::vector<Command *> ToAnalyze{Record->MWriteLeaves .begin (),
@@ -490,11 +505,12 @@ Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record,
490
505
break ;
491
506
}
492
507
493
- if (Visited. insert (Dep.MDepCommand ). second )
508
+ if (markNodeAsVisited (Dep.MDepCommand , Visited) )
494
509
NewAnalyze.push_back (Dep.MDepCommand );
495
510
}
496
511
ToAnalyze.insert (ToAnalyze.end (), NewAnalyze.begin (), NewAnalyze.end ());
497
512
}
513
+ unmarkVisitedNodes (Visited);
498
514
return RetDeps;
499
515
}
500
516
@@ -808,20 +824,20 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
808
824
return ;
809
825
810
826
std::queue<Command *> ToVisit;
811
- std::set <Command *> Visited;
827
+ std::vector <Command *> Visited;
812
828
std::vector<Command *> CmdsToDelete;
813
829
// First, mark all allocas for deletion and their direct users for traversal
814
830
// Dependencies of the users will be cleaned up during the traversal
815
831
for (Command *AllocaCmd : AllocaCommands) {
816
- Visited. insert (AllocaCmd);
832
+ markNodeAsVisited (AllocaCmd, Visited );
817
833
818
834
for (Command *UserCmd : AllocaCmd->MUsers )
819
835
// Linked alloca cmd may be in users of this alloca. We're not going to
820
836
// visit it.
821
837
if (UserCmd->getType () != Command::CommandType::ALLOCA)
822
838
ToVisit.push (UserCmd);
823
839
else
824
- Visited. insert (UserCmd);
840
+ markNodeAsVisited (UserCmd, Visited );
825
841
826
842
CmdsToDelete.push_back (AllocaCmd);
827
843
// These commands will be deleted later, clear users now to avoid
@@ -835,7 +851,7 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
835
851
AllocaCommandBase *LinkedCmd = AllocaCmd->MLinkedAllocaCmd ;
836
852
837
853
if (LinkedCmd) {
838
- assert (Visited. count ( LinkedCmd) );
854
+ assert (LinkedCmd-> MVisited );
839
855
840
856
for (DepDesc &Dep : AllocaCmd->MDeps )
841
857
if (Dep.MDepCommand )
@@ -848,7 +864,7 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
848
864
Command *Cmd = ToVisit.front ();
849
865
ToVisit.pop ();
850
866
851
- if (!Visited. insert (Cmd). second )
867
+ if (!markNodeAsVisited (Cmd, Visited) )
852
868
continue ;
853
869
854
870
for (Command *UserCmd : Cmd->MUsers )
@@ -885,6 +901,8 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
885
901
}
886
902
}
887
903
904
+ unmarkVisitedNodes (Visited);
905
+
888
906
for (Command *Cmd : CmdsToDelete) {
889
907
Cmd->getEvent ()->setCommand (nullptr );
890
908
delete Cmd;
@@ -893,14 +911,14 @@ void Scheduler::GraphBuilder::cleanupCommandsForRecord(MemObjRecord *Record) {
893
911
894
912
void Scheduler::GraphBuilder::cleanupFinishedCommands (Command *FinishedCmd) {
895
913
std::queue<Command *> CmdsToVisit ({FinishedCmd});
896
- std::set <Command *> Visited;
914
+ std::vector <Command *> Visited;
897
915
898
916
// Traverse the graph using BFS
899
917
while (!CmdsToVisit.empty ()) {
900
918
Command *Cmd = CmdsToVisit.front ();
901
919
CmdsToVisit.pop ();
902
920
903
- if (!Visited. insert (Cmd). second )
921
+ if (!markNodeAsVisited (Cmd, Visited) )
904
922
continue ;
905
923
906
924
for (const DepDesc &Dep : Cmd->MDeps ) {
@@ -932,8 +950,10 @@ void Scheduler::GraphBuilder::cleanupFinishedCommands(Command *FinishedCmd) {
932
950
}
933
951
Cmd->getEvent ()->setCommand (nullptr );
934
952
953
+ Visited.pop_back ();
935
954
delete Cmd;
936
955
}
956
+ unmarkVisitedNodes (Visited);
937
957
}
938
958
939
959
void Scheduler::GraphBuilder::removeRecordForMemObj (SYCLMemObjI *MemObject) {
0 commit comments