@@ -317,9 +317,23 @@ static void flushCrossQueueDeps(const std::vector<EventImplPtr> &EventImpls,
317
317
}
318
318
}
319
319
320
+ namespace {
321
+
322
+ struct EnqueueNativeCommandData {
323
+ sycl::interop_handle ih;
324
+ std::function<void (interop_handle)> func;
325
+ };
326
+
327
+ void InteropFreeFunc (pi_queue InteropQueue, void *InteropData) {
328
+ auto *Data = reinterpret_cast <EnqueueNativeCommandData *>(InteropData);
329
+ return Data->func (Data->ih );
330
+ };
331
+ } // namespace
332
+
320
333
class DispatchHostTask {
321
334
ExecCGCommand *MThisCmd;
322
335
std::vector<interop_handle::ReqToMem> MReqToMem;
336
+ std::vector<pi_mem> MReqPiMem;
323
337
324
338
bool waitForEvents () const {
325
339
std::map<const PluginPtr, std::vector<EventImplPtr>>
@@ -365,8 +379,10 @@ class DispatchHostTask {
365
379
366
380
public:
367
381
DispatchHostTask (ExecCGCommand *ThisCmd,
368
- std::vector<interop_handle::ReqToMem> ReqToMem)
369
- : MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)) {}
382
+ std::vector<interop_handle::ReqToMem> ReqToMem,
383
+ std::vector<pi_mem> ReqPiMem)
384
+ : MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)),
385
+ MReqPiMem (std::move(ReqPiMem)) {}
370
386
371
387
void operator ()() const {
372
388
assert (MThisCmd->getCG ().getType () == CG::CGTYPE::CodeplayHostTask);
@@ -402,8 +418,27 @@ class DispatchHostTask {
402
418
interop_handle IH{MReqToMem, HostTask.MQueue ,
403
419
HostTask.MQueue ->getDeviceImplPtr (),
404
420
HostTask.MQueue ->getContextImplPtr ()};
405
-
406
- HostTask.MHostTask ->call (MThisCmd->MEvent ->getHostProfilingInfo (), IH);
421
+ if (IH.get_backend () == backend::ext_oneapi_cuda ||
422
+ IH.get_backend () == backend::ext_oneapi_hip) {
423
+ EnqueueNativeCommandData CustomOpData{
424
+ IH, HostTask.MHostTask ->MInteropTask };
425
+
426
+ // We are assuming that we have already synchronized with the HT's
427
+ // dependent events, and that the user will synchronize before the end
428
+ // of the HT lambda. As such we don't pass in any events, or ask for
429
+ // one back.
430
+ //
431
+ // This entry point is needed in order to migrate memory across
432
+ // devices in the same context for CUDA and HIP backends
433
+ HostTask.MQueue ->getPlugin ()
434
+ ->call <PiApiKind::piextEnqueueNativeCommand>(
435
+ HostTask.MQueue ->getHandleRef (), InteropFreeFunc,
436
+ &CustomOpData, MReqPiMem.size (), MReqPiMem.data (),
437
+ 0 , nullptr , nullptr );
438
+ } else {
439
+ HostTask.MHostTask ->call (MThisCmd->MEvent ->getHostProfilingInfo (),
440
+ IH);
441
+ }
407
442
} else
408
443
HostTask.MHostTask ->call (MThisCmd->MEvent ->getHostProfilingInfo ());
409
444
} catch (...) {
@@ -3121,13 +3156,14 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
3121
3156
}
3122
3157
3123
3158
std::vector<interop_handle::ReqToMem> ReqToMem;
3159
+ std::vector<pi_mem> ReqPiMem;
3124
3160
3125
3161
if (HostTask->MHostTask ->isInteropTask ()) {
3126
3162
// Extract the Mem Objects for all Requirements, to ensure they are
3127
3163
// available if a user asks for them inside the interop task scope
3128
3164
const std::vector<Requirement *> &HandlerReq =
3129
3165
HostTask->getRequirements ();
3130
- auto ReqToMemConv = [&ReqToMem, HostTask](Requirement *Req) {
3166
+ auto ReqToMemConv = [&ReqToMem, &ReqPiMem, HostTask](Requirement *Req) {
3131
3167
const std::vector<AllocaCommandBase *> &AllocaCmds =
3132
3168
Req->MSYCLMemObj ->MRecord ->MAllocaCommands ;
3133
3169
@@ -3137,6 +3173,7 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
3137
3173
auto MemArg =
3138
3174
reinterpret_cast <pi_mem>(AllocaCmd->getMemAllocation ());
3139
3175
ReqToMem.emplace_back (std::make_pair (Req, MemArg));
3176
+ ReqPiMem.emplace_back (MemArg);
3140
3177
3141
3178
return ;
3142
3179
}
@@ -3158,7 +3195,7 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
3158
3195
copySubmissionCodeLocation ();
3159
3196
3160
3197
MQueue->getThreadPool ().submit <DispatchHostTask>(
3161
- DispatchHostTask (this , std::move (ReqToMem)));
3198
+ DispatchHostTask (this , std::move (ReqToMem)), std::move (ReqPiMem) );
3162
3199
3163
3200
MShouldCompleteEventIfPossible = false ;
3164
3201
0 commit comments