Skip to content

Commit d2578a3

Browse files
authored
[Sycl][Graph] Reimplement topological sort algorithm (#17495)
The current implementation of the topological sort algorithm relies on a recursive function. This causes stack overflow issues when the graphs are very large. There is also a performance issue in the current implementation related to the way std::find() is being used. This commit reimplements the topological sort algorithm to fix these issues: - Uses an iterative approach instead of recursive. - Removes the use of std::find() and instead relies on keeping the state with a helper variable. - Reimplements the cycle checking algorithm to use the same topological sort function as the one used during scheduling. - Fixes a bug with make_edge() not removing the dest node from the list of root nodes.
1 parent 4463200 commit d2578a3

File tree

5 files changed

+204
-112
lines changed

5 files changed

+204
-112
lines changed

sycl/doc/design/CommandGraph.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,15 @@ Edges are stored in each node as lists of predecessor and successor nodes.
100100

101101
## Execution Order
102102

103-
The current way graph nodes are linearized into execution order is using a
104-
reversed depth-first sorting algorithm. Alternative algorithms, such as
105-
breadth-first, are possible and may give better performance on certain
106-
workloads/hardware. In the future there might be options for allowing the
107-
user to control this implementation detail.
103+
Graph nodes are currently linearized into execution order using a topological
104+
sort algorithm. This algorithm uses a breadth-first search approach
105+
and is an implementation of Kahn's algorithm. Alternative algorithms, such as
106+
depth-first search, are possible and may give better performance on certain
107+
workloads/hardware. However, depth first searches are usually implemented
108+
using recursion, which can lead to stack overflow issues on large graphs.
109+
In the future, there might be options for allowing the user to control this
110+
implementation detail. It might also be possible to automatically change
111+
which algorithm is used based on characteristics such as the graph's size.
108112

109113
## Scheduler Integration
110114

sycl/source/detail/graph_impl.cpp

Lines changed: 67 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#define __SYCL_GRAPH_IMPL_CPP
1010

11+
#include <stack>
1112
#include <detail/graph_impl.hpp>
1213
#include <detail/handler_impl.hpp>
1314
#include <detail/kernel_arg_mask.hpp>
@@ -31,64 +32,47 @@ namespace experimental {
3132
namespace detail {
3233

3334
namespace {
34-
/// Visits a node on the graph and it's successors recursively in a depth-first
35-
/// approach.
36-
/// @param[in] Node The current node being visited.
37-
/// @param[in,out] VisitedNodes A set of unique nodes which have already been
38-
/// visited.
39-
/// @param[in] NodeStack Stack of nodes which are currently being visited on the
40-
/// current path through the graph.
41-
/// @param[in] NodeFunc The function object to be run on each node. A return
42-
/// value of true indicates the search should be ended immediately and the
43-
/// function will return.
44-
/// @return True if the search should end immediately, false if not.
45-
bool visitNodeDepthFirst(
46-
std::shared_ptr<node_impl> Node,
47-
std::set<std::shared_ptr<node_impl>> &VisitedNodes,
48-
std::deque<std::shared_ptr<node_impl>> &NodeStack,
49-
std::function<bool(std::shared_ptr<node_impl> &,
50-
std::deque<std::shared_ptr<node_impl>> &)>
51-
NodeFunc) {
52-
auto EarlyReturn = NodeFunc(Node, NodeStack);
53-
if (EarlyReturn) {
54-
return true;
55-
}
56-
NodeStack.push_back(Node);
57-
Node->MVisited = true;
58-
VisitedNodes.emplace(Node);
59-
for (auto &Successor : Node->MSuccessors) {
60-
if (visitNodeDepthFirst(Successor.lock(), VisitedNodes, NodeStack,
61-
NodeFunc)) {
62-
return true;
63-
}
64-
}
65-
NodeStack.pop_back();
66-
return false;
67-
}
68-
69-
/// Recursively add nodes to execution stack.
70-
/// @param NodeImpl Node to schedule.
71-
/// @param Schedule Execution ordering to add node to.
72-
/// @param PartitionBounded If set to true, the topological sort is stopped at
73-
/// partition borders. Hence, nodes belonging to a partition different from the
74-
/// NodeImpl partition are not processed.
75-
void sortTopological(std::shared_ptr<node_impl> NodeImpl,
76-
std::list<std::shared_ptr<node_impl>> &Schedule,
77-
bool PartitionBounded = false) {
78-
for (auto &Succ : NodeImpl->MSuccessors) {
79-
auto NextNode = Succ.lock();
80-
if (PartitionBounded &&
81-
(NextNode->MPartitionNum != NodeImpl->MPartitionNum)) {
82-
continue;
83-
}
84-
// Check if we've already scheduled this node
85-
if (std::find(Schedule.begin(), Schedule.end(), NextNode) ==
86-
Schedule.end()) {
87-
sortTopological(NextNode, Schedule, PartitionBounded);
35+
/// Topologically sorts the graph in order to schedule nodes for execution.
36+
/// This implementation is based on Kahn's algorithm which uses a Breadth-first
37+
/// search approach.
38+
/// For performance reasons, this function uses the MTotalVisitedEdges
39+
/// member variable of the node_impl class. It's the caller responsibility to
40+
/// make sure that MTotalVisitedEdges is set to 0 for all nodes in the graph
41+
/// before calling this function.
42+
/// @param[in] Roots List of root nodes.
43+
/// @param[out] SortedNodes The graph nodes sorted in topological order.
44+
/// @param[in] PartitionBounded If set to true, the topological sort is stopped
45+
/// at partition borders. Hence, nodes belonging to a partition different from
46+
/// the NodeImpl partition are not processed.
47+
void sortTopological(std::set<std::weak_ptr<node_impl>,
48+
std::owner_less<std::weak_ptr<node_impl>>> &Roots,
49+
std::list<std::shared_ptr<node_impl>> &SortedNodes,
50+
bool PartitionBounded) {
51+
std::stack<std::weak_ptr<node_impl>> Source;
52+
53+
for (auto &Node : Roots) {
54+
Source.push(Node);
55+
}
56+
57+
while (!Source.empty()) {
58+
auto Node = Source.top().lock();
59+
Source.pop();
60+
SortedNodes.push_back(Node);
61+
62+
for (auto &SuccWP : Node->MSuccessors) {
63+
auto Succ = SuccWP.lock();
64+
65+
if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum)) {
66+
continue;
67+
}
68+
69+
auto &TotalVisitedEdges = Succ->MTotalVisitedEdges;
70+
++TotalVisitedEdges;
71+
if (TotalVisitedEdges == Succ->MPredecessors.size()) {
72+
Source.push(Succ);
73+
}
8874
}
8975
}
90-
91-
Schedule.push_front(NodeImpl);
9276
}
9377

9478
/// Propagates the partition number `PartitionNum` to predecessors.
@@ -180,9 +164,9 @@ std::vector<node> createNodesFromImpls(
180164

181165
void partition::schedule() {
182166
if (MSchedule.empty()) {
183-
for (auto &Node : MRoots) {
184-
sortTopological(Node.lock(), MSchedule, true);
185-
}
167+
// There is no need to reset MTotalVisitedEdges before calling
168+
// sortTopological because this function is only called once per partition.
169+
sortTopological(MRoots, MSchedule, true);
186170
}
187171
}
188172

@@ -311,6 +295,7 @@ static void checkGraphPropertiesAndThrow(const property_list &Properties) {
311295
#define __SYCL_MANUALLY_DEFINED_PROP(NS_QUALIFIER, PROP_NAME)
312296
switch (PropertyKind) {
313297
#include <sycl/ext/oneapi/experimental/detail/properties/graph_properties.def>
298+
314299
default:
315300
return false;
316301
}
@@ -627,44 +612,20 @@ bool graph_impl::clearQueues() {
627612
return AnyQueuesCleared;
628613
}
629614

630-
void graph_impl::searchDepthFirst(
631-
std::function<bool(std::shared_ptr<node_impl> &,
632-
std::deque<std::shared_ptr<node_impl>> &)>
633-
NodeFunc) {
634-
// Track nodes visited during the search which can be used by NodeFunc in
635-
// depth first search queries. Currently unusued but is an
636-
// integral part of depth first searches.
637-
std::set<std::shared_ptr<node_impl>> VisitedNodes;
615+
bool graph_impl::checkForCycles() {
616+
std::list<std::shared_ptr<node_impl>> SortedNodes;
617+
sortTopological(MRoots, SortedNodes, false);
638618

639-
for (auto &Root : MRoots) {
640-
std::deque<std::shared_ptr<node_impl>> NodeStack;
641-
if (visitNodeDepthFirst(Root.lock(), VisitedNodes, NodeStack, NodeFunc)) {
642-
break;
643-
}
644-
}
619+
// If after a topological sort, not all the nodes in the graph are sorted,
620+
// then there must be at least one cycle in the graph. This is guaranteed
621+
// by Kahn's algorithm, which sortTopological() implements.
622+
bool CycleFound = SortedNodes.size() != MNodeStorage.size();
645623

646-
// Reset the visited status of all nodes encountered in the search.
647-
for (auto &Node : VisitedNodes) {
648-
Node->MVisited = false;
624+
// Reset the MTotalVisitedEdges variable to prepare for the next cycle check.
625+
for (auto &Node : MNodeStorage) {
626+
Node->MTotalVisitedEdges = 0;
649627
}
650-
}
651628

652-
bool graph_impl::checkForCycles() {
653-
// Using a depth-first search and checking if we vist a node more than once in
654-
// the current path to identify if there are cycles.
655-
bool CycleFound = false;
656-
auto CheckFunc = [&](std::shared_ptr<node_impl> &Node,
657-
std::deque<std::shared_ptr<node_impl>> &NodeStack) {
658-
// If the current node has previously been found in the current path through
659-
// the graph then we have a cycle and we end the search early.
660-
if (std::find(NodeStack.begin(), NodeStack.end(), Node) !=
661-
NodeStack.end()) {
662-
CycleFound = true;
663-
return true;
664-
}
665-
return false;
666-
};
667-
searchDepthFirst(CheckFunc);
668629
return CycleFound;
669630
}
670631

@@ -698,19 +659,31 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
698659
"Dest must be a node inside the graph.");
699660
}
700661

662+
bool DestWasGraphRoot = Dest->MPredecessors.size() == 0;
663+
701664
// We need to add the edges first before checking for cycles
702665
Src->registerSuccessor(Dest);
703666

667+
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
668+
if (DestLostRootStatus) {
669+
// Dest is no longer a Root node, so we need to remove it from MRoots.
670+
MRoots.erase(Dest);
671+
}
672+
704673
// We can skip cycle checks if either Dest has no successors (cycle not
705674
// possible) or cycle checks have been disabled with the no_cycle_check
706675
// property;
707676
if (Dest->MSuccessors.empty() || !MSkipCycleChecks) {
708677
bool CycleFound = checkForCycles();
709678

710679
if (CycleFound) {
711-
// Remove the added successor and predecessor
680+
// Remove the added successor and predecessor.
712681
Src->MSuccessors.pop_back();
713682
Dest->MPredecessors.pop_back();
683+
if (DestLostRootStatus) {
684+
// Add Dest back into MRoots.
685+
MRoots.insert(Dest);
686+
}
714687

715688
throw sycl::exception(make_error_code(sycl::errc::invalid),
716689
"Command graphs cannot contain cycles.");

sycl/source/detail/graph_impl.hpp

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
102102
/// subgraph node.
103103
std::shared_ptr<exec_graph_impl> MSubGraphImpl;
104104

105-
/// Used for tracking visited status during cycle checks.
106-
bool MVisited = false;
105+
/// Used for tracking visited status during cycle checks and node scheduling.
106+
size_t MTotalVisitedEdges = 0;
107107

108108
/// Partition number needed to assign a Node to a a partition.
109109
/// Note : This number is only used during the partitionning process and
@@ -1130,17 +1130,6 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
11301130
unsigned long long getID() const { return MID; }
11311131

11321132
private:
1133-
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
1134-
/// @param NodeFunc A function which receives as input a node in the graph to
1135-
/// perform operations on as well as the stack of nodes encountered in the
1136-
/// current path. The return value of this function determines whether an
1137-
/// early exit is triggered, if true the depth-first search will end
1138-
/// immediately and no further nodes will be visited.
1139-
void
1140-
searchDepthFirst(std::function<bool(std::shared_ptr<node_impl> &,
1141-
std::deque<std::shared_ptr<node_impl>> &)>
1142-
NodeFunc);
1143-
11441133
/// Check the graph for cycles by performing a depth-first search of the
11451134
/// graph. If a node is visited more than once in a given path through the
11461135
/// graph, a cycle is present and the search ends immediately.

sycl/unittests/Extensions/CommandGraph/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_sycl_unittest(CommandGraphExtensionTests OBJECT
1010
Queries.cpp
1111
Regressions.cpp
1212
Subgraph.cpp
13+
TopologicalSort.cpp
1314
Update.cpp
1415
Properties.cpp
1516
)

0 commit comments

Comments
 (0)