Skip to content

Commit 6e3be4e

Browse files
author
Hugh Delaney
committed
Keep track of only shared USM allocations and prefetch only for those
1 parent 3d6bc0f commit 6e3be4e

File tree

3 files changed

+21
-63
lines changed

3 files changed

+21
-63
lines changed

sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#pragma once
99

1010
#include <set>
11-
#include <unordered_map>
1211

1312
#include "common.hpp"
1413
#include "device.hpp"
@@ -104,57 +103,6 @@ struct ur_context_handle_t_ {
104103

105104
ur_usm_pool_handle_t getOwningURPool(umf_memory_pool_t *UMFPool);
106105

107-
/// We need to keep track of USM mappings in AMD HIP, as certain extra
108-
/// synchronization *is* actually required for correctness.
109-
/// During kernel enqueue we must dispatch a prefetch for each kernel argument
110-
/// that points to a USM mapping to ensure the mapping is correctly
111-
/// populated on the device (https://github.com/intel/llvm/issues/7252). Thus,
112-
/// we keep track of mappings in the context, and then check against them just
113-
/// before the kernel is launched. The stream against which the kernel is
114-
/// launched is not known until enqueue time, but the USM mappings can happen
115-
/// at any time. Thus, they are tracked on the context used for the urUSM*
116-
/// mapping.
117-
///
118-
/// The three utility functions are simple wrappers around a mapping from a
119-
/// pointer to a size.
120-
void addUSMMapping(void *Ptr, size_t Size) {
121-
std::lock_guard<std::mutex> Guard(Mutex);
122-
assert(USMMappings.find(Ptr) == USMMappings.end() &&
123-
"mapping already exists");
124-
USMMappings[Ptr] = Size;
125-
}
126-
127-
void removeUSMMapping(const void *Ptr) {
128-
std::lock_guard<std::mutex> guard(Mutex);
129-
auto It = USMMappings.find(Ptr);
130-
if (It != USMMappings.end())
131-
USMMappings.erase(It);
132-
}
133-
134-
std::pair<const void *, size_t> getUSMMapping(const void *Ptr) {
135-
std::lock_guard<std::mutex> Guard(Mutex);
136-
auto It = USMMappings.find(Ptr);
137-
// The simple case is the fast case...
138-
if (It != USMMappings.end())
139-
return *It;
140-
141-
// ... but in the failure case we have to fall back to a full scan to search
142-
// for "offset" pointers in case the user passes in the middle of an
143-
// allocation. We have to do some not-so-ordained-by-the-standard ordered
144-
// comparisons of pointers here, but it'll work on all platforms we support.
145-
uintptr_t PtrVal = (uintptr_t)Ptr;
146-
for (std::pair<const void *, size_t> Pair : USMMappings) {
147-
uintptr_t BaseAddr = (uintptr_t)Pair.first;
148-
uintptr_t EndAddr = BaseAddr + Pair.second;
149-
if (PtrVal > BaseAddr && PtrVal < EndAddr) {
150-
// If we've found something now, offset *must* be nonzero
151-
assert(Pair.second);
152-
return Pair;
153-
}
154-
}
155-
return {nullptr, 0};
156-
}
157-
158106
private:
159107
std::mutex Mutex;
160108
std::vector<deleter_data> ExtendedDeleters;

sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
254254
try {
255255
ur_device_handle_t Dev = hQueue->getDevice();
256256
ScopedContext Active(Dev);
257-
ur_context_handle_t Ctx = hQueue->getContext();
258257

259258
uint32_t StreamToken;
260259
ur_stream_quard Guard;
@@ -263,15 +262,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
263262
hipFunction_t HIPFunc = hKernel->get();
264263

265264
hipDevice_t HIPDev = Dev->get();
266-
for (const void *P : hKernel->getPtrArgs()) {
267-
auto [Addr, Size] = Ctx->getUSMMapping(P);
268-
if (!Addr)
269-
continue;
270-
if (hipMemPrefetchAsync(Addr, Size, HIPDev, HIPStream) != hipSuccess)
271-
return UR_RESULT_ERROR_INVALID_KERNEL_ARGS;
265+
266+
// Some args using shared USM require prefetch
267+
for (auto [Ptr, Size] : hKernel->Args.PtrArgsRequiringPrefetch) {
268+
if (Ptr && Size) {
269+
UR_CHECK_ERROR(hipMemPrefetchAsync(Ptr, Size, HIPDev, HIPStream));
270+
}
272271
}
273-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
274-
phEventWaitList);
272+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
273+
phEventWaitList));
275274

276275
// Set the implicit global offset parameter if kernel has offset variant
277276
if (hKernel->getWithOffsetParameter()) {

sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ struct ur_kernel_handle_t_ {
5757
args_index_t Indices;
5858
args_size_t OffsetPerIndex;
5959
std::set<const void *> PtrArgs;
60+
// Ptr args needing prefetch arranged [Ptr, Size of alloca]
61+
std::set<std::pair<const void *, size_t>> PtrArgsRequiringPrefetch;
6062

6163
std::uint32_t ImplicitOffsetArgs[3] = {0, 0, 0};
6264

@@ -177,11 +179,20 @@ struct ur_kernel_handle_t_ {
177179
Args.addArg(Index, Size, Arg);
178180
}
179181

180-
/// We track all pointer arguments to be able to issue prefetches at enqueue
181-
/// time
182182
void setKernelPtrArg(int Index, size_t Size, const void *PtrArg) {
183183
Args.PtrArgs.insert(*static_cast<void *const *>(PtrArg));
184184
setKernelArg(Index, Size, PtrArg);
185+
// Ptr args using managed memory may require prefetch
186+
hipPointerAttribute_t Attribs;
187+
// We are only using hipPointerGetAttributes to check if the ptr refers to
188+
// a managed memory location, meaning the Ptr may require a prefetch at
189+
// kernel launch. If this call fails then it means that the Ptr may have
190+
// been a host Ptr, which is not a problem
191+
if (hipPointerGetAttributes(&Attribs, PtrArg) == hipSuccess &&
192+
Attribs.isManaged) {
193+
Args.PtrArgsRequiringPrefetch.insert(
194+
{*static_cast<void *const *>(PtrArg), Size});
195+
}
185196
}
186197

187198
bool isPtrArg(const void *ptr) {

0 commit comments

Comments
 (0)