Skip to content

Commit 373d3d2

Browse files
committed
Fix host-tasks being enqueued before they should be updated.
1 parent 9189aad commit 373d3d2

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ void exec_graph_impl::makePartitions() {
196196
}
197197
}
198198

199+
MContainsHostTask = HostTaskList.size() > 0;
199200
// Annotate nodes
200201
// The first step in graph partitioning is to annotate all nodes of the graph
201202
// with a temporary partition or group number. This step allows us to group
@@ -1080,6 +1081,16 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10801081
} else if ((CurrentPartition->MSchedule.size() > 0) &&
10811082
(CurrentPartition->MSchedule.front()->MCGType ==
10821083
sycl::detail::CGType::CodeplayHostTask)) {
1084+
// If we have pending updates then we need to make sure that they are
1085+
// completed before the host-task is enqueued, to ensure it has received
1086+
// those updates prior to calling node->getCGCopy()
1087+
if (MUpdateEvents.size() > 0) {
1088+
for (auto &Event : MUpdateEvents) {
1089+
Event->wait_and_throw(Event);
1090+
}
1091+
MUpdateEvents.clear();
1092+
}
1093+
10831094
auto NodeImpl = CurrentPartition->MSchedule.front();
10841095
// Schedule host task
10851096
NodeImpl->MCommandGroup->getEvents().insert(
@@ -1438,9 +1449,17 @@ void exec_graph_impl::update(
14381449
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
14391450
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
14401451
sycl::async_handler{}, sycl::property_list{});
1441-
// Don't need to care about the return event here because it is synchronous
1442-
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
1443-
this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);
1452+
1453+
auto UpdateEvent =
1454+
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
1455+
this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);
1456+
1457+
// If the graph contains host-task(s) we need to track update events so we
1458+
// can explicitly wait on them before enqueue further host-tasks to ensure
1459+
// updates have taken effect.
1460+
if (MContainsHostTask) {
1461+
MUpdateEvents.push_back(UpdateEvent);
1462+
}
14441463
} else {
14451464
for (auto &Node : Nodes) {
14461465
updateImpl(Node);

sycl/source/detail/graph_impl.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
475475
HostTaskCG->getAccStorage() = OtherHostTaskCG->getAccStorage();
476476
HostTaskCG->getRequirements() = OtherHostTaskCG->getRequirements();
477477
HostTaskCG->MHostTask = OtherHostTaskCG->MHostTask;
478+
HostTaskCG->getEvents() = OtherHostTaskCG->getEvents();
478479
break;
479480
}
480481
default:
@@ -1453,6 +1454,12 @@ class exec_graph_impl {
14531454
unsigned long long MID;
14541455
// Used for std::hash in order to create a unique hash for the instance.
14551456
inline static std::atomic<unsigned long long> NextAvailableID = 0;
1457+
// True if this graph contains any host-tasks, controls whether we store
1458+
// events in MUpdateEvents.
1459+
bool MContainsHostTask = false;
1460+
// Contains events for updates submitted through the scheduler as we need to
1461+
// wait on them when enqueuing host-tasks.
1462+
std::vector<sycl::detail::EventImplPtr> MUpdateEvents;
14561463
};
14571464

14581465
class dynamic_parameter_impl {

sycl/test-e2e/Graph/Inputs/whole_update_host_task.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,23 +100,20 @@ int main() {
100100
// Fill graphB with nodes, with a different set of pointers
101101
add_nodes_to_graph(GraphB, Queue, PtrA2, PtrB2, PtrC2, ModValue);
102102

103-
// Execute several Iterations of the graph for 1st set of buffers
103+
// Execute several Iterations of the graph, updating in between each
104+
// execution.
104105
event Event;
105106
for (unsigned n = 0; n < Iterations; n++) {
106107
Event = Queue.submit([&](handler &CGH) {
107108
CGH.depends_on(Event);
108109
CGH.ext_oneapi_graph(GraphExec);
109110
});
110-
}
111-
112-
GraphExec.update(GraphB);
113-
114-
// Execute several Iterations of the graph for 2nd set of buffers
115-
for (unsigned n = 0; n < Iterations; n++) {
111+
GraphExec.update(GraphB);
116112
Event = Queue.submit([&](handler &CGH) {
117113
CGH.depends_on(Event);
118114
CGH.ext_oneapi_graph(GraphExec);
119115
});
116+
GraphExec.update(GraphA);
120117
}
121118

122119
Queue.wait_and_throw();

0 commit comments

Comments
 (0)