Skip to content

Commit 5081a98

Browse files
committed
Smoke test for advance integer indexing
1 parent d378b56 commit 5081a98

File tree

6 files changed

+50
-9
lines changed

6 files changed

+50
-9
lines changed

torch_np/_helpers.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,23 @@ def ndarrays_to_tensors(*inputs):
7474
"""Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
7575
from ._ndarray import asarray, ndarray
7676

77-
return tuple(
78-
[value.get() if isinstance(value, ndarray) else value for value in inputs]
79-
)
77+
if len(inputs) == 0:
78+
return ValueError()
79+
elif len(inputs) == 1:
80+
input_ = inputs[0]
81+
if isinstance(input_, ndarray):
82+
return input_.get()
83+
elif isinstance(input_, tuple):
84+
result = []
85+
for sub_input in input_:
86+
sub_result = ndarrays_to_tensors(sub_input)
87+
result.append(sub_result)
88+
return tuple(result)
89+
else:
90+
return input_
91+
else:
92+
assert isinstance(inputs, tuple) # sanity check
93+
return ndarrays_to_tensors(inputs)
8094

8195

8296
def to_tensors(*inputs):

torch_np/_ndarray.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,9 @@ def clip(self, min, max, out=None):
410410
)
411411

412412
### indexing ###
413-
def __getitem__(self, *args, **kwds):
414-
t_args = _helpers.ndarrays_to_tensors(*args)
415-
return ndarray._from_tensor_and_base(
416-
self._tensor.__getitem__(*t_args, **kwds), self
417-
)
413+
def __getitem__(self, index):
414+
t_index = _helpers.ndarrays_to_tensors(index)
415+
return ndarray._from_tensor_and_base(self._tensor.__getitem__(t_index), self)
418416

419417
def __setitem__(self, index, value):
420418
value = asarray(value).get()

torch_np/_wrapper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,8 @@ def empty_like(prototype, dtype=None, order="K", subok=False, shape=None):
312312

313313
def full(shape, fill_value, dtype=None, order="C", *, like=None):
314314
_util.subok_not_ok(like)
315+
if isinstance(shape, int):
316+
shape = (shape,)
315317
if order != "C":
316318
raise NotImplementedError
317319
if isinstance(fill_value, ndarray):

torch_np/tests/__init__.py

Whitespace-only changes.

torch_np/tests/test_ndarray_methods.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import itertools
22

33
import pytest
4-
from pytest import raises as assert_raises
54

65
# import numpy as np
6+
import torch
7+
from pytest import raises as assert_raises
8+
79
import torch_np as np
10+
from torch_np._ndarray import ndarray
811
from torch_np.testing import assert_equal
912

1013

torch_np/tests/test_xps.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,27 @@ def test_full(shape, data):
5858
assert np.isnan(out).all()
5959
else:
6060
assert (out == fill_value).all()
61+
62+
63+
def integer_array_indices(shape, result_shape) -> st.SearchStrategy[tuple]:
64+
# See hypothesis.extra.numpy.integer_array_indices()
65+
# n.b. result_shape only accepts a shape, as opposed to only accepting a strategy
66+
def array_for(index_shape, size):
67+
return xps.arrays(
68+
dtype=xps.integer_dtypes(),
69+
shape=index_shape,
70+
elements=st.integers(-size, size - 1),
71+
)
72+
73+
return st.tuples(*(array_for(result_shape, size) for size in shape))
74+
75+
76+
@given(
77+
x=xps.arrays(dtype=xps.integer_dtypes(), shape=xps.array_shapes()),
78+
data=st.data(),
79+
)
80+
def test_integer_indexing(x, data):
81+
result_shape = data.draw(xps.array_shapes(), label="result_shape")
82+
idx = data.draw(integer_array_indices(x.shape, result_shape), label="idx")
83+
result = x[idx]
84+
assert result.shape == result_shape

0 commit comments

Comments
 (0)