Skip to content

Commit 508ee90

Browse files
committed
Addressing PR comments
1 parent 5501907 commit 508ee90

File tree

2 files changed

+30
-29
lines changed

2 files changed

+30
-29
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@ void duplicateNode(const std::shared_ptr<node_impl> Node,
7474
}
7575
}
7676

77-
} // anonymous namespace
78-
7977
/// Recursively add nodes to execution stack.
8078
/// @param NodeImpl Node to schedule.
8179
/// @param Schedule Execution ordering to add node to.
@@ -90,6 +88,7 @@ void sortTopological(std::shared_ptr<node_impl> NodeImpl,
9088

9189
Schedule.push_front(NodeImpl);
9290
}
91+
} // anonymous namespace
9392

9493
void exec_graph_impl::schedule() {
9594
if (MSchedule.empty()) {
@@ -122,8 +121,7 @@ std::shared_ptr<node_impl> graph_impl::addNodesToExits(
122121

123122
// Find all exit nodes in the current graph and register the Inputs as
124123
// successors
125-
for (size_t i = 0; i < MNodeStorage.size(); i++) {
126-
auto NodeImpl = MNodeStorage[i];
124+
for (auto &NodeImpl : MNodeStorage) {
127125
if (NodeImpl->MSuccessors.size() == 0) {
128126
for (auto &Input : Inputs) {
129127
NodeImpl->registerSuccessor(Input, NodeImpl);
@@ -184,16 +182,8 @@ graph_impl::add(const std::vector<std::shared_ptr<node_impl>> &Dep) {
184182
Deps.insert(Deps.end(), MExtraDependencies.begin(), MExtraDependencies.end());
185183

186184
MNodeStorage.push_back(NodeImpl);
187-
// TODO: Encapsulate in separate function to avoid duplication
188-
if (!Deps.empty()) {
189-
for (auto &N : Deps) {
190-
N->registerSuccessor(NodeImpl, N); // register successor
191-
this->removeRoot(NodeImpl); // remove receiver from root node
192-
// list
193-
}
194-
} else {
195-
this->addRoot(NodeImpl);
196-
}
185+
186+
addDepsToNode(NodeImpl, Deps);
197187

198188
return NodeImpl;
199189
}
@@ -268,7 +258,7 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
268258
MemObj->markBeingUsedInGraph();
269259
}
270260
// Look through the graph for nodes which share this requirement
271-
for (auto Node : MNodeStorage) {
261+
for (auto &Node : MNodeStorage) {
272262
if (Node->hasRequirement(Req)) {
273263
bool ShouldAddDep = true;
274264
// If any of this node's successors have this requirement then we skip
@@ -304,15 +294,7 @@ graph_impl::add(sycl::detail::CG::CGTYPE CGType,
304294
std::make_shared<node_impl>(CGType, std::move(CommandGroup));
305295
MNodeStorage.push_back(NodeImpl);
306296

307-
if (!Deps.empty()) {
308-
for (auto &N : Deps) {
309-
N->registerSuccessor(NodeImpl, N); // register successor
310-
this->removeRoot(NodeImpl); // remove receiver from root node
311-
// list
312-
}
313-
} else {
314-
this->addRoot(NodeImpl);
315-
}
297+
addDepsToNode(NodeImpl, Deps);
316298

317299
// Set barrier nodes as prerequisites (new start points) for subsequent nodes
318300
if (CGType == sycl::detail::CG::Barrier) {
@@ -387,7 +369,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
387369

388370
bool SrcFound = false;
389371
bool DestFound = false;
390-
for (auto Node : MNodeStorage) {
372+
for (const auto &Node : MNodeStorage) {
391373

392374
SrcFound |= Node == Src;
393375
DestFound |= Node == Dest;

sycl/source/detail/graph_impl.hpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ class node_impl {
178178
return nullptr;
179179
}
180180

181-
/// Tests is the caller is similar to Node, this is only used for testing.
181+
/// Tests if the caller is similar to Node, this is only used for testing.
182182
/// @param Node The node to check for similarity.
183183
/// @param CompareContentOnly Skip comparisons related to graph structure,
184184
/// compare only the type and command groups of the nodes
@@ -381,7 +381,10 @@ class graph_impl {
381381
/// List of root nodes.
382382
std::set<std::shared_ptr<node_impl>> MRoots;
383383

384-
/// Storage for all nodes contained within a graph
384+
/// Storage for all nodes contained within a graph. Nodes are connected to
385+
/// each other via weak_ptrs and so do not extend each other's lifetimes.
386+
/// This storage allows easy iteration over all nodes in the graph, rather
387+
/// than needing an expensive depth first search.
385388
std::vector<std::shared_ptr<node_impl>> MNodeStorage;
386389

387390
/// Find the last node added to this graph from an in-order queue.
@@ -430,8 +433,8 @@ class graph_impl {
430433
/// @param NodeA pointer to the first node for comparison
431434
/// @param NodeB pointer to the second node for comparison
432435
/// @return true is same structure found, false otherwise
433-
bool checkNodeRecursive(std::shared_ptr<node_impl> NodeA,
434-
std::shared_ptr<node_impl> NodeB) const {
436+
static bool checkNodeRecursive(std::shared_ptr<node_impl> NodeA,
437+
std::shared_ptr<node_impl> NodeB) {
435438
size_t FoundCnt = 0;
436439
for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
437440
for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
@@ -565,6 +568,22 @@ class graph_impl {
565568
std::shared_ptr<node_impl>
566569
addNodesToExits(const std::list<std::shared_ptr<node_impl>> &NodeList);
567570

571+
/// Adds dependencies for a new node, if it has no deps it will be
572+
/// added as a root node.
573+
/// @param Node The node to add deps for
574+
/// @param Deps List of dependent nodes
575+
void addDepsToNode(std::shared_ptr<node_impl> Node,
576+
const std::vector<std::shared_ptr<node_impl>> &Deps) {
577+
if (!Deps.empty()) {
578+
for (auto &N : Deps) {
579+
N->registerSuccessor(Node, N);
580+
this->removeRoot(Node);
581+
}
582+
} else {
583+
this->addRoot(Node);
584+
}
585+
}
586+
568587
/// Context associated with this graph.
569588
sycl::context MContext;
570589
/// Device associated with this graph. All graph nodes will execute on this

0 commit comments

Comments
 (0)