Skip to content

Commit 67a81f6

Browse files
authored
[SYCL][Graph] Refactor node storage inside graphs (#11596)
- Store graph nodes inside a vector for more efficient searches
- Replace several depth-first search operations with iterations over node storage
- Node successors are now stored as weak_ptrs
- Unit tests updated to reflect these changes
1 parent 444afde commit 67a81f6

File tree

3 files changed

+248
-253
lines changed

3 files changed

+248
-253
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 64 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -29,50 +29,6 @@ namespace experimental {
2929
namespace detail {
3030

3131
namespace {
32-
33-
/// Recursively check if a given node is an exit node, and add the new nodes as
34-
/// successors if so.
35-
/// @param[in] CurrentNode Node to check as exit node.
36-
/// @param[in] NewInputs Nodes to add as successors.
37-
void connectToExitNodes(
38-
std::shared_ptr<node_impl> CurrentNode,
39-
const std::vector<std::shared_ptr<node_impl>> &NewInputs) {
40-
if (CurrentNode->MSuccessors.size() > 0) {
41-
for (auto &Successor : CurrentNode->MSuccessors) {
42-
connectToExitNodes(Successor, NewInputs);
43-
}
44-
45-
} else {
46-
for (auto &Input : NewInputs) {
47-
CurrentNode->registerSuccessor(Input, CurrentNode);
48-
}
49-
}
50-
}
51-
52-
/// Recursively check if a graph node or its successors contain a given
53-
/// requirement.
54-
/// @param[in] Req The requirement to check for.
55-
/// @param[in] CurrentNode The current graph node being checked.
56-
/// @param[in,out] Deps The unique list of dependencies which have been
57-
/// identified for this requirement.
58-
/// @return True if a dependency was added in this node or any of its
59-
/// successors.
60-
bool checkForRequirement(sycl::detail::AccessorImplHost *Req,
61-
const std::shared_ptr<node_impl> &CurrentNode,
62-
std::set<std::shared_ptr<node_impl>> &Deps) {
63-
bool SuccessorAddedDep = false;
64-
for (auto &Successor : CurrentNode->MSuccessors) {
65-
SuccessorAddedDep |= checkForRequirement(Req, Successor, Deps);
66-
}
67-
68-
if (!CurrentNode->isEmpty() && Deps.find(CurrentNode) == Deps.end() &&
69-
CurrentNode->hasRequirement(Req) && !SuccessorAddedDep) {
70-
Deps.insert(CurrentNode);
71-
return true;
72-
}
73-
return SuccessorAddedDep;
74-
}
75-
7632
/// Visits a node on the graph and its successors recursively in a depth-first
7733
/// approach.
7834
/// @param[in] Node The current node being visited.
@@ -99,7 +55,8 @@ bool visitNodeDepthFirst(
9955
Node->MVisited = true;
10056
VisitedNodes.emplace(Node);
10157
for (auto &Successor : Node->MSuccessors) {
102-
if (visitNodeDepthFirst(Successor, VisitedNodes, NodeStack, NodeFunc)) {
58+
if (visitNodeDepthFirst(Successor.lock(), VisitedNodes, NodeStack,
59+
NodeFunc)) {
10360
return true;
10461
}
10562
}
@@ -117,12 +74,28 @@ void duplicateNode(const std::shared_ptr<node_impl> Node,
11774
}
11875
}
11976

77+
/// Recursively add nodes to execution stack.
78+
/// @param NodeImpl Node to schedule.
79+
/// @param Schedule Execution ordering to add node to.
80+
void sortTopological(std::shared_ptr<node_impl> NodeImpl,
81+
std::list<std::shared_ptr<node_impl>> &Schedule) {
82+
for (auto &Succ : NodeImpl->MSuccessors) {
83+
// Check if we've already scheduled this node
84+
auto NextNode = Succ.lock();
85+
if (std::find(Schedule.begin(), Schedule.end(), NextNode) ==
86+
Schedule.end()) {
87+
sortTopological(NextNode, Schedule);
88+
}
89+
}
90+
91+
Schedule.push_front(NodeImpl);
92+
}
12093
} // anonymous namespace
12194

12295
void exec_graph_impl::schedule() {
12396
if (MSchedule.empty()) {
12497
for (auto &Node : MGraphImpl->MRoots) {
125-
Node->sortTopological(Node, MSchedule);
98+
sortTopological(Node.lock(), MSchedule);
12699
}
127100
}
128101
}
@@ -148,10 +121,19 @@ std::shared_ptr<node_impl> graph_impl::addNodesToExits(
148121
}
149122
}
150123

151-
// Recursively walk the graph to find exit nodes and connect up the inputs
152-
// TODO: Consider caching exit nodes so we don't have to do this
153-
for (auto &NodeImpl : MRoots) {
154-
connectToExitNodes(NodeImpl, Inputs);
124+
// Find all exit nodes in the current graph and register the Inputs as
125+
// successors
126+
for (auto &NodeImpl : MNodeStorage) {
127+
if (NodeImpl->MSuccessors.size() == 0) {
128+
for (auto &Input : Inputs) {
129+
NodeImpl->registerSuccessor(Input, NodeImpl);
130+
}
131+
}
132+
}
133+
134+
// Add all the new nodes to the node storage
135+
for (auto &Node : NodeList) {
136+
MNodeStorage.push_back(Node);
155137
}
156138

157139
return this->add(Outputs);
@@ -175,7 +157,7 @@ std::shared_ptr<node_impl> graph_impl::addSubgraphNodes(
175157
*NewNodesIt = NodeCopy;
176158
NodesMap.insert({Node, NodeCopy});
177159
for (auto &NextNode : Node->MSuccessors) {
178-
auto Successor = NodesMap.at(NextNode);
160+
auto Successor = NodesMap.at(NextNode.lock());
179161
NodeCopy->registerSuccessor(Successor, NodeCopy);
180162
}
181163
}
@@ -201,16 +183,9 @@ graph_impl::add(const std::vector<std::shared_ptr<node_impl>> &Dep) {
201183
// Add any deps from the vector of extra dependencies
202184
Deps.insert(Deps.end(), MExtraDependencies.begin(), MExtraDependencies.end());
203185

204-
// TODO: Encapsulate in separate function to avoid duplication
205-
if (!Deps.empty()) {
206-
for (auto &N : Deps) {
207-
N->registerSuccessor(NodeImpl, N); // register successor
208-
this->removeRoot(NodeImpl); // remove receiver from root node
209-
// list
210-
}
211-
} else {
212-
this->addRoot(NodeImpl);
213-
}
186+
MNodeStorage.push_back(NodeImpl);
187+
188+
addDepsToNode(NodeImpl, Deps);
214189

215190
return NodeImpl;
216191
}
@@ -285,17 +260,28 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
285260
MemObj->markBeingUsedInGraph();
286261
}
287262
// Look through the graph for nodes which share this requirement
288-
for (auto &NodePtr : MRoots) {
289-
checkForRequirement(Req, NodePtr, UniqueDeps);
263+
for (auto &Node : MNodeStorage) {
264+
if (Node->hasRequirement(Req)) {
265+
bool ShouldAddDep = true;
266+
// If any of this node's successors have this requirement then we skip
267+
// adding the current node as a dependency.
268+
for (auto &Succ : Node->MSuccessors) {
269+
if (Succ.lock()->hasRequirement(Req)) {
270+
ShouldAddDep = false;
271+
break;
272+
}
273+
}
274+
if (ShouldAddDep) {
275+
UniqueDeps.insert(Node);
276+
}
277+
}
290278
}
291279
}
292280

293281
// Add any nodes specified by event dependencies into the dependency list
294282
for (auto &Dep : CommandGroup->getEvents()) {
295283
if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl != MEventsMap.end()) {
296-
if (UniqueDeps.find(NodeImpl->second) == UniqueDeps.end()) {
297-
UniqueDeps.insert(NodeImpl->second);
298-
}
284+
UniqueDeps.insert(NodeImpl->second);
299285
} else {
300286
throw sycl::exception(sycl::make_error_code(errc::invalid),
301287
"Event dependency from handler::depends_on does "
@@ -311,15 +297,9 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
311297

312298
const std::shared_ptr<node_impl> &NodeImpl =
313299
std::make_shared<node_impl>(CGType, std::move(CommandGroup));
314-
if (!Deps.empty()) {
315-
for (auto &N : Deps) {
316-
N->registerSuccessor(NodeImpl, N); // register successor
317-
this->removeRoot(NodeImpl); // remove receiver from root node
318-
// list
319-
}
320-
} else {
321-
this->addRoot(NodeImpl);
322-
}
300+
MNodeStorage.push_back(NodeImpl);
301+
302+
addDepsToNode(NodeImpl, Deps);
323303

324304
// Set barrier nodes as prerequisites (new start points) for subsequent nodes
325305
if (CGType == sycl::detail::CG::Barrier) {
@@ -353,7 +333,7 @@ void graph_impl::searchDepthFirst(
353333

354334
for (auto &Root : MRoots) {
355335
std::deque<std::shared_ptr<node_impl>> NodeStack;
356-
if (visitNodeDepthFirst(Root, VisitedNodes, NodeStack, NodeFunc)) {
336+
if (visitNodeDepthFirst(Root.lock(), VisitedNodes, NodeStack, NodeFunc)) {
357337
break;
358338
}
359339
}
@@ -394,18 +374,15 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
394374

395375
bool SrcFound = false;
396376
bool DestFound = false;
397-
auto CheckForNodes = [&](std::shared_ptr<node_impl> &Node,
398-
std::deque<std::shared_ptr<node_impl>> &) {
399-
if (Node == Src) {
400-
SrcFound = true;
401-
}
402-
if (Node == Dest) {
403-
DestFound = true;
404-
}
405-
return SrcFound && DestFound;
406-
};
377+
for (const auto &Node : MNodeStorage) {
378+
379+
SrcFound |= Node == Src;
380+
DestFound |= Node == Dest;
407381

408-
searchDepthFirst(CheckForNodes);
382+
if (SrcFound && DestFound) {
383+
break;
384+
}
385+
}
409386

410387
if (!SrcFound) {
411388
throw sycl::exception(make_error_code(sycl::errc::invalid),

0 commit comments

Comments
 (0)