@@ -29,50 +29,6 @@ namespace experimental {
29
29
namespace detail {
30
30
31
31
namespace {
32
-
33
- // / Recursively check if a given node is an exit node, and add the new nodes as
34
- // / successors if so.
35
- // / @param[in] CurrentNode Node to check as exit node.
36
- // / @param[in] NewInputs Noes to add as successors.
37
- void connectToExitNodes (
38
- std::shared_ptr<node_impl> CurrentNode,
39
- const std::vector<std::shared_ptr<node_impl>> &NewInputs) {
40
- if (CurrentNode->MSuccessors .size () > 0 ) {
41
- for (auto &Successor : CurrentNode->MSuccessors ) {
42
- connectToExitNodes (Successor, NewInputs);
43
- }
44
-
45
- } else {
46
- for (auto &Input : NewInputs) {
47
- CurrentNode->registerSuccessor (Input, CurrentNode);
48
- }
49
- }
50
- }
51
-
52
- // / Recursive check if a graph node or its successors contains a given
53
- // / requirement.
54
- // / @param[in] Req The requirement to check for.
55
- // / @param[in] CurrentNode The current graph node being checked.
56
- // / @param[in,out] Deps The unique list of dependencies which have been
57
- // / identified for this requirement.
58
- // / @return True if a dependency was added in this node or any of its
59
- // / successors.
60
- bool checkForRequirement (sycl::detail::AccessorImplHost *Req,
61
- const std::shared_ptr<node_impl> &CurrentNode,
62
- std::set<std::shared_ptr<node_impl>> &Deps) {
63
- bool SuccessorAddedDep = false ;
64
- for (auto &Successor : CurrentNode->MSuccessors ) {
65
- SuccessorAddedDep |= checkForRequirement (Req, Successor, Deps);
66
- }
67
-
68
- if (!CurrentNode->isEmpty () && Deps.find (CurrentNode) == Deps.end () &&
69
- CurrentNode->hasRequirement (Req) && !SuccessorAddedDep) {
70
- Deps.insert (CurrentNode);
71
- return true ;
72
- }
73
- return SuccessorAddedDep;
74
- }
75
-
76
32
// / Visits a node on the graph and it's successors recursively in a depth-first
77
33
// / approach.
78
34
// / @param[in] Node The current node being visited.
@@ -99,7 +55,8 @@ bool visitNodeDepthFirst(
99
55
Node->MVisited = true ;
100
56
VisitedNodes.emplace (Node);
101
57
for (auto &Successor : Node->MSuccessors ) {
102
- if (visitNodeDepthFirst (Successor, VisitedNodes, NodeStack, NodeFunc)) {
58
+ if (visitNodeDepthFirst (Successor.lock (), VisitedNodes, NodeStack,
59
+ NodeFunc)) {
103
60
return true ;
104
61
}
105
62
}
@@ -117,12 +74,28 @@ void duplicateNode(const std::shared_ptr<node_impl> Node,
117
74
}
118
75
}
119
76
77
+ // / Recursively add nodes to execution stack.
78
+ // / @param NodeImpl Node to schedule.
79
+ // / @param Schedule Execution ordering to add node to.
80
+ void sortTopological (std::shared_ptr<node_impl> NodeImpl,
81
+ std::list<std::shared_ptr<node_impl>> &Schedule) {
82
+ for (auto &Succ : NodeImpl->MSuccessors ) {
83
+ // Check if we've already scheduled this node
84
+ auto NextNode = Succ.lock ();
85
+ if (std::find (Schedule.begin (), Schedule.end (), NextNode) ==
86
+ Schedule.end ()) {
87
+ sortTopological (NextNode, Schedule);
88
+ }
89
+ }
90
+
91
+ Schedule.push_front (NodeImpl);
92
+ }
120
93
} // anonymous namespace
121
94
122
95
void exec_graph_impl::schedule () {
123
96
if (MSchedule.empty ()) {
124
97
for (auto &Node : MGraphImpl->MRoots ) {
125
- Node-> sortTopological (Node, MSchedule);
98
+ sortTopological (Node. lock () , MSchedule);
126
99
}
127
100
}
128
101
}
@@ -148,10 +121,19 @@ std::shared_ptr<node_impl> graph_impl::addNodesToExits(
148
121
}
149
122
}
150
123
151
- // Recursively walk the graph to find exit nodes and connect up the inputs
152
- // TODO: Consider caching exit nodes so we don't have to do this
153
- for (auto &NodeImpl : MRoots) {
154
- connectToExitNodes (NodeImpl, Inputs);
124
+ // Find all exit nodes in the current graph and register the Inputs as
125
+ // successors
126
+ for (auto &NodeImpl : MNodeStorage) {
127
+ if (NodeImpl->MSuccessors .size () == 0 ) {
128
+ for (auto &Input : Inputs) {
129
+ NodeImpl->registerSuccessor (Input, NodeImpl);
130
+ }
131
+ }
132
+ }
133
+
134
+ // Add all the new nodes to the node storage
135
+ for (auto &Node : NodeList) {
136
+ MNodeStorage.push_back (Node);
155
137
}
156
138
157
139
return this ->add (Outputs);
@@ -175,7 +157,7 @@ std::shared_ptr<node_impl> graph_impl::addSubgraphNodes(
175
157
*NewNodesIt = NodeCopy;
176
158
NodesMap.insert ({Node, NodeCopy});
177
159
for (auto &NextNode : Node->MSuccessors ) {
178
- auto Successor = NodesMap.at (NextNode);
160
+ auto Successor = NodesMap.at (NextNode. lock () );
179
161
NodeCopy->registerSuccessor (Successor, NodeCopy);
180
162
}
181
163
}
@@ -201,16 +183,9 @@ graph_impl::add(const std::vector<std::shared_ptr<node_impl>> &Dep) {
201
183
// Add any deps from the vector of extra dependencies
202
184
Deps.insert (Deps.end (), MExtraDependencies.begin (), MExtraDependencies.end ());
203
185
204
- // TODO: Encapsulate in separate function to avoid duplication
205
- if (!Deps.empty ()) {
206
- for (auto &N : Deps) {
207
- N->registerSuccessor (NodeImpl, N); // register successor
208
- this ->removeRoot (NodeImpl); // remove receiver from root node
209
- // list
210
- }
211
- } else {
212
- this ->addRoot (NodeImpl);
213
- }
186
+ MNodeStorage.push_back (NodeImpl);
187
+
188
+ addDepsToNode (NodeImpl, Deps);
214
189
215
190
return NodeImpl;
216
191
}
@@ -285,17 +260,28 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
285
260
MemObj->markBeingUsedInGraph ();
286
261
}
287
262
// Look through the graph for nodes which share this requirement
288
- for (auto &NodePtr : MRoots) {
289
- checkForRequirement (Req, NodePtr, UniqueDeps);
263
+ for (auto &Node : MNodeStorage) {
264
+ if (Node->hasRequirement (Req)) {
265
+ bool ShouldAddDep = true ;
266
+ // If any of this node's successors have this requirement then we skip
267
+ // adding the current node as a dependency.
268
+ for (auto &Succ : Node->MSuccessors ) {
269
+ if (Succ.lock ()->hasRequirement (Req)) {
270
+ ShouldAddDep = false ;
271
+ break ;
272
+ }
273
+ }
274
+ if (ShouldAddDep) {
275
+ UniqueDeps.insert (Node);
276
+ }
277
+ }
290
278
}
291
279
}
292
280
293
281
// Add any nodes specified by event dependencies into the dependency list
294
282
for (auto &Dep : CommandGroup->getEvents ()) {
295
283
if (auto NodeImpl = MEventsMap.find (Dep); NodeImpl != MEventsMap.end ()) {
296
- if (UniqueDeps.find (NodeImpl->second ) == UniqueDeps.end ()) {
297
- UniqueDeps.insert (NodeImpl->second );
298
- }
284
+ UniqueDeps.insert (NodeImpl->second );
299
285
} else {
300
286
throw sycl::exception (sycl::make_error_code (errc::invalid),
301
287
" Event dependency from handler::depends_on does "
@@ -311,15 +297,9 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
311
297
312
298
const std::shared_ptr<node_impl> &NodeImpl =
313
299
std::make_shared<node_impl>(CGType, std::move (CommandGroup));
314
- if (!Deps.empty ()) {
315
- for (auto &N : Deps) {
316
- N->registerSuccessor (NodeImpl, N); // register successor
317
- this ->removeRoot (NodeImpl); // remove receiver from root node
318
- // list
319
- }
320
- } else {
321
- this ->addRoot (NodeImpl);
322
- }
300
+ MNodeStorage.push_back (NodeImpl);
301
+
302
+ addDepsToNode (NodeImpl, Deps);
323
303
324
304
// Set barrier nodes as prerequisites (new start points) for subsequent nodes
325
305
if (CGType == sycl::detail::CG::Barrier) {
@@ -353,7 +333,7 @@ void graph_impl::searchDepthFirst(
353
333
354
334
for (auto &Root : MRoots) {
355
335
std::deque<std::shared_ptr<node_impl>> NodeStack;
356
- if (visitNodeDepthFirst (Root, VisitedNodes, NodeStack, NodeFunc)) {
336
+ if (visitNodeDepthFirst (Root. lock () , VisitedNodes, NodeStack, NodeFunc)) {
357
337
break ;
358
338
}
359
339
}
@@ -394,18 +374,15 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
394
374
395
375
bool SrcFound = false ;
396
376
bool DestFound = false ;
397
- auto CheckForNodes = [&](std::shared_ptr<node_impl> &Node,
398
- std::deque<std::shared_ptr<node_impl>> &) {
399
- if (Node == Src) {
400
- SrcFound = true ;
401
- }
402
- if (Node == Dest) {
403
- DestFound = true ;
404
- }
405
- return SrcFound && DestFound;
406
- };
377
+ for (const auto &Node : MNodeStorage) {
378
+
379
+ SrcFound |= Node == Src;
380
+ DestFound |= Node == Dest;
407
381
408
- searchDepthFirst (CheckForNodes);
382
+ if (SrcFound && DestFound) {
383
+ break ;
384
+ }
385
+ }
409
386
410
387
if (!SrcFound) {
411
388
throw sycl::exception (make_error_code (sycl::errc::invalid),
0 commit comments