@@ -1081,15 +1081,6 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
1081
1081
} else if ((CurrentPartition->MSchedule .size () > 0 ) &&
1082
1082
(CurrentPartition->MSchedule .front ()->MCGType ==
1083
1083
sycl::detail::CGType::CodeplayHostTask)) {
1084
- // If we have pending updates then we need to make sure that they are
1085
- // completed before the host-task is enqueued, to ensure it has received
1086
- // those updates prior to calling node->getCGCopy()
1087
- if (MUpdateEvents.size () > 0 ) {
1088
- for (auto &Event : MUpdateEvents) {
1089
- Event->wait_and_throw (Event);
1090
- }
1091
- MUpdateEvents.clear ();
1092
- }
1093
1084
1094
1085
auto NodeImpl = CurrentPartition->MSchedule .front ();
1095
1086
// Schedule host task
@@ -1389,7 +1380,7 @@ void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
1389
1380
}
1390
1381
1391
1382
void exec_graph_impl::update (
1392
- const std::vector<std::shared_ptr<node_impl>> Nodes) {
1383
+ const std::vector<std::shared_ptr<node_impl>> & Nodes) {
1393
1384
1394
1385
if (!MIsUpdatable) {
1395
1386
throw sycl::exception (sycl::make_error_code (errc::invalid),
@@ -1445,21 +1436,45 @@ void exec_graph_impl::update(
1445
1436
NeedScheduledUpdate |= MExecutionEvents.size () > 0 ;
1446
1437
1447
1438
if (NeedScheduledUpdate) {
1439
+ // Copy the list of nodes as we may need to modify it
1440
+ auto NodesCopy = Nodes;
1441
+
1442
+ // If the graph contains host tasks we need special handling here because
1443
+ // their state lives in the graph object itself, so we must do the update
1444
+ // immediately here. Whereas all other command state lives in the backend so
1445
+ // it can be scheduled along with other commands.
1446
+ if (MContainsHostTask) {
1447
+ std::vector<std::shared_ptr<node_impl>> HostTasks;
1448
+ // Remove any nodes that are host tasks and put them in HostTasks
1449
+ auto RemovedIter = std::remove_if (
1450
+ NodesCopy.begin (), NodesCopy.end (),
1451
+ [&HostTasks](const std::shared_ptr<node_impl> &Node) -> bool {
1452
+ if (Node->MNodeType == node_type::host_task) {
1453
+ HostTasks.push_back (Node);
1454
+ return true ;
1455
+ }
1456
+ return false ;
1457
+ });
1458
+ // Clean up extra elements in NodesCopy after the remove
1459
+ NodesCopy.erase (RemovedIter, NodesCopy.end ());
1460
+
1461
+ // Update host-tasks synchronously
1462
+ for (auto &HostTaskNode : HostTasks) {
1463
+ updateImpl (HostTaskNode);
1464
+ }
1465
+ }
1466
+
1448
1467
auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
1449
1468
sycl::detail::getSyclObjImpl (MGraphImpl->getDevice ()),
1450
1469
sycl::detail::getSyclObjImpl (MGraphImpl->getContext ()),
1451
1470
sycl::async_handler{}, sycl::property_list{});
1452
1471
1453
1472
auto UpdateEvent =
1454
1473
sycl::detail::Scheduler::getInstance ().addCommandGraphUpdate (
1455
- this , Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);
1474
+ this , std::move (NodesCopy), AllocaQueue, UpdateRequirements,
1475
+ MExecutionEvents);
1456
1476
1457
- // If the graph contains host-task(s) we need to track update events so we
1458
- // can explicitly wait on them before enqueue further host-tasks to ensure
1459
- // updates have taken effect.
1460
- if (MContainsHostTask) {
1461
- MUpdateEvents.push_back (UpdateEvent);
1462
- }
1477
+ MExecutionEvents.push_back (UpdateEvent);
1463
1478
} else {
1464
1479
for (auto &Node : Nodes) {
1465
1480
updateImpl (Node);
0 commit comments