Skip to content

Commit ee1d590

Browse files
Implemented support for constructing MemoryUSM* from object with __sycl_usm_array_interface__ when array-info is not contiguous (#400)
* implemented feature #396
1 parent d97370a commit ee1d590

File tree

4 files changed

+357
-164
lines changed

4 files changed

+357
-164
lines changed

dpctl/memory/_memory.pxd

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,22 @@ cdef public class _Memory [object Py_MemoryObject, type Py_MemoryType]:
5151
cpdef bytes tobytes(self)
5252

5353
@staticmethod
54-
cdef public SyclDevice get_pointer_device(DPCTLSyclUSMRef p, SyclContext ctx)
54+
cdef public SyclDevice get_pointer_device(
55+
DPCTLSyclUSMRef p, SyclContext ctx)
5556
@staticmethod
5657
cdef public bytes get_pointer_type(DPCTLSyclUSMRef p, SyclContext ctx)
5758

5859

59-
cdef public class MemoryUSMShared(_Memory) [object PyMemoryUSMSharedObject, type PyMemoryUSMSharedType]:
60+
cdef public class MemoryUSMShared(_Memory) [object PyMemoryUSMSharedObject,
61+
type PyMemoryUSMSharedType]:
6062
pass
6163

6264

63-
cdef public class MemoryUSMHost(_Memory) [object PyMemoryUSMHostObject, type PyMemoryUSMHostType]:
65+
cdef public class MemoryUSMHost(_Memory) [object PyMemoryUSMHostObject,
66+
type PyMemoryUSMHostType]:
6467
pass
6568

6669

67-
cdef public class MemoryUSMDevice(_Memory) [object PyMemoryUSMDeviceObject, type PyMemoryUSMDeviceType]:
70+
cdef public class MemoryUSMDevice(_Memory) [object PyMemoryUSMDeviceObject,
71+
type PyMemoryUSMDeviceType]:
6872
pass

dpctl/memory/_memory.pyx

Lines changed: 72 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -35,68 +35,15 @@ from cpython cimport pycapsule
3535

3636
import numpy as np
3737
import numbers
38+
import collections
3839

3940
__all__ = [
4041
"MemoryUSMShared",
4142
"MemoryUSMHost",
4243
"MemoryUSMDevice"
4344
]
4445

45-
cdef object _sycl_usm_ary_iface_error():
46-
return ValueError("__sycl_usm_array_interface__ is malformed")
47-
48-
49-
cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(SyclQueue q):
50-
return DPCTLQueue_Copy(q.get_queue_ref())
51-
52-
53-
cdef DPCTLSyclQueueRef _queue_ref_copy_from_USMRef_and_SyclContext(
54-
DPCTLSyclUSMRef ptr, SyclContext ctx):
55-
""" Obtain device from pointer and sycl context, use
56-
context and device to create a queue from which this memory
57-
can be accessible.
58-
"""
59-
cdef SyclDevice dev = _Memory.get_pointer_device(ptr, ctx)
60-
cdef DPCTLSyclContextRef CRef = NULL
61-
cdef DPCTLSyclDeviceRef DRef = NULL
62-
CRef = ctx.get_context_ref()
63-
DRef = dev.get_device_ref()
64-
return DPCTLQueue_Create(CRef, DRef, NULL, 0)
65-
66-
67-
cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
68-
DPCTLSyclUSMRef ptr, object syclobj):
69-
""" Constructs queue from pointer and syclobject from
70-
__sycl_usm_array_interface__
71-
"""
72-
cdef DPCTLSyclQueueRef QRef = NULL
73-
cdef SyclContext ctx
74-
if type(syclobj) is SyclQueue:
75-
return _queue_ref_copy_from_SyclQueue(<SyclQueue> syclobj)
76-
elif type(syclobj) is SyclContext:
77-
ctx = <SyclContext>syclobj
78-
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
79-
elif type(syclobj) is str:
80-
q = SyclQueue(syclobj)
81-
return _queue_ref_copy_from_SyclQueue(<SyclQueue> q)
82-
elif pycapsule.PyCapsule_IsValid(syclobj, "SyclQueueRef"):
83-
q = SyclQueue(syclobj)
84-
return _queue_ref_copy_from_SyclQueue(<SyclQueue> q)
85-
elif pycapsule.PyCapsule_IsValid(syclobj, "SyclContextRef"):
86-
ctx = <SyclContext>SyclContext(syclobj)
87-
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
88-
elif hasattr(syclobj, '_get_capsule'):
89-
cap = syclobj._get_capsule()
90-
if pycapsule.PyCapsule_IsValid(cap, "SyclQueueRef"):
91-
q = SyclQueue(cap)
92-
return _queue_ref_copy_from_SyclQueue(<SyclQueue> q)
93-
elif pycapsule.PyCapsule_IsValid(cap, "SyclContexRef"):
94-
ctx = <SyclContext>SyclContext(cap)
95-
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
96-
else:
97-
return QRef
98-
else:
99-
return QRef
46+
include "_sycl_usm_array_interface_utils.pxi"
10047

10148

10249
cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
@@ -126,66 +73,6 @@ cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
12673
)
12774

12875

129-
cdef class _BufferData:
130-
"""
131-
Internal data struct populated from parsing
132-
`__sycl_usm_array_interface__` dictionary
133-
"""
134-
cdef DPCTLSyclUSMRef p
135-
cdef int writeable
136-
cdef object dt
137-
cdef Py_ssize_t itemsize
138-
cdef Py_ssize_t nbytes
139-
cdef SyclQueue queue
140-
141-
@staticmethod
142-
cdef _BufferData from_sycl_usm_ary_iface(dict ary_iface):
143-
cdef object ary_data_tuple = ary_iface.get('data', None)
144-
cdef object ary_typestr = ary_iface.get('typestr', None)
145-
cdef object ary_shape = ary_iface.get('shape', None)
146-
cdef object ary_strides = ary_iface.get('strides', None)
147-
cdef object ary_syclobj = ary_iface.get('syclobj', None)
148-
cdef Py_ssize_t ary_offset = ary_iface.get('offset', 0)
149-
cdef int ary_version = ary_iface.get('version', 0)
150-
cdef object dt
151-
cdef _BufferData buf
152-
cdef Py_ssize_t arr_data_ptr
153-
cdef SyclDevice dev
154-
cdef SyclContext ctx
155-
cdef DPCTLSyclQueueRef QRef = NULL
156-
157-
if ary_version != 1:
158-
raise _sycl_usm_ary_iface_error()
159-
if not ary_data_tuple or len(ary_data_tuple) != 2:
160-
raise _sycl_usm_ary_iface_error()
161-
if not ary_shape or len(ary_shape) != 1 or ary_shape[0] < 1:
162-
raise ValueError
163-
try:
164-
dt = np.dtype(ary_typestr)
165-
except TypeError:
166-
raise _sycl_usm_ary_iface_error()
167-
if (ary_strides and len(ary_strides) != 1
168-
and ary_strides[0] != dt.itemsize):
169-
raise ValueError("Must be contiguous")
170-
171-
if (not ary_syclobj or
172-
not isinstance(ary_syclobj,
173-
(dpctl.SyclQueue, dpctl.SyclContext))):
174-
raise _sycl_usm_ary_iface_error()
175-
176-
buf = _BufferData.__new__(_BufferData)
177-
arr_data_ptr = <Py_ssize_t>ary_data_tuple[0]
178-
buf.p = <DPCTLSyclUSMRef>(<void*>arr_data_ptr)
179-
buf.writeable = 1 if ary_data_tuple[1] else 0
180-
buf.itemsize = <Py_ssize_t>(dt.itemsize)
181-
buf.nbytes = (<Py_ssize_t>ary_shape[0]) * buf.itemsize
182-
183-
QRef = get_queue_ref_from_ptr_and_syclobj(buf.p, ary_syclobj)
184-
buf.queue = SyclQueue._create(QRef)
185-
186-
return buf
187-
188-
18976
def _to_memory(unsigned char [::1] b, str usm_kind):
19077
"""
19178
Constructs Memory of the same size as the argument
@@ -272,7 +159,7 @@ cdef class _Memory:
272159
elif hasattr(other, '__sycl_usm_array_interface__'):
273160
other_iface = other.__sycl_usm_array_interface__
274161
if isinstance(other_iface, dict):
275-
other_buf = _BufferData.from_sycl_usm_ary_iface(other_iface)
162+
other_buf = _USMBufferData.from_sycl_usm_ary_iface(other_iface)
276163
self.memory_ptr = other_buf.p
277164
self.nbytes = other_buf.nbytes
278165
self.queue = other_buf.queue
@@ -415,13 +302,16 @@ cdef class _Memory:
415302
return obj
416303

417304
cpdef copy_from_host(self, object obj):
418-
"""Copy content of Python buffer provided by `obj` to instance memory."""
305+
"""
306+
Copy content of Python buffer provided by `obj` to instance memory.
307+
"""
419308
cdef const unsigned char[::1] host_buf = obj
420309
cdef Py_ssize_t buf_len = len(host_buf)
421310

422311
if (buf_len > self.nbytes):
423312
raise ValueError("Source object is too large to be "
424-
"accommodated in {} bytes buffer".format(self.nbytes))
313+
"accommodated in {} bytes buffer".format(
314+
self.nbytes))
425315
# call kernel to copy from
426316
DPCTLQueue_Memcpy(
427317
self.queue.get_queue_ref(),
@@ -433,19 +323,20 @@ cdef class _Memory:
433323
cpdef copy_from_device(self, object sycl_usm_ary):
434324
"""Copy SYCL memory underlying the argument object into
435325
the memory of the instance"""
436-
cdef _BufferData src_buf
326+
cdef _USMBufferData src_buf
437327
cdef const char* kind
438328

439329
if not hasattr(sycl_usm_ary, '__sycl_usm_array_interface__'):
440330
raise ValueError("Object does not implement "
441331
"`__sycl_usm_array_interface__` protocol")
442332
sycl_usm_ary_iface = sycl_usm_ary.__sycl_usm_array_interface__
443333
if isinstance(sycl_usm_ary_iface, dict):
444-
src_buf = _BufferData.from_sycl_usm_ary_iface(sycl_usm_ary_iface)
334+
src_buf = _USMBufferData.from_sycl_usm_ary_iface(sycl_usm_ary_iface)
445335

446336
if (src_buf.nbytes > self.nbytes):
447337
raise ValueError("Source object is too large to "
448-
"be accommondated in {} bytes buffer".format(self.nbytes))
338+
"be accommondated in {} bytes buffer".format(
339+
self.nbytes))
449340
kind = DPCTLUSM_GetPointerType(
450341
src_buf.p, self.queue.get_sycl_context().get_context_ref())
451342
if (kind == b'unknown'):
@@ -477,107 +368,128 @@ cdef class _Memory:
477368

478369
@staticmethod
479370
cdef SyclDevice get_pointer_device(DPCTLSyclUSMRef p, SyclContext ctx):
480-
"""Returns sycl device used to allocate given pointer `p` in given sycl context `ctx`"""
481-
cdef DPCTLSyclDeviceRef dref = DPCTLUSM_GetPointerDevice(p, ctx.get_context_ref())
371+
"""
372+
Returns sycl device used to allocate given pointer `p` in
373+
given sycl context `ctx`
374+
"""
375+
cdef DPCTLSyclDeviceRef dref = DPCTLUSM_GetPointerDevice(
376+
p, ctx.get_context_ref())
482377

483378
return SyclDevice._create(dref)
484379

485380
@staticmethod
486381
cdef bytes get_pointer_type(DPCTLSyclUSMRef p, SyclContext ctx):
487382
"""Returns USM-type of given pointer `p` in given sycl context `ctx`"""
488-
cdef const char * usm_type = DPCTLUSM_GetPointerType(p, ctx.get_context_ref())
383+
cdef const char * usm_type = DPCTLUSM_GetPointerType(
384+
p, ctx.get_context_ref())
489385

490386
return <bytes>usm_type
491387

492388

493389
cdef class MemoryUSMShared(_Memory):
494390
"""
495-
MemoryUSMShared(nbytes, alignment=0, queue=None, copy=False) allocates nbytes of
496-
USM shared memory.
391+
MemoryUSMShared(nbytes, alignment=0, queue=None, copy=False)
392+
allocates nbytes of USM shared memory.
497393
498394
Non-positive alignments are not used (malloc_shared is used instead).
499395
For the queue=None case the `dpctl.SyclQueue()` is used to allocate memory.
500396
501-
MemoryUSMShared(usm_obj) constructor create instance from `usm_obj` expected to
502-
implement `__sycl_usm_array_interface__` protocol and exposing a contiguous block of
503-
USM memory of USM shared type. Using copy=True to perform a copy if USM type is other
504-
than 'shared'.
397+
MemoryUSMShared(usm_obj) constructor creates an instance from `usm_obj`
398+
expected to implement `__sycl_usm_array_interface__` protocol and exposing
399+
a contiguous block of USM memory of USM shared type. Using copy=True to
400+
perform a copy if USM type is other than 'shared'.
505401
"""
506-
def __cinit__(self, other, *, Py_ssize_t alignment=0, SyclQueue queue=None, int copy=False):
402+
def __cinit__(self, other, *, Py_ssize_t alignment=0,
403+
SyclQueue queue=None, int copy=False):
507404
if (isinstance(other, numbers.Integral)):
508405
self._cinit_alloc(alignment, <Py_ssize_t>other, b"shared", queue)
509406
else:
510407
self._cinit_other(other)
511408
if (self.get_usm_type() != "shared"):
512409
if copy:
513-
self._cinit_alloc(0, <Py_ssize_t>self.nbytes, b"shared", queue)
410+
self._cinit_alloc(0, <Py_ssize_t>self.nbytes,
411+
b"shared", queue)
514412
self.copy_from_device(other)
515413
else:
516-
raise ValueError("USM pointer in the argument {} is not a USM shared pointer. "
517-
"Zero-copy operation is not possible with copy=False. "
518-
"Either use copy=True, or use a constructor appropriate for "
519-
"type '{}'".format(other, self.get_usm_type()))
414+
raise ValueError(
415+
"USM pointer in the argument {} is not a "
416+
"USM shared pointer. "
417+
"Zero-copy operation is not possible with "
418+
"copy=False. "
419+
"Either use copy=True, or use a constructor "
420+
"appropriate for "
421+
"type '{}'".format(other, self.get_usm_type()))
520422

521423
def __getbuffer__(self, Py_buffer *buffer, int flags):
522424
self._getbuffer(buffer, flags)
523425

524426

525427
cdef class MemoryUSMHost(_Memory):
526428
"""
527-
MemoryUSMHost(nbytes, alignment=0, queue=None, copy=False) allocates nbytes of
528-
USM host memory.
429+
MemoryUSMHost(nbytes, alignment=0, queue=None, copy=False)
430+
allocates nbytes of USM host memory.
529431
530432
Non-positive alignments are not used (malloc_host is used instead).
531433
For the queue=None case `dpctl.SyclQueue()` is used to allocate memory.
532434
533-
MemoryUSMDevice(usm_obj) constructor create instance from `usm_obj` expected to
534-
implement `__sycl_usm_array_interface__` protocol and exposing a contiguous block of
535-
USM memory of USM host type. Using copy=True to perform a copy if USM type is other
536-
than 'host'.
435+
MemoryUSMHost(usm_obj) constructor creates an instance from `usm_obj`
436+
expected to implement `__sycl_usm_array_interface__` protocol and exposing
437+
a contiguous block of USM memory of USM host type. Using copy=True to
438+
perform a copy if USM type is other than 'host'.
537439
"""
538-
def __cinit__(self, other, *, Py_ssize_t alignment=0, SyclQueue queue=None, int copy=False):
440+
def __cinit__(self, other, *, Py_ssize_t alignment=0,
441+
SyclQueue queue=None, int copy=False):
539442
if (isinstance(other, numbers.Integral)):
540443
self._cinit_alloc(alignment, <Py_ssize_t>other, b"host", queue)
541444
else:
542445
self._cinit_other(other)
543446
if (self.get_usm_type() != "host"):
544447
if copy:
545-
self._cinit_alloc(0, <Py_ssize_t>self.nbytes, b"host", queue)
448+
self._cinit_alloc(0, <Py_ssize_t>self.nbytes,
449+
b"host", queue)
546450
self.copy_from_device(other)
547451
else:
548-
raise ValueError("USM pointer in the argument {} is not a USM host pointer. "
549-
"Zero-copy operation is not possible with copy=False. "
550-
"Either use copy=True, or use a constructor appropriate for "
551-
"type '{}'".format(other, self.get_usm_type()))
452+
raise ValueError(
453+
"USM pointer in the argument {} is "
454+
"not a USM host pointer. "
455+
"Zero-copy operation is not possible with copy=False. "
456+
"Either use copy=True, or use a constructor "
457+
"appropriate for type '{}'".format(
458+
other, self.get_usm_type()))
552459

553460
def __getbuffer__(self, Py_buffer *buffer, int flags):
554461
self._getbuffer(buffer, flags)
555462

556463

557464
cdef class MemoryUSMDevice(_Memory):
558465
"""
559-
MemoryUSMDevice(nbytes, alignment=0, queue=None, copy=False) allocates nbytes of
560-
USM device memory.
466+
MemoryUSMDevice(nbytes, alignment=0, queue=None, copy=False)
467+
allocates nbytes of USM device memory.
561468
562469
Non-positive alignments are not used (malloc_device is used instead).
563470
For the queue=None case the `dpctl.SyclQueue()` is used to allocate memory.
564471
565-
MemoryUSMDevice(usm_obj) constructor create instance from `usm_obj` expected to
566-
implement `__sycl_usm_array_interface__` protocol and exposing a contiguous block of
567-
USM memory of USM device type. Using copy=True to perform a copy if USM type is other
568-
than 'device'.
472+
MemoryUSMDevice(usm_obj) constructor creates an instance from `usm_obj`
473+
expected to implement `__sycl_usm_array_interface__` protocol and exposing
474+
a contiguous block of USM memory of USM device type. Using copy=True to
475+
perform a copy if USM type is other than 'device'.
569476
"""
570-
def __cinit__(self, other, *, Py_ssize_t alignment=0, SyclQueue queue=None, int copy=False):
477+
def __cinit__(self, other, *, Py_ssize_t alignment=0,
478+
SyclQueue queue=None, int copy=False):
571479
if (isinstance(other, numbers.Integral)):
572480
self._cinit_alloc(alignment, <Py_ssize_t>other, b"device", queue)
573481
else:
574482
self._cinit_other(other)
575483
if (self.get_usm_type() != "device"):
576484
if copy:
577-
self._cinit_alloc(0, <Py_ssize_t>self.nbytes, b"device", queue)
485+
self._cinit_alloc(0, <Py_ssize_t>self.nbytes,
486+
b"device", queue)
578487
self.copy_from_device(other)
579488
else:
580-
raise ValueError("USM pointer in the argument {} is not a USM device pointer. "
581-
"Zero-copy operation is not possible with copy=False. "
582-
"Either use copy=True, or use a constructor appropriate for "
583-
"type '{}'".format(other, self.get_usm_type()))
489+
raise ValueError(
490+
"USM pointer in the argument {} is not "
491+
"a USM device pointer. "
492+
"Zero-copy operation is not possible with copy=False. "
493+
"Either use copy=True, or use a constructor "
494+
"appropriate for type '{}'".format(
495+
other, self.get_usm_type()))

0 commit comments

Comments
 (0)