Skip to content

Commit 099af9c

Browse files
authored
[SYCL][HIP] Use valid cast for CUdeviceptr in HIP PI (#6207)
Replace invalid reinterpret_cast with static_cast for CUdeviceptr in hip_piextMemGetNativeHandle when building HIP PI on CUDA machine; Signed-off-by: Lukas Sommer <[email protected]>
1 parent 0bfffd6 commit 099af9c

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

sycl/plugins/hip/pi_hip.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,8 +2132,27 @@ pi_result hip_piMemGetInfo(pi_mem memObj, pi_mem_info queriedInfo,
21322132
/// \return PI_SUCCESS
21332133
pi_result hip_piextMemGetNativeHandle(pi_mem mem,
21342134
pi_native_handle *nativeHandle) {
2135+
#if defined(__HIP_PLATFORM_NVIDIA__)
2136+
if (sizeof(_pi_mem::mem_::buffer_mem_::native_type) >
2137+
sizeof(pi_native_handle)) {
2138+
// Check that all the upper bits that cannot be represented by
2139+
// pi_native_handle are empty.
2140+
// NOTE: The following shift might trigger a warning, but the check in the
2141+
// if above makes sure that this does not underflow.
2142+
_pi_mem::mem_::buffer_mem_::native_type upperBits =
2143+
mem->mem_.buffer_mem_.get() >> (sizeof(pi_native_handle) * CHAR_BIT);
2144+
if (upperBits) {
2145+
// Return an error if any of the remaining bits is non-zero.
2146+
return PI_INVALID_MEM_OBJECT;
2147+
}
2148+
}
2149+
*nativeHandle = static_cast<pi_native_handle>(mem->mem_.buffer_mem_.get());
2150+
#elif defined(__HIP_PLATFORM_AMD__)
21352151
*nativeHandle =
21362152
reinterpret_cast<pi_native_handle>(mem->mem_.buffer_mem_.get());
2153+
#else
2154+
#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__");
2155+
#endif
21372156
return PI_SUCCESS;
21382157
}
21392158

0 commit comments

Comments
 (0)