Skip to content

Commit 8ed0a4c

Browse files
committed
Style changes and return a device in context for host allocs
Signed-off-by: James Brodman <[email protected]>
1 parent bd7142a commit 8ed0a4c

File tree

2 files changed

+30
-17
lines changed

2 files changed

+30
-17
lines changed

sycl/source/detail/usm/usm_impl.cpp

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -243,22 +243,27 @@ alloc get_pointer_type(const void *Ptr, const context &Ctxt) {
243243
return alloc::host;
244244

245245
std::shared_ptr<detail::context_impl> CtxImpl = detail::getSyclObjImpl(Ctxt);
246-
pi_context C = CtxImpl->getHandleRef();
246+
pi_context PICtx = CtxImpl->getHandleRef();
247247
pi_usm_type AllocTy;
248248

249249
// query type using PI function
250-
PI_CALL(piextUSMGetMemAllocInfo)(C, Ptr, PI_MEM_ALLOC_TYPE,
250+
PI_CALL(piextUSMGetMemAllocInfo)(PICtx, Ptr, PI_MEM_ALLOC_TYPE,
251251
sizeof(pi_usm_type), &AllocTy, nullptr);
252252

253-
alloc ResultAlloc = alloc::unknown;
254-
if (AllocTy == PI_MEM_TYPE_HOST) {
253+
alloc ResultAlloc;
254+
switch (AllocTy) {
255+
case PI_MEM_TYPE_HOST:
255256
ResultAlloc = alloc::host;
256-
}
257-
else if (AllocTy == PI_MEM_TYPE_DEVICE) {
257+
break;
258+
case PI_MEM_TYPE_DEVICE:
258259
ResultAlloc = alloc::device;
259-
}
260-
else if (AllocTy == PI_MEM_TYPE_SHARED) {
260+
break;
261+
case PI_MEM_TYPE_SHARED:
261262
ResultAlloc = alloc::shared;
263+
break;
264+
default:
265+
ResultAlloc = alloc::unknown;
266+
break;
262267
}
263268

264269
return ResultAlloc;
@@ -273,21 +278,26 @@ device get_pointer_device(const void *Ptr, const context &Ctxt) {
273278
if (Ctxt.is_host())
274279
return Ctxt.get_devices()[0];
275280

281+
std::shared_ptr<detail::context_impl> CtxImpl = detail::getSyclObjImpl(Ctxt);
282+
276283
// Check if ptr is a host allocation
277-
if (get_pointer_type(Ptr, Ctxt) == alloc::host)
278-
return device();
284+
if (get_pointer_type(Ptr, Ctxt) == alloc::host) {
285+
auto Devs = CtxImpl->getDevices();
286+
if (Devs.size() == 0)
287+
throw runtime_error("No devices in passed context!");
279288

280-
std::shared_ptr<detail::context_impl> CtxImpl = detail::getSyclObjImpl(Ctxt);
281-
pi_context C = CtxImpl->getHandleRef();
289+
// Just return the first device in the context
290+
return Devs[0];
291+
}
292+
293+
pi_context PICtx = CtxImpl->getHandleRef();
282294
pi_device DeviceId;
283295

284296
// query device using PI function
285-
PI_CALL(piextUSMGetMemAllocInfo)(C, Ptr, PI_MEM_ALLOC_DEVICE,
297+
PI_CALL(piextUSMGetMemAllocInfo)(PICtx, Ptr, PI_MEM_ALLOC_DEVICE,
286298
sizeof(pi_device), &DeviceId, nullptr);
287299

288-
auto Devs = Ctxt.get_devices();
289-
290-
for (auto D : Devs) {
300+
for (const auto D : CtxImpl->getDevices()) {
291301
// Try to find the real sycl device used in the context
292302
if (detail::pi::cast<pi_device>(D.get()) == DeviceId)
293303
return D;

sycl/test/usm/pointer_query.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ int main() {
8686
return 10;
8787
}
8888
D = get_pointer_device(array, ctxt);
89-
if (!D.is_host()) {
89+
auto Devs = ctxt.get_devices();
90+
auto result = std::find(Devs.begin(), Devs.end(), D);
91+
if (result == Devs.end()) {
92+
// Returned device was not in queried context
9093
return 11;
9194
}
9295
free(array, ctxt);

0 commit comments

Comments
 (0)