Skip to content

[CUDA][HIP] Fix host task mem migration and add pi entry point for urEnqueueNativeCommandExp #14353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sycl/include/sycl/detail/host_task_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class HostTask {
if (HPI)
HPI->end();
}

friend class DispatchHostTask;
};

class CGHostTask : public CG {
Expand Down
3 changes: 3 additions & 0 deletions sycl/include/sycl/detail/pi.def
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,7 @@ _PI_API(piextVirtualMemUnmap)
_PI_API(piextVirtualMemSetAccess)
_PI_API(piextVirtualMemGetInfo)

// Enqueue native command
_PI_API(piextEnqueueNativeCommand)

#undef _PI_API
25 changes: 24 additions & 1 deletion sycl/include/sycl/detail/pi.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,10 @@
// PI_EXT_ONEAPI_DEVICE_INFO_SUPPORTS_VIRTUAL_MEM device info descriptor,
// _pi_virtual_mem_granularity_info enum, _pi_virtual_mem_info enum and
// pi_virtual_access_flags bit flags.
// 15.55 Added piextEnqueueNativeCommand as well as associated types and enums

#define _PI_H_VERSION_MAJOR 15
#define _PI_H_VERSION_MINOR 54
#define _PI_H_VERSION_MINOR 55

#define _PI_STRING_HELPER(a) #a
#define _PI_CONCAT(a, b) _PI_STRING_HELPER(a.b)
Expand Down Expand Up @@ -512,6 +513,8 @@ typedef enum {

// Virtual memory support
PI_EXT_ONEAPI_DEVICE_INFO_SUPPORTS_VIRTUAL_MEM = 0x2011E,
// Native enqueue
PI_EXT_ONEAPI_DEVICE_INFO_ENQUEUE_NATIVE_COMMAND_SUPPORT = 0x2011F,
} _pi_device_info;

typedef enum {
Expand Down Expand Up @@ -1279,6 +1282,7 @@ using pi_image_mem_handle = void *;
using pi_interop_mem_handle = pi_uint64;
using pi_interop_semaphore_handle = pi_uint64;
using pi_physical_mem = _pi_physical_mem *;
using pi_enqueue_native_command_function = void (*)(pi_queue, void *);

typedef struct {
pi_image_channel_order image_channel_order;
Expand Down Expand Up @@ -3201,6 +3205,25 @@ __SYCL_EXPORT pi_result piextSignalExternalSemaphore(
pi_uint32 num_events_in_wait_list, const pi_event *event_wait_list,
pi_event *event);

/// API to enqueue work through a backend API such that the plugin can schedule
/// the backend API calls within its own DAG.
///
/// \param command_queue is the queue on which the native command is enqueued
/// \param fn is the user submitted native function enqueueing work to a
/// backend API; it has the signature void (*)(pi_queue, void *)
/// \param data is the data that will be used in fn
/// \param num_mems is the number of mems in mem_list
/// \param mem_list is the list of mems that are used in fn
/// \param num_events_in_wait_list is the number of events in the wait list
/// \param event_wait_list is the list of events to wait on before this
/// operation
/// \param event is the returned event representing this operation
__SYCL_EXPORT pi_result piextEnqueueNativeCommand(
pi_queue command_queue, pi_enqueue_native_command_function fn, void *data,
pi_uint32 num_mems, const pi_mem *mem_list,
pi_uint32 num_events_in_wait_list, const pi_event *event_wait_list,
pi_event *event);

typedef enum {
_PI_SANITIZE_TYPE_NONE = 0x0,
_PI_SANITIZE_TYPE_ADDRESS = 0x1,
Expand Down
10 changes: 10 additions & 0 deletions sycl/plugins/cuda/pi_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,16 @@ pi_result piextVirtualMemGetInfo(pi_context context, const void *ptr,
param_value_size_ret);
}

// CUDA plugin entry point for piextEnqueueNativeCommand. All arguments are
// handed unchanged to the shared pi2ur implementation and its result is
// returned as-is.
pi_result piextEnqueueNativeCommand(
    pi_queue Queue, pi_enqueue_native_command_function Fn, void *Data,
    pi_uint32 NumMems, const pi_mem *Mems, pi_uint32 NumEventsInWaitList,
    const pi_event *EventWaitList, pi_event *Event) {
  const pi_result Result = pi2ur::piextEnqueueNativeCommand(
      Queue, Fn, Data, NumMems, Mems, NumEventsInWaitList, EventWaitList,
      Event);
  return Result;
}

const char SupportedVersion[] = _PI_CUDA_PLUGIN_VERSION_STRING;

pi_result piPluginInit(pi_plugin *PluginInit) {
Expand Down
10 changes: 10 additions & 0 deletions sycl/plugins/hip/pi_hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,16 @@ pi_result piextVirtualMemGetInfo(pi_context context, const void *ptr,
param_value_size_ret);
}

// HIP plugin entry point for piextEnqueueNativeCommand. All arguments are
// handed unchanged to the shared pi2ur implementation and its result is
// returned as-is.
pi_result piextEnqueueNativeCommand(
    pi_queue Queue, pi_enqueue_native_command_function Fn, void *Data,
    pi_uint32 NumMems, const pi_mem *Mems, pi_uint32 NumEventsInWaitList,
    const pi_event *EventWaitList, pi_event *Event) {
  const pi_result Result = pi2ur::piextEnqueueNativeCommand(
      Queue, Fn, Data, NumMems, Mems, NumEventsInWaitList, EventWaitList,
      Event);
  return Result;
}

const char SupportedVersion[] = _PI_HIP_PLUGIN_VERSION_STRING;

pi_result piPluginInit(pi_plugin *PluginInit) {
Expand Down
10 changes: 10 additions & 0 deletions sycl/plugins/level_zero/pi_level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,16 @@ pi_result piextVirtualMemGetInfo(pi_context Context, const void *Ptr,
ParamValueSizeRet);
}

// Level Zero plugin entry point for piextEnqueueNativeCommand. All arguments
// are handed unchanged to the shared pi2ur implementation and its result is
// returned as-is.
pi_result piextEnqueueNativeCommand(
    pi_queue Queue, pi_enqueue_native_command_function Fn, void *Data,
    pi_uint32 NumMems, const pi_mem *Mems, pi_uint32 NumEventsInWaitList,
    const pi_event *EventWaitList, pi_event *Event) {
  const pi_result Result = pi2ur::piextEnqueueNativeCommand(
      Queue, Fn, Data, NumMems, Mems, NumEventsInWaitList, EventWaitList,
      Event);
  return Result;
}

const char SupportedVersion[] = _PI_LEVEL_ZERO_PLUGIN_VERSION_STRING;

pi_result piPluginInit(pi_plugin *PluginInit) { // missing
Expand Down
10 changes: 10 additions & 0 deletions sycl/plugins/native_cpu/pi_native_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,16 @@ pi_result piextVirtualMemGetInfo(pi_context context, const void *ptr,
param_value_size_ret);
}

// Native CPU plugin entry point for piextEnqueueNativeCommand. All arguments
// are handed unchanged to the shared pi2ur implementation and its result is
// returned as-is.
pi_result piextEnqueueNativeCommand(
    pi_queue Queue, pi_enqueue_native_command_function Fn, void *Data,
    pi_uint32 NumMems, const pi_mem *Mems, pi_uint32 NumEventsInWaitList,
    const pi_event *EventWaitList, pi_event *Event) {
  const pi_result Result = pi2ur::piextEnqueueNativeCommand(
      Queue, Fn, Data, NumMems, Mems, NumEventsInWaitList, EventWaitList,
      Event);
  return Result;
}

// Initialize function table with stubs.
#define _PI_API(api) \
(PluginInit->PiFunctionTable).api = (decltype(&::api))(&api);
Expand Down
10 changes: 10 additions & 0 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,16 @@ pi_result piextVirtualMemGetInfo(pi_context Context, const void *Ptr,
ParamValueSizeRet);
}

// OpenCL plugin entry point for piextEnqueueNativeCommand. All arguments are
// handed unchanged to the shared pi2ur implementation and its result is
// returned as-is.
pi_result piextEnqueueNativeCommand(
    pi_queue Queue, pi_enqueue_native_command_function Fn, void *Data,
    pi_uint32 NumMems, const pi_mem *Mems, pi_uint32 NumEventsInWaitList,
    const pi_event *EventWaitList, pi_event *Event) {
  const pi_result Result = pi2ur::piextEnqueueNativeCommand(
      Queue, Fn, Data, NumMems, Mems, NumEventsInWaitList, EventWaitList,
      Event);
  return Result;
}

// OpenCL plugin entry point for piTearDown. Hands PluginParameter to the
// shared pi2ur implementation and returns its result unchanged.
pi_result piTearDown(void *PluginParameter) {
  const pi_result Result = pi2ur::piTearDown(PluginParameter);
  return Result;
}
Expand Down
30 changes: 29 additions & 1 deletion sycl/plugins/unified_runtime/pi2ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,9 @@ inline pi_result piDeviceGetInfo(pi_device Device, pi_device_info ParamName,
PI_TO_UR_MAP_DEVICE_INFO(
PI_EXT_ONEAPI_DEVICE_INFO_TIMESTAMP_RECORDING_SUPPORT,
UR_DEVICE_INFO_TIMESTAMP_RECORDING_SUPPORT_EXP)
PI_TO_UR_MAP_DEVICE_INFO(
PI_EXT_ONEAPI_DEVICE_INFO_ENQUEUE_NATIVE_COMMAND_SUPPORT,
UR_DEVICE_INFO_ENQUEUE_NATIVE_COMMAND_SUPPORT_EXP)
PI_TO_UR_MAP_DEVICE_INFO(PI_EXT_INTEL_DEVICE_INFO_ESIMD_SUPPORT,
UR_DEVICE_INFO_ESIMD_SUPPORT)
PI_TO_UR_MAP_DEVICE_INFO(PI_EXT_ONEAPI_DEVICE_INFO_COMPONENT_DEVICES,
Expand Down Expand Up @@ -5722,7 +5725,6 @@ piextVirtualMemGranularityGetInfo(pi_context Context, pi_device Device,
HANDLE_ERRORS(urVirtualMemGranularityGetInfo(UrContext, UrDevice, InfoType,
ParamValueSize, ParamValue,
ParamValueSizeRet));

return PI_SUCCESS;
}

Expand Down Expand Up @@ -5882,4 +5884,30 @@ inline pi_result piextVirtualMemGetInfo(pi_context Context, const void *Ptr,
// Virtual Memory
///////////////////////////////////////////////////////////////////////////////

///////////////////////////////////////////////////////////////////////////////
// Enqueue Native Command Extension
/// Maps piextEnqueueNativeCommand onto urEnqueueNativeCommandExp.
///
/// Each PI handle, handle list and the native callback is reinterpreted as
/// its UR counterpart and forwarded as-is; no properties are passed to UR.
/// Queue is checked with PI_ASSERT (PI_ERROR_INVALID_QUEUE) before anything
/// else. Returns PI_SUCCESS once the UR call has gone through HANDLE_ERRORS
/// (presumably an early return with the translated error on failure - see the
/// macro definition).
inline pi_result
piextEnqueueNativeCommand(pi_queue Queue, pi_enqueue_native_command_function Fn,
                          void *Data, pi_uint32 NumMems, const pi_mem *MemList,
                          pi_uint32 NumEventsInWaitList,
                          const pi_event *EventWaitList, pi_event *Event) {
  PI_ASSERT(Queue, PI_ERROR_INVALID_QUEUE);

  // UR's callback has the same shape as the PI one, only with a UR queue
  // handle in place of pi_queue.
  using UrNativeCommandFn = void (*)(ur_queue_handle_t, void *);

  ur_queue_handle_t QueueHandle = reinterpret_cast<ur_queue_handle_t>(Queue);
  UrNativeCommandFn NativeFn = reinterpret_cast<UrNativeCommandFn>(Fn);
  auto MemHandles = reinterpret_cast<const ur_mem_handle_t *>(MemList);
  auto WaitHandles =
      reinterpret_cast<const ur_event_handle_t *>(EventWaitList);
  auto OutEvent = reinterpret_cast<ur_event_handle_t *>(Event);

  HANDLE_ERRORS(urEnqueueNativeCommandExp(
      QueueHandle, NativeFn, Data, NumMems, MemHandles,
      nullptr /*pProperties*/, NumEventsInWaitList, WaitHandles, OutEvent));

  return PI_SUCCESS;
}
// Enqueue Native Command Extension
///////////////////////////////////////////////////////////////////////////////

} // namespace pi2ur
10 changes: 10 additions & 0 deletions sycl/plugins/unified_runtime/pi_unified_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,16 @@ __SYCL_EXPORT pi_result piextSignalExternalSemaphore(
EventWaitList, Event);
}

// Unified Runtime plugin entry point for piextEnqueueNativeCommand. All
// arguments are handed unchanged to the shared pi2ur implementation and its
// result is returned as-is.
// NOTE(review): neighbouring entry points in this file are spelled with
// __SYCL_EXPORT on the definition; this one relies on the declaration only -
// confirm that is intended.
pi_result piextEnqueueNativeCommand(
    pi_queue Queue, pi_enqueue_native_command_function Fn, void *Data,
    pi_uint32 NumMems, const pi_mem *Mems, pi_uint32 NumEventsInWaitList,
    const pi_event *EventWaitList, pi_event *Event) {
  const pi_result Result = pi2ur::piextEnqueueNativeCommand(
      Queue, Fn, Data, NumMems, Mems, NumEventsInWaitList, EventWaitList,
      Event);
  return Result;
}

// This interface is not in Unified Runtime currently
__SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
PI_ASSERT(PluginInit, PI_ERROR_INVALID_VALUE);
Expand Down
54 changes: 48 additions & 6 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,23 @@ static void flushCrossQueueDeps(const std::vector<EventImplPtr> &EventImpls,
}
}

namespace {

// Payload passed through the opaque `void *` argument of the native-command
// callback: the interop handle to give to the user's callable, and the
// callable itself.
struct EnqueueNativeCommandData {
  sycl::interop_handle ih;
  std::function<void(interop_handle)> func;
};

// Callback with the pi_enqueue_native_command_function signature.
// InteropData must point at an EnqueueNativeCommandData; the stored callable
// is invoked with the stored interop handle. InteropQueue is not used.
void InteropFreeFunc(pi_queue InteropQueue, void *InteropData) {
  auto &Payload = *static_cast<EnqueueNativeCommandData *>(InteropData);
  Payload.func(Payload.ih);
}
} // namespace

class DispatchHostTask {
ExecCGCommand *MThisCmd;
std::vector<interop_handle::ReqToMem> MReqToMem;
std::vector<pi_mem> MReqPiMem;

bool waitForEvents() const {
std::map<const PluginPtr, std::vector<EventImplPtr>>
Expand Down Expand Up @@ -365,8 +379,10 @@ class DispatchHostTask {

public:
DispatchHostTask(ExecCGCommand *ThisCmd,
std::vector<interop_handle::ReqToMem> ReqToMem)
: MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)) {}
std::vector<interop_handle::ReqToMem> ReqToMem,
std::vector<pi_mem> ReqPiMem)
: MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)),
MReqPiMem(std::move(ReqPiMem)) {}

void operator()() const {
assert(MThisCmd->getCG().getType() == CG::CGTYPE::CodeplayHostTask);
Expand Down Expand Up @@ -402,8 +418,32 @@ class DispatchHostTask {
interop_handle IH{MReqToMem, HostTask.MQueue,
HostTask.MQueue->getDeviceImplPtr(),
HostTask.MQueue->getContextImplPtr()};

HostTask.MHostTask->call(MThisCmd->MEvent->getHostProfilingInfo(), IH);
// TODO: should all the backends that support this entry point use this
// for host task?
auto &Queue = HostTask.MQueue;
bool NativeCommandSupport = false;
Queue->getPlugin()->call<PiApiKind::piDeviceGetInfo>(
detail::getSyclObjImpl(Queue->get_device())->getHandleRef(),
PI_EXT_ONEAPI_DEVICE_INFO_ENQUEUE_NATIVE_COMMAND_SUPPORT,
sizeof(NativeCommandSupport), &NativeCommandSupport, nullptr);
if (NativeCommandSupport) {
EnqueueNativeCommandData CustomOpData{
IH, HostTask.MHostTask->MInteropTask};

// We are assuming that we have already synchronized with the HT's
// dependent events, and that the user will synchronize before the end
// of the HT lambda. As such we don't pass in any events, or ask for
// one back.
//
// This entry point is needed in order to migrate memory across
// devices in the same context for CUDA and HIP backends
Queue->getPlugin()->call<PiApiKind::piextEnqueueNativeCommand>(
HostTask.MQueue->getHandleRef(), InteropFreeFunc, &CustomOpData,
MReqPiMem.size(), MReqPiMem.data(), 0, nullptr, nullptr);
} else {
HostTask.MHostTask->call(MThisCmd->MEvent->getHostProfilingInfo(),
IH);
}
} else
HostTask.MHostTask->call(MThisCmd->MEvent->getHostProfilingInfo());
} catch (...) {
Expand Down Expand Up @@ -3121,13 +3161,14 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
}

std::vector<interop_handle::ReqToMem> ReqToMem;
std::vector<pi_mem> ReqPiMem;

if (HostTask->MHostTask->isInteropTask()) {
// Extract the Mem Objects for all Requirements, to ensure they are
// available if a user asks for them inside the interop task scope
const std::vector<Requirement *> &HandlerReq =
HostTask->getRequirements();
auto ReqToMemConv = [&ReqToMem, HostTask](Requirement *Req) {
auto ReqToMemConv = [&ReqToMem, &ReqPiMem, HostTask](Requirement *Req) {
const std::vector<AllocaCommandBase *> &AllocaCmds =
Req->MSYCLMemObj->MRecord->MAllocaCommands;

Expand All @@ -3137,6 +3178,7 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
auto MemArg =
reinterpret_cast<pi_mem>(AllocaCmd->getMemAllocation());
ReqToMem.emplace_back(std::make_pair(Req, MemArg));
ReqPiMem.emplace_back(MemArg);

return;
}
Expand All @@ -3158,7 +3200,7 @@ pi_int32 ExecCGCommand::enqueueImpQueue() {
copySubmissionCodeLocation();

MQueue->getThreadPool().submit<DispatchHostTask>(
DispatchHostTask(this, std::move(ReqToMem)));
DispatchHostTask(this, std::move(ReqToMem), std::move(ReqPiMem)));

MShouldCompleteEventIfPossible = false;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
// REQUIRES: cuda
// XFAIL: cuda
//
// FIXME: this is broken with a multi device context
//
// RUN: %{build} -o %t.out -lcuda
// RUN: %{run} %t.out
Expand Down Expand Up @@ -31,12 +28,6 @@ int main() {
platform(gpu_selector_v).get_devices(info::device_type::gpu);
std::cout << Devices.size() << " devices found" << std::endl;

if (Devices.size() == 1) {
// Since this is XFAIL for Devices.size() > 1 we need to return failure if
// test can't run
return 1;
}

context C(Devices);

int Index = 0;
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/pi_cuda_symbol_check.dump
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ piextDisablePeerAccess
piextEnablePeerAccess
piextEnqueueCommandBuffer
piextEnqueueCooperativeKernelLaunch
piextEnqueueNativeCommand
piextEnqueueReadHostPipe
piextEnqueueWriteHostPipe
piextEventCreateWithNativeHandle
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/pi_hip_symbol_check.dump
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ piextDisablePeerAccess
piextEnablePeerAccess
piextEnqueueCommandBuffer
piextEnqueueCooperativeKernelLaunch
piextEnqueueNativeCommand
piextEnqueueReadHostPipe
piextEnqueueWriteHostPipe
piextEventCreateWithNativeHandle
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/pi_level_zero_symbol_check.dump
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ piextDisablePeerAccess
piextEnablePeerAccess
piextEnqueueCommandBuffer
piextEnqueueCooperativeKernelLaunch
piextEnqueueNativeCommand
piextEnqueueReadHostPipe
piextEnqueueWriteHostPipe
piextEventCreateWithNativeHandle
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/pi_nativecpu_symbol_check.dump
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ piextDisablePeerAccess
piextEnablePeerAccess
piextEnqueueCommandBuffer
piextEnqueueCooperativeKernelLaunch
piextEnqueueNativeCommand
piextEnqueueReadHostPipe
piextEnqueueWriteHostPipe
piextEventCreateWithNativeHandle
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/pi_opencl_symbol_check.dump
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ piextDisablePeerAccess
piextEnablePeerAccess
piextEnqueueCommandBuffer
piextEnqueueCooperativeKernelLaunch
piextEnqueueNativeCommand
piextEnqueueReadHostPipe
piextEnqueueWriteHostPipe
piextEventCreateWithNativeHandle
Expand Down
Loading
Loading