Skip to content

Commit a722e78

Browse files
authored
[SYCL][Graph] Memory reuse for graph allocations in a single graph (#18340)
- Optimize memory use for allocations within a single graph by reusing memory where possible - New handler impl member for node dependency access with the CGF - New E2E tests for memory reuse - Add missing CGType -> string conversion for graph printing alloc and free nodes
1 parent 01e67f8 commit a722e78

19 files changed

+1164
-65
lines changed

sycl/doc/design/CommandGraph.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,26 @@ safely assumed to be more performant. It is not likely we'll try to allow
349349
in-order execution in more scenarios through a complicated (and imperfect)
350350
heuristic but rather expose this as a hint the user can provide.
351351

352+
### Graph Allocation Memory Reuse
353+
354+
When adding a new allocation node to a graph, memory allocations which have
355+
previously been freed are checked to see if they can be reused. Because we have
356+
to return a pointer to the user immediately when the CGF for a node is
357+
processed, we have to do this eagerly anytime `async_malloc()` is called.
358+
Allocations track the last free node associated with them to represent the most
359+
recent use of that allocation.
360+
361+
To be reused, the two allocations must meet these criteria:
362+
363+
- They must be of the same allocation type (device/host/shared).
364+
- They must have a matching size.
365+
- They must have the same properties (currently only read-only matters).
366+
- There must be a path from the last free node associated with a given
367+
allocation to the new allocation node being added.
368+
369+
If these criteria are met we update the last free node for the allocation and
370+
return the existing pointer to the user.
371+
352372
## Backend Implementation
353373

354374
Implementation of UR command-buffers for each of the supported SYCL 2020

sycl/source/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ set(SYCL_COMMON_SOURCES
322322
"detail/memory_pool_impl.cpp"
323323
"detail/async_alloc.cpp"
324324
"detail/memory_pool.cpp"
325+
"detail/graph_memory_pool.cpp"
325326
"$<$<PLATFORM_ID:Windows>:detail/windows_ur.cpp>"
326327
"$<$<OR:$<PLATFORM_ID:Linux>,$<PLATFORM_ID:Darwin>>:detail/posix_ur.cpp>"
327328
)

sycl/source/detail/async_alloc.cpp

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "sycl/accessor.hpp"
910
#include <detail/context_impl.hpp>
1011
#include <detail/event_impl.hpp>
12+
#include <detail/graph_impl.hpp>
1113
#include <detail/queue_impl.hpp>
1214
#include <sycl/detail/ur.hpp>
1315
#include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
@@ -29,6 +31,27 @@ getUrEvents(const std::vector<std::shared_ptr<detail::event_impl>> &DepEvents) {
2931
}
3032
return RetUrEvents;
3133
}
34+
35+
std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
36+
sycl::handler &Handler, const std::shared_ptr<detail::queue_impl> &Queue,
37+
const std::shared_ptr<detail::graph_impl> &Graph,
38+
const std::vector<std::shared_ptr<detail::event_impl>> &DepEvents) {
39+
auto HandlerImpl = detail::getSyclObjImpl(Handler);
40+
// Get dependent graph nodes from any events
41+
auto DepNodes = Graph->getNodesForEvents(DepEvents);
42+
// If this node was added explicitly we may have node deps in the handler as
43+
// well, so add them to the list
44+
DepNodes.insert(DepNodes.end(), HandlerImpl->MNodeDeps.begin(),
45+
HandlerImpl->MNodeDeps.end());
46+
// If this is being recorded from an in-order queue we need to get the last
47+
// in-order node if any, since this will later become a dependency of the
48+
// node being processed here.
49+
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue);
50+
LastInOrderNode) {
51+
DepNodes.push_back(LastInOrderNode);
52+
}
53+
return DepNodes;
54+
}
3255
} // namespace
3356

3457
__SYCL_EXPORT
@@ -46,22 +69,23 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
4669

4770
auto &Adapter = h.getContextImplPtr()->getAdapter();
4871

49-
// Get events to wait on.
50-
auto depEvents = getUrEvents(h.impl->CGData.MEvents);
51-
uint32_t numEvents = h.impl->CGData.MEvents.size();
72+
// Get CG event dependencies for this allocation.
73+
const auto &DepEvents = h.impl->CGData.MEvents;
74+
auto UREvents = getUrEvents(DepEvents);
5275

5376
void *alloc = nullptr;
5477

5578
ur_event_handle_t Event = nullptr;
5679
// If a graph is present do the allocation from the graph memory pool instead.
5780
if (auto Graph = h.getCommandGraph(); Graph) {
58-
alloc = Graph->getMemPool().malloc(size, kind);
81+
auto DepNodes = getDepGraphNodes(h, h.MQueue, Graph, DepEvents);
82+
alloc = Graph->getMemPool().malloc(size, kind, DepNodes);
5983
} else {
6084
auto &Q = h.MQueue->getHandleRef();
6185
Adapter->call<sycl::errc::runtime,
6286
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
63-
Q, (ur_usm_pool_handle_t)0, size, nullptr, numEvents, depEvents.data(),
64-
&alloc, &Event);
87+
Q, (ur_usm_pool_handle_t)0, size, nullptr, UREvents.size(),
88+
UREvents.data(), &alloc, &Event);
6589
}
6690

6791
// Async malloc must return a void* immediately.
@@ -95,24 +119,26 @@ __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
95119
auto &Adapter = h.getContextImplPtr()->getAdapter();
96120
auto &memPoolImpl = sycl::detail::getSyclObjImpl(pool);
97121

98-
// Get events to wait on.
99-
auto depEvents = getUrEvents(h.impl->CGData.MEvents);
100-
uint32_t numEvents = h.impl->CGData.MEvents.size();
122+
// Get CG event dependencies for this allocation.
123+
const auto &DepEvents = h.impl->CGData.MEvents;
124+
auto UREvents = getUrEvents(DepEvents);
101125

102126
void *alloc = nullptr;
103127

104128
ur_event_handle_t Event = nullptr;
105129
// If a graph is present do the allocation from the graph memory pool instead.
106130
if (auto Graph = h.getCommandGraph(); Graph) {
131+
auto DepNodes = getDepGraphNodes(h, h.MQueue, Graph, DepEvents);
132+
107133
// Memory pool is passed as the graph may use some properties of it.
108-
alloc = Graph->getMemPool().malloc(size, pool.get_alloc_kind(),
134+
alloc = Graph->getMemPool().malloc(size, pool.get_alloc_kind(), DepNodes,
109135
sycl::detail::getSyclObjImpl(pool));
110136
} else {
111137
auto &Q = h.MQueue->getHandleRef();
112138
Adapter->call<sycl::errc::runtime,
113139
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
114-
Q, memPoolImpl.get()->get_handle(), size, nullptr, numEvents,
115-
depEvents.data(), &alloc, &Event);
140+
Q, memPoolImpl.get()->get_handle(), size, nullptr, UREvents.size(),
141+
UREvents.data(), &alloc, &Event);
116142
}
117143
// Async malloc must return a void* immediately.
118144
// Set up CommandGroup which is a no-op and pass the event from the alloc.
@@ -140,6 +166,9 @@ async_malloc_from_pool(const sycl::queue &q, size_t size,
140166
}
141167

142168
__SYCL_EXPORT void async_free(sycl::handler &h, void *ptr) {
169+
// We only check for errors for the graph here because marking the allocation
170+
// as free in the graph memory pool requires a node object which doesn't exist
171+
// at this point.
143172
if (auto Graph = h.getCommandGraph(); Graph) {
144173
// Check if the pointer to be freed has an associated allocation node, and
145174
// error if not

sycl/source/detail/graph_impl.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,8 @@ graph_impl::graph_impl(const sycl::context &SyclContext,
344344
const sycl::device &SyclDevice,
345345
const sycl::property_list &PropList)
346346
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
347-
MEventsMap(), MInorderQueueMap(), MGraphMemPool(SyclContext, SyclDevice),
347+
MEventsMap(), MInorderQueueMap(),
348+
MGraphMemPool(*this, SyclContext, SyclDevice),
348349
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
349350
checkGraphPropertiesAndThrow(PropList);
350351
if (PropList.has_property<property::graph::no_cycle_check>()) {
@@ -509,6 +510,10 @@ graph_impl::add(std::function<void(handler &)> CGF,
509510
sycl::handler Handler{shared_from_this()};
510511
#endif
511512

513+
// Pass the node deps to the handler so they are available when processing the
514+
// CGF, need for async_malloc nodes.
515+
Handler.impl->MNodeDeps = Deps;
516+
512517
#if XPTI_ENABLE_INSTRUMENTATION
513518
// Save code location if one was set in TLS.
514519
// Ideally it would be nice to capture user's call code location
@@ -532,6 +537,10 @@ graph_impl::add(std::function<void(handler &)> CGF,
532537

533538
Handler.finalize();
534539

540+
// In explicit mode the handler processing of the CGF does not need a write
541+
// lock as it does not modify the graph, we extract information from it here
542+
// and modify the graph.
543+
graph_impl::WriteLock Lock(MMutex);
535544
node_type NodeType =
536545
Handler.impl->MUserFacingNodeType !=
537546
ext::oneapi::experimental::node_type::empty
@@ -602,6 +611,14 @@ graph_impl::add(node_type NodeType,
602611

603612
addDepsToNode(NodeImpl, Deps);
604613

614+
if (NodeType == node_type::async_free) {
615+
auto AsyncFreeCG =
616+
static_cast<CGAsyncFree *>(NodeImpl->MCommandGroup.get());
617+
// If this is an async free node mark that it is now available for reuse,
618+
// and pass the async free node for tracking.
619+
MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), NodeImpl);
620+
}
621+
605622
return NodeImpl;
606623
}
607624

@@ -1791,7 +1808,6 @@ node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
17911808
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
17921809
}
17931810

1794-
graph_impl::WriteLock Lock(impl->MMutex);
17951811
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(CGF, {}, DepImpls);
17961812
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
17971813
}

sycl/source/detail/graph_impl.hpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,12 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
735735
case sycl::detail::CGType::EnqueueNativeCommand:
736736
Stream << "CGNativeCommand \\n";
737737
break;
738+
case sycl::detail::CGType::AsyncAlloc:
739+
Stream << "CGAsyncAlloc \\n";
740+
break;
741+
case sycl::detail::CGType::AsyncFree:
742+
Stream << "CGAsyncFree \\n";
743+
break;
738744
default:
739745
Stream << "Other \\n";
740746
break;
@@ -937,6 +943,31 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
937943
"No node in this graph is associated with this event");
938944
}
939945

946+
/// Find the nodes associated with a list of SYCL events. Throws if no node is
947+
/// found for a given event.
948+
/// @param Events Events to find nodes for.
949+
/// @return A list of node counterparts for each event, in the same order.
950+
std::vector<std::shared_ptr<node_impl>> getNodesForEvents(
951+
const std::vector<std::shared_ptr<sycl::detail::event_impl>> &Events) {
952+
std::vector<std::shared_ptr<node_impl>> NodeList{};
953+
NodeList.reserve(Events.size());
954+
955+
ReadLock Lock(MMutex);
956+
957+
for (const auto &Event : Events) {
958+
if (auto NodeFound = MEventsMap.find(Event);
959+
NodeFound != std::end(MEventsMap)) {
960+
NodeList.push_back(NodeFound->second);
961+
} else {
962+
throw sycl::exception(
963+
sycl::make_error_code(errc::invalid),
964+
"No node in this graph is associated with this event");
965+
}
966+
}
967+
968+
return NodeList;
969+
}
970+
940971
/// Query for the context tied to this graph.
941972
/// @return Context associated with graph.
942973
sycl::context getContext() const { return MContext; }
@@ -1191,6 +1222,13 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
11911222
/// this graph.
11921223
size_t getExecGraphCount() const { return MExecGraphCount; }
11931224

1225+
/// Resets the visited edges variable across all nodes in the graph to 0.
1226+
void resetNodeVisitedEdges() {
1227+
for (auto &Node : MNodeStorage) {
1228+
Node->MTotalVisitedEdges = 0;
1229+
}
1230+
}
1231+
11941232
private:
11951233
/// Check the graph for cycles by performing a depth-first search of the
11961234
/// graph. If a node is visited more than once in a given path through the

0 commit comments

Comments
 (0)