Skip to content

Commit 0598209

Browse files
committed
[SYCL][Graph] Add error checking to make_edge (#264)
- make_edge now checks for cycles - no_cycle_check property can now be passed to skip them - Various other error checks in make_edge - Generic depth first search mechanism added to graph_impl - New e2e tests for cycle checks - Unit tests for other basic errors - Prevent adding duplicate edges - Adds testing for the graph structure after a cycle error is caught to ensure it is unchanged. - Skip cycle checks when dst has no successors
1 parent 1a5e532 commit 0598209

File tree

4 files changed

+377
-7
lines changed

4 files changed

+377
-7
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 139 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include <sycl/feature_test.hpp>
1616
#include <sycl/queue.hpp>
1717

18+
#include <deque>
19+
1820
// Developer switch to use emulation mode on all backends, even those that
1921
// report native support, this is useful for debugging.
2022
#define FORCE_EMULATION_MODE 0
@@ -71,6 +73,40 @@ bool checkForRequirement(sycl::detail::AccessorImplHost *Req,
7173
}
7274
return SuccessorAddedDep;
7375
}
76+
77+
/// Visits a node on the graph and it's successors recursively in a depth-first
78+
/// approach.
79+
/// @param[in] Node The current node being visited.
80+
/// @param[in,out] VisitedNodes A set of unique nodes which have already been
81+
/// visited.
82+
/// @param[in] NodeStack Stack of nodes which are currently being visited on the
83+
/// current path through the graph.
84+
/// @param[in] NodeFunc The function object to be run on each node. A return
85+
/// value of true indicates the search should be ended immediately and the
86+
/// function will return.
87+
/// @return True if the search should end immediately, false if not.
88+
bool visitNodeDepthFirst(
89+
std::shared_ptr<node_impl> Node,
90+
std::set<std::shared_ptr<node_impl>> &VisitedNodes,
91+
std::deque<std::shared_ptr<node_impl>> &NodeStack,
92+
std::function<bool(std::shared_ptr<node_impl> &,
93+
std::deque<std::shared_ptr<node_impl>> &)>
94+
NodeFunc) {
95+
auto EarlyReturn = NodeFunc(Node, NodeStack);
96+
if (EarlyReturn) {
97+
return true;
98+
}
99+
NodeStack.push_back(Node);
100+
Node->MVisited = true;
101+
VisitedNodes.emplace(Node);
102+
for (auto &Successor : Node->MSuccessors) {
103+
if (visitNodeDepthFirst(Successor, VisitedNodes, NodeStack, NodeFunc)) {
104+
return true;
105+
}
106+
}
107+
NodeStack.pop_back();
108+
return false;
109+
}
74110
} // anonymous namespace
75111

76112
void exec_graph_impl::schedule() {
@@ -226,6 +262,105 @@ bool graph_impl::clearQueues() {
226262
return AnyQueuesCleared;
227263
}
228264

265+
void graph_impl::searchDepthFirst(
266+
std::function<bool(std::shared_ptr<node_impl> &,
267+
std::deque<std::shared_ptr<node_impl>> &)>
268+
NodeFunc) {
269+
// Track nodes visited during the search which can be used by NodeFunc in
270+
// depth first search queries. Currently unusued but is an
271+
// integral part of depth first searches.
272+
std::set<std::shared_ptr<node_impl>> VisitedNodes;
273+
274+
for (auto &Root : MRoots) {
275+
std::deque<std::shared_ptr<node_impl>> NodeStack;
276+
if (visitNodeDepthFirst(Root, VisitedNodes, NodeStack, NodeFunc)) {
277+
break;
278+
}
279+
}
280+
281+
// Reset the visited status of all nodes encountered in the search.
282+
for (auto &Node : VisitedNodes) {
283+
Node->MVisited = false;
284+
}
285+
}
286+
287+
bool graph_impl::checkForCycles() {
288+
// Using a depth-first search and checking if we vist a node more than once in
289+
// the current path to identify if there are cycles.
290+
bool CycleFound = false;
291+
auto CheckFunc = [&](std::shared_ptr<node_impl> &Node,
292+
std::deque<std::shared_ptr<node_impl>> &NodeStack) {
293+
// If the current node has previously been found in the current path through
294+
// the graph then we have a cycle and we end the search early.
295+
if (std::find(NodeStack.begin(), NodeStack.end(), Node) !=
296+
NodeStack.end()) {
297+
CycleFound = true;
298+
return true;
299+
}
300+
return false;
301+
};
302+
searchDepthFirst(CheckFunc);
303+
return CycleFound;
304+
}
305+
306+
void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
307+
std::shared_ptr<node_impl> Dest) {
308+
if (MRecordingQueues.size()) {
309+
throw sycl::exception(make_error_code(sycl::errc::invalid),
310+
"make_edge() cannot be called when a queue is "
311+
"currently recording commands to a graph.");
312+
}
313+
if (Src == Dest) {
314+
throw sycl::exception(
315+
make_error_code(sycl::errc::invalid),
316+
"make_edge() cannot be called when Src and Dest are the same.");
317+
}
318+
319+
bool SrcFound = false;
320+
bool DestFound = false;
321+
auto CheckForNodes = [&](std::shared_ptr<node_impl> &Node,
322+
std::deque<std::shared_ptr<node_impl>> &) {
323+
if (Node == Src) {
324+
SrcFound = true;
325+
}
326+
if (Node == Dest) {
327+
DestFound = true;
328+
}
329+
return SrcFound && DestFound;
330+
};
331+
332+
searchDepthFirst(CheckForNodes);
333+
334+
if (!SrcFound) {
335+
throw sycl::exception(make_error_code(sycl::errc::invalid),
336+
"Src must be a node inside the graph.");
337+
}
338+
if (!DestFound) {
339+
throw sycl::exception(make_error_code(sycl::errc::invalid),
340+
"Dest must be a node inside the graph.");
341+
}
342+
343+
// We need to add the edges first before checking for cycles
344+
Src->registerSuccessor(Dest, Src);
345+
346+
// We can skip cycle checks if either Dest has no successors (cycle not
347+
// possible) or cycle checks have been disabled with the no_cycle_check
348+
// property;
349+
if (Dest->MSuccessors.empty() || !MSkipCycleChecks) {
350+
bool CycleFound = checkForCycles();
351+
352+
if (CycleFound) {
353+
// Remove the added successor and predecessor
354+
Src->MSuccessors.pop_back();
355+
Dest->MPredecessors.pop_back();
356+
357+
throw sycl::exception(make_error_code(sycl::errc::invalid),
358+
"Command graphs cannot contain cycles.");
359+
}
360+
}
361+
removeRoot(Dest); // remove receiver from root node list
362+
}
363+
229364
// Check if nodes are empty and if so loop back through predecessors until we
230365
// find the real dependency.
231366
void exec_graph_impl::findRealDeps(
@@ -463,8 +598,9 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
463598

464599
modifiable_command_graph::modifiable_command_graph(
465600
const sycl::context &SyclContext, const sycl::device &SyclDevice,
466-
const sycl::property_list &)
467-
: impl(std::make_shared<detail::graph_impl>(SyclContext, SyclDevice)) {}
601+
const sycl::property_list &PropList)
602+
: impl(std::make_shared<detail::graph_impl>(SyclContext, SyclDevice,
603+
PropList)) {}
468604

469605
node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
470606
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
@@ -494,9 +630,7 @@ void modifiable_command_graph::make_edge(node &Src, node &Dest) {
494630
std::shared_ptr<detail::node_impl> ReceiverImpl =
495631
sycl::detail::getSyclObjImpl(Dest);
496632

497-
SenderImpl->registerSuccessor(ReceiverImpl,
498-
SenderImpl); // register successor
499-
impl->removeRoot(ReceiverImpl); // remove receiver from root node list
633+
impl->makeEdge(SenderImpl, ReceiverImpl);
500634
}
501635

502636
command_graph<graph_state::executable>

sycl/source/detail/graph_impl.hpp

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <detail/kernel_impl.hpp>
1818

1919
#include <cstring>
20+
#include <deque>
2021
#include <functional>
2122
#include <list>
2223
#include <set>
@@ -43,6 +44,9 @@ class node_impl {
4344
/// Command group object which stores all args etc needed to enqueue the node
4445
std::unique_ptr<sycl::detail::CG> MCommandGroup;
4546

47+
/// Used for tracking visited status during cycle checks.
48+
bool MVisited = false;
49+
4650
/// Add successor to the node.
4751
/// @param Node Node to add as a successor.
4852
/// @param Prev Predecessor to \p node being added as successor.
@@ -51,13 +55,23 @@ class node_impl {
5155
/// use a raw \p this pointer, so the extra \Prev parameter is passed.
5256
void registerSuccessor(const std::shared_ptr<node_impl> &Node,
5357
const std::shared_ptr<node_impl> &Prev) {
58+
if (std::find(MSuccessors.begin(), MSuccessors.end(), Node) !=
59+
MSuccessors.end()) {
60+
return;
61+
}
5462
MSuccessors.push_back(Node);
5563
Node->registerPredecessor(Prev);
5664
}
5765

5866
/// Add predecessor to the node.
5967
/// @param Node Node to add as a predecessor.
6068
void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
69+
if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
70+
[&Node](const std::weak_ptr<node_impl> &Ptr) {
71+
return Ptr.lock() == Node;
72+
}) != MPredecessors.end()) {
73+
return;
74+
}
6175
MPredecessors.push_back(Node);
6276
}
6377

@@ -183,9 +197,15 @@ class graph_impl {
183197
/// Constructor.
184198
/// @param SyclContext Context to use for graph.
185199
/// @param SyclDevice Device to create nodes with.
186-
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice)
200+
/// @param PropList Optional list of properties.
201+
graph_impl(const sycl::context &SyclContext, const sycl::device &SyclDevice,
202+
const sycl::property_list &PropList = {})
187203
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
188-
MEventsMap(), MInorderQueueMap() {}
204+
MEventsMap(), MInorderQueueMap() {
205+
if (PropList.has_property<property::graph::no_cycle_check>()) {
206+
MSkipCycleChecks = true;
207+
}
208+
}
189209

190210
/// Insert node into list of root nodes.
191211
/// @param Root Node to add to list of root nodes.
@@ -315,7 +335,32 @@ class graph_impl {
315335
MInorderQueueMap[QueueWeakPtr] = Node;
316336
}
317337

338+
/// Make an edge between two nodes in the graph. Performs some mandatory
339+
/// error checks as well as an optional check for cycles introduced by making
340+
/// this edge.
341+
/// @param Src The source of the new edge.
342+
/// @param Dest The destination of the new edge.
343+
void makeEdge(std::shared_ptr<node_impl> Src,
344+
std::shared_ptr<node_impl> Dest);
345+
318346
private:
347+
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
348+
/// @param NodeFunc A function which receives as input a node in the graph to
349+
/// perform operations on as well as the stack of nodes encountered in the
350+
/// current path. The return value of this function determines whether an
351+
/// early exit is triggered, if true the depth-first search will end
352+
/// immediately and no further nodes will be visited.
353+
void
354+
searchDepthFirst(std::function<bool(std::shared_ptr<node_impl> &,
355+
std::deque<std::shared_ptr<node_impl>> &)>
356+
NodeFunc);
357+
358+
/// Check the graph for cycles by performing a depth-first search of the
359+
/// graph. If a node is visited more than once in a given path through the
360+
/// graph, a cycle is present and the search ends immediately.
361+
/// @return True if a cycle is detected, false if not.
362+
bool checkForCycles();
363+
319364
/// Context associated with this graph.
320365
sycl::context MContext;
321366
/// Device associated with this graph. All graph nodes will execute on this
@@ -333,6 +378,9 @@ class graph_impl {
333378
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
334379
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
335380
MInorderQueueMap;
381+
/// Controls whether we skip the cycle checks in makeEdge, set by the presence
382+
/// of the no_cycle_check property on construction.
383+
bool MSkipCycleChecks = false;
336384
};
337385

338386
/// Class representing the implementation of command_graph<executable>.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// REQUIRES: level_zero, gpu
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
5+
// Tests that introducing a cycle to the graph will throw when
6+
// property::graph::no_cycle_check is not passed to the graph constructor and
7+
// will not throw when it is.
8+
9+
#include "../graph_common.hpp"
10+
11+
void CreateGraphWithCyclesTest(bool DisableCycleChecks) {
12+
13+
// If we are testing without cycle checks we need to do multiple iterations so
14+
// we can test multiple types of cycle, since introducing a cycle with no
15+
// checks may put the graph into an undefined state.
16+
const size_t Iterations = DisableCycleChecks ? 2 : 1;
17+
18+
queue Queue;
19+
20+
property_list Props;
21+
22+
if (DisableCycleChecks) {
23+
Props = {ext::oneapi::experimental::property::graph::no_cycle_check{}};
24+
}
25+
26+
for (size_t i = 0; i < Iterations; i++) {
27+
ext::oneapi::experimental::command_graph Graph{Queue.get_context(),
28+
Queue.get_device(), Props};
29+
30+
auto NodeA = Graph.add([&](sycl::handler &CGH) {
31+
CGH.single_task<class testKernelA>([=]() {});
32+
});
33+
auto NodeB = Graph.add([&](sycl::handler &CGH) {
34+
CGH.single_task<class testKernelB>([=]() {});
35+
});
36+
auto NodeC = Graph.add([&](sycl::handler &CGH) {
37+
CGH.single_task<class testKernelC>([=]() {});
38+
});
39+
40+
// Make normal edges
41+
std::error_code ErrorCode = sycl::make_error_code(sycl::errc::success);
42+
try {
43+
Graph.make_edge(NodeA, NodeB);
44+
Graph.make_edge(NodeB, NodeC);
45+
} catch (const sycl::exception &e) {
46+
ErrorCode = e.code();
47+
}
48+
49+
assert(ErrorCode == sycl::errc::success);
50+
51+
// Introduce cycles to the graph. If we are performing cycle checks we can
52+
// test both cycles, if they are disabled we need to test one per iteration.
53+
if (i == 0 || !DisableCycleChecks) {
54+
ErrorCode = sycl::make_error_code(sycl::errc::success);
55+
try {
56+
Graph.make_edge(NodeC, NodeA);
57+
} catch (const sycl::exception &e) {
58+
ErrorCode = e.code();
59+
}
60+
61+
assert(ErrorCode ==
62+
(DisableCycleChecks ? sycl::errc::success : sycl::errc::invalid));
63+
}
64+
65+
if (i == 1 || !DisableCycleChecks) {
66+
ErrorCode = sycl::make_error_code(sycl::errc::success);
67+
try {
68+
Graph.make_edge(NodeC, NodeB);
69+
} catch (const sycl::exception &e) {
70+
ErrorCode = e.code();
71+
}
72+
73+
assert(ErrorCode ==
74+
(DisableCycleChecks ? sycl::errc::success : sycl::errc::invalid));
75+
}
76+
}
77+
}
78+
79+
int main() {
80+
// Test with cycle checks
81+
CreateGraphWithCyclesTest(false);
82+
// Test without cycle checks
83+
CreateGraphWithCyclesTest(true);
84+
85+
return 0;
86+
}

0 commit comments

Comments
 (0)