@@ -243,22 +243,27 @@ alloc get_pointer_type(const void *Ptr, const context &Ctxt) {
243
243
return alloc::host;
244
244
245
245
std::shared_ptr<detail::context_impl> CtxImpl = detail::getSyclObjImpl (Ctxt);
246
- pi_context C = CtxImpl->getHandleRef ();
246
+ pi_context PICtx = CtxImpl->getHandleRef ();
247
247
pi_usm_type AllocTy;
248
248
249
249
// query type using PI function
250
- PI_CALL (piextUSMGetMemAllocInfo)(C , Ptr, PI_MEM_ALLOC_TYPE,
250
+ PI_CALL (piextUSMGetMemAllocInfo)(PICtx , Ptr, PI_MEM_ALLOC_TYPE,
251
251
sizeof (pi_usm_type), &AllocTy, nullptr );
252
252
253
- alloc ResultAlloc = alloc::unknown;
254
- if (AllocTy == PI_MEM_TYPE_HOST) {
253
+ alloc ResultAlloc;
254
+ switch (AllocTy) {
255
+ case PI_MEM_TYPE_HOST:
255
256
ResultAlloc = alloc::host;
256
- }
257
- else if (AllocTy == PI_MEM_TYPE_DEVICE) {
257
+ break ;
258
+ case PI_MEM_TYPE_DEVICE:
258
259
ResultAlloc = alloc::device;
259
- }
260
- else if (AllocTy == PI_MEM_TYPE_SHARED) {
260
+ break ;
261
+ case PI_MEM_TYPE_SHARED:
261
262
ResultAlloc = alloc::shared;
263
+ break ;
264
+ default :
265
+ ResultAlloc = alloc::unknown;
266
+ break ;
262
267
}
263
268
264
269
return ResultAlloc;
@@ -273,21 +278,26 @@ device get_pointer_device(const void *Ptr, const context &Ctxt) {
273
278
if (Ctxt.is_host ())
274
279
return Ctxt.get_devices ()[0 ];
275
280
281
+ std::shared_ptr<detail::context_impl> CtxImpl = detail::getSyclObjImpl (Ctxt);
282
+
276
283
// Check if ptr is a host allocation
277
- if (get_pointer_type (Ptr, Ctxt) == alloc::host)
278
- return device ();
284
+ if (get_pointer_type (Ptr, Ctxt) == alloc::host) {
285
+ auto Devs = CtxImpl->getDevices ();
286
+ if (Devs.size () == 0 )
287
+ throw runtime_error (" No devices in passed context!" );
279
288
280
- std::shared_ptr<detail::context_impl> CtxImpl = detail::getSyclObjImpl (Ctxt);
281
- pi_context C = CtxImpl->getHandleRef ();
289
+ // Just return the first device in the context
290
+ return Devs[0 ];
291
+ }
292
+
293
+ pi_context PICtx = CtxImpl->getHandleRef ();
282
294
pi_device DeviceId;
283
295
284
296
// query device using PI function
285
- PI_CALL (piextUSMGetMemAllocInfo)(C , Ptr, PI_MEM_ALLOC_DEVICE,
297
+ PI_CALL (piextUSMGetMemAllocInfo)(PICtx , Ptr, PI_MEM_ALLOC_DEVICE,
286
298
sizeof (pi_device), &DeviceId, nullptr );
287
299
288
- auto Devs = Ctxt.get_devices ();
289
-
290
- for (auto D : Devs) {
300
+ for (const auto D : CtxImpl->getDevices ()) {
291
301
// Try to find the real sycl device used in the context
292
302
if (detail::pi::cast<pi_device>(D.get ()) == DeviceId)
293
303
return D;
0 commit comments