Skip to content

Commit ccf64e5

Browse files
committed
Add common reference semantics to sycl graphs
Adds missing common reference semantic functionality such as operator==, operator!= and hash functions to all sycl graph related classes.
1 parent 70f92a0 commit ccf64e5

File tree

5 files changed

+323
-6
lines changed

5 files changed

+323
-6
lines changed

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1818
#include <sycl/detail/string_view.hpp>
1919
#endif
20-
#include <sycl/device.hpp> // for device
20+
#include <sycl/device.hpp> // for device
2121
#include <sycl/ext/oneapi/experimental/detail/properties/graph_properties.hpp> // for graph properties classes
2222
#include <sycl/nd_range.hpp> // for range, nd_range
2323
#include <sycl/properties/property_traits.hpp> // for is_property, is_property_of
@@ -142,6 +142,14 @@ class __SYCL_EXPORT node {
142142
/// Update the Range of this node if it is a kernel execution node
143143
template <int Dimensions> void update_range(range<Dimensions> executionRange);
144144

145+
/// Common Reference Semantics
146+
friend bool operator==(const node &LHS, const node &RHS) {
147+
return LHS.impl == RHS.impl;
148+
}
149+
friend bool operator!=(const node &LHS, const node &RHS) {
150+
return LHS.impl != RHS.impl;
151+
}
152+
145153
private:
146154
node(const std::shared_ptr<detail::node_impl> &Impl) : impl(Impl) {}
147155

@@ -181,6 +189,16 @@ class __SYCL_EXPORT dynamic_command_group {
181189
size_t get_active_index() const;
182190
void set_active_index(size_t Index);
183191

192+
/// Common Reference Semantics
193+
friend bool operator==(const dynamic_command_group &LHS,
194+
const dynamic_command_group &RHS) {
195+
return LHS.impl == RHS.impl;
196+
}
197+
friend bool operator!=(const dynamic_command_group &LHS,
198+
const dynamic_command_group &RHS) {
199+
return LHS.impl != RHS.impl;
200+
}
201+
184202
private:
185203
template <class Obj>
186204
friend const decltype(Obj::impl) &
@@ -307,6 +325,16 @@ class __SYCL_EXPORT modifiable_command_graph
307325
/// Get a list of all root nodes (nodes without dependencies) in this graph.
308326
std::vector<node> get_root_nodes() const;
309327

328+
/// Common Reference Semantics
329+
friend bool operator==(const modifiable_command_graph &LHS,
330+
const modifiable_command_graph &RHS) {
331+
return LHS.impl == RHS.impl;
332+
}
333+
friend bool operator!=(const modifiable_command_graph &LHS,
334+
const modifiable_command_graph &RHS) {
335+
return LHS.impl != RHS.impl;
336+
}
337+
310338
protected:
311339
/// Constructor used internally by the runtime.
312340
/// @param Impl Detail implementation class to construct object with.
@@ -386,6 +414,16 @@ class __SYCL_EXPORT executable_command_graph
386414
/// @param Nodes The nodes to use for updating the graph.
387415
void update(const std::vector<node> &Nodes);
388416

417+
/// Common Reference Semantics
418+
friend bool operator==(const executable_command_graph &LHS,
419+
const executable_command_graph &RHS) {
420+
return LHS.impl == RHS.impl;
421+
}
422+
friend bool operator!=(const executable_command_graph &LHS,
423+
const executable_command_graph &RHS) {
424+
return LHS.impl != RHS.impl;
425+
}
426+
389427
protected:
390428
/// Constructor used by internal runtime.
391429
/// @param Graph Detail implementation class to construct with.
@@ -452,6 +490,16 @@ class __SYCL_EXPORT dynamic_parameter_base {
452490
Graph,
453491
size_t ParamSize, const void *Data);
454492

493+
/// Common Reference Semantics
494+
friend bool operator==(const dynamic_parameter_base &LHS,
495+
const dynamic_parameter_base &RHS) {
496+
return LHS.impl == RHS.impl;
497+
}
498+
friend bool operator!=(const dynamic_parameter_base &LHS,
499+
const dynamic_parameter_base &RHS) {
500+
return LHS.impl != RHS.impl;
501+
}
502+
455503
protected:
456504
void updateValue(const void *NewValue, size_t Size);
457505

@@ -512,3 +560,37 @@ command_graph(const context &SyclContext, const device &SyclDevice,
512560

513561
} // namespace _V1
514562
} // namespace sycl
563+
564+
namespace std {
565+
template <> struct __SYCL_EXPORT hash<sycl::ext::oneapi::experimental::node> {
566+
size_t operator()(const sycl::ext::oneapi::experimental::node &Node) const;
567+
};
568+
569+
template <>
570+
struct __SYCL_EXPORT
571+
hash<sycl::ext::oneapi::experimental::dynamic_command_group> {
572+
size_t operator()(const sycl::ext::oneapi::experimental::dynamic_command_group
573+
&DynamicCGH) const;
574+
};
575+
576+
template <sycl::ext::oneapi::experimental::graph_state State>
577+
struct __SYCL_EXPORT
578+
hash<sycl::ext::oneapi::experimental::command_graph<State>> {
579+
size_t operator()(const sycl::ext::oneapi::experimental::command_graph<State>
580+
&Graph) const {
581+
auto ID = sycl::detail::getSyclObjImpl(Graph)->getID();
582+
return std::hash<decltype(ID)>()(ID);
583+
}
584+
};
585+
586+
template <typename ValueT>
587+
struct __SYCL_EXPORT
588+
hash<sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>> {
589+
size_t
590+
operator()(const sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>
591+
&DynamicParam) const {
592+
auto ID = sycl::detail::getSyclObjImpl(DynamicParam)->getID();
593+
return std::hash<decltype(ID)>()(ID);
594+
}
595+
};
596+
} // namespace std

sycl/source/detail/graph_impl.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,8 @@ graph_impl::graph_impl(const sycl::context &SyclContext,
324324
const sycl::device &SyclDevice,
325325
const sycl::property_list &PropList)
326326
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
327-
MEventsMap(), MInorderQueueMap() {
327+
MEventsMap(), MInorderQueueMap(),
328+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
328329
checkGraphPropertiesAndThrow(PropList);
329330
if (PropList.has_property<property::graph::no_cycle_check>()) {
330331
MSkipCycleChecks = true;
@@ -913,7 +914,8 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
913914
MExecutionEvents(),
914915
MIsUpdatable(PropList.has_property<property::graph::updatable>()),
915916
MEnableProfiling(
916-
PropList.has_property<property::graph::enable_profiling>()) {
917+
PropList.has_property<property::graph::enable_profiling>()),
918+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
917919
checkGraphPropertiesAndThrow(PropList);
918920
// If the graph has been marked as updatable then check if the backend
919921
// actually supports that. Devices supporting aspect::ext_oneapi_graph must
@@ -2035,7 +2037,8 @@ void dynamic_parameter_impl::updateCGAccessor(
20352037

20362038
dynamic_command_group_impl::dynamic_command_group_impl(
20372039
const command_graph<graph_state::modifiable> &Graph)
2038-
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0) {}
2040+
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0),
2041+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {}
20392042

20402043
void dynamic_command_group_impl::finalizeCGFList(
20412044
const std::vector<std::function<void(handler &)>> &CGFList) {
@@ -2159,3 +2162,17 @@ void dynamic_command_group::set_active_index(size_t Index) {
21592162
} // namespace ext
21602163
} // namespace _V1
21612164
} // namespace sycl
2165+
2166+
size_t std::hash<sycl::ext::oneapi::experimental::node>::operator()(
2167+
const sycl::ext::oneapi::experimental::node &Node) const {
2168+
auto ID = sycl::detail::getSyclObjImpl(Node)->getID();
2169+
return std::hash<decltype(ID)>()(ID);
2170+
}
2171+
2172+
size_t
2173+
std::hash<sycl::ext::oneapi::experimental::dynamic_command_group>::operator()(
2174+
const sycl::ext::oneapi::experimental::dynamic_command_group &DynamicCGH)
2175+
const {
2176+
auto ID = sycl::detail::getSyclObjImpl(DynamicCGH)->getID();
2177+
return std::hash<decltype(ID)>()(ID);
2178+
}

sycl/source/detail/graph_impl.hpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
11201120
return MBarrierDependencyMap[Queue];
11211121
}
11221122

1123+
unsigned long long getID() { return MID; }
1124+
11231125
private:
11241126
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
11251127
/// @param NodeFunc A function which receives as input a node in the graph to
@@ -1198,6 +1200,9 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
11981200
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
11991201
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
12001202
MBarrierDependencyMap;
1203+
1204+
unsigned long long MID;
1205+
inline static std::atomic<unsigned long long> NextAvailableID = 0;
12011206
};
12021207

12031208
/// Class representing the implementation of command_graph<executable>.
@@ -1297,6 +1302,8 @@ class exec_graph_impl {
12971302

12981303
void updateImpl(std::shared_ptr<node_impl> NodeImpl);
12991304

1305+
unsigned long long getID() { return MID; }
1306+
13001307
private:
13011308
/// Create a command-group for the node and add it to command-buffer by going
13021309
/// through the scheduler.
@@ -1408,21 +1415,26 @@ class exec_graph_impl {
14081415
// Stores a cache of node ids from modifiable graph nodes to the companion
14091416
// node(s) in this graph. Used for quick access when updating this graph.
14101417
std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;
1418+
1419+
unsigned long long MID;
1420+
inline static std::atomic<unsigned long long> NextAvailableID = 0;
14111421
};
14121422

14131423
class dynamic_parameter_impl {
14141424
public:
14151425
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
14161426
size_t ParamSize, const void *Data)
1417-
: MGraph(GraphImpl), MValueStorage(ParamSize) {
1427+
: MGraph(GraphImpl), MValueStorage(ParamSize),
1428+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
14181429
std::memcpy(MValueStorage.data(), Data, ParamSize);
14191430
}
14201431

14211432
/// sycl_ext_oneapi_raw_kernel_arg constructor
14221433
/// Parameter size is taken from member of raw_kernel_arg object.
14231434
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl, size_t,
14241435
raw_kernel_arg *Data)
1425-
: MGraph(GraphImpl) {
1436+
: MGraph(GraphImpl),
1437+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
14261438
size_t RawArgSize = Data->MArgSize;
14271439
const void *RawArgData = Data->MArgData;
14281440
MValueStorage.reserve(RawArgSize);
@@ -1493,13 +1505,19 @@ class dynamic_parameter_impl {
14931505
int ArgIndex,
14941506
const sycl::detail::AccessorBaseHost *Acc);
14951507

1508+
unsigned long long getID() { return MID; }
1509+
14961510
// Weak ptrs to node_impls which will be updated
14971511
std::vector<std::pair<std::weak_ptr<node_impl>, int>> MNodes;
14981512
// Dynamic command-groups which will be updated
14991513
std::vector<DynamicCGInfo> MDynCGs;
15001514

15011515
std::shared_ptr<graph_impl> MGraph;
15021516
std::vector<std::byte> MValueStorage;
1517+
1518+
private:
1519+
unsigned long long MID;
1520+
inline static std::atomic<unsigned long long> NextAvailableID = 0;
15031521
};
15041522

15051523
class dynamic_command_group_impl
@@ -1540,6 +1558,12 @@ class dynamic_command_group_impl
15401558

15411559
/// List of nodes using this dynamic command-group.
15421560
std::vector<std::weak_ptr<node_impl>> MNodes;
1561+
1562+
unsigned long long getID() { return MID; }
1563+
1564+
private:
1565+
unsigned long long MID;
1566+
inline static std::atomic<unsigned long long> NextAvailableID = 0;
15431567
};
15441568
} // namespace detail
15451569
} // namespace experimental

sycl/unittests/Extensions/CommandGraph/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ set(CMAKE_CXX_EXTENSIONS OFF)
33
add_sycl_unittest(CommandGraphExtensionTests OBJECT
44
Barrier.cpp
55
CommandGraph.cpp
6+
CommonReferenceSemantics.cpp
67
Exceptions.cpp
78
InOrderQueue.cpp
89
MultiThreaded.cpp

0 commit comments

Comments
 (0)