Skip to content

Commit 4b980d0

Browse files
mfrancepilloisBensuoEwanC
committed
[SYCL][Graph] Throw exception when explicit add called on a graph recording a queue (#283)
Factorizes the exception throwing method when the explicit API is used on a graph recording a queue Improves the test while_recording to test throwing an invalid exception for the two explicit graph:add entry points. Addresses issue #271 * Update sycl/source/detail/graph_impl.hpp Co-authored-by: Ben Tracy <[email protected]> * Update sycl/test-e2e/Graph/Explicit/while_recording.cpp Co-authored-by: Ben Tracy <[email protected]> * Update sycl/source/detail/graph_impl.hpp Co-authored-by: Ewan Crawford <[email protected]> --------- Co-authored-by: Ben Tracy <[email protected]> Co-authored-by: Ewan Crawford <[email protected]>
1 parent 99e6944 commit 4b980d0

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,11 +305,7 @@ bool graph_impl::checkForCycles() {
305305

306306
void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
307307
std::shared_ptr<node_impl> Dest) {
308-
if (MRecordingQueues.size()) {
309-
throw sycl::exception(make_error_code(sycl::errc::invalid),
310-
"make_edge() cannot be called when a queue is "
311-
"currently recording commands to a graph.");
312-
}
308+
throwIfGraphRecordingQueue("make_edge()");
313309
if (Src == Dest) {
314310
throw sycl::exception(
315311
make_error_code(sycl::errc::invalid),
@@ -610,6 +606,7 @@ modifiable_command_graph::modifiable_command_graph(
610606
PropList)) {}
611607

612608
node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
609+
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
613610
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
614611
for (auto &D : Deps) {
615612
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
@@ -621,6 +618,7 @@ node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
621618

622619
node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
623620
const std::vector<node> &Deps) {
621+
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
624622
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
625623
for (auto &D : Deps) {
626624
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));

sycl/source/detail/graph_impl.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,18 @@ class graph_impl {
343343
void makeEdge(std::shared_ptr<node_impl> Src,
344344
std::shared_ptr<node_impl> Dest);
345345

346+
/// Throws an invalid exception if this function is called
347+
/// while a queue is recording commands to the graph.
348+
/// @param ExceptionMsg Message to append to the exception message
349+
void throwIfGraphRecordingQueue(const std::string ExceptionMsg) const {
350+
if (MRecordingQueues.size()) {
351+
throw sycl::exception(make_error_code(sycl::errc::invalid),
352+
ExceptionMsg +
353+
" cannot be called when a queue "
354+
"is currently recording commands to a graph.");
355+
}
356+
}
357+
346358
private:
347359
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
348360
/// @param NodeFunc A function which receives as input a node in the graph to

sycl/test-e2e/Graph/Explicit/add_node_while_recording.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
//
77
// CHECK-NOT: LEAK
88

9-
// Expected Fail as exception not implemented yet
10-
// XFAIL: *
11-
129
// Tests attempting to add a node to a command_graph while it is being
1310
// recorded to by a queue is an error.
1411

@@ -30,8 +27,17 @@ int main() {
3027
Success = true;
3128
}
3229
}
30+
assert(Success);
3331

34-
Graph.end_recording();
32+
Success = false;
33+
try {
34+
Graph.add({});
35+
} catch (sycl::exception &E) {
36+
auto StdErrc = E.code().value();
37+
Success = (StdErrc == static_cast<int>(errc::invalid));
38+
}
3539
assert(Success);
40+
41+
Graph.end_recording();
3642
return 0;
3743
}

0 commit comments

Comments
 (0)