Skip to content

Commit d0b25d4

Browse files
authored
[SYCL][CUDA] Support host-device memcpy2D (#8181)
Adds support for host-device memcpy2D copies, including host pointers that are not known to the CUDA driver.
1 parent b66236a commit d0b25d4

File tree

1 file changed

+34
-29
lines changed

1 file changed

+34
-29
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,34 @@ pi_result enqueueEventsWait(pi_queue command_queue, CUstream stream,
342342
}
343343
}
344344

345+
template <typename PtrT>
346+
void getUSMHostOrDevicePtr(PtrT usm_ptr, CUmemorytype *out_mem_type,
347+
CUdeviceptr *out_dev_ptr, PtrT *out_host_ptr) {
348+
// do not throw if cuPointerGetAttribute returns CUDA_ERROR_INVALID_VALUE
349+
// checks with PI_CHECK_ERROR are not suggested
350+
CUresult ret = cuPointerGetAttribute(
351+
out_mem_type, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, (CUdeviceptr)usm_ptr);
352+
assert((*out_mem_type != CU_MEMORYTYPE_ARRAY &&
353+
*out_mem_type != CU_MEMORYTYPE_UNIFIED) &&
354+
"ARRAY, UNIFIED types are not supported!");
355+
356+
// pointer not known to the CUDA subsystem (possibly a system allocated ptr)
357+
if (ret == CUDA_ERROR_INVALID_VALUE) {
358+
*out_mem_type = CU_MEMORYTYPE_HOST;
359+
*out_dev_ptr = 0;
360+
*out_host_ptr = usm_ptr;
361+
362+
// todo: resets the above "non-stick" error
363+
} else if (ret == CUDA_SUCCESS) {
364+
*out_dev_ptr = (*out_mem_type == CU_MEMORYTYPE_DEVICE)
365+
? reinterpret_cast<CUdeviceptr>(usm_ptr)
366+
: 0;
367+
*out_host_ptr = (*out_mem_type == CU_MEMORYTYPE_HOST) ? usm_ptr : nullptr;
368+
} else {
369+
PI_CHECK_ERROR(ret);
370+
}
371+
}
372+
345373
} // anonymous namespace
346374

347375
/// ------ Error handling, matching OpenCL plugin semantics.
@@ -998,7 +1026,6 @@ pi_result cuda_piContextGetInfo(pi_context context, pi_context_info param_name,
9981026
capabilities);
9991027
}
10001028
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
1001-
// 2D USM memcpy is supported.
10021029
return getInfo<pi_bool>(param_value_size, param_value, param_value_size_ret,
10031030
true);
10041031
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_FILL2D_SUPPORT:
@@ -5261,39 +5288,17 @@ pi_result cuda_piextUSMEnqueueMemcpy2D(pi_queue queue, pi_bool blocking,
52615288
(*event)->start();
52625289
}
52635290

5264-
// Determine the direction of Copy using cuPointerGetAttributes
5291+
// Determine the direction of copy using cuPointerGetAttribute
52655292
// for both the src_ptr and dst_ptr
5266-
// TODO: Doesn't yet support CU_MEMORYTYPE_UNIFIED
5267-
CUpointer_attribute attributes = {CU_POINTER_ATTRIBUTE_MEMORY_TYPE};
5268-
5269-
CUmemorytype src_type = static_cast<CUmemorytype>(0);
5270-
void *src_attribute_values[] = {(void *)(&src_type)};
5271-
result = PI_CHECK_ERROR(cuPointerGetAttributes(
5272-
1, &attributes, src_attribute_values, (CUdeviceptr)src_ptr));
5273-
assert(src_type == CU_MEMORYTYPE_DEVICE || src_type == CU_MEMORYTYPE_HOST);
5274-
5275-
CUmemorytype dst_type = static_cast<CUmemorytype>(0);
5276-
void *dst_attribute_values[] = {(void *)(&dst_type)};
5277-
result = PI_CHECK_ERROR(cuPointerGetAttributes(
5278-
1, &attributes, dst_attribute_values, (CUdeviceptr)dst_ptr));
5279-
assert(dst_type == CU_MEMORYTYPE_DEVICE || dst_type == CU_MEMORYTYPE_HOST);
5280-
52815293
CUDA_MEMCPY2D cpyDesc = {0};
52825294

5283-
cpyDesc.srcMemoryType = src_type;
5284-
cpyDesc.srcDevice = (src_type == CU_MEMORYTYPE_DEVICE)
5285-
? reinterpret_cast<CUdeviceptr>(src_ptr)
5286-
: 0;
5287-
cpyDesc.srcHost = (src_type == CU_MEMORYTYPE_HOST) ? src_ptr : nullptr;
5288-
cpyDesc.srcPitch = src_pitch;
5295+
getUSMHostOrDevicePtr(src_ptr, &cpyDesc.srcMemoryType, &cpyDesc.srcDevice,
5296+
&cpyDesc.srcHost);
5297+
getUSMHostOrDevicePtr(dst_ptr, &cpyDesc.dstMemoryType, &cpyDesc.dstDevice,
5298+
&cpyDesc.dstHost);
52895299

5290-
cpyDesc.dstMemoryType = dst_type;
5291-
cpyDesc.dstDevice = (dst_type == CU_MEMORYTYPE_DEVICE)
5292-
? reinterpret_cast<CUdeviceptr>(dst_ptr)
5293-
: 0;
5294-
cpyDesc.dstHost = (dst_type == CU_MEMORYTYPE_HOST) ? dst_ptr : nullptr;
52955300
cpyDesc.dstPitch = dst_pitch;
5296-
5301+
cpyDesc.srcPitch = src_pitch;
52975302
cpyDesc.WidthInBytes = width;
52985303
cpyDesc.Height = height;
52995304

0 commit comments

Comments
 (0)