Skip to content

Commit 7e79083

Browse files
Enable use of np.int64 to specify shape of usm_ndarray
1 parent 9018745 commit 7e79083

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,19 @@ cdef class usm_ndarray:
182182
cdef bint is_fp16 = False
183183

184184
self._reset()
185-
if (not isinstance(shape, (list, tuple))
186-
and not hasattr(shape, 'tolist')):
187-
try:
188-
<Py_ssize_t> shape
189-
shape = [shape, ]
190-
except Exception:
191-
raise TypeError("Argument shape must be a list or a tuple.")
185+
if not isinstance(shape, (list, tuple)):
186+
if hasattr(shape, 'tolist'):
187+
fn = getattr(shape, 'tolist')
188+
if callable(fn):
189+
shape = shape.tolist()
190+
if not isinstance(shape, (list, tuple)):
191+
try:
192+
<Py_ssize_t> shape
193+
shape = [shape, ]
194+
except Exception:
195+
raise TypeError(
196+
"Argument shape must be a list or a tuple."
197+
)
192198
nd = len(shape)
193199
if dtype is None:
194200
if isinstance(buffer, (dpmem._memory._Memory, usm_ndarray)):

0 commit comments

Comments
 (0)