8
8
9
9
#define __SYCL_GRAPH_IMPL_CPP
10
10
11
+ #include < stack>
11
12
#include < detail/graph_impl.hpp>
12
13
#include < detail/handler_impl.hpp>
13
14
#include < detail/kernel_arg_mask.hpp>
@@ -31,64 +32,47 @@ namespace experimental {
31
32
namespace detail {
32
33
33
34
namespace {
34
- // / Visits a node on the graph and its successors recursively in a depth-first
35
- // / approach.
36
- // / @param[in] Node The current node being visited.
37
- // / @param[in,out] VisitedNodes A set of unique nodes which have already been
38
- // / visited.
39
- // / @param[in] NodeStack Stack of nodes which are currently being visited on the
40
- // / current path through the graph.
41
- // / @param[in] NodeFunc The function object to be run on each node. A return
42
- // / value of true indicates the search should be ended immediately and the
43
- // / function will return.
44
- // / @return True if the search should end immediately, false if not.
45
- bool visitNodeDepthFirst (
46
- std::shared_ptr<node_impl> Node,
47
- std::set<std::shared_ptr<node_impl>> &VisitedNodes,
48
- std::deque<std::shared_ptr<node_impl>> &NodeStack,
49
- std::function<bool (std::shared_ptr<node_impl> &,
50
- std::deque<std::shared_ptr<node_impl>> &)>
51
- NodeFunc) {
52
- auto EarlyReturn = NodeFunc (Node, NodeStack);
53
- if (EarlyReturn) {
54
- return true ;
55
- }
56
- NodeStack.push_back (Node);
57
- Node->MVisited = true ;
58
- VisitedNodes.emplace (Node);
59
- for (auto &Successor : Node->MSuccessors ) {
60
- if (visitNodeDepthFirst (Successor.lock (), VisitedNodes, NodeStack,
61
- NodeFunc)) {
62
- return true ;
63
- }
64
- }
65
- NodeStack.pop_back ();
66
- return false ;
67
- }
68
-
69
- // / Recursively add nodes to execution stack.
70
- // / @param NodeImpl Node to schedule.
71
- // / @param Schedule Execution ordering to add node to.
72
- // / @param PartitionBounded If set to true, the topological sort is stopped at
73
- // / partition borders. Hence, nodes belonging to a partition different from the
74
- // / NodeImpl partition are not processed.
75
- void sortTopological (std::shared_ptr<node_impl> NodeImpl,
76
- std::list<std::shared_ptr<node_impl>> &Schedule,
77
- bool PartitionBounded = false ) {
78
- for (auto &Succ : NodeImpl->MSuccessors ) {
79
- auto NextNode = Succ.lock ();
80
- if (PartitionBounded &&
81
- (NextNode->MPartitionNum != NodeImpl->MPartitionNum )) {
82
- continue ;
83
- }
84
- // Check if we've already scheduled this node
85
- if (std::find (Schedule.begin (), Schedule.end (), NextNode) ==
86
- Schedule.end ()) {
87
- sortTopological (NextNode, Schedule, PartitionBounded);
35
+ // / Topologically sorts the graph in order to schedule nodes for execution.
36
+ // / This implementation is based on Kahn's algorithm, which uses a breadth-first
37
+ // / search approach.
38
+ // / For performance reasons, this function uses the MTotalVisitedEdges
39
+ // / member variable of the node_impl class. It's the caller's responsibility to
40
+ // / make sure that MTotalVisitedEdges is set to 0 for all nodes in the graph
41
+ // / before calling this function.
42
+ // / @param[in] Roots List of root nodes.
43
+ // / @param[out] SortedNodes The graph nodes sorted in topological order.
44
+ // / @param[in] PartitionBounded If set to true, the topological sort is stopped
45
+ // / at partition borders. Hence, nodes belonging to a partition different from
46
+ // / the partition of the node currently being processed are not processed.
47
+ void sortTopological (std::set<std::weak_ptr<node_impl>,
48
+ std::owner_less<std::weak_ptr<node_impl>>> &Roots,
49
+ std::list<std::shared_ptr<node_impl>> &SortedNodes,
50
+ bool PartitionBounded) {
51
+ std::stack<std::weak_ptr<node_impl>> Source;
52
+
53
+ for (auto &Node : Roots) {
54
+ Source.push (Node);
55
+ }
56
+
57
+ while (!Source.empty ()) {
58
+ auto Node = Source.top ().lock ();
59
+ Source.pop ();
60
+ SortedNodes.push_back (Node);
61
+
62
+ for (auto &SuccWP : Node->MSuccessors ) {
63
+ auto Succ = SuccWP.lock ();
64
+
65
+ if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum )) {
66
+ continue ;
67
+ }
68
+
69
+ auto &TotalVisitedEdges = Succ->MTotalVisitedEdges ;
70
+ ++TotalVisitedEdges;
71
+ if (TotalVisitedEdges == Succ->MPredecessors .size ()) {
72
+ Source.push (Succ);
73
+ }
88
74
}
89
75
}
90
-
91
- Schedule.push_front (NodeImpl);
92
76
}
93
77
94
78
// / Propagates the partition number `PartitionNum` to predecessors.
@@ -180,9 +164,9 @@ std::vector<node> createNodesFromImpls(
180
164
181
165
void partition::schedule () {
182
166
if (MSchedule.empty ()) {
183
- for ( auto &Node : MRoots) {
184
- sortTopological (Node. lock (), MSchedule, true );
185
- }
167
+ // There is no need to reset MTotalVisitedEdges before calling
168
+ // sortTopological because this function is only called once per partition.
169
+ sortTopological (MRoots, MSchedule, true );
186
170
}
187
171
}
188
172
@@ -311,6 +295,7 @@ static void checkGraphPropertiesAndThrow(const property_list &Properties) {
311
295
#define __SYCL_MANUALLY_DEFINED_PROP (NS_QUALIFIER, PROP_NAME )
312
296
switch (PropertyKind) {
313
297
#include < sycl/ext/oneapi/experimental/detail/properties/graph_properties.def>
298
+
314
299
default :
315
300
return false ;
316
301
}
@@ -627,44 +612,20 @@ bool graph_impl::clearQueues() {
627
612
return AnyQueuesCleared;
628
613
}
629
614
630
- void graph_impl::searchDepthFirst (
631
- std::function<bool (std::shared_ptr<node_impl> &,
632
- std::deque<std::shared_ptr<node_impl>> &)>
633
- NodeFunc) {
634
- // Track nodes visited during the search which can be used by NodeFunc in
635
- // depth first search queries. Currently unused but is an
636
- // integral part of depth first searches.
637
- std::set<std::shared_ptr<node_impl>> VisitedNodes;
615
+ bool graph_impl::checkForCycles () {
616
+ std::list<std::shared_ptr<node_impl>> SortedNodes;
617
+ sortTopological (MRoots, SortedNodes, false );
638
618
639
- for (auto &Root : MRoots) {
640
- std::deque<std::shared_ptr<node_impl>> NodeStack;
641
- if (visitNodeDepthFirst (Root.lock (), VisitedNodes, NodeStack, NodeFunc)) {
642
- break ;
643
- }
644
- }
619
+ // If after a topological sort, not all the nodes in the graph are sorted,
620
+ // then there must be at least one cycle in the graph. This is guaranteed
621
+ // by Kahn's algorithm, which sortTopological() implements.
622
+ bool CycleFound = SortedNodes.size () != MNodeStorage.size ();
645
623
646
- // Reset the visited status of all nodes encountered in the search.
647
- for (auto &Node : VisitedNodes ) {
648
- Node->MVisited = false ;
624
+ // Reset the MTotalVisitedEdges variable to prepare for the next cycle check.
625
+ for (auto &Node : MNodeStorage ) {
626
+ Node->MTotalVisitedEdges = 0 ;
649
627
}
650
- }
651
628
652
- bool graph_impl::checkForCycles () {
653
- // Using a depth-first search and checking if we visit a node more than once in
654
- // the current path to identify if there are cycles.
655
- bool CycleFound = false ;
656
- auto CheckFunc = [&](std::shared_ptr<node_impl> &Node,
657
- std::deque<std::shared_ptr<node_impl>> &NodeStack) {
658
- // If the current node has previously been found in the current path through
659
- // the graph then we have a cycle and we end the search early.
660
- if (std::find (NodeStack.begin (), NodeStack.end (), Node) !=
661
- NodeStack.end ()) {
662
- CycleFound = true ;
663
- return true ;
664
- }
665
- return false ;
666
- };
667
- searchDepthFirst (CheckFunc);
668
629
return CycleFound;
669
630
}
670
631
@@ -698,19 +659,31 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
698
659
" Dest must be a node inside the graph." );
699
660
}
700
661
662
+ bool DestWasGraphRoot = Dest->MPredecessors .size () == 0 ;
663
+
701
664
// We need to add the edges first before checking for cycles
702
665
Src->registerSuccessor (Dest);
703
666
667
+ bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors .size () == 1 ;
668
+ if (DestLostRootStatus) {
669
+ // Dest is no longer a Root node, so we need to remove it from MRoots.
670
+ MRoots.erase (Dest);
671
+ }
672
+
704
673
// We can skip cycle checks if either Dest has no successors (cycle not
705
674
// possible) or cycle checks have been disabled with the no_cycle_check
706
675
// property;
707
676
if (Dest->MSuccessors .empty () || !MSkipCycleChecks) {
708
677
bool CycleFound = checkForCycles ();
709
678
710
679
if (CycleFound) {
711
- // Remove the added successor and predecessor
680
+ // Remove the added successor and predecessor.
712
681
Src->MSuccessors .pop_back ();
713
682
Dest->MPredecessors .pop_back ();
683
+ if (DestLostRootStatus) {
684
+ // Add Dest back into MRoots.
685
+ MRoots.insert (Dest);
686
+ }
714
687
715
688
throw sycl::exception (make_error_code (sycl::errc::invalid),
716
689
" Command graphs cannot contain cycles." );
0 commit comments