Skip to content

Commit e566695

Browse files
Closes #189
When type-checking input, check for the input to be instance of abtract class numbers.Integral, rather than class `int` Modified test suite to not expect the test from the ticket to fail.
1 parent 926e315 commit e566695

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

dpctl/memory/_memory.pyx

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_FromStringAndSize
3434
from cpython cimport pycapsule
3535

3636
import numpy as np
37+
import numbers
3738

3839
__all__ = [
3940
"MemoryUSMShared",
@@ -51,21 +52,21 @@ cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(SyclQueue q):
5152

5253
cdef DPCTLSyclQueueRef _queue_ref_copy_from_USMRef_and_SyclContext(
5354
DPCTLSyclUSMRef ptr, SyclContext ctx):
54-
""" Obtain device from pointer and sycl context, use
55+
""" Obtain device from pointer and sycl context, use
5556
context and device to create a queue from which this memory
5657
can be accessible.
5758
"""
5859
cdef SyclDevice dev = _Memory.get_pointer_device(ptr, ctx)
5960
cdef DPCTLSyclContextRef CRef = NULL
60-
cdef DPCTLSyclDeviceRef DRef = NULL
61+
cdef DPCTLSyclDeviceRef DRef = NULL
6162
CRef = ctx.get_context_ref()
6263
DRef = dev.get_device_ref()
6364
return DPCTLQueue_Create(CRef, DRef, NULL, 0)
6465

6566

6667
cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
6768
DPCTLSyclUSMRef ptr, object syclobj):
68-
""" Constructs queue from pointer and syclobject from
69+
""" Constructs queue from pointer and syclobject from
6970
__sycl_usm_array_interface__
7071
"""
7172
cdef DPCTLSyclQueueRef QRef = NULL
@@ -96,7 +97,7 @@ cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
9697
return QRef
9798
else:
9899
return QRef
99-
100+
100101

101102
cdef void copy_via_host(void *dest_ptr, SyclQueue dest_queue,
102103
void *src_ptr, SyclQueue src_queue, size_t nbytes):
@@ -207,6 +208,9 @@ def _to_memory(unsigned char [::1] b, str usm_kind):
207208

208209

209210
cdef class _Memory:
211+
""" Internal class implementing methods common to
212+
MemoryUSMShared, MemoryUSMDevice, MemoryUSMHost
213+
"""
210214
cdef _cinit_empty(self):
211215
self.memory_ptr = NULL
212216
self.nbytes = 0
@@ -500,7 +504,7 @@ cdef class MemoryUSMShared(_Memory):
500504
than 'shared'.
501505
"""
502506
def __cinit__(self, other, *, Py_ssize_t alignment=0, SyclQueue queue=None, int copy=False):
503-
if (isinstance(other, int)):
507+
if (isinstance(other, numbers.Integral)):
504508
self._cinit_alloc(alignment, <Py_ssize_t>other, b"shared", queue)
505509
else:
506510
self._cinit_other(other)
@@ -532,7 +536,7 @@ cdef class MemoryUSMHost(_Memory):
532536
than 'host'.
533537
"""
534538
def __cinit__(self, other, *, Py_ssize_t alignment=0, SyclQueue queue=None, int copy=False):
535-
if (isinstance(other, int)):
539+
if (isinstance(other, numbers.Integral)):
536540
self._cinit_alloc(alignment, <Py_ssize_t>other, b"host", queue)
537541
else:
538542
self._cinit_other(other)
@@ -564,7 +568,7 @@ cdef class MemoryUSMDevice(_Memory):
564568
than 'device'.
565569
"""
566570
def __cinit__(self, other, *, Py_ssize_t alignment=0, SyclQueue queue=None, int copy=False):
567-
if (isinstance(other, int)):
571+
if (isinstance(other, numbers.Integral)):
568572
self._cinit_alloc(alignment, <Py_ssize_t>other, b"device", queue)
569573
else:
570574
self._cinit_other(other)

dpctl/tests/test_sycl_usm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ def test_memory_create(self):
4848
self.assertEqual(mobj.nbytes, nbytes)
4949
self.assertTrue(hasattr(mobj, "__sycl_usm_array_interface__"))
5050

51-
@unittest.expectedFailure
5251
@unittest.skipUnless(
5352
has_sycl_platforms(), "No SYCL devices except the default host device."
5453
)
5554
def test_memory_create_with_np(self):
56-
mobj = dpctl.memory.MemoryUSMShared(np.int64(16384))
55+
nbytes = 16384
56+
mobj = dpctl.memory.MemoryUSMShared(np.int64(nbytes))
57+
self.assertEqual(mobj.nbytes, nbytes)
5758
self.assertTrue(hasattr(mobj, "__sycl_usm_array_interface__"))
5859

5960
def _create_memory(self):

0 commit comments

Comments
 (0)