Skip to content

Commit 40dd254

Browse files
committed
Comment why normalise newshape in _detail.implementations.reshape
1 parent 8e857b8 commit 40dd254

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torch_np/_detail/implementations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,10 @@ def reshape(tensor, *shape, order="C"):
504504
if order != "C":
505505
raise NotImplementedError
506506
newshape = shape[0] if len(shape) == 1 else shape
507+
# convert any tnp.ndarray inputs into tensors before passing to torch.Tensor.reshape
508+
t_newshape = _helpers.ndarrays_to_tensors(newshape)
507509
# if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
508-
result = tensor.reshape(_helpers.ndarrays_to_tensors(newshape))
510+
result = tensor.reshape(t_newshape)
509511
return result
510512

511513

0 commit comments

Comments
 (0)