Skip to content

Commit 926e38e

Browse files
authored
[SYCL][USM] Add support for APIs that query properties of USM pointers (#1016)
Signed-off-by: James Brodman <[email protected]>
1 parent 16ce311 commit 926e38e

File tree

3 files changed

+185
-0
lines changed

3 files changed

+185
-0
lines changed

sycl/include/CL/sycl/usm.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,5 +145,19 @@ T *aligned_alloc(size_t Alignment, size_t Count, const queue &Q,
145145
Kind);
146146
}
147147

148+
// Pointer queries
149+
/// Query the allocation type from a USM pointer
150+
///
151+
/// @param ptr is the USM pointer to query
152+
/// @param ctxt is the sycl context the ptr was allocated in
153+
usm::alloc get_pointer_type(const void *ptr, const context &ctxt);
154+
155+
/// Queries the device against which the pointer was allocated
156+
/// Throws an invalid_object_error if ptr is a host allocation.
157+
///
158+
/// @param ptr is the USM pointer to query
159+
/// @param ctxt is the sycl context the ptr was allocated in
160+
device get_pointer_device(const void *ptr, const context &ctxt);
161+
148162
} // namespace sycl
149163
} // namespace cl

sycl/source/detail/usm/usm_impl.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,5 +231,80 @@ void *aligned_alloc(size_t Alignment, size_t Size, const queue &Q, alloc Kind) {
231231
return aligned_alloc(Alignment, Size, Q.get_device(), Q.get_context(), Kind);
232232
}
233233

234+
// Pointer queries
235+
/// Query the allocation type from a USM pointer
236+
/// Returns alloc::host for all pointers in a host context.
237+
///
238+
/// @param ptr is the USM pointer to query
239+
/// @param ctxt is the sycl context the ptr was allocated in
240+
alloc get_pointer_type(const void *Ptr, const context &Ctxt) {
241+
// Everything on a host device is just system malloc so call it host
242+
if (Ctxt.is_host())
243+
return alloc::host;
244+
245+
std::shared_ptr<detail::context_impl> CtxImpl = detail::getSyclObjImpl(Ctxt);
246+
pi_context PICtx = CtxImpl->getHandleRef();
247+
pi_usm_type AllocTy;
248+
249+
// query type using PI function
250+
PI_CALL(piextUSMGetMemAllocInfo)(PICtx, Ptr, PI_MEM_ALLOC_TYPE,
251+
sizeof(pi_usm_type), &AllocTy, nullptr);
252+
253+
alloc ResultAlloc;
254+
switch (AllocTy) {
255+
case PI_MEM_TYPE_HOST:
256+
ResultAlloc = alloc::host;
257+
break;
258+
case PI_MEM_TYPE_DEVICE:
259+
ResultAlloc = alloc::device;
260+
break;
261+
case PI_MEM_TYPE_SHARED:
262+
ResultAlloc = alloc::shared;
263+
break;
264+
default:
265+
ResultAlloc = alloc::unknown;
266+
break;
267+
}
268+
269+
return ResultAlloc;
270+
}
271+
272+
/// Queries the device against which the pointer was allocated
273+
///
274+
/// @param ptr is the USM pointer to query
275+
/// @param ctxt is the sycl context the ptr was allocated in
276+
device get_pointer_device(const void *Ptr, const context &Ctxt) {
277+
// Just return the host device in the host context
278+
if (Ctxt.is_host())
279+
return Ctxt.get_devices()[0];
280+
281+
std::shared_ptr<detail::context_impl> CtxImpl = detail::getSyclObjImpl(Ctxt);
282+
283+
// Check if ptr is a host allocation
284+
if (get_pointer_type(Ptr, Ctxt) == alloc::host) {
285+
auto Devs = CtxImpl->getDevices();
286+
if (Devs.size() == 0)
287+
throw runtime_error("No devices in passed context!");
288+
289+
// Just return the first device in the context
290+
return Devs[0];
291+
}
292+
293+
pi_context PICtx = CtxImpl->getHandleRef();
294+
pi_device DeviceId;
295+
296+
// query device using PI function
297+
PI_CALL(piextUSMGetMemAllocInfo)(PICtx, Ptr, PI_MEM_ALLOC_DEVICE,
298+
sizeof(pi_device), &DeviceId, nullptr);
299+
300+
for (const device &Dev : CtxImpl->getDevices()) {
301+
// Try to find the real sycl device used in the context
302+
if (detail::getSyclObjImpl(Dev)->getHandleRef() == DeviceId)
303+
return Dev;
304+
}
305+
306+
throw runtime_error("Cannot find device associated with USM allocation!");
307+
}
308+
234309
} // namespace sycl
235310
} // namespace cl

sycl/test/usm/pointer_query.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: %clangxx -fsycl %s -o %t1.out
2+
// RUN: env SYCL_DEVICE_TYPE=HOST %t1.out
3+
4+
//==-------------- pointer_query.cpp - Pointer Query test ------------------==//
5+
//
6+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
7+
// See https://llvm.org/LICENSE.txt for license information.
8+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
#include <CL/sycl.hpp>
13+
14+
using namespace cl::sycl;
15+
16+
int main() {
17+
int *array = nullptr;
18+
const int N = 4;
19+
queue q;
20+
auto dev = q.get_device();
21+
auto ctxt = q.get_context();
22+
23+
if (!(dev.get_info<info::device::usm_device_allocations>() &&
24+
dev.get_info<info::device::usm_shared_allocations>() &&
25+
dev.get_info<info::device::usm_host_allocations>()))
26+
return 0;
27+
28+
usm::alloc Kind;
29+
device D;
30+
31+
// Test device allocs
32+
array = (int *)malloc_device(N * sizeof(int), q);
33+
if (array == nullptr) {
34+
return 1;
35+
}
36+
Kind = get_pointer_type(array, ctxt);
37+
if (ctxt.is_host()) {
38+
// for now, host device treats all allocations
39+
// as host allocations
40+
if (Kind != usm::alloc::host) {
41+
return 2;
42+
}
43+
} else {
44+
if (Kind != usm::alloc::device) {
45+
return 3;
46+
}
47+
}
48+
D = get_pointer_device(array, ctxt);
49+
if (D != dev) {
50+
return 4;
51+
}
52+
free(array, ctxt);
53+
54+
// Test shared allocs
55+
array = (int *)malloc_shared(N * sizeof(int), q);
56+
if (array == nullptr) {
57+
return 5;
58+
}
59+
Kind = get_pointer_type(array, ctxt);
60+
if (ctxt.is_host()) {
61+
// for now, host device treats all allocations
62+
// as host allocations
63+
if (Kind != usm::alloc::host) {
64+
return 6;
65+
}
66+
} else {
67+
if (Kind != usm::alloc::shared) {
68+
return 7;
69+
}
70+
}
71+
D = get_pointer_device(array, ctxt);
72+
if (D != dev) {
73+
return 8;
74+
}
75+
free(array, ctxt);
76+
77+
// Test host allocs
78+
array = (int *)malloc_host(N * sizeof(int), q);
79+
if (array == nullptr) {
80+
return 9;
81+
}
82+
Kind = get_pointer_type(array, ctxt);
83+
if (Kind != usm::alloc::host) {
84+
return 10;
85+
}
86+
D = get_pointer_device(array, ctxt);
87+
auto Devs = ctxt.get_devices();
88+
auto result = std::find(Devs.begin(), Devs.end(), D);
89+
if (result == Devs.end()) {
90+
// Returned device was not in queried context
91+
return 11;
92+
}
93+
free(array, ctxt);
94+
95+
return 0;
96+
}

0 commit comments

Comments
 (0)