Skip to content

Commit 726432d

Browse files
Resolve gh-1769
Modify shape.setter to handle integers. ``` In [1]: import dpctl.tensor as dpt In [2]: x = dpt.ones((2, 3)) In [3]: x.shape = 6 In [4]: x Out[4]: usm_ndarray([1., 1., 1., 1., 1., 1.], dtype=float32) In [5]: x = dpt.ones((2, 3)) In [6]: class Six: ...: def __init__(self, dim=1): ...: self.v = (1,) * (dim - 1) + (6,) ...: def __len__(self): ...: return len(self.v) ...: def __iter__(self): ...: return iter(self.v) ...: In [7]: x.shape = Six(3) In [8]: x.shape Out[8]: (1, 1, 6) ```
1 parent d79dae1 commit 726432d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,11 @@ cdef class usm_ndarray:
576576

577577
from ._reshape import reshaped_strides
578578

579-
new_nd = len(new_shape)
579+
try:
580+
new_nd = len(new_shape)
581+
except TypeError:
582+
new_nd = 1
583+
new_shape = (new_shape,)
580584
try:
581585
new_shape = tuple(operator.index(dim) for dim in new_shape)
582586
except TypeError:

0 commit comments

Comments
 (0)