Skip to content

Commit 7851aed

Browse files
authored
Merge pull request #98 from Quansight-Labs/take_rebase
take() implementation, rebase
2 parents 0e8e6b5 + 73ea76c commit 7851aed

File tree

6 files changed

+48
-40
lines changed

6 files changed

+48
-40
lines changed

torch_np/_funcs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,24 @@ def asfarray():
930930
# ### put/take_along_axis ###
931931

932932

933+
@normalizer
934+
def take(
935+
a: ArrayLike,
936+
indices: ArrayLike,
937+
axis=None,
938+
out: Optional[NDArray] = None,
939+
mode="raise",
940+
):
941+
if mode != "raise":
942+
raise NotImplementedError(f"{mode=}")
943+
944+
(a,), axis = _util.axis_none_ravel(a, axis=axis)
945+
axis = _util.normalize_axis_index(axis, a.ndim)
946+
idx = (slice(None),) * axis + (indices, ...)
947+
result = a[idx]
948+
return result
949+
950+
933951
@normalizer
934952
def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
935953
(arr,), axis = _util.axis_none_ravel(arr, axis=axis)

torch_np/_ndarray.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,8 @@ def __setitem__(self, index, value):
407407
value = _helpers.ndarrays_to_tensors(value)
408408
return self.tensor.__setitem__(index, value)
409409

410+
take = _funcs.take
411+
410412

411413
# This is the ideally the only place which talks to ndarray directly.
412414
# The rest goes through asarray (preferred) or array.

torch_np/tests/numpy_tests/core/test_indexing.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,9 +1053,8 @@ def test_non_integer_argument_errors(self):
10531053

10541054
assert_raises(TypeError, np.reshape, a, (1., 1., -1))
10551055
assert_raises(TypeError, np.reshape, a, (np.array(1.), -1))
1056-
pytest.xfail("XXX: take not implemented")
10571056
assert_raises(TypeError, np.take, a, [0], 1.)
1058-
assert_raises(TypeError, np.take, a, [0], np.float64(1.))
1057+
assert_raises((TypeError, RuntimeError), np.take, a, [0], np.float64(1.))
10591058

10601059
@pytest.mark.skip(
10611060
reason=(
@@ -1089,7 +1088,6 @@ def test_bool_as_int_argument_errors(self):
10891088
# array is thus also deprecated, but not with the same message:
10901089
assert_warns(DeprecationWarning, operator.index, np.True_)
10911090

1092-
pytest.xfail("XXX: take not implemented")
10931091
assert_raises(TypeError, np.take, args=(a, [0], False))
10941092

10951093
pytest.skip("torch consumes boolean tensors as ints, no bother raising here")
@@ -1138,8 +1136,7 @@ def test_array_to_index_error(self):
11381136
# so no exception is expected. The raising is effectively tested above.
11391137
a = np.array([[[1]]])
11401138

1141-
pytest.xfail("XXX: take not implemented")
1142-
assert_raises(TypeError, np.take, a, [0], a)
1139+
assert_raises((TypeError, RuntimeError), np.take, a, [0], a)
11431140

11441141
pytest.skip(
11451142
"Multi-dimensional tensors are indexable just as long as they only "

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3817,57 +3817,41 @@ def test_kwargs(self):
38173817
np.putmask(a=x, values=[-1, -2], mask=[0, 1])
38183818

38193819

3820-
@pytest.mark.xfail(reason='TODO')
38213820
class TestTake:
38223821
def tst_basic(self, x):
38233822
ind = list(range(x.shape[0]))
3824-
assert_array_equal(x.take(ind, axis=0), x)
3823+
assert_array_equal(np.take(x, ind, axis=0), x)
38253824

38263825
def test_ip_types(self):
3827-
unchecked_types = [bytes, str, np.void]
3828-
38293826
x = np.random.random(24)*100
3830-
x.shape = 2, 3, 4
3827+
x = np.reshape(x, (2, 3, 4))
38313828
for types in np.sctypes.values():
38323829
for T in types:
3833-
if T not in unchecked_types:
3834-
self.tst_basic(x.copy().astype(T))
3835-
3836-
# Also test string of a length which uses an untypical length
3837-
self.tst_basic(x.astype("S3"))
3830+
self.tst_basic(x.copy().astype(T))
38383831

38393832
def test_raise(self):
38403833
x = np.random.random(24)*100
3841-
x.shape = 2, 3, 4
3842-
assert_raises(IndexError, x.take, [0, 1, 2], axis=0)
3843-
assert_raises(IndexError, x.take, [-3], axis=0)
3844-
assert_array_equal(x.take([-1], axis=0)[0], x[1])
3834+
x = np.reshape(x, (2, 3, 4))
3835+
assert_raises(IndexError, np.take, x, [0, 1, 2], axis=0)
3836+
assert_raises(IndexError, np.take, x, [-3], axis=0)
3837+
assert_array_equal(np.take(x, [-1], axis=0)[0], x[1])
38453838

3839+
@pytest.mark.xfail(reason="XXX: take(..., mode='clip')")
38463840
def test_clip(self):
38473841
x = np.random.random(24)*100
3848-
x.shape = 2, 3, 4
3849-
assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0])
3850-
assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1])
3842+
x = np.reshape(x, (2, 3, 4))
3843+
assert_array_equal(np.take(x, [-1], axis=0, mode='clip')[0], x[0])
3844+
assert_array_equal(np.take(x, [2], axis=0, mode='clip')[0], x[1])
38513845

3846+
@pytest.mark.xfail(reason="XXX: take(..., mode='wrap')")
38523847
def test_wrap(self):
38533848
x = np.random.random(24)*100
3854-
x.shape = 2, 3, 4
3855-
assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1])
3856-
assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0])
3857-
assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1])
3858-
3859-
@pytest.mark.parametrize('dtype', ('>i4', '<i4'))
3860-
def test_byteorder(self, dtype):
3861-
x = np.array([1, 2, 3], dtype)
3862-
assert_array_equal(x.take([0, 2, 1]), [1, 3, 2])
3863-
3864-
def test_record_array(self):
3865-
# Note mixed byteorder.
3866-
rec = np.array([(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
3867-
dtype=[('x', '<f8'), ('y', '>f8'), ('z', '<f8')])
3868-
rec1 = rec.take([1])
3869-
assert_(rec1['x'] == 5.0 and rec1['y'] == 4.0)
3849+
x = np.reshape(x, (2, 3, 4))
3850+
assert_array_equal(np.take(x, [-1], axis=0, mode='wrap')[0], x[1])
3851+
assert_array_equal(np.take(x, [2], axis=0, mode='wrap')[0], x[0])
3852+
assert_array_equal(np.take(x, [3], axis=0, mode='wrap')[0], x[1])
38703853

3854+
@pytest.mark.xfail(reason="XXX: take(mode='wrap')")
38713855
def test_out_overlap(self):
38723856
# gh-6272 check overlap on out
38733857
x = np.arange(5)

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,6 @@ def test_sum(self):
282282

283283
assert_equal(tgt, out)
284284

285-
@pytest.mark.xfail(reason="TODO implement take(...)")
286285
def test_take(self):
287286
tgt = [2, 3, 5]
288287
indices = [1, 2, 4]

torch_np/tests/numpy_tests/lib/test_arraysetops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,6 @@ def test_unique_axis_list(self):
768768
assert_array_equal(unique(inp, axis=0), unique(inp_arr, axis=0), msg)
769769
assert_array_equal(unique(inp, axis=1), unique(inp_arr, axis=1), msg)
770770

771-
@pytest.mark.xfail(reason='TODO: implement take')
772771
def test_unique_axis(self):
773772
types = []
774773
types.extend(np.typecodes['AllInteger'])
@@ -857,6 +856,15 @@ def _run_axis_tests(self, dtype):
857856
result = np.array([[0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 1, 0]])
858857
assert_array_equal(unique(data, axis=1), result.astype(dtype), msg)
859858

859+
pytest.xfail("torch has different unique ordering behaviour")
860+
# e.g.
861+
#
862+
# >>> x = np.array([[[1, 1], [0, 1]], [[1, 0], [0, 0]]])
863+
# >>> np.unique(x, axis=2)
864+
# [[1, 1], [0, 1]], [[1, 0], [0, 0]]
865+
# >>> torch.unique(torch.as_tensor(x), dim=2)
866+
# [[1, 1], [1, 0]], [[0, 1], [0, 0]]
867+
#
860868
msg = 'Unique with 3d array and axis=2 failed'
861869
data3d = np.array([[[1, 1],
862870
[1, 0]],

0 commit comments

Comments
 (0)