Skip to content

[SYCL][USM] Add more extensive checks to USM pointer queries #1123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions sycl/source/detail/usm/usm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ void *aligned_alloc(size_t Alignment, size_t Size, const queue &Q, alloc Kind) {
/// @param ptr is the USM pointer to query
/// @param ctxt is the sycl context the ptr was allocated in
alloc get_pointer_type(const void *Ptr, const context &Ctxt) {
if (!Ptr)
return alloc::unknown;

// Everything on a host device is just system malloc so call it host
if (Ctxt.is_host())
return alloc::host;
Expand All @@ -251,8 +254,18 @@ alloc get_pointer_type(const void *Ptr, const context &Ctxt) {

// query type using PI function
const detail::plugin &Plugin = CtxImpl->getPlugin();
Plugin.call<detail::PiApiKind::piextUSMGetMemAllocInfo>(
PICtx, Ptr, PI_MEM_ALLOC_TYPE, sizeof(pi_usm_type), &AllocTy, nullptr);
RT::PiResult Err =
Plugin.call_nocheck<detail::PiApiKind::piextUSMGetMemAllocInfo>(
PICtx, Ptr, PI_MEM_ALLOC_TYPE, sizeof(pi_usm_type), &AllocTy,
nullptr);

// PI_INVALID_VALUE means USM doesn't know about this ptr
if (Err == PI_INVALID_VALUE)
return alloc::unknown;
// otherwise PI_SUCCESS is expected
if (Err != PI_SUCCESS) {
throw runtime_error("Error querying USM pointer: ", Err);
}

alloc ResultAlloc;
switch (AllocTy) {
Expand All @@ -278,6 +291,10 @@ alloc get_pointer_type(const void *Ptr, const context &Ctxt) {
/// @param ptr is the USM pointer to query
/// @param ctxt is the sycl context the ptr was allocated in
device get_pointer_device(const void *Ptr, const context &Ctxt) {
// Check if ptr is a valid USM pointer
if (get_pointer_type(Ptr, Ctxt) == alloc::unknown)
throw runtime_error("Ptr not a valid USM allocation!");

// Just return the host device in the host context
if (Ctxt.is_host())
return Ctxt.get_devices()[0];
Expand Down
27 changes: 27 additions & 0 deletions sycl/test/usm/pointer_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,32 @@ int main() {
}
free(array, ctxt);

// Test invalid ptrs
Kind = get_pointer_type(nullptr, ctxt);
if (Kind != usm::alloc::unknown) {
return 11;
}

// next checks only valid for non-host contexts
array = (int*)malloc(N*sizeof(int));
Kind = get_pointer_type(array, ctxt);
if (!ctxt.is_host()) {
if (Kind != usm::alloc::unknown) {
return 12;
}
try {
D = get_pointer_device(array, ctxt);
} catch (runtime_error) {
return 0;
}
return 13;
} else {
// host ctxts always report host
if (Kind != usm::alloc::host) {
return 14;
}
}
free(array);

return 0;
}