6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
+ #include " sycl/accessor.hpp"
9
10
#include < detail/context_impl.hpp>
10
11
#include < detail/event_impl.hpp>
12
+ #include < detail/graph_impl.hpp>
11
13
#include < detail/queue_impl.hpp>
12
14
#include < sycl/detail/ur.hpp>
13
15
#include < sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
@@ -29,6 +31,27 @@ getUrEvents(const std::vector<std::shared_ptr<detail::event_impl>> &DepEvents) {
29
31
}
30
32
return RetUrEvents;
31
33
}
34
+
35
+ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes (
36
+ sycl::handler &Handler, const std::shared_ptr<detail::queue_impl> &Queue,
37
+ const std::shared_ptr<detail::graph_impl> &Graph,
38
+ const std::vector<std::shared_ptr<detail::event_impl>> &DepEvents) {
39
+ auto HandlerImpl = detail::getSyclObjImpl (Handler);
40
+ // Get dependent graph nodes from any events
41
+ auto DepNodes = Graph->getNodesForEvents (DepEvents);
42
+ // If this node was added explicitly we may have node deps in the handler as
43
+ // well, so add them to the list
44
+ DepNodes.insert (DepNodes.end (), HandlerImpl->MNodeDeps .begin (),
45
+ HandlerImpl->MNodeDeps .end ());
46
+ // If this is being recorded from an in-order queue we need to get the last
47
+ // in-order node if any, since this will later become a dependency of the
48
+ // node being processed here.
49
+ if (const auto &LastInOrderNode = Graph->getLastInorderNode (Queue);
50
+ LastInOrderNode) {
51
+ DepNodes.push_back (LastInOrderNode);
52
+ }
53
+ return DepNodes;
54
+ }
32
55
} // namespace
33
56
34
57
__SYCL_EXPORT
@@ -46,22 +69,23 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
46
69
47
70
auto &Adapter = h.getContextImplPtr ()->getAdapter ();
48
71
49
- // Get events to wait on .
50
- auto depEvents = getUrEvents ( h.impl ->CGData .MEvents ) ;
51
- uint32_t numEvents = h. impl -> CGData . MEvents . size ( );
72
+ // Get CG event dependencies for this allocation .
73
+ const auto &DepEvents = h.impl ->CGData .MEvents ;
74
+ auto UREvents = getUrEvents (DepEvents );
52
75
53
76
void *alloc = nullptr ;
54
77
55
78
ur_event_handle_t Event = nullptr ;
56
79
// If a graph is present do the allocation from the graph memory pool instead.
57
80
if (auto Graph = h.getCommandGraph (); Graph) {
58
- alloc = Graph->getMemPool ().malloc (size, kind);
81
+ auto DepNodes = getDepGraphNodes (h, h.MQueue , Graph, DepEvents);
82
+ alloc = Graph->getMemPool ().malloc (size, kind, DepNodes);
59
83
} else {
60
84
auto &Q = h.MQueue ->getHandleRef ();
61
85
Adapter->call <sycl::errc::runtime,
62
86
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
63
- Q, (ur_usm_pool_handle_t )0 , size, nullptr , numEvents, depEvents. data (),
64
- &alloc, &Event);
87
+ Q, (ur_usm_pool_handle_t )0 , size, nullptr , UREvents. size (),
88
+ UREvents. data (), &alloc, &Event);
65
89
}
66
90
67
91
// Async malloc must return a void* immediately.
@@ -95,24 +119,26 @@ __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
95
119
auto &Adapter = h.getContextImplPtr ()->getAdapter ();
96
120
auto &memPoolImpl = sycl::detail::getSyclObjImpl (pool);
97
121
98
- // Get events to wait on .
99
- auto depEvents = getUrEvents ( h.impl ->CGData .MEvents ) ;
100
- uint32_t numEvents = h. impl -> CGData . MEvents . size ( );
122
+ // Get CG event dependencies for this allocation .
123
+ const auto &DepEvents = h.impl ->CGData .MEvents ;
124
+ auto UREvents = getUrEvents (DepEvents );
101
125
102
126
void *alloc = nullptr ;
103
127
104
128
ur_event_handle_t Event = nullptr ;
105
129
// If a graph is present do the allocation from the graph memory pool instead.
106
130
if (auto Graph = h.getCommandGraph (); Graph) {
131
+ auto DepNodes = getDepGraphNodes (h, h.MQueue , Graph, DepEvents);
132
+
107
133
// Memory pool is passed as the graph may use some properties of it.
108
- alloc = Graph->getMemPool ().malloc (size, pool.get_alloc_kind (),
134
+ alloc = Graph->getMemPool ().malloc (size, pool.get_alloc_kind (), DepNodes,
109
135
sycl::detail::getSyclObjImpl (pool));
110
136
} else {
111
137
auto &Q = h.MQueue ->getHandleRef ();
112
138
Adapter->call <sycl::errc::runtime,
113
139
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
114
- Q, memPoolImpl.get ()->get_handle (), size, nullptr , numEvents ,
115
- depEvents .data (), &alloc, &Event);
140
+ Q, memPoolImpl.get ()->get_handle (), size, nullptr , UREvents. size () ,
141
+ UREvents .data (), &alloc, &Event);
116
142
}
117
143
// Async malloc must return a void* immediately.
118
144
// Set up CommandGroup which is a no-op and pass the event from the alloc.
@@ -140,6 +166,9 @@ async_malloc_from_pool(const sycl::queue &q, size_t size,
140
166
}
141
167
142
168
__SYCL_EXPORT void async_free (sycl::handler &h, void *ptr) {
169
+ // We only check for errors for the graph here because marking the allocation
170
+ // as free in the graph memory pool requires a node object which doesn't exist
171
+ // at this point.
143
172
if (auto Graph = h.getCommandGraph (); Graph) {
144
173
// Check if the pointer to be freed has an associated allocation node, and
145
174
// error if not
0 commit comments