Skip to content

Commit 2e212e0

Browse files
author
Hugh Delaney
authored
[CUDA][HIP] Fix host task mem migration and add pi entry point for urEnqueueNativeCommandExp (#14353)
The SYCL RT assumes that for devices in the same context, no mem migration needs to occur across devices for a kernel launch or host task. However, a CUdeviceptr is relevant to a specific device, so mem migration must occur between devices in a ctx. Where the assumption that the SYCL RT makes about native mems being accessible to all devices in a context does not hold, it must hand off the HT lambda to the plugin, so that the plugin can handle the necessary mem migration. This patch uses the new urEnqueueNativeCommandExp to execute the HT lambda, which takes care of mem migration implicitly in the plugin.
1 parent 43286ab commit 2e212e0

File tree

18 files changed

+179
-17
lines changed

18 files changed

+179
-17
lines changed

sycl/include/sycl/detail/host_task_impl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class HostTask {
4747
if (HPI)
4848
HPI->end();
4949
}
50+
51+
friend class DispatchHostTask;
5052
};
5153

5254
class CGHostTask : public CG {

sycl/include/sycl/detail/pi.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,4 +227,7 @@ _PI_API(piextVirtualMemUnmap)
227227
_PI_API(piextVirtualMemSetAccess)
228228
_PI_API(piextVirtualMemGetInfo)
229229

230+
// Enqueue native command
231+
_PI_API(piextEnqueueNativeCommand)
232+
230233
#undef _PI_API

sycl/include/sycl/detail/pi.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,10 @@
195195
// PI_EXT_ONEAPI_DEVICE_INFO_SUPPORTS_VIRTUAL_MEM device info descriptor,
196196
// _pi_virtual_mem_granularity_info enum, _pi_virtual_mem_info enum and
197197
// pi_virtual_access_flags bit flags.
198+
// 15.55 Added piextEnqueueNativeCommand as well as associated types and enums
198199

199200
#define _PI_H_VERSION_MAJOR 15
200-
#define _PI_H_VERSION_MINOR 54
201+
#define _PI_H_VERSION_MINOR 55
201202

202203
#define _PI_STRING_HELPER(a) #a
203204
#define _PI_CONCAT(a, b) _PI_STRING_HELPER(a.b)
@@ -512,6 +513,8 @@ typedef enum {
512513

513514
// Virtual memory support
514515
PI_EXT_ONEAPI_DEVICE_INFO_SUPPORTS_VIRTUAL_MEM = 0x2011E,
516+
// Native enqueue
517+
PI_EXT_ONEAPI_DEVICE_INFO_ENQUEUE_NATIVE_COMMAND_SUPPORT = 0x2011F,
515518
} _pi_device_info;
516519

517520
typedef enum {
@@ -1279,6 +1282,7 @@ using pi_image_mem_handle = void *;
12791282
using pi_interop_mem_handle = pi_uint64;
12801283
using pi_interop_semaphore_handle = pi_uint64;
12811284
using pi_physical_mem = _pi_physical_mem *;
1285+
using pi_enqueue_native_command_function = void (*)(pi_queue, void *);
12821286

12831287
typedef struct {
12841288
pi_image_channel_order image_channel_order;
@@ -3201,6 +3205,25 @@ __SYCL_EXPORT pi_result piextSignalExternalSemaphore(
32013205
pi_uint32 num_events_in_wait_list, const pi_event *event_wait_list,
32023206
pi_event *event);
32033207

3208+
/// API to enqueue work through a backend API such that the plugin can schedule
3209+
/// the backend API calls within its own DAG.
3210+
///
3211+
/// \param command_queue is the queue on which the native command is enqueued
3212+
/// \param fn is the user submitted native function enqueueing work to a
3213+
/// backend API
3214+
/// \param data is the data that will be used in fn
3215+
/// \param num_mems is the number of mems in mem_list
3216+
/// \param mem_list is the list of mems that are used in fn
3217+
/// \param num_events_in_wait_list is the number of events in the wait list
3218+
/// \param event_wait_list is the list of events to wait on before this
3219+
/// operation
3220+
/// \param event is the returned event representing this operation
3221+
__SYCL_EXPORT pi_result piextEnqueueNativeCommand(
3222+
pi_queue command_queue, pi_enqueue_native_command_function fn, void *data,
3223+
pi_uint32 num_mems, const pi_mem *mem_list,
3224+
pi_uint32 num_events_in_wait_list, const pi_event *event_wait_list,
3225+
pi_event *event);
3226+
32043227
typedef enum {
32053228
_PI_SANITIZE_TYPE_NONE = 0x0,
32063229
_PI_SANITIZE_TYPE_ADDRESS = 0x1,

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,16 @@ pi_result piextVirtualMemGetInfo(pi_context context, const void *ptr,
13611361
param_value_size_ret);
13621362
}
13631363

1364+
pi_result
1365+
piextEnqueueNativeCommand(pi_queue Queue, pi_enqueue_native_command_function Fn,
1366+
void *Data, pi_uint32 NumMems, const pi_mem *Mems,
1367+
pi_uint32 NumEventsInWaitList,
1368+
const pi_event *EventWaitList, pi_event *Event) {
1369+
return pi2ur::piextEnqueueNativeCommand(Queue, Fn, Data, NumMems, Mems,
1370+
NumEventsInWaitList, EventWaitList,
1371+
Event);
1372+
}
1373+
13641374
const char SupportedVersion[] = _PI_CUDA_PLUGIN_VERSION_STRING;
13651375

13661376
pi_result piPluginInit(pi_plugin *PluginInit) {

sycl/plugins/hip/pi_hip.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,16 @@ pi_result piextVirtualMemGetInfo(pi_context context, const void *ptr,
13641364
param_value_size_ret);
13651365
}
13661366

1367+
pi_result
1368+
piextEnqueueNativeCommand(pi_queue Queue, pi_enqueue_native_command_function Fn,
1369+
void *Data, pi_uint32 NumMems, const pi_mem *Mems,
1370+
pi_uint32 NumEventsInWaitList,
1371+
const pi_event *EventWaitList, pi_event *Event) {
1372+
return pi2ur::piextEnqueueNativeCommand(Queue, Fn, Data, NumMems, Mems,
1373+
NumEventsInWaitList, EventWaitList,
1374+
Event);
1375+
}
1376+
13671377
const char SupportedVersion[] = _PI_HIP_PLUGIN_VERSION_STRING;
13681378

13691379
pi_result piPluginInit(pi_plugin *PluginInit) {

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,6 +1562,16 @@ pi_result piextVirtualMemGetInfo(pi_context Context, const void *Ptr,
15621562
ParamValueSizeRet);
15631563
}
15641564

1565+
pi_result
1566+
piextEnqueueNativeCommand(pi_queue Queue, pi_enqueue_native_command_function Fn,
1567+
void *Data, pi_uint32 NumMems, const pi_mem *Mems,
1568+
pi_uint32 NumEventsInWaitList,
1569+
const pi_event *EventWaitList, pi_event *Event) {
1570+
return pi2ur::piextEnqueueNativeCommand(Queue, Fn, Data, NumMems, Mems,
1571+
NumEventsInWaitList, EventWaitList,
1572+
Event);
1573+
}
1574+
15651575
const char SupportedVersion[] = _PI_LEVEL_ZERO_PLUGIN_VERSION_STRING;
15661576

15671577
pi_result piPluginInit(pi_plugin *PluginInit) { // missing

sycl/plugins/native_cpu/pi_native_cpu.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,16 @@ pi_result piextVirtualMemGetInfo(pi_context context, const void *ptr,
13841384
param_value_size_ret);
13851385
}
13861386

1387+
pi_result
1388+
piextEnqueueNativeCommand(pi_queue Queue, pi_enqueue_native_command_function Fn,
1389+
void *Data, pi_uint32 NumMems, const pi_mem *Mems,
1390+
pi_uint32 NumEventsInWaitList,
1391+
const pi_event *EventWaitList, pi_event *Event) {
1392+
return pi2ur::piextEnqueueNativeCommand(Queue, Fn, Data, NumMems, Mems,
1393+
NumEventsInWaitList, EventWaitList,
1394+
Event);
1395+
}
1396+
13871397
// Initialize function table with stubs.
13881398
#define _PI_API(api) \
13891399
(PluginInit->PiFunctionTable).api = (decltype(&::api))(&api);

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,16 @@ pi_result piextVirtualMemGetInfo(pi_context Context, const void *Ptr,
12911291
ParamValueSizeRet);
12921292
}
12931293

1294+
pi_result
1295+
piextEnqueueNativeCommand(pi_queue Queue, pi_enqueue_native_command_function Fn,
1296+
void *Data, pi_uint32 NumMems, const pi_mem *Mems,
1297+
pi_uint32 NumEventsInWaitList,
1298+
const pi_event *EventWaitList, pi_event *Event) {
1299+
return pi2ur::piextEnqueueNativeCommand(Queue, Fn, Data, NumMems, Mems,
1300+
NumEventsInWaitList, EventWaitList,
1301+
Event);
1302+
}
1303+
12941304
pi_result piTearDown(void *PluginParameter) {
12951305
return pi2ur::piTearDown(PluginParameter);
12961306
}

sycl/plugins/unified_runtime/pi2ur.hpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,9 @@ inline pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,
13251325
PI_TO_UR_MAP_DEVICE_INFO(
13261326
PI_EXT_ONEAPI_DEVICE_INFO_TIMESTAMP_RECORDING_SUPPORT,
13271327
UR_DEVICE_INFO_TIMESTAMP_RECORDING_SUPPORT_EXP)
1328+
PI_TO_UR_MAP_DEVICE_INFO(
1329+
PI_EXT_ONEAPI_DEVICE_INFO_ENQUEUE_NATIVE_COMMAND_SUPPORT,
1330+
UR_DEVICE_INFO_ENQUEUE_NATIVE_COMMAND_SUPPORT_EXP)
13281331
PI_TO_UR_MAP_DEVICE_INFO(PI_EXT_INTEL_DEVICE_INFO_ESIMD_SUPPORT,
13291332
UR_DEVICE_INFO_ESIMD_SUPPORT)
13301333
PI_TO_UR_MAP_DEVICE_INFO(PI_EXT_ONEAPI_DEVICE_INFO_COMPONENT_DEVICES,
@@ -5722,7 +5725,6 @@ piextVirtualMemGranularityGetInfo(pi_context Context, pi_device Device,
57225725
HANDLE_ERRORS(urVirtualMemGranularityGetInfo(UrContext, UrDevice, InfoType,
57235726
ParamValueSize, ParamValue,
57245727
ParamValueSizeRet));
5725-
57265728
return PI_SUCCESS;
57275729
}
57285730

@@ -5882,4 +5884,30 @@ inline pi_result piextVirtualMemGetInfo(pi_context Context, const void *Ptr,
58825884
// Virtual Memory
58835885
///////////////////////////////////////////////////////////////////////////////
58845886

5887+
///////////////////////////////////////////////////////////////////////////////
5888+
// Enqueue Native Command Extension
5889+
inline pi_result
5890+
piextEnqueueNativeCommand(pi_queue Queue, pi_enqueue_native_command_function Fn,
5891+
void *Data, pi_uint32 NumMems, const pi_mem *MemList,
5892+
pi_uint32 NumEventsInWaitList,
5893+
const pi_event *EventWaitList, pi_event *Event) {
5894+
PI_ASSERT(Queue, PI_ERROR_INVALID_QUEUE);
5895+
5896+
auto UrQueue = reinterpret_cast<ur_queue_handle_t>(Queue);
5897+
auto UrFn = reinterpret_cast<void (*)(ur_queue_handle_t, void *)>(Fn);
5898+
const ur_mem_handle_t *UrMemList =
5899+
reinterpret_cast<const ur_mem_handle_t *>(MemList);
5900+
const ur_event_handle_t *UrEventWaitList =
5901+
reinterpret_cast<const ur_event_handle_t *>(EventWaitList);
5902+
ur_event_handle_t *UREvent = reinterpret_cast<ur_event_handle_t *>(Event);
5903+
5904+
HANDLE_ERRORS(urEnqueueNativeCommandExp(
5905+
UrQueue, UrFn, Data, NumMems, UrMemList, nullptr /*pProperties*/,
5906+
NumEventsInWaitList, UrEventWaitList, UREvent));
5907+
5908+
return PI_SUCCESS;
5909+
}
5910+
// Enqueue Native Command Extension
5911+
///////////////////////////////////////////////////////////////////////////////
5912+
58855913
} // namespace pi2ur

sycl/plugins/unified_runtime/pi_unified_runtime.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,16 @@ __SYCL_EXPORT pi_result piextSignalExternalSemaphore(
14471447
EventWaitList, Event);
14481448
}
14491449

1450+
pi_result
1451+
piextEnqueueNativeCommand(pi_queue Queue, pi_enqueue_native_command_function Fn,
1452+
void *Data, pi_uint32 NumMems, const pi_mem *Mems,
1453+
pi_uint32 NumEventsInWaitList,
1454+
const pi_event *EventWaitList, pi_event *Event) {
1455+
return pi2ur::piextEnqueueNativeCommand(Queue, Fn, Data, NumMems, Mems,
1456+
NumEventsInWaitList, EventWaitList,
1457+
Event);
1458+
}
1459+
14501460
// This interface is not in Unified Runtime currently
14511461
__SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
14521462
PI_ASSERT(PluginInit, PI_ERROR_INVALID_VALUE);

sycl/source/detail/scheduler/commands.cpp

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,23 @@ static void flushCrossQueueDeps(const std::vector<EventImplPtr> &EventImpls,
317317
}
318318
}
319319

320+
namespace {
321+
322+
struct EnqueueNativeCommandData {
323+
sycl::interop_handle ih;
324+
std::function<void(interop_handle)> func;
325+
};
326+
327+
void InteropFreeFunc(pi_queue InteropQueue, void *InteropData) {
328+
auto *Data = reinterpret_cast<EnqueueNativeCommandData *>(InteropData);
329+
return Data->func(Data->ih);
330+
}
331+
} // namespace
332+
320333
class DispatchHostTask {
321334
ExecCGCommand *MThisCmd;
322335
std::vector<interop_handle::ReqToMem> MReqToMem;
336+
std::vector<pi_mem> MReqPiMem;
323337

324338
bool waitForEvents() const {
325339
std::map<const PluginPtr, std::vector<EventImplPtr>>
@@ -365,8 +379,10 @@ class DispatchHostTask {
365379

366380
public:
367381
DispatchHostTask(ExecCGCommand *ThisCmd,
368-
std::vector<interop_handle::ReqToMem> ReqToMem)
369-
: MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)) {}
382+
std::vector<interop_handle::ReqToMem> ReqToMem,
383+
std::vector<pi_mem> ReqPiMem)
384+
: MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)),
385+
MReqPiMem(std::move(ReqPiMem)) {}
370386

371387
void operator()() const {
372388
assert(MThisCmd->getCG().getType() == CG::CGTYPE::CodeplayHostTask);
@@ -402,8 +418,32 @@ class DispatchHostTask {
402418
interop_handle IH{MReqToMem, HostTask.MQueue,
403419
HostTask.MQueue->getDeviceImplPtr(),
404420
HostTask.MQueue->getContextImplPtr()};
405-
406-
HostTask.MHostTask->call(MThisCmd->MEvent->getHostProfilingInfo(), IH);
421+
// TODO: should all the backends that support this entry point use this
422+
// for host task?
423+
auto &Queue = HostTask.MQueue;
424+
bool NativeCommandSupport = false;
425+
Queue->getPlugin()->call<PiApiKind::piDeviceGetInfo>(
426+
detail::getSyclObjImpl(Queue->get_device())->getHandleRef(),
427+
PI_EXT_ONEAPI_DEVICE_INFO_ENQUEUE_NATIVE_COMMAND_SUPPORT,
428+
sizeof(NativeCommandSupport), &NativeCommandSupport, nullptr);
429+
if (NativeCommandSupport) {
430+
EnqueueNativeCommandData CustomOpData{
431+
IH, HostTask.MHostTask->MInteropTask};
432+
433+
// We are assuming that we have already synchronized with the HT's
434+
// dependent events, and that the user will synchronize before the end
435+
// of the HT lambda. As such we don't pass in any events, or ask for
436+
// one back.
437+
//
438+
// This entry point is needed in order to migrate memory across
439+
// devices in the same context for CUDA and HIP backends
440+
Queue->getPlugin()->call<PiApiKind::piextEnqueueNativeCommand>(
441+
HostTask.MQueue->getHandleRef(), InteropFreeFunc, &CustomOpData,
442+
MReqPiMem.size(), MReqPiMem.data(), 0, nullptr, nullptr);
443+
} else {
444+
HostTask.MHostTask->call(MThisCmd->MEvent->getHostProfilingInfo(),
445+
IH);
446+
}
407447
} else
408448
HostTask.MHostTask->call(MThisCmd->MEvent->getHostProfilingInfo());
409449
} catch (...) {
@@ -3121,13 +3161,14 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
31213161
}
31223162

31233163
std::vector<interop_handle::ReqToMem> ReqToMem;
3164+
std::vector<pi_mem> ReqPiMem;
31243165

31253166
if (HostTask->MHostTask->isInteropTask()) {
31263167
// Extract the Mem Objects for all Requirements, to ensure they are
31273168
// available if a user asks for them inside the interop task scope
31283169
const std::vector<Requirement *> &HandlerReq =
31293170
HostTask->getRequirements();
3130-
auto ReqToMemConv = [&ReqToMem, HostTask](Requirement *Req) {
3171+
auto ReqToMemConv = [&ReqToMem, &ReqPiMem, HostTask](Requirement *Req) {
31313172
const std::vector<AllocaCommandBase *> &AllocaCmds =
31323173
Req->MSYCLMemObj->MRecord->MAllocaCommands;
31333174

@@ -3137,6 +3178,7 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
31373178
auto MemArg =
31383179
reinterpret_cast<pi_mem>(AllocaCmd->getMemAllocation());
31393180
ReqToMem.emplace_back(std::make_pair(Req, MemArg));
3181+
ReqPiMem.emplace_back(MemArg);
31403182

31413183
return;
31423184
}
@@ -3158,7 +3200,7 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
31583200
copySubmissionCodeLocation();
31593201

31603202
MQueue->getThreadPool().submit<DispatchHostTask>(
3161-
DispatchHostTask(this, std::move(ReqToMem)));
3203+
DispatchHostTask(this, std::move(ReqToMem), std::move(ReqPiMem)));
31623204

31633205
MShouldCompleteEventIfPossible = false;
31643206

sycl/test-e2e/HostInteropTask/interop-task-cuda-buffer-migrate.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
// REQUIRES: cuda
2-
// XFAIL: cuda
3-
//
4-
// FIXME: this is broken with a multi device context
52
//
63
// RUN: %{build} -o %t.out -lcuda
74
// RUN: %{run} %t.out
@@ -31,12 +28,6 @@ int main() {
3128
platform(gpu_selector_v).get_devices(info::device_type::gpu);
3229
std::cout << Devices.size() << " devices found" << std::endl;
3330

34-
if (Devices.size() == 1) {
35-
// Since this is XFAIL for Devices.size() > 1 we need to return failure if
36-
// test can't run
37-
return 1;
38-
}
39-
4031
context C(Devices);
4132

4233
int Index = 0;

sycl/test/abi/pi_cuda_symbol_check.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ piextDisablePeerAccess
114114
piextEnablePeerAccess
115115
piextEnqueueCommandBuffer
116116
piextEnqueueCooperativeKernelLaunch
117+
piextEnqueueNativeCommand
117118
piextEnqueueReadHostPipe
118119
piextEnqueueWriteHostPipe
119120
piextEventCreateWithNativeHandle

sycl/test/abi/pi_hip_symbol_check.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ piextDisablePeerAccess
114114
piextEnablePeerAccess
115115
piextEnqueueCommandBuffer
116116
piextEnqueueCooperativeKernelLaunch
117+
piextEnqueueNativeCommand
117118
piextEnqueueReadHostPipe
118119
piextEnqueueWriteHostPipe
119120
piextEventCreateWithNativeHandle

sycl/test/abi/pi_level_zero_symbol_check.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ piextDisablePeerAccess
113113
piextEnablePeerAccess
114114
piextEnqueueCommandBuffer
115115
piextEnqueueCooperativeKernelLaunch
116+
piextEnqueueNativeCommand
116117
piextEnqueueReadHostPipe
117118
piextEnqueueWriteHostPipe
118119
piextEventCreateWithNativeHandle

sycl/test/abi/pi_nativecpu_symbol_check.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ piextDisablePeerAccess
114114
piextEnablePeerAccess
115115
piextEnqueueCommandBuffer
116116
piextEnqueueCooperativeKernelLaunch
117+
piextEnqueueNativeCommand
117118
piextEnqueueReadHostPipe
118119
piextEnqueueWriteHostPipe
119120
piextEventCreateWithNativeHandle

sycl/test/abi/pi_opencl_symbol_check.dump

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ piextDisablePeerAccess
113113
piextEnablePeerAccess
114114
piextEnqueueCommandBuffer
115115
piextEnqueueCooperativeKernelLaunch
116+
piextEnqueueNativeCommand
116117
piextEnqueueReadHostPipe
117118
piextEnqueueWriteHostPipe
118119
piextEventCreateWithNativeHandle

0 commit comments

Comments
 (0)