Skip to content

Commit 9008a5d

Browse files
authored
[SYCL][CUDA][HIP] Implement piextUSMEnqueueMemcpy2D (#7941)
1 parent 24c2aa8 commit 9008a5d

File tree

3 files changed

+149
-21
lines changed

3 files changed

+149
-21
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 93 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,9 @@ pi_result cuda_piContextGetInfo(pi_context context, pi_context_info param_name,
10671067
capabilities);
10681068
}
10691069
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
1070+
// 2D USM memcpy is supported.
1071+
return getInfo<pi_bool>(param_value_size, param_value, param_value_size_ret,
1072+
true);
10701073
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_FILL2D_SUPPORT:
10711074
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_MEMSET2D_SUPPORT:
10721075
// 2D USM operations currently not supported.
@@ -1949,10 +1952,12 @@ pi_result cuda_piDeviceGetInfo(pi_device device, pi_device_info param_name,
19491952
CUresult current_ctx_device_ret = cuCtxGetDevice(&current_ctx_device);
19501953
if (current_ctx_device_ret != CUDA_ERROR_INVALID_CONTEXT)
19511954
PI_CHECK_ERROR(current_ctx_device_ret);
1952-
bool need_primary_ctx = current_ctx_device_ret == CUDA_ERROR_INVALID_CONTEXT ||
1953-
current_ctx_device != device->get();
1955+
bool need_primary_ctx =
1956+
current_ctx_device_ret == CUDA_ERROR_INVALID_CONTEXT ||
1957+
current_ctx_device != device->get();
19541958
if (need_primary_ctx) {
1955-
// Use the primary context for the device if no context with the device is set.
1959+
// Use the primary context for the device if no context with the device is
1960+
// set.
19561961
CUcontext primary_context;
19571962
PI_CHECK_ERROR(cuDevicePrimaryCtxRetain(&primary_context, device->get()));
19581963
PI_CHECK_ERROR(cuCtxSetCurrent(primary_context));
@@ -5383,14 +5388,91 @@ pi_result cuda_piextUSMEnqueueMemset2D(pi_queue, void *, size_t, int, size_t,
53835388
return {};
53845389
}
53855390

5386-
// TODO: Implement this. Remember to return true for
5387-
// PI_EXT_ONEAPI_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT when it is implemented.
5388-
pi_result cuda_piextUSMEnqueueMemcpy2D(pi_queue, pi_bool, void *, size_t,
5389-
const void *, size_t, size_t, size_t,
5390-
pi_uint32, const pi_event *,
5391-
pi_event *) {
5392-
sycl::detail::pi::die("piextUSMEnqueueMemcpy2D not implemented");
5393-
return {};
5391+
/// 2D Memcpy API
5392+
///
5393+
/// \param queue is the queue to submit to
5394+
/// \param blocking is whether this operation should block the host
5395+
/// \param dst_ptr is the location the data will be copied
5396+
/// \param dst_pitch is the total width of the destination memory including
5397+
/// padding
5398+
/// \param src_ptr is the data to be copied
5399+
/// \param dst_pitch is the total width of the source memory including padding
5400+
/// \param width is width in bytes of each row to be copied
5401+
/// \param height is height the columns to be copied
5402+
/// \param num_events_in_waitlist is the number of events to wait on
5403+
/// \param events_waitlist is an array of events to wait on
5404+
/// \param event is the event that represents this operation
5405+
pi_result cuda_piextUSMEnqueueMemcpy2D(pi_queue queue, pi_bool blocking,
5406+
void *dst_ptr, size_t dst_pitch,
5407+
const void *src_ptr, size_t src_pitch,
5408+
size_t width, size_t height,
5409+
pi_uint32 num_events_in_wait_list,
5410+
const pi_event *event_wait_list,
5411+
pi_event *event) {
5412+
5413+
assert(queue != nullptr);
5414+
5415+
pi_result result = PI_SUCCESS;
5416+
5417+
try {
5418+
ScopedContext active(queue->get_context());
5419+
CUstream cuStream = queue->get_next_transfer_stream();
5420+
result = enqueueEventsWait(queue, cuStream, num_events_in_wait_list,
5421+
event_wait_list);
5422+
if (event) {
5423+
(*event) = _pi_event::make_native(PI_COMMAND_TYPE_MEM_BUFFER_COPY_RECT,
5424+
queue, cuStream);
5425+
(*event)->start();
5426+
}
5427+
5428+
// Determine the direction of Copy using cuPointerGetAttributes
5429+
// for both the src_ptr and dst_ptr
5430+
// TODO: Doesn't yet support CU_MEMORYTYPE_UNIFIED
5431+
CUpointer_attribute attributes = {CU_POINTER_ATTRIBUTE_MEMORY_TYPE};
5432+
5433+
CUmemorytype src_type = static_cast<CUmemorytype>(0);
5434+
void *src_attribute_values[] = {(void *)(&src_type)};
5435+
result = PI_CHECK_ERROR(cuPointerGetAttributes(
5436+
1, &attributes, src_attribute_values, (CUdeviceptr)src_ptr));
5437+
assert(src_type == CU_MEMORYTYPE_DEVICE || src_type == CU_MEMORYTYPE_HOST);
5438+
5439+
CUmemorytype dst_type = static_cast<CUmemorytype>(0);
5440+
void *dst_attribute_values[] = {(void *)(&dst_type)};
5441+
result = PI_CHECK_ERROR(cuPointerGetAttributes(
5442+
1, &attributes, dst_attribute_values, (CUdeviceptr)dst_ptr));
5443+
assert(dst_type == CU_MEMORYTYPE_DEVICE || dst_type == CU_MEMORYTYPE_HOST);
5444+
5445+
CUDA_MEMCPY2D cpyDesc = {0};
5446+
5447+
cpyDesc.srcMemoryType = src_type;
5448+
cpyDesc.srcDevice = (src_type == CU_MEMORYTYPE_DEVICE)
5449+
? reinterpret_cast<CUdeviceptr>(src_ptr)
5450+
: 0;
5451+
cpyDesc.srcHost = (src_type == CU_MEMORYTYPE_HOST) ? src_ptr : nullptr;
5452+
cpyDesc.srcPitch = src_pitch;
5453+
5454+
cpyDesc.dstMemoryType = dst_type;
5455+
cpyDesc.dstDevice = (dst_type == CU_MEMORYTYPE_DEVICE)
5456+
? reinterpret_cast<CUdeviceptr>(dst_ptr)
5457+
: 0;
5458+
cpyDesc.dstHost = (dst_type == CU_MEMORYTYPE_HOST) ? dst_ptr : nullptr;
5459+
cpyDesc.dstPitch = dst_pitch;
5460+
5461+
cpyDesc.WidthInBytes = width;
5462+
cpyDesc.Height = height;
5463+
5464+
result = PI_CHECK_ERROR(cuMemcpy2DAsync(&cpyDesc, cuStream));
5465+
5466+
if (event) {
5467+
(*event)->record();
5468+
}
5469+
if (blocking) {
5470+
result = PI_CHECK_ERROR(cuStreamSynchronize(cuStream));
5471+
}
5472+
} catch (pi_result err) {
5473+
result = err;
5474+
}
5475+
return result;
53945476
}
53955477

53965478
/// API to query information about USM allocated pointers

sycl/plugins/hip/pi_hip.cpp

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,8 @@ pi_result hip_piContextGetInfo(pi_context context, pi_context_info param_name,
10001000
return getInfo(param_value_size, param_value, param_value_size_ret,
10011001
context->get_reference_count());
10021002
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
1003+
return getInfo<pi_bool>(param_value_size, param_value, param_value_size_ret,
1004+
true);
10031005
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_FILL2D_SUPPORT:
10041006
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_MEMSET2D_SUPPORT:
10051007
// 2D USM operations currently not supported.
@@ -5122,13 +5124,57 @@ pi_result hip_piextUSMEnqueueMemset2D(pi_queue, void *, size_t, int, size_t,
51225124
return {};
51235125
}
51245126

5125-
// TODO: Implement this. Remember to return true for
5126-
// PI_EXT_ONEAPI_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT when it is implemented.
5127-
pi_result hip_piextUSMEnqueueMemcpy2D(pi_queue, pi_bool, void *, size_t,
5128-
const void *, size_t, size_t, size_t,
5129-
pi_uint32, const pi_event *, pi_event *) {
5130-
sycl::detail::pi::die("piextUSMEnqueueMemcpy2D not implemented");
5131-
return {};
5127+
/// 2D Memcpy API
5128+
///
5129+
/// \param queue is the queue to submit to
5130+
/// \param blocking is whether this operation should block the host
5131+
/// \param dst_ptr is the location the data will be copied
5132+
/// \param dst_pitch is the total width of the destination memory including
5133+
/// padding
5134+
/// \param src_ptr is the data to be copied
5135+
/// \param dst_pitch is the total width of the source memory including padding
5136+
/// \param width is width in bytes of each row to be copied
5137+
/// \param height is height the columns to be copied
5138+
/// \param num_events_in_waitlist is the number of events to wait on
5139+
/// \param events_waitlist is an array of events to wait on
5140+
/// \param event is the event that represents this operation
5141+
pi_result hip_piextUSMEnqueueMemcpy2D(pi_queue queue, pi_bool blocking,
5142+
void *dst_ptr, size_t dst_pitch,
5143+
const void *src_ptr, size_t src_pitch,
5144+
size_t width, size_t height,
5145+
pi_uint32 num_events_in_wait_list,
5146+
const pi_event *event_wait_list,
5147+
pi_event *event) {
5148+
assert(queue != nullptr);
5149+
5150+
pi_result result = PI_SUCCESS;
5151+
5152+
try {
5153+
ScopedContext active(queue->get_context());
5154+
hipStream_t hipStream = queue->get_next_transfer_stream();
5155+
result = enqueueEventsWait(queue, hipStream, num_events_in_wait_list,
5156+
event_wait_list);
5157+
if (event) {
5158+
(*event) = _pi_event::make_native(PI_COMMAND_TYPE_MEM_BUFFER_COPY_RECT,
5159+
queue, hipStream);
5160+
(*event)->start();
5161+
}
5162+
5163+
result = PI_CHECK_ERROR(hipMemcpy2DAsync(dst_ptr, dst_pitch, src_ptr,
5164+
src_pitch, width, height,
5165+
hipMemcpyDefault, hipStream));
5166+
5167+
if (event) {
5168+
(*event)->record();
5169+
}
5170+
if (blocking) {
5171+
result = PI_CHECK_ERROR(hipStreamSynchronize(hipStream));
5172+
}
5173+
} catch (pi_result err) {
5174+
result = err;
5175+
}
5176+
5177+
return result;
51325178
}
51335179

51345180
/// API to query information about USM allocated pointers
@@ -5461,4 +5507,4 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
54615507

54625508
} // extern "C"
54635509

5464-
hipEvent_t _pi_platform::evBase_{nullptr};
5510+
hipEvent_t _pi_platform::evBase_{nullptr};

sycl/source/detail/memory_manager.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ void MemoryManager::copy_usm(const void *SrcMem, QueueImplPtr SrcQueue,
858858

859859
const detail::plugin &Plugin = SrcQueue->getPlugin();
860860
Plugin.call<PiApiKind::piextUSMEnqueueMemcpy>(SrcQueue->getHandleRef(),
861-
/* blocking */ false, DstMem,
861+
/* blocking */ PI_FALSE, DstMem,
862862
SrcMem, Len, DepEvents.size(),
863863
DepEvents.data(), OutEvent);
864864
}
@@ -933,7 +933,7 @@ void MemoryManager::copy_2d_usm(const void *SrcMem, size_t SrcPitch,
933933
"NULL pointer argument in 2D memory copy operation.");
934934
const detail::plugin &Plugin = Queue->getPlugin();
935935
Plugin.call<PiApiKind::piextUSMEnqueueMemcpy2D>(
936-
Queue->getHandleRef(), /*blocking=*/false, DstMem, DstPitch, SrcMem,
936+
Queue->getHandleRef(), /*blocking=*/PI_FALSE, DstMem, DstPitch, SrcMem,
937937
SrcPitch, Width, Height, DepEvents.size(), DepEvents.data(), OutEvent);
938938
}
939939

0 commit comments

Comments
 (0)