@@ -342,6 +342,34 @@ pi_result enqueueEventsWait(pi_queue command_queue, CUstream stream,
342
342
}
343
343
}
344
344
345
+ template <typename PtrT>
346
+ void getUSMHostOrDevicePtr (PtrT usm_ptr, CUmemorytype *out_mem_type,
347
+ CUdeviceptr *out_dev_ptr, PtrT *out_host_ptr) {
348
+ // do not throw if cuPointerGetAttribute returns CUDA_ERROR_INVALID_VALUE
349
+ // checks with PI_CHECK_ERROR are not suggested
350
+ CUresult ret = cuPointerGetAttribute (
351
+ out_mem_type, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, (CUdeviceptr)usm_ptr);
352
+ assert ((*out_mem_type != CU_MEMORYTYPE_ARRAY &&
353
+ *out_mem_type != CU_MEMORYTYPE_UNIFIED) &&
354
+ " ARRAY, UNIFIED types are not supported!" );
355
+
356
+ // pointer not known to the CUDA subsystem (possibly a system allocated ptr)
357
+ if (ret == CUDA_ERROR_INVALID_VALUE) {
358
+ *out_mem_type = CU_MEMORYTYPE_HOST;
359
+ *out_dev_ptr = 0 ;
360
+ *out_host_ptr = usm_ptr;
361
+
362
+ // todo: resets the above "non-stick" error
363
+ } else if (ret == CUDA_SUCCESS) {
364
+ *out_dev_ptr = (*out_mem_type == CU_MEMORYTYPE_DEVICE)
365
+ ? reinterpret_cast <CUdeviceptr>(usm_ptr)
366
+ : 0 ;
367
+ *out_host_ptr = (*out_mem_type == CU_MEMORYTYPE_HOST) ? usm_ptr : nullptr ;
368
+ } else {
369
+ PI_CHECK_ERROR (ret);
370
+ }
371
+ }
372
+
345
373
} // anonymous namespace
346
374
347
375
// / ------ Error handling, matching OpenCL plugin semantics.
@@ -998,7 +1026,6 @@ pi_result cuda_piContextGetInfo(pi_context context, pi_context_info param_name,
998
1026
capabilities);
999
1027
}
1000
1028
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
1001
- // 2D USM memcpy is supported.
1002
1029
return getInfo<pi_bool>(param_value_size, param_value, param_value_size_ret,
1003
1030
true );
1004
1031
case PI_EXT_ONEAPI_CONTEXT_INFO_USM_FILL2D_SUPPORT:
@@ -5261,39 +5288,17 @@ pi_result cuda_piextUSMEnqueueMemcpy2D(pi_queue queue, pi_bool blocking,
5261
5288
(*event)->start ();
5262
5289
}
5263
5290
5264
- // Determine the direction of Copy using cuPointerGetAttributes
5291
+ // Determine the direction of copy using cuPointerGetAttribute
5265
5292
// for both the src_ptr and dst_ptr
5266
- // TODO: Doesn't yet support CU_MEMORYTYPE_UNIFIED
5267
- CUpointer_attribute attributes = {CU_POINTER_ATTRIBUTE_MEMORY_TYPE};
5268
-
5269
- CUmemorytype src_type = static_cast <CUmemorytype>(0 );
5270
- void *src_attribute_values[] = {(void *)(&src_type)};
5271
- result = PI_CHECK_ERROR (cuPointerGetAttributes (
5272
- 1 , &attributes, src_attribute_values, (CUdeviceptr)src_ptr));
5273
- assert (src_type == CU_MEMORYTYPE_DEVICE || src_type == CU_MEMORYTYPE_HOST);
5274
-
5275
- CUmemorytype dst_type = static_cast <CUmemorytype>(0 );
5276
- void *dst_attribute_values[] = {(void *)(&dst_type)};
5277
- result = PI_CHECK_ERROR (cuPointerGetAttributes (
5278
- 1 , &attributes, dst_attribute_values, (CUdeviceptr)dst_ptr));
5279
- assert (dst_type == CU_MEMORYTYPE_DEVICE || dst_type == CU_MEMORYTYPE_HOST);
5280
-
5281
5293
CUDA_MEMCPY2D cpyDesc = {0 };
5282
5294
5283
- cpyDesc.srcMemoryType = src_type;
5284
- cpyDesc.srcDevice = (src_type == CU_MEMORYTYPE_DEVICE)
5285
- ? reinterpret_cast <CUdeviceptr>(src_ptr)
5286
- : 0 ;
5287
- cpyDesc.srcHost = (src_type == CU_MEMORYTYPE_HOST) ? src_ptr : nullptr ;
5288
- cpyDesc.srcPitch = src_pitch;
5295
+ getUSMHostOrDevicePtr (src_ptr, &cpyDesc.srcMemoryType , &cpyDesc.srcDevice ,
5296
+ &cpyDesc.srcHost );
5297
+ getUSMHostOrDevicePtr (dst_ptr, &cpyDesc.dstMemoryType , &cpyDesc.dstDevice ,
5298
+ &cpyDesc.dstHost );
5289
5299
5290
- cpyDesc.dstMemoryType = dst_type;
5291
- cpyDesc.dstDevice = (dst_type == CU_MEMORYTYPE_DEVICE)
5292
- ? reinterpret_cast <CUdeviceptr>(dst_ptr)
5293
- : 0 ;
5294
- cpyDesc.dstHost = (dst_type == CU_MEMORYTYPE_HOST) ? dst_ptr : nullptr ;
5295
5300
cpyDesc.dstPitch = dst_pitch;
5296
-
5301
+ cpyDesc. srcPitch = src_pitch;
5297
5302
cpyDesc.WidthInBytes = width;
5298
5303
cpyDesc.Height = height;
5299
5304
0 commit comments