Skip to content

Commit bc20c32

Browse files
committed
Address issues with previous solution
- Instead of trying to update host tasks through the scheduler, simply perform the host-task updates immediately (synchronously), before scheduling the rest of the commands.
1 parent 373d3d2 commit bc20c32

File tree

2 files changed

+36
-23
lines changed

2 files changed

+36
-23
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,15 +1081,6 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10811081
} else if ((CurrentPartition->MSchedule.size() > 0) &&
10821082
(CurrentPartition->MSchedule.front()->MCGType ==
10831083
sycl::detail::CGType::CodeplayHostTask)) {
1084-
// If we have pending updates then we need to make sure that they are
1085-
// completed before the host-task is enqueued, to ensure it has received
1086-
// those updates prior to calling node->getCGCopy()
1087-
if (MUpdateEvents.size() > 0) {
1088-
for (auto &Event : MUpdateEvents) {
1089-
Event->wait_and_throw(Event);
1090-
}
1091-
MUpdateEvents.clear();
1092-
}
10931084

10941085
auto NodeImpl = CurrentPartition->MSchedule.front();
10951086
// Schedule host task
@@ -1389,7 +1380,7 @@ void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
13891380
}
13901381

13911382
void exec_graph_impl::update(
1392-
const std::vector<std::shared_ptr<node_impl>> Nodes) {
1383+
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
13931384

13941385
if (!MIsUpdatable) {
13951386
throw sycl::exception(sycl::make_error_code(errc::invalid),
@@ -1445,21 +1436,45 @@ void exec_graph_impl::update(
14451436
NeedScheduledUpdate |= MExecutionEvents.size() > 0;
14461437

14471438
if (NeedScheduledUpdate) {
1439+
// Copy the list of nodes as we may need to modify it
1440+
auto NodesCopy = Nodes;
1441+
1442+
// If the graph contains host tasks we need special handling here because
1443+
// their state lives in the graph object itself, so we must do the update
1444+
// immediately here, whereas all other command state lives in the backend so
1445+
// it can be scheduled along with other commands.
1446+
if (MContainsHostTask) {
1447+
std::vector<std::shared_ptr<node_impl>> HostTasks;
1448+
// Remove any nodes that are host tasks and put them in HostTasks
1449+
auto RemovedIter = std::remove_if(
1450+
NodesCopy.begin(), NodesCopy.end(),
1451+
[&HostTasks](const std::shared_ptr<node_impl> &Node) -> bool {
1452+
if (Node->MNodeType == node_type::host_task) {
1453+
HostTasks.push_back(Node);
1454+
return true;
1455+
}
1456+
return false;
1457+
});
1458+
// Clean up extra elements in NodesCopy after the remove
1459+
NodesCopy.erase(RemovedIter, NodesCopy.end());
1460+
1461+
// Update host-tasks synchronously
1462+
for (auto &HostTaskNode : HostTasks) {
1463+
updateImpl(HostTaskNode);
1464+
}
1465+
}
1466+
14481467
auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
14491468
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
14501469
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
14511470
sycl::async_handler{}, sycl::property_list{});
14521471

14531472
auto UpdateEvent =
14541473
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
1455-
this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);
1474+
this, std::move(NodesCopy), AllocaQueue, UpdateRequirements,
1475+
MExecutionEvents);
14561476

1457-
// If the graph contains host-task(s) we need to track update events so we
1458-
// can explicitly wait on them before enqueue further host-tasks to ensure
1459-
// updates have taken effect.
1460-
if (MContainsHostTask) {
1461-
MUpdateEvents.push_back(UpdateEvent);
1462-
}
1477+
MExecutionEvents.push_back(UpdateEvent);
14631478
} else {
14641479
for (auto &Node : Nodes) {
14651480
updateImpl(Node);

sycl/source/detail/graph_impl.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,7 +1333,7 @@ class exec_graph_impl {
13331333

13341334
void update(std::shared_ptr<graph_impl> GraphImpl);
13351335
void update(std::shared_ptr<node_impl> Node);
1336-
void update(const std::vector<std::shared_ptr<node_impl>> Nodes);
1336+
void update(const std::vector<std::shared_ptr<node_impl>> &Nodes);
13371337

13381338
void updateImpl(std::shared_ptr<node_impl> NodeImpl);
13391339

@@ -1454,12 +1454,10 @@ class exec_graph_impl {
14541454
unsigned long long MID;
14551455
// Used for std::hash in order to create a unique hash for the instance.
14561456
inline static std::atomic<unsigned long long> NextAvailableID = 0;
1457-
// True if this graph contains any host-tasks, controls whether we store
1458-
// events in MUpdateEvents.
1457+
1458+
// True if this graph contains any host-tasks, indicates we need special
1459+
// handling for them during update().
14591460
bool MContainsHostTask = false;
1460-
// Contains events for updates submitted through the scheduler as we need to
1461-
// wait on them when enqueuing host-tasks.
1462-
std::vector<sycl::detail::EventImplPtr> MUpdateEvents;
14631461
};
14641462

14651463
class dynamic_parameter_impl {

0 commit comments

Comments
 (0)