|
| 1 | +//==--------- graph.hpp --- SYCL graph extension ---------------------------==// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#pragma once |
| 10 | + |
| 11 | +#include <memory> |
| 12 | +#include <vector> |
| 13 | + |
| 14 | +#include <sycl/detail/common.hpp> |
| 15 | +#include <sycl/detail/defines_elementary.hpp> |
| 16 | +#include <sycl/property_list.hpp> |
| 17 | + |
| 18 | +namespace sycl { |
| 19 | +__SYCL_INLINE_VER_NAMESPACE(_V1) { |
| 20 | + |
| 21 | +class handler; |
| 22 | +class queue; |
| 23 | +class device; |
| 24 | +namespace ext { |
| 25 | +namespace oneapi { |
| 26 | +namespace experimental { |
| 27 | + |
| 28 | +namespace detail { |
| 29 | +class node_impl; |
| 30 | +class graph_impl; |
| 31 | +class exec_graph_impl; |
| 32 | + |
| 33 | +} // namespace detail |
| 34 | + |
| 35 | +/// State to template the command_graph class on. |
| 36 | +enum class graph_state { |
| 37 | + modifiable, ///< In modifiable state, commands can be added to graph. |
| 38 | + executable, ///< In executable state, the graph is ready to execute. |
| 39 | +}; |
| 40 | + |
| 41 | +/// Class representing a node in the graph, returned by command_graph::add(). |
| 42 | +class __SYCL_EXPORT node { |
| 43 | +private: |
| 44 | + node(const std::shared_ptr<detail::node_impl> &Impl) : impl(Impl) {} |
| 45 | + |
| 46 | + template <class Obj> |
| 47 | + friend decltype(Obj::impl) |
| 48 | + sycl::detail::getSyclObjImpl(const Obj &SyclObject); |
| 49 | + template <class T> |
| 50 | + friend T sycl::detail::createSyclObjFromImpl(decltype(T::impl) ImplObj); |
| 51 | + |
| 52 | + std::shared_ptr<detail::node_impl> impl; |
| 53 | +}; |
| 54 | + |
| 55 | +namespace property { |
| 56 | +namespace graph { |
| 57 | + |
| 58 | +/// Property passed to command_graph constructor to disable checking for cycles. |
| 59 | +/// |
| 60 | +/// \todo Cycle check not yet implemented. |
| 61 | +class no_cycle_check : public ::sycl::detail::DataLessProperty< |
| 62 | + ::sycl::detail::GraphNoCycleCheck> { |
| 63 | +public: |
| 64 | + no_cycle_check() = default; |
| 65 | +}; |
| 66 | + |
| 67 | +} // namespace graph |
| 68 | + |
| 69 | +namespace node { |
| 70 | + |
| 71 | +/// Property used to define dependent nodes when creating a new node with |
| 72 | +/// command_graph::add(). |
| 73 | +class depends_on : public ::sycl::detail::PropertyWithData< |
| 74 | + ::sycl::detail::GraphNodeDependencies> { |
| 75 | +public: |
| 76 | + template <typename... NodeTN> depends_on(NodeTN... nodes) : MDeps{nodes...} {} |
| 77 | + |
| 78 | + const std::vector<::sycl::ext::oneapi::experimental::node> & |
| 79 | + get_dependencies() const { |
| 80 | + return MDeps; |
| 81 | + } |
| 82 | + |
| 83 | +private: |
| 84 | + const std::vector<::sycl::ext::oneapi::experimental::node> MDeps; |
| 85 | +}; |
| 86 | + |
| 87 | +} // namespace node |
| 88 | +} // namespace property |
| 89 | + |
| 90 | +/// Graph in the modifiable state. |
| 91 | +template <graph_state State = graph_state::modifiable> |
| 92 | +class __SYCL_EXPORT command_graph { |
| 93 | +public: |
| 94 | + /// Constructor. |
| 95 | + /// @param SyclContext Context to use for graph. |
| 96 | + /// @param SyclDevice Device all nodes will be associated with. |
| 97 | + /// @param PropList Optional list of properties to pass. |
| 98 | + command_graph(const context &SyclContext, const device &SyclDevice, |
| 99 | + const property_list &PropList = {}); |
| 100 | + |
| 101 | + /// Add an empty node to the graph. |
| 102 | + /// @param PropList Property list used to pass [0..n] predecessor nodes. |
| 103 | + /// @return Constructed empty node which has been added to the graph. |
| 104 | + node add(const property_list &PropList = {}) { |
| 105 | + if (PropList.has_property<property::node::depends_on>()) { |
| 106 | + auto Deps = PropList.get_property<property::node::depends_on>(); |
| 107 | + return addImpl(Deps.get_dependencies()); |
| 108 | + } |
| 109 | + return addImpl({}); |
| 110 | + } |
| 111 | + |
| 112 | + /// Add a command-group node to the graph. |
| 113 | + /// @param CGF Command-group function to create node with. |
| 114 | + /// @param PropList Property list used to pass [0..n] predecessor nodes. |
| 115 | + /// @return Constructed node which has been added to the graph. |
| 116 | + template <typename T> node add(T CGF, const property_list &PropList = {}) { |
| 117 | + if (PropList.has_property<property::node::depends_on>()) { |
| 118 | + auto Deps = PropList.get_property<property::node::depends_on>(); |
| 119 | + return addImpl(CGF, Deps.get_dependencies()); |
| 120 | + } |
| 121 | + return addImpl(CGF, {}); |
| 122 | + } |
| 123 | + |
| 124 | + /// Add a dependency between two nodes. |
| 125 | + /// @param Src Node which will be a dependency of \p Dest. |
| 126 | + /// @param Dest Node which will be dependent on \p Src. |
| 127 | + void make_edge(node &Src, node &Dest); |
| 128 | + |
| 129 | + /// Finalize modifiable graph into an executable graph. |
| 130 | + /// @param PropList Property list used to pass properties for finalization. |
| 131 | + /// @return Executable graph object. |
| 132 | + command_graph<graph_state::executable> |
| 133 | + finalize(const property_list &PropList = {}) const; |
| 134 | + |
| 135 | + /// Change the state of a queue to be recording and associate this graph with |
| 136 | + /// it. |
| 137 | + /// @param RecordingQueue The queue to change state on and associate this |
| 138 | + /// graph with. |
| 139 | + /// @return True if the queue had its state changed from executing to |
| 140 | + /// recording. |
| 141 | + bool begin_recording(queue &RecordingQueue); |
| 142 | + |
| 143 | + /// Change the state of multiple queues to be recording and associate this |
| 144 | + /// graph with each of them. |
| 145 | + /// @param RecordingQueues The queues to change state on and associate this |
| 146 | + /// graph with. |
| 147 | + /// @return True if any queue had its state changed from executing to |
| 148 | + /// recording. |
| 149 | + bool begin_recording(const std::vector<queue> &RecordingQueues); |
| 150 | + |
| 151 | + /// Set all queues currently recording to this graph to the executing state. |
| 152 | + /// @return True if any queue had its state changed from recording to |
| 153 | + /// executing. |
| 154 | + bool end_recording(); |
| 155 | + |
| 156 | + /// Set a queue currently recording to this graph to the executing state. |
| 157 | + /// @param RecordingQueue The queue to change state on. |
| 158 | + /// @return True if the queue had its state changed from recording to |
| 159 | + /// executing. |
| 160 | + bool end_recording(queue &RecordingQueue); |
| 161 | + |
| 162 | + /// Set multiple queues currently recording to this graph to the executing |
| 163 | + /// state. |
| 164 | + /// @param RecordingQueues The queues to change state on. |
| 165 | + /// @return True if any queue had its state changed from recording to |
| 166 | + /// executing. |
| 167 | + bool end_recording(const std::vector<queue> &RecordingQueues); |
| 168 | + |
| 169 | +private: |
| 170 | + /// Constructor used internally by the runtime. |
| 171 | + /// @param Impl Detail implementation class to construct object with. |
| 172 | + command_graph(const std::shared_ptr<detail::graph_impl> &Impl) : impl(Impl) {} |
| 173 | + |
| 174 | + /// Template-less implementation of add() for CGF nodes. |
| 175 | + /// @param CGF Command-group function to add. |
| 176 | + /// @param Dep List of predecessor nodes. |
| 177 | + /// @return Node added to the graph. |
| 178 | + node addImpl(std::function<void(handler &)> CGF, |
| 179 | + const std::vector<node> &Dep); |
| 180 | + |
| 181 | + /// Template-less implementation of add() for empty nodes. |
| 182 | + /// @param Dep List of predecessor nodes. |
| 183 | + /// @return Node added to the graph. |
| 184 | + node addImpl(const std::vector<node> &Dep); |
| 185 | + |
| 186 | + template <class Obj> |
| 187 | + friend decltype(Obj::impl) |
| 188 | + sycl::detail::getSyclObjImpl(const Obj &SyclObject); |
| 189 | + template <class T> |
| 190 | + friend T sycl::detail::createSyclObjFromImpl(decltype(T::impl) ImplObj); |
| 191 | + |
| 192 | + std::shared_ptr<detail::graph_impl> impl; |
| 193 | +}; |
| 194 | + |
| 195 | +template <> class __SYCL_EXPORT command_graph<graph_state::executable> { |
| 196 | +public: |
| 197 | + /// An executable command-graph is not user constructable. |
| 198 | + command_graph() = delete; |
| 199 | + |
| 200 | + /// Update the inputs & output of the graph. |
| 201 | + /// @param Graph Graph to use the inputs and outputs of. |
| 202 | + void update(const command_graph<graph_state::modifiable> &Graph); |
| 203 | + |
| 204 | +private: |
| 205 | + /// Constructor used by internal runtime. |
| 206 | + /// @param Graph Detail implementation class to construct with. |
| 207 | + /// @param Ctx Context to use for graph. |
| 208 | + command_graph(const std::shared_ptr<detail::graph_impl> &Graph, |
| 209 | + const sycl::context &Ctx); |
| 210 | + |
| 211 | + template <class Obj> |
| 212 | + friend decltype(Obj::impl) |
| 213 | + sycl::detail::getSyclObjImpl(const Obj &SyclObject); |
| 214 | + |
| 215 | + /// Creates a backend representation of the graph in \p impl member variable. |
| 216 | + void finalizeImpl(); |
| 217 | + |
| 218 | + int MTag; |
| 219 | + std::shared_ptr<detail::exec_graph_impl> impl; |
| 220 | + |
| 221 | + friend class command_graph<graph_state::modifiable>; |
| 222 | +}; |
| 223 | + |
| 224 | +/// Additional CTAD deduction guide. |
| 225 | +template <graph_state State = graph_state::modifiable> |
| 226 | +command_graph(const context &SyclContext, const device &SyclDevice, |
| 227 | + const property_list &PropList) -> command_graph<State>; |
| 228 | + |
| 229 | +} // namespace experimental |
| 230 | +} // namespace oneapi |
| 231 | +} // namespace ext |
| 232 | + |
| 233 | +template <> |
| 234 | +struct is_property<ext::oneapi::experimental::property::graph::no_cycle_check> |
| 235 | + : std::true_type {}; |
| 236 | + |
| 237 | +template <> |
| 238 | +struct is_property<ext::oneapi::experimental::property::node::depends_on> |
| 239 | + : std::true_type {}; |
| 240 | + |
| 241 | +template <> |
| 242 | +struct is_property_of< |
| 243 | + ext::oneapi::experimental::property::graph::no_cycle_check, |
| 244 | + ext::oneapi::experimental::command_graph< |
| 245 | + ext::oneapi::experimental::graph_state::modifiable>> : std::true_type { |
| 246 | +}; |
| 247 | + |
| 248 | +template <> |
| 249 | +struct is_property_of<ext::oneapi::experimental::property::node::depends_on, |
| 250 | + ext::oneapi::experimental::node> : std::true_type {}; |
| 251 | + |
| 252 | +} // __SYCL_INLINE_VER_NAMESPACE(_V1) |
| 253 | +} // namespace sycl |
0 commit comments