|
15 | 15 | #include <sycl/feature_test.hpp>
|
16 | 16 | #include <sycl/queue.hpp>
|
17 | 17 |
|
| 18 | +#include <deque> |
| 19 | + |
18 | 20 | // Developer switch to use emulation mode on all backends, even those that
|
19 | 21 | // report native support, this is useful for debugging.
|
20 | 22 | #define FORCE_EMULATION_MODE 0
|
@@ -71,6 +73,40 @@ bool checkForRequirement(sycl::detail::AccessorImplHost *Req,
|
71 | 73 | }
|
72 | 74 | return SuccessorAddedDep;
|
73 | 75 | }
|
| 76 | + |
| 77 | +/// Visits a node on the graph and it's successors recursively in a depth-first |
| 78 | +/// approach. |
| 79 | +/// @param[in] Node The current node being visited. |
| 80 | +/// @param[in,out] VisitedNodes A set of unique nodes which have already been |
| 81 | +/// visited. |
| 82 | +/// @param[in] NodeStack Stack of nodes which are currently being visited on the |
| 83 | +/// current path through the graph. |
| 84 | +/// @param[in] NodeFunc The function object to be run on each node. A return |
| 85 | +/// value of true indicates the search should be ended immediately and the |
| 86 | +/// function will return. |
| 87 | +/// @return True if the search should end immediately, false if not. |
| 88 | +bool visitNodeDepthFirst( |
| 89 | + std::shared_ptr<node_impl> Node, |
| 90 | + std::set<std::shared_ptr<node_impl>> &VisitedNodes, |
| 91 | + std::deque<std::shared_ptr<node_impl>> &NodeStack, |
| 92 | + std::function<bool(std::shared_ptr<node_impl> &, |
| 93 | + std::deque<std::shared_ptr<node_impl>> &)> |
| 94 | + NodeFunc) { |
| 95 | + auto EarlyReturn = NodeFunc(Node, NodeStack); |
| 96 | + if (EarlyReturn) { |
| 97 | + return true; |
| 98 | + } |
| 99 | + NodeStack.push_back(Node); |
| 100 | + Node->MVisited = true; |
| 101 | + VisitedNodes.emplace(Node); |
| 102 | + for (auto &Successor : Node->MSuccessors) { |
| 103 | + if (visitNodeDepthFirst(Successor, VisitedNodes, NodeStack, NodeFunc)) { |
| 104 | + return true; |
| 105 | + } |
| 106 | + } |
| 107 | + NodeStack.pop_back(); |
| 108 | + return false; |
| 109 | +} |
74 | 110 | } // anonymous namespace
|
75 | 111 |
|
76 | 112 | void exec_graph_impl::schedule() {
|
@@ -226,6 +262,105 @@ bool graph_impl::clearQueues() {
|
226 | 262 | return AnyQueuesCleared;
|
227 | 263 | }
|
228 | 264 |
|
| 265 | +void graph_impl::searchDepthFirst( |
| 266 | + std::function<bool(std::shared_ptr<node_impl> &, |
| 267 | + std::deque<std::shared_ptr<node_impl>> &)> |
| 268 | + NodeFunc) { |
| 269 | + // Track nodes visited during the search which can be used by NodeFunc in |
| 270 | + // depth first search queries. Currently unusued but is an |
| 271 | + // integral part of depth first searches. |
| 272 | + std::set<std::shared_ptr<node_impl>> VisitedNodes; |
| 273 | + |
| 274 | + for (auto &Root : MRoots) { |
| 275 | + std::deque<std::shared_ptr<node_impl>> NodeStack; |
| 276 | + if (visitNodeDepthFirst(Root, VisitedNodes, NodeStack, NodeFunc)) { |
| 277 | + break; |
| 278 | + } |
| 279 | + } |
| 280 | + |
| 281 | + // Reset the visited status of all nodes encountered in the search. |
| 282 | + for (auto &Node : VisitedNodes) { |
| 283 | + Node->MVisited = false; |
| 284 | + } |
| 285 | +} |
| 286 | + |
| 287 | +bool graph_impl::checkForCycles() { |
| 288 | + // Using a depth-first search and checking if we vist a node more than once in |
| 289 | + // the current path to identify if there are cycles. |
| 290 | + bool CycleFound = false; |
| 291 | + auto CheckFunc = [&](std::shared_ptr<node_impl> &Node, |
| 292 | + std::deque<std::shared_ptr<node_impl>> &NodeStack) { |
| 293 | + // If the current node has previously been found in the current path through |
| 294 | + // the graph then we have a cycle and we end the search early. |
| 295 | + if (std::find(NodeStack.begin(), NodeStack.end(), Node) != |
| 296 | + NodeStack.end()) { |
| 297 | + CycleFound = true; |
| 298 | + return true; |
| 299 | + } |
| 300 | + return false; |
| 301 | + }; |
| 302 | + searchDepthFirst(CheckFunc); |
| 303 | + return CycleFound; |
| 304 | +} |
| 305 | + |
| 306 | +void graph_impl::makeEdge(std::shared_ptr<node_impl> Src, |
| 307 | + std::shared_ptr<node_impl> Dest) { |
| 308 | + if (MRecordingQueues.size()) { |
| 309 | + throw sycl::exception(make_error_code(sycl::errc::invalid), |
| 310 | + "make_edge() cannot be called when a queue is " |
| 311 | + "currently recording commands to a graph."); |
| 312 | + } |
| 313 | + if (Src == Dest) { |
| 314 | + throw sycl::exception( |
| 315 | + make_error_code(sycl::errc::invalid), |
| 316 | + "make_edge() cannot be called when Src and Dest are the same."); |
| 317 | + } |
| 318 | + |
| 319 | + bool SrcFound = false; |
| 320 | + bool DestFound = false; |
| 321 | + auto CheckForNodes = [&](std::shared_ptr<node_impl> &Node, |
| 322 | + std::deque<std::shared_ptr<node_impl>> &) { |
| 323 | + if (Node == Src) { |
| 324 | + SrcFound = true; |
| 325 | + } |
| 326 | + if (Node == Dest) { |
| 327 | + DestFound = true; |
| 328 | + } |
| 329 | + return SrcFound && DestFound; |
| 330 | + }; |
| 331 | + |
| 332 | + searchDepthFirst(CheckForNodes); |
| 333 | + |
| 334 | + if (!SrcFound) { |
| 335 | + throw sycl::exception(make_error_code(sycl::errc::invalid), |
| 336 | + "Src must be a node inside the graph."); |
| 337 | + } |
| 338 | + if (!DestFound) { |
| 339 | + throw sycl::exception(make_error_code(sycl::errc::invalid), |
| 340 | + "Dest must be a node inside the graph."); |
| 341 | + } |
| 342 | + |
| 343 | + // We need to add the edges first before checking for cycles |
| 344 | + Src->registerSuccessor(Dest, Src); |
| 345 | + |
| 346 | + // We can skip cycle checks if either Dest has no successors (cycle not |
| 347 | + // possible) or cycle checks have been disabled with the no_cycle_check |
| 348 | + // property; |
| 349 | + if (Dest->MSuccessors.empty() || !MSkipCycleChecks) { |
| 350 | + bool CycleFound = checkForCycles(); |
| 351 | + |
| 352 | + if (CycleFound) { |
| 353 | + // Remove the added successor and predecessor |
| 354 | + Src->MSuccessors.pop_back(); |
| 355 | + Dest->MPredecessors.pop_back(); |
| 356 | + |
| 357 | + throw sycl::exception(make_error_code(sycl::errc::invalid), |
| 358 | + "Command graphs cannot contain cycles."); |
| 359 | + } |
| 360 | + } |
| 361 | + removeRoot(Dest); // remove receiver from root node list |
| 362 | +} |
| 363 | + |
229 | 364 | // Check if nodes are empty and if so loop back through predecessors until we
|
230 | 365 | // find the real dependency.
|
231 | 366 | void exec_graph_impl::findRealDeps(
|
@@ -463,8 +598,9 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
|
463 | 598 |
|
464 | 599 | modifiable_command_graph::modifiable_command_graph(
|
465 | 600 | const sycl::context &SyclContext, const sycl::device &SyclDevice,
|
466 |
| - const sycl::property_list &) |
467 |
| - : impl(std::make_shared<detail::graph_impl>(SyclContext, SyclDevice)) {} |
| 601 | + const sycl::property_list &PropList) |
| 602 | + : impl(std::make_shared<detail::graph_impl>(SyclContext, SyclDevice, |
| 603 | + PropList)) {} |
468 | 604 |
|
469 | 605 | node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
|
470 | 606 | std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
|
@@ -494,9 +630,7 @@ void modifiable_command_graph::make_edge(node &Src, node &Dest) {
|
494 | 630 | std::shared_ptr<detail::node_impl> ReceiverImpl =
|
495 | 631 | sycl::detail::getSyclObjImpl(Dest);
|
496 | 632 |
|
497 |
| - SenderImpl->registerSuccessor(ReceiverImpl, |
498 |
| - SenderImpl); // register successor |
499 |
| - impl->removeRoot(ReceiverImpl); // remove receiver from root node list |
| 633 | + impl->makeEdge(SenderImpl, ReceiverImpl); |
500 | 634 | }
|
501 | 635 |
|
502 | 636 | command_graph<graph_state::executable>
|
|
0 commit comments