Skip to content

Commit def7c42

Browse files
author
Hugh Delaney
committed
Fix host task mem migration for CUDA and HIP
The SYCL RT assumes that for devices in the same context, no mem migration needs to occur across devices for a kernel launch or host task. However, a CUdeviceptr is relevant to a specific device, so mem migration must occur between devices in a context. Since this assumption that the SYCL RT makes — that native mems are accessible to all devices in a context — does not hold, it must hand off the HT lambda to the plugin, so that the plugin can handle the necessary mem migration. This patch uses the new urEnqueueCustomCommandExp to execute the HT lambda, which takes care of mem migration implicitly in the plugin.
1 parent 3040061 commit def7c42

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

sycl/source/detail/scheduler/commands.cpp

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,23 @@ static void flushCrossQueueDeps(const std::vector<EventImplPtr> &EventImpls,
317317
}
318318
}
319319

320+
namespace {
321+
322+
struct EnqueueNativeCommandData {
323+
sycl::interop_handle ih;
324+
std::function<void(interop_handle)> func;
325+
};
326+
327+
void InteropFreeFunc(pi_queue InteropQueue, void *InteropData) {
328+
auto *Data = reinterpret_cast<EnqueueNativeCommandData *>(InteropData);
329+
return Data->func(Data->ih);
330+
};
331+
} // namespace
332+
320333
class DispatchHostTask {
321334
ExecCGCommand *MThisCmd;
322335
std::vector<interop_handle::ReqToMem> MReqToMem;
336+
std::vector<pi_mem> MReqPiMem;
323337

324338
bool waitForEvents() const {
325339
std::map<const PluginPtr, std::vector<EventImplPtr>>
@@ -365,8 +379,10 @@ class DispatchHostTask {
365379

366380
public:
367381
DispatchHostTask(ExecCGCommand *ThisCmd,
368-
std::vector<interop_handle::ReqToMem> ReqToMem)
369-
: MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)) {}
382+
std::vector<interop_handle::ReqToMem> ReqToMem,
383+
std::vector<pi_mem> ReqPiMem)
384+
: MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)),
385+
MReqPiMem(std::move(ReqPiMem)) {}
370386

371387
void operator()() const {
372388
assert(MThisCmd->getCG().getType() == CG::CGTYPE::CodeplayHostTask);
@@ -402,8 +418,27 @@ class DispatchHostTask {
402418
interop_handle IH{MReqToMem, HostTask.MQueue,
403419
HostTask.MQueue->getDeviceImplPtr(),
404420
HostTask.MQueue->getContextImplPtr()};
405-
406-
HostTask.MHostTask->call(MThisCmd->MEvent->getHostProfilingInfo(), IH);
421+
if (IH.get_backend() == backend::ext_oneapi_cuda ||
422+
IH.get_backend() == backend::ext_oneapi_hip) {
423+
EnqueueNativeCommandData CustomOpData{
424+
IH, HostTask.MHostTask->MInteropTask};
425+
426+
// We are assuming that we have already synchronized with the HT's
427+
// dependent events, and that the user will synchronize before the end
428+
// of the HT lambda. As such we don't pass in any events, or ask for
429+
// one back.
430+
//
431+
// This entry point is needed in order to migrate memory across
432+
// devices in the same context for CUDA and HIP backends
433+
HostTask.MQueue->getPlugin()
434+
->call<PiApiKind::piextEnqueueNativeCommand>(
435+
HostTask.MQueue->getHandleRef(), InteropFreeFunc,
436+
&CustomOpData, MReqPiMem.size(), MReqPiMem.data(),
437+
0, nullptr, nullptr);
438+
} else {
439+
HostTask.MHostTask->call(MThisCmd->MEvent->getHostProfilingInfo(),
440+
IH);
441+
}
407442
} else
408443
HostTask.MHostTask->call(MThisCmd->MEvent->getHostProfilingInfo());
409444
} catch (...) {
@@ -3121,13 +3156,14 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
31213156
}
31223157

31233158
std::vector<interop_handle::ReqToMem> ReqToMem;
3159+
std::vector<pi_mem> ReqPiMem;
31243160

31253161
if (HostTask->MHostTask->isInteropTask()) {
31263162
// Extract the Mem Objects for all Requirements, to ensure they are
31273163
// available if a user asks for them inside the interop task scope
31283164
const std::vector<Requirement *> &HandlerReq =
31293165
HostTask->getRequirements();
3130-
auto ReqToMemConv = [&ReqToMem, HostTask](Requirement *Req) {
3166+
auto ReqToMemConv = [&ReqToMem, &ReqPiMem, HostTask](Requirement *Req) {
31313167
const std::vector<AllocaCommandBase *> &AllocaCmds =
31323168
Req->MSYCLMemObj->MRecord->MAllocaCommands;
31333169

@@ -3137,6 +3173,7 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
31373173
auto MemArg =
31383174
reinterpret_cast<pi_mem>(AllocaCmd->getMemAllocation());
31393175
ReqToMem.emplace_back(std::make_pair(Req, MemArg));
3176+
ReqPiMem.emplace_back(MemArg);
31403177

31413178
return;
31423179
}
@@ -3158,7 +3195,7 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
31583195
copySubmissionCodeLocation();
31593196

31603197
MQueue->getThreadPool().submit<DispatchHostTask>(
3161-
DispatchHostTask(this, std::move(ReqToMem)));
3198+
DispatchHostTask(this, std::move(ReqToMem)), std::move(ReqPiMem));
31623199

31633200
MShouldCompleteEventIfPossible = false;
31643201

sycl/test-e2e/HostInteropTask/interop-task-cuda-buffer-migrate.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
// REQUIRES: cuda
2-
// XFAIL: cuda
3-
//
4-
// FIXME: this is broken with a multi device context
52
//
63
// RUN: %{build} -o %t.out -lcuda
74
// RUN: %{run} %t.out

0 commit comments

Comments
 (0)