@@ -581,9 +581,9 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
581
581
std::vector<std::shared_ptr<detail::node_impl>> &Deps) {
582
582
// Set of Dependent nodes based on CG event and accessor dependencies.
583
583
std::set<std::shared_ptr<node_impl>> DynCGDeps =
584
- getCGEdges (DynCGImpl->MKernels [0 ]);
584
+ getCGEdges (DynCGImpl->MCommandGroups [0 ]);
585
585
for (unsigned i = 1 ; i < DynCGImpl->getNumCGs (); i++) {
586
- auto &CG = DynCGImpl->MKernels [i];
586
+ auto &CG = DynCGImpl->MCommandGroups [i];
587
587
auto CGEdges = getCGEdges (CG);
588
588
if (CGEdges != DynCGDeps) {
589
589
throw sycl::exception (make_error_code (sycl::errc::invalid),
@@ -593,14 +593,16 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
593
593
}
594
594
595
595
// Track and mark the memory objects being used by the graph.
596
- for (auto &CG : DynCGImpl->MKernels ) {
596
+ for (auto &CG : DynCGImpl->MCommandGroups ) {
597
597
markCGMemObjs (CG);
598
598
}
599
599
600
600
// Get active dynamic command-group CG and use to create a node object
601
- const auto &ActiveKernel = DynCGImpl->getActiveKernel ();
601
+ const auto &ActiveKernel = DynCGImpl->getActiveCG ();
602
+ node_type NodeType =
603
+ ext::oneapi::experimental::detail::getNodeTypeFromCG (DynCGImpl->MCGType );
602
604
std::shared_ptr<detail::node_impl> NodeImpl =
603
- add (node_type::kernel , ActiveKernel, Deps);
605
+ add (NodeType , ActiveKernel, Deps);
604
606
605
607
// Add an event associated with this explicit node for mixed usage
606
608
addEventForNode (std::make_shared<sycl::detail::event_impl>(), NodeImpl);
@@ -1400,11 +1402,11 @@ void exec_graph_impl::update(
1400
1402
" Node passed to update() is not part of the graph." );
1401
1403
}
1402
1404
1403
- if (!( Node->isEmpty () || Node-> MCGType == sycl::detail::CGType::Kernel ||
1404
- Node-> MCGType == sycl::detail::CGType::Barrier)) {
1405
- throw sycl::exception ( errc::invalid,
1406
- " Unsupported node type for update. Only kernel, "
1407
- " barrier and empty nodes are supported." );
1405
+ if (!Node->isUpdatable ()) {
1406
+ throw sycl::exception (
1407
+ errc::invalid,
1408
+ " Unsupported node type for update. Only kernel, host_task , "
1409
+ " barrier and empty nodes are supported." );
1408
1410
}
1409
1411
1410
1412
if (const auto &CG = Node->MCommandGroup ;
@@ -1445,23 +1447,46 @@ void exec_graph_impl::update(
1445
1447
}
1446
1448
}
1447
1449
1448
- // Rebuild cached requirements for this graph with updated nodes
1450
+ // Rebuild cached requirements and accessor storage for this graph with
1451
+ // updated nodes
1449
1452
MRequirements.clear ();
1453
+ MAccessors.clear ();
1450
1454
for (auto &Node : MNodeStorage) {
1451
1455
if (!Node->MCommandGroup )
1452
1456
continue ;
1453
1457
MRequirements.insert (MRequirements.end (),
1454
1458
Node->MCommandGroup ->getRequirements ().begin (),
1455
1459
Node->MCommandGroup ->getRequirements ().end ());
1460
+ MAccessors.insert (MAccessors.end (),
1461
+ Node->MCommandGroup ->getAccStorage ().begin (),
1462
+ Node->MCommandGroup ->getAccStorage ().end ());
1456
1463
}
1457
1464
}
1458
1465
1459
1466
void exec_graph_impl::updateImpl (std::shared_ptr<node_impl> Node) {
1460
- // Kernel node update is the only command type supported in UR for update.
1461
- // Updating any other types of nodes, e.g. empty & barrier nodes is a no-op.
1462
- if (Node->MCGType != sycl::detail::CGType::Kernel) {
1467
+ // Updating empty or barrier nodes is a no-op
1468
+ if (Node->isEmpty () || Node->MNodeType == node_type::ext_oneapi_barrier) {
1469
+ return ;
1470
+ }
1471
+
1472
+ // Query the ID cache to find the equivalent exec node for the node passed to
1473
+ // this function.
1474
+ // TODO: Handle subgraphs or any other cases where multiple nodes may be
1475
+ // associated with a single key, once those node types are supported for
1476
+ // update.
1477
+ auto ExecNode = MIDCache.find (Node->MID );
1478
+ assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1479
+
1480
+ // Update ExecNode with new values from Node, in case we ever need to
1481
+ // rebuild the command buffers
1482
+ ExecNode->second ->updateFromOtherNode (Node);
1483
+
1484
+ // Host task update only requires updating the node itself, so can return
1485
+ // early
1486
+ if (Node->MNodeType == node_type::host_task) {
1463
1487
return ;
1464
1488
}
1489
+
1465
1490
auto ContextImpl = sycl::detail::getSyclObjImpl (MContext);
1466
1491
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter ();
1467
1492
auto DeviceImpl = sycl::detail::getSyclObjImpl (MGraphImpl->getDevice ());
@@ -1614,18 +1639,6 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
1614
1639
UpdateDesc.pNewLocalWorkSize = LocalSize;
1615
1640
UpdateDesc.newWorkDim = NDRDesc.Dims ;
1616
1641
1617
- // Query the ID cache to find the equivalent exec node for the node passed to
1618
- // this function.
1619
- // TODO: Handle subgraphs or any other cases where multiple nodes may be
1620
- // associated with a single key, once those node types are supported for
1621
- // update.
1622
- auto ExecNode = MIDCache.find (Node->MID );
1623
- assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1624
-
1625
- // Update ExecNode with new values from Node, in case we ever need to
1626
- // rebuild the command buffers
1627
- ExecNode->second ->updateFromOtherNode (Node);
1628
-
1629
1642
ur_exp_command_buffer_command_handle_t Command =
1630
1643
MCommandMap[ExecNode->second ];
1631
1644
ur_result_t Res = Adapter->call_nocheck <
@@ -1929,7 +1942,7 @@ void dynamic_parameter_impl::updateValue(const void *NewValue, size_t Size) {
1929
1942
for (auto &DynCGInfo : MDynCGs) {
1930
1943
auto DynCG = DynCGInfo.DynCG .lock ();
1931
1944
if (DynCG) {
1932
- auto &CG = DynCG->MKernels [DynCGInfo.CGIndex ];
1945
+ auto &CG = DynCG->MCommandGroups [DynCGInfo.CGIndex ];
1933
1946
dynamic_parameter_impl::updateCGArgValue (CG, DynCGInfo.ArgIndex , NewValue,
1934
1947
Size);
1935
1948
}
@@ -1952,7 +1965,7 @@ void dynamic_parameter_impl::updateAccessor(
1952
1965
for (auto &DynCGInfo : MDynCGs) {
1953
1966
auto DynCG = DynCGInfo.DynCG .lock ();
1954
1967
if (DynCG) {
1955
- auto &CG = DynCG->MKernels [DynCGInfo.CGIndex ];
1968
+ auto &CG = DynCG->MCommandGroups [DynCGInfo.CGIndex ];
1956
1969
dynamic_parameter_impl::updateCGAccessor (CG, DynCGInfo.ArgIndex , Acc);
1957
1970
}
1958
1971
}
@@ -2040,38 +2053,67 @@ void dynamic_command_group_impl::finalizeCGFList(
2040
2053
sycl::handler Handler{MGraph};
2041
2054
CGF (Handler);
2042
2055
2043
- if (Handler.getType () != sycl::detail::CGType::Kernel) {
2056
+ if (Handler.getType () != sycl::detail::CGType::Kernel &&
2057
+ Handler.getType () != sycl::detail::CGType::CodeplayHostTask) {
2044
2058
throw sycl::exception (
2045
2059
make_error_code (errc::invalid),
2046
- " The only type of command-groups that can be used in "
2047
- " dynamic command-groups is kernels." );
2060
+ " The only types of command-groups that can be used in "
2061
+ " dynamic command-groups are kernels and host-tasks." );
2062
+ }
2063
+
2064
+ // We need to store the first CG's type so we can check they are all the
2065
+ // same
2066
+ if (CGFIndex == 0 ) {
2067
+ MCGType = Handler.getType ();
2068
+ } else if (MCGType != Handler.getType ()) {
2069
+ throw sycl::exception (make_error_code (errc::invalid),
2070
+ " Command-groups in a dynamic command-group must "
2071
+ " all be the same type." );
2048
2072
}
2049
2073
2050
2074
Handler.finalize ();
2051
2075
2052
2076
// Take unique_ptr<detail::CG> object from handler and convert to
2053
- // shared_ptr<detail::CGExecKernel > to store
2077
+ // shared_ptr<detail::CG > to store
2054
2078
sycl::detail::CG *RawCGPtr = Handler.impl ->MGraphNodeCG .release ();
2055
- auto RawCGExecPtr = static_cast <sycl::detail::CGExecKernel *>(RawCGPtr);
2056
- MKernels.push_back (
2057
- std::shared_ptr<sycl::detail::CGExecKernel>(RawCGExecPtr));
2079
+ MCommandGroups.push_back (std::shared_ptr<sycl::detail::CG>(RawCGPtr));
2058
2080
2059
- // Track dynamic_parameter usage in command-list
2081
+ // Track dynamic_parameter usage in command-group
2060
2082
auto &DynamicParams = Handler.impl ->MDynamicParameters ;
2083
+
2084
+ if (DynamicParams.size () > 0 &&
2085
+ Handler.getType () == sycl::detail::CGType::CodeplayHostTask) {
2086
+ throw sycl::exception (make_error_code (errc::invalid),
2087
+ " Cannot use dynamic parameters in a host_task" );
2088
+ }
2061
2089
for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
2062
2090
DynamicParam->registerDynCG (shared_from_this (), CGFIndex, ArgIndex);
2063
2091
}
2064
2092
}
2065
2093
2066
- // For each CGExecKernel store the list of alternative kernels, not
2094
+ // Host tasks don't need to store alternative kernels
2095
+ if (MCGType == sycl::detail::CGType::CodeplayHostTask) {
2096
+ return ;
2097
+ }
2098
+
2099
+ // For each Kernel CG store the list of alternative kernels, not
2067
2100
// including itself.
2068
2101
using CGExecKernelSP = std::shared_ptr<sycl::detail::CGExecKernel>;
2069
2102
using CGExecKernelWP = std::weak_ptr<sycl::detail::CGExecKernel>;
2070
- for (auto KernelCG : MKernels) {
2103
+ for (std::shared_ptr<sycl::detail::CG> CommandGroup : MCommandGroups) {
2104
+ CGExecKernelSP KernelCG =
2105
+ std::dynamic_pointer_cast<sycl::detail::CGExecKernel>(CommandGroup);
2071
2106
std::vector<CGExecKernelWP> Alternatives;
2072
- std::copy_if (
2073
- MKernels.begin (), MKernels.end (), std::back_inserter (Alternatives),
2074
- [&KernelCG](const CGExecKernelSP &K) { return K != KernelCG; });
2107
+
2108
+ // Add all other command groups except for the current one to the list of
2109
+ // alternatives
2110
+ for (auto &OtherCG : MCommandGroups) {
2111
+ CGExecKernelSP OtherKernelCG =
2112
+ std::dynamic_pointer_cast<sycl::detail::CGExecKernel>(OtherCG);
2113
+ if (KernelCG != OtherKernelCG) {
2114
+ Alternatives.push_back (OtherKernelCG);
2115
+ }
2116
+ }
2075
2117
2076
2118
KernelCG->MAlternativeKernels = std::move (Alternatives);
2077
2119
}
@@ -2087,7 +2129,7 @@ void dynamic_command_group_impl::setActiveIndex(size_t Index) {
2087
2129
// Update nodes using the dynamic command-group to use the new active CG
2088
2130
for (auto &Node : MNodes) {
2089
2131
if (auto NodeSP = Node.lock ()) {
2090
- NodeSP->MCommandGroup = getActiveKernel ();
2132
+ NodeSP->MCommandGroup = getActiveCG ();
2091
2133
}
2092
2134
}
2093
2135
}
0 commit comments