Commit 9189aad

[SYCL][Graph] Enable host-task update in graphs
- Update spec wording to allow updating host-task function in graphs
- Support host-tasks in dynamic command-groups
- Support host-tasks in whole graph update
- Add E2E tests for both scenarios
- Fix passing incorrect accessors to graph update command after update
1 parent f3d12f0 commit 9189aad

12 files changed, +891 -104 lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc

Lines changed: 65 additions & 27 deletions
@@ -551,7 +551,7 @@ Parameters:
551551

552552
|===
553553

554-
==== Dynamic Command Groups
554+
==== Dynamic Command Groups [[dynamic-command-groups]]
555555

556556
[source,c++]
557557
----
@@ -570,12 +570,13 @@ public:
570570
Dynamic command-groups can be added as nodes to a graph. They provide a
571571
mechanism that allows updating the command-group function of a node after the
572572
graph is finalized. There is always one command-group function in the dynamic
573-
command-group that is set as active, this is the kernel which will execute for
574-
the node when the graph is finalized into an executable state `command_graph`,
575-
and all the other command-group functions in `cgfList` will be ignored. The
576-
executable `command_graph` node can then be updated to a different kernel in
577-
`cgfList`, by selecting a new active index on the dynamic command-group object
578-
and calling the `update(node& node)` method on the executable `command_graph`.
573+
command-group that is set as active, this is the command-group which will
574+
execute for the node when the graph is finalized into an executable state
575+
`command_graph`, and all the other command-group functions in `cgfList` will be
576+
ignored. The executable `command_graph` node can then be updated to a different
577+
kernel in `cgfList`, by selecting a new active index on the dynamic
578+
command-group object and calling the `update(node& node)` method on the
579+
executable `command_graph`.
579580

580581
The `dynamic_command_group` class provides the {crs}[common reference semantics].
581582

@@ -584,9 +585,13 @@ about updating command-groups.
584585

585586
===== Limitations
586587

587-
Dynamic command-groups can only contain kernel operations. Trying to construct
588-
a dynamic command-group with functions that contain other operations will
589-
result in an error.
588+
Dynamic command-groups can only contain the following operations:
589+
590+
* Kernel operations
591+
* <<host-tasks, Host-tasks>>
592+
593+
Trying to construct a dynamic command-group with functions that contain other
594+
operations will result in an error.
590595

591596
All the command-group functions in a dynamic command-group must have identical dependencies.
592597
It is not allowed for a dynamic command-group to have command-group functions that would
@@ -625,10 +630,13 @@ Exceptions:
625630
property for more information.
626631

627632
* Throws with error code `invalid` if the `dynamic_command_group` is created with
628-
command-group functions that are not kernel executions.
633+
command-group functions that are not kernel executions or host-tasks.
629634

630635
* Throws with error code `invalid` if `cgfList` is empty.
631636

637+
* Throws with error code `invalid` if the types of all command-groups in
638+
`cgfList` do not match.
639+
632640
|
633641
[source,c++]
634642
----
@@ -829,32 +837,54 @@ possible.
829837

830838
===== Supported Features
831839

832-
The only types of nodes that are currently able to be updated in a graph are
833-
kernel execution nodes.
840+
The only types of nodes that are currently able to be updated in a graph are:
834841

835-
There are two different API's that can be used to update a graph:
842+
* Kernel executions
843+
* <<host-tasks, Host-tasks>>
844+
845+
There are two different APIs that can be used to update a graph:
836846

837847
* <<individual-node-update, Individual Node Update>> which allows updating
838848
individual nodes of a command-graph.
839849
* <<whole-graph-update, Whole Graph Update>> which allows updating the
840850
entirety of the graph simultaneously by using another graph as a
841851
reference.
842852

843-
The aspects of a kernel execution node that can be changed during update are
844-
different depending on the API used to perform the update:
853+
The following table illustrates the aspects of each supported node type that can be changed
854+
depending on the API used to perform the update.
855+
856+
Table {counter: tableNumber}. Graph update capabilites for supported node types.
857+
[cols="1,2a,2a"]
858+
|===
859+
|Node Type|<<individual-node-update, Individual Node Update>>|<<whole-graph-update, Whole Graph Update>>
860+
861+
|`node_type::kernel`
862+
|
863+
864+
* Kernel function
865+
* Kernel Parameters
866+
* ND-range
867+
868+
|
869+
* Kernel Parameters
870+
* ND-range
845871

846-
* For the <<individual-node-update, Individual Node Update>> API it's possible to update
847-
the kernel function, the parameters to the kernel, and the ND-range.
848-
* For the <<whole-graph-update, Whole Graph Update>> API, only the parameters of the kernel
849-
and the ND-range can be updated.
872+
|`node_type::host_task`
873+
|
874+
* Host-task function
875+
|
876+
* Host-task function
877+
878+
|===
850879

851880
===== Individual Node Update [[individual-node-update]]
852881

853-
Individual nodes of an executable graph can be updated directly. Depending on the attribute
854-
of the node that requires updating, different API's should be used:
882+
Individual nodes of an executable graph can be updated directly. Depending on the attribute or `node_type` of the node that requires updating, different API's should be used:
855883

856884
====== Parameter Updates
857885

886+
_Supported Node Types: Kernel_
887+
858888
Parameters to individual nodes in a graph in the `executable` state can be
859889
updated between graph executions using dynamic parameters. A `dynamic_parameter`
860890
object is created with a modifiable state graph and an initial value for the
@@ -884,6 +914,8 @@ will maintain the graphs data dependencies.
884914

885915
====== Execution Range Updates
886916

917+
_Supported Node Types: Kernel_
918+
887919
Another configuration that can be updated is the execution range of the
888920
kernel, this can be set through `node::update_nd_range()` or
889921
`node::update_range()` but does not require any prior registration.
@@ -897,10 +929,13 @@ code may be defined as operating in a different dimension.
897929

898930
====== Command Group Updates
899931

900-
The command-groups of a kernel node can be updated using dynamic command-groups.
901-
Dynamic command-groups allow replacing the command-group function of a kernel
902-
node with a different one. This effectively allows updating the kernel function
903-
and/or the kernel execution range.
932+
_Supported Node Types: Kernel, Host-task_
933+
934+
The command-groups of a kernel node can be updated using
935+
<<dynamic-command-groups, Dynamic Command-Groups>>. Dynamic command-groups allow
936+
replacing the command-group function of a kernel node with a different one. This
937+
effectively allows updating the kernel function and/or the kernel execution
938+
range.
904939

905940
Command-group updates are performed by creating an instance of the
906941
`dynamic_command_group` class. A dynamic command-group is created with a modifiable
@@ -1972,7 +2007,7 @@ Any code like this should be moved to a separate host-task and added to the
19722007
graph via the recording or explicit APIs in order to be compatible with this
19732008
extension.
19742009

1975-
=== Host Tasks
2010+
=== Host Tasks [[host-tasks]]
19762011

19772012
:host-task: https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#subsec:interfaces.hosttasks
19782013

@@ -1992,6 +2027,9 @@ auto node = graph.add([&](sycl::handler& cgh){
19922027
});
19932028
----
19942029

2030+
Host-tasks can be updated using <<executable-graph-update, Executable Graph Update>>.
2031+
2032+
19952033
=== Queue Behavior In Recording Mode
19962034

19972035
When a queue is placed in recording mode via a call to `command_graph::begin_recording`,

sycl/source/detail/graph_impl.cpp

Lines changed: 84 additions & 42 deletions
@@ -581,9 +581,9 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
581581
std::vector<std::shared_ptr<detail::node_impl>> &Deps) {
582582
// Set of Dependent nodes based on CG event and accessor dependencies.
583583
std::set<std::shared_ptr<node_impl>> DynCGDeps =
584-
getCGEdges(DynCGImpl->MKernels[0]);
584+
getCGEdges(DynCGImpl->MCommandGroups[0]);
585585
for (unsigned i = 1; i < DynCGImpl->getNumCGs(); i++) {
586-
auto &CG = DynCGImpl->MKernels[i];
586+
auto &CG = DynCGImpl->MCommandGroups[i];
587587
auto CGEdges = getCGEdges(CG);
588588
if (CGEdges != DynCGDeps) {
589589
throw sycl::exception(make_error_code(sycl::errc::invalid),
@@ -593,14 +593,16 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
593593
}
594594

595595
// Track and mark the memory objects being used by the graph.
596-
for (auto &CG : DynCGImpl->MKernels) {
596+
for (auto &CG : DynCGImpl->MCommandGroups) {
597597
markCGMemObjs(CG);
598598
}
599599

600600
// Get active dynamic command-group CG and use to create a node object
601-
const auto &ActiveKernel = DynCGImpl->getActiveKernel();
601+
const auto &ActiveKernel = DynCGImpl->getActiveCG();
602+
node_type NodeType =
603+
ext::oneapi::experimental::detail::getNodeTypeFromCG(DynCGImpl->MCGType);
602604
std::shared_ptr<detail::node_impl> NodeImpl =
603-
add(node_type::kernel, ActiveKernel, Deps);
605+
add(NodeType, ActiveKernel, Deps);
604606

605607
// Add an event associated with this explicit node for mixed usage
606608
addEventForNode(std::make_shared<sycl::detail::event_impl>(), NodeImpl);
@@ -1400,11 +1402,11 @@ void exec_graph_impl::update(
14001402
"Node passed to update() is not part of the graph.");
14011403
}
14021404

1403-
if (!(Node->isEmpty() || Node->MCGType == sycl::detail::CGType::Kernel ||
1404-
Node->MCGType == sycl::detail::CGType::Barrier)) {
1405-
throw sycl::exception(errc::invalid,
1406-
"Unsupported node type for update. Only kernel, "
1407-
"barrier and empty nodes are supported.");
1405+
if (!Node->isUpdatable()) {
1406+
throw sycl::exception(
1407+
errc::invalid,
1408+
"Unsupported node type for update. Only kernel, host_task, "
1409+
"barrier and empty nodes are supported.");
14081410
}
14091411

14101412
if (const auto &CG = Node->MCommandGroup;
@@ -1445,23 +1447,46 @@ void exec_graph_impl::update(
14451447
}
14461448
}
14471449

1448-
// Rebuild cached requirements for this graph with updated nodes
1450+
// Rebuild cached requirements and accessor storage for this graph with
1451+
// updated nodes
14491452
MRequirements.clear();
1453+
MAccessors.clear();
14501454
for (auto &Node : MNodeStorage) {
14511455
if (!Node->MCommandGroup)
14521456
continue;
14531457
MRequirements.insert(MRequirements.end(),
14541458
Node->MCommandGroup->getRequirements().begin(),
14551459
Node->MCommandGroup->getRequirements().end());
1460+
MAccessors.insert(MAccessors.end(),
1461+
Node->MCommandGroup->getAccStorage().begin(),
1462+
Node->MCommandGroup->getAccStorage().end());
14561463
}
14571464
}
14581465

14591466
void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
1460-
// Kernel node update is the only command type supported in UR for update.
1461-
// Updating any other types of nodes, e.g. empty & barrier nodes is a no-op.
1462-
if (Node->MCGType != sycl::detail::CGType::Kernel) {
1467+
// Updating empty or barrier nodes is a no-op
1468+
if (Node->isEmpty() || Node->MNodeType == node_type::ext_oneapi_barrier) {
1469+
return;
1470+
}
1471+
1472+
// Query the ID cache to find the equivalent exec node for the node passed to
1473+
// this function.
1474+
// TODO: Handle subgraphs or any other cases where multiple nodes may be
1475+
// associated with a single key, once those node types are supported for
1476+
// update.
1477+
auto ExecNode = MIDCache.find(Node->MID);
1478+
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1479+
1480+
// Update ExecNode with new values from Node, in case we ever need to
1481+
// rebuild the command buffers
1482+
ExecNode->second->updateFromOtherNode(Node);
1483+
1484+
// Host task update only requires updating the node itself, so can return
1485+
// early
1486+
if (Node->MNodeType == node_type::host_task) {
14631487
return;
14641488
}
1489+
14651490
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
14661491
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
14671492
auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice());
@@ -1614,18 +1639,6 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
16141639
UpdateDesc.pNewLocalWorkSize = LocalSize;
16151640
UpdateDesc.newWorkDim = NDRDesc.Dims;
16161641

1617-
// Query the ID cache to find the equivalent exec node for the node passed to
1618-
// this function.
1619-
// TODO: Handle subgraphs or any other cases where multiple nodes may be
1620-
// associated with a single key, once those node types are supported for
1621-
// update.
1622-
auto ExecNode = MIDCache.find(Node->MID);
1623-
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1624-
1625-
// Update ExecNode with new values from Node, in case we ever need to
1626-
// rebuild the command buffers
1627-
ExecNode->second->updateFromOtherNode(Node);
1628-
16291642
ur_exp_command_buffer_command_handle_t Command =
16301643
MCommandMap[ExecNode->second];
16311644
ur_result_t Res = Adapter->call_nocheck<
@@ -1929,7 +1942,7 @@ void dynamic_parameter_impl::updateValue(const void *NewValue, size_t Size) {
19291942
for (auto &DynCGInfo : MDynCGs) {
19301943
auto DynCG = DynCGInfo.DynCG.lock();
19311944
if (DynCG) {
1932-
auto &CG = DynCG->MKernels[DynCGInfo.CGIndex];
1945+
auto &CG = DynCG->MCommandGroups[DynCGInfo.CGIndex];
19331946
dynamic_parameter_impl::updateCGArgValue(CG, DynCGInfo.ArgIndex, NewValue,
19341947
Size);
19351948
}
@@ -1952,7 +1965,7 @@ void dynamic_parameter_impl::updateAccessor(
19521965
for (auto &DynCGInfo : MDynCGs) {
19531966
auto DynCG = DynCGInfo.DynCG.lock();
19541967
if (DynCG) {
1955-
auto &CG = DynCG->MKernels[DynCGInfo.CGIndex];
1968+
auto &CG = DynCG->MCommandGroups[DynCGInfo.CGIndex];
19561969
dynamic_parameter_impl::updateCGAccessor(CG, DynCGInfo.ArgIndex, Acc);
19571970
}
19581971
}
@@ -2040,38 +2053,67 @@ void dynamic_command_group_impl::finalizeCGFList(
20402053
sycl::handler Handler{MGraph};
20412054
CGF(Handler);
20422055

2043-
if (Handler.getType() != sycl::detail::CGType::Kernel) {
2056+
if (Handler.getType() != sycl::detail::CGType::Kernel &&
2057+
Handler.getType() != sycl::detail::CGType::CodeplayHostTask) {
20442058
throw sycl::exception(
20452059
make_error_code(errc::invalid),
2046-
"The only type of command-groups that can be used in "
2047-
"dynamic command-groups is kernels.");
2060+
"The only types of command-groups that can be used in "
2061+
"dynamic command-groups are kernels and host-tasks.");
2062+
}
2063+
2064+
// We need to store the first CG's type so we can check they are all the
2065+
// same
2066+
if (CGFIndex == 0) {
2067+
MCGType = Handler.getType();
2068+
} else if (MCGType != Handler.getType()) {
2069+
throw sycl::exception(make_error_code(errc::invalid),
2070+
"Command-groups in a dynamic command-group must "
2071+
"all be the same type.");
20482072
}
20492073

20502074
Handler.finalize();
20512075

20522076
// Take unique_ptr<detail::CG> object from handler and convert to
2053-
// shared_ptr<detail::CGExecKernel> to store
2077+
// shared_ptr<detail::CG> to store
20542078
sycl::detail::CG *RawCGPtr = Handler.impl->MGraphNodeCG.release();
2055-
auto RawCGExecPtr = static_cast<sycl::detail::CGExecKernel *>(RawCGPtr);
2056-
MKernels.push_back(
2057-
std::shared_ptr<sycl::detail::CGExecKernel>(RawCGExecPtr));
2079+
MCommandGroups.push_back(std::shared_ptr<sycl::detail::CG>(RawCGPtr));
20582080

2059-
// Track dynamic_parameter usage in command-list
2081+
// Track dynamic_parameter usage in command-group
20602082
auto &DynamicParams = Handler.impl->MDynamicParameters;
2083+
2084+
if (DynamicParams.size() > 0 &&
2085+
Handler.getType() == sycl::detail::CGType::CodeplayHostTask) {
2086+
throw sycl::exception(make_error_code(errc::invalid),
2087+
"Cannot use dynamic parameters in a host_task");
2088+
}
20612089
for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
20622090
DynamicParam->registerDynCG(shared_from_this(), CGFIndex, ArgIndex);
20632091
}
20642092
}
20652093

2066-
// For each CGExecKernel store the list of alternative kernels, not
2094+
// Host tasks don't need to store alternative kernels
2095+
if (MCGType == sycl::detail::CGType::CodeplayHostTask) {
2096+
return;
2097+
}
2098+
2099+
// For each Kernel CG store the list of alternative kernels, not
20672100
// including itself.
20682101
using CGExecKernelSP = std::shared_ptr<sycl::detail::CGExecKernel>;
20692102
using CGExecKernelWP = std::weak_ptr<sycl::detail::CGExecKernel>;
2070-
for (auto KernelCG : MKernels) {
2103+
for (std::shared_ptr<sycl::detail::CG> CommandGroup : MCommandGroups) {
2104+
CGExecKernelSP KernelCG =
2105+
std::dynamic_pointer_cast<sycl::detail::CGExecKernel>(CommandGroup);
20712106
std::vector<CGExecKernelWP> Alternatives;
2072-
std::copy_if(
2073-
MKernels.begin(), MKernels.end(), std::back_inserter(Alternatives),
2074-
[&KernelCG](const CGExecKernelSP &K) { return K != KernelCG; });
2107+
2108+
// Add all other command groups except for the current one to the list of
2109+
// alternatives
2110+
for (auto &OtherCG : MCommandGroups) {
2111+
CGExecKernelSP OtherKernelCG =
2112+
std::dynamic_pointer_cast<sycl::detail::CGExecKernel>(OtherCG);
2113+
if (KernelCG != OtherKernelCG) {
2114+
Alternatives.push_back(OtherKernelCG);
2115+
}
2116+
}
20752117

20762118
KernelCG->MAlternativeKernels = std::move(Alternatives);
20772119
}
@@ -2087,7 +2129,7 @@ void dynamic_command_group_impl::setActiveIndex(size_t Index) {
20872129
// Update nodes using the dynamic command-group to use the new active CG
20882130
for (auto &Node : MNodes) {
20892131
if (auto NodeSP = Node.lock()) {
2090-
NodeSP->MCommandGroup = getActiveKernel();
2132+
NodeSP->MCommandGroup = getActiveCG();
20912133
}
20922134
}
20932135
}

0 commit comments

Comments
 (0)