Skip to content

Commit 0e2438e

Browse files
committed
ENH: unique
1 parent 108c0e4 commit 0e2438e

File tree

5 files changed

+117
-111
lines changed

5 files changed

+117
-111
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -779,18 +779,6 @@ def union1d(ar1, ar2):
779779
raise NotImplementedError
780780

781781

782-
def unique(
783-
ar,
784-
return_index=False,
785-
return_inverse=False,
786-
return_counts=False,
787-
axis=None,
788-
*,
789-
equal_nan=True,
790-
):
791-
raise NotImplementedError
792-
793-
794782
def unpackbits(a, /, axis=None, count=None, bitorder="big"):
795783
raise NotImplementedError
796784

torch_np/_detail/implementations.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,3 +766,40 @@ def dot(t_a, t_b):
766766
else:
767767
result = torch.matmul(t_a, t_b)
768768
return result
769+
770+
771+
# ### unique et al ###
772+
773+
774+
def unique(
775+
tensor,
776+
return_index=False,
777+
return_inverse=False,
778+
return_counts=False,
779+
axis=None,
780+
*,
781+
equal_nan=True,
782+
):
783+
if return_index or not equal_nan:
784+
raise NotImplementedError
785+
786+
if axis is None:
787+
tensor = tensor.ravel()
788+
axis = 0
789+
axis = _util.normalize_axis_index(axis, tensor.ndim)
790+
791+
is_half = tensor.dtype == torch.float16
792+
if is_half:
793+
tensor = tensor.to(torch.float32)
794+
795+
result = torch.unique(
796+
tensor, return_inverse=return_inverse, return_counts=return_counts, dim=axis
797+
)
798+
799+
if is_half:
800+
if isinstance(result, tuple):
801+
result = (result[0].to(torch.float16),) + result[1:]
802+
else:
803+
result = result.to(torch.float16)
804+
805+
return result

torch_np/_wrapper.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,34 @@ def argsort(a, axis=-1, kind=None, order=None):
11531153
return asarray(result)
11541154

11551155

1156+
# ### unqiue et al ###
1157+
1158+
1159+
def unique(
1160+
ar,
1161+
return_index=False,
1162+
return_inverse=False,
1163+
return_counts=False,
1164+
axis=None,
1165+
*,
1166+
equal_nan=True,
1167+
):
1168+
tensor = asarray(ar).get()
1169+
result = _impl.unique(
1170+
tensor,
1171+
return_index=return_index,
1172+
return_inverse=return_inverse,
1173+
return_counts=return_counts,
1174+
axis=axis,
1175+
equal_nan=equal_nan,
1176+
)
1177+
1178+
if isinstance(result, tuple):
1179+
return tuple(asarray(x) for x in result)
1180+
else:
1181+
return asarray(result)
1182+
1183+
11561184
###### mapping from numpy API objects to wrappers from this module ######
11571185

11581186
# All is in the mapping dict in _mapping.py

torch_np/tests/numpy_tests/lib/test_arraysetops.py

Lines changed: 46 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import torch_np as np
55

66
from torch_np.testing import (assert_array_equal, assert_equal)
7+
8+
from torch_np import unique
9+
710
from numpy.lib.arraysetops import (
8-
ediff1d, intersect1d, setxor1d, union1d, setdiff1d, unique, in1d, isin
11+
ediff1d, intersect1d, setxor1d, union1d, setdiff1d, in1d, isin
912
)
1013
import pytest
1114
from pytest import raises as assert_raises
@@ -615,7 +618,6 @@ def test_manyways(self):
615618
assert_array_equal(c1, c2)
616619

617620

618-
@pytest.mark.xfail(reason='TODO')
619621
class TestUnique:
620622

621623
def test_unique_1d(self):
@@ -627,10 +629,10 @@ def check_all(a, b, i1, i2, c, dt):
627629
v = unique(a)
628630
assert_array_equal(v, b, msg)
629631

630-
msg = base_msg.format('return_index', dt)
631-
v, j = unique(a, True, False, False)
632-
assert_array_equal(v, b, msg)
633-
assert_array_equal(j, i1, msg)
632+
# msg = base_msg.format('return_index', dt)
633+
# v, j = unique(a, True, False, False)
634+
# assert_array_equal(v, b, msg)
635+
# assert_array_equal(j, i1, msg)
634636

635637
msg = base_msg.format('return_inverse', dt)
636638
v, j = unique(a, False, True, False)
@@ -642,31 +644,31 @@ def check_all(a, b, i1, i2, c, dt):
642644
assert_array_equal(v, b, msg)
643645
assert_array_equal(j, c, msg)
644646

645-
msg = base_msg.format('return_index and return_inverse', dt)
646-
v, j1, j2 = unique(a, True, True, False)
647-
assert_array_equal(v, b, msg)
648-
assert_array_equal(j1, i1, msg)
649-
assert_array_equal(j2, i2, msg)
647+
# msg = base_msg.format('return_index and return_inverse', dt)
648+
# v, j1, j2 = unique(a, True, True, False)
649+
# assert_array_equal(v, b, msg)
650+
# assert_array_equal(j1, i1, msg)
651+
# assert_array_equal(j2, i2, msg)
650652

651-
msg = base_msg.format('return_index and return_counts', dt)
652-
v, j1, j2 = unique(a, True, False, True)
653-
assert_array_equal(v, b, msg)
654-
assert_array_equal(j1, i1, msg)
655-
assert_array_equal(j2, c, msg)
653+
# msg = base_msg.format('return_index and return_counts', dt)
654+
# v, j1, j2 = unique(a, True, False, True)
655+
# assert_array_equal(v, b, msg)
656+
# assert_array_equal(j1, i1, msg)
657+
# assert_array_equal(j2, c, msg)
656658

657659
msg = base_msg.format('return_inverse and return_counts', dt)
658660
v, j1, j2 = unique(a, False, True, True)
659661
assert_array_equal(v, b, msg)
660662
assert_array_equal(j1, i2, msg)
661663
assert_array_equal(j2, c, msg)
662664

663-
msg = base_msg.format(('return_index, return_inverse '
664-
'and return_counts'), dt)
665-
v, j1, j2, j3 = unique(a, True, True, True)
666-
assert_array_equal(v, b, msg)
667-
assert_array_equal(j1, i1, msg)
668-
assert_array_equal(j2, i2, msg)
669-
assert_array_equal(j3, c, msg)
665+
# msg = base_msg.format(('return_index, return_inverse '
666+
# 'and return_counts'), dt)
667+
# v, j1, j2, j3 = unique(a, True, True, True)
668+
# assert_array_equal(v, b, msg)
669+
# assert_array_equal(j1, i1, msg)
670+
# assert_array_equal(j2, i2, msg)
671+
# assert_array_equal(j3, c, msg)
670672

671673
a = [5, 7, 1, 2, 1, 5, 7]*10
672674
b = [1, 2, 5, 7]
@@ -678,30 +680,20 @@ def check_all(a, b, i1, i2, c, dt):
678680
types = []
679681
types.extend(np.typecodes['AllInteger'])
680682
types.extend(np.typecodes['AllFloat'])
681-
types.append('datetime64[D]')
682-
types.append('timedelta64[D]')
683683
for dt in types:
684+
685+
if dt in 'FD':
686+
# RuntimeError: "unique" not implemented for 'ComplexFloat'
687+
continue
688+
684689
aa = np.array(a, dt)
685690
bb = np.array(b, dt)
686691
check_all(aa, bb, i1, i2, c, dt)
687692

688-
# test for object arrays
689-
dt = 'O'
690-
aa = np.empty(len(a), dt)
691-
aa[:] = a
692-
bb = np.empty(len(b), dt)
693-
bb[:] = b
694-
check_all(aa, bb, i1, i2, c, dt)
695-
696-
# test for structured arrays
697-
dt = [('', 'i'), ('', 'i')]
698-
aa = np.array(list(zip(a, a)), dt)
699-
bb = np.array(list(zip(b, b)), dt)
700-
check_all(aa, bb, i1, i2, c, dt)
701-
702693
# test for ticket #2799
703-
aa = [1. + 0.j, 1 - 1.j, 1]
704-
assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])
694+
# RuntimeError: "unique" not implemented for 'ComplexFloat'
695+
# aa = [1. + 0.j, 1 - 1.j, 1]
696+
# assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])
705697

706698
# test for ticket #4785
707699
a = [(1, 2), (1, 2), (2, 3)]
@@ -713,23 +705,21 @@ def check_all(a, b, i1, i2, c, dt):
713705
assert_array_equal(a2, unq)
714706
assert_array_equal(a2_inv, inv)
715707

716-
# test for chararrays with return_inverse (gh-5099)
717-
a = np.chararray(5)
718-
a[...] = ''
719-
a2, a2_inv = np.unique(a, return_inverse=True)
720-
assert_array_equal(a2_inv, np.zeros(5))
721708

722709
# test for ticket #9137
723710
a = []
724-
a1_idx = np.unique(a, return_index=True)[1]
711+
# a1_idx = np.unique(a, return_index=True)[1]
725712
a2_inv = np.unique(a, return_inverse=True)[1]
726-
a3_idx, a3_inv = np.unique(a, return_index=True,
727-
return_inverse=True)[1:]
728-
assert_equal(a1_idx.dtype, np.intp)
713+
# a3_idx, a3_inv = np.unique(a, return_index=True,
714+
# return_inverse=True)[1:]
715+
# assert_equal(a1_idx.dtype, np.intp)
729716
assert_equal(a2_inv.dtype, np.intp)
730-
assert_equal(a3_idx.dtype, np.intp)
731-
assert_equal(a3_inv.dtype, np.intp)
717+
# assert_equal(a3_idx.dtype, np.intp)
718+
# assert_equal(a3_inv.dtype, np.intp)
732719

720+
721+
@pytest.mark.xfail(reason='unique with nans')
722+
def test_unique_1d_2(self):
733723
# test for ticket 2111 - float
734724
a = [2.0, np.nan, 1.0, np.nan]
735725
ua = [1.0, 2.0, np.nan]
@@ -752,30 +742,6 @@ def check_all(a, b, i1, i2, c, dt):
752742
assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))
753743
assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))
754744

755-
# test for ticket 2111 - datetime64
756-
nat = np.datetime64('nat')
757-
a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]
758-
ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]
759-
ua_idx = [2, 0, 1]
760-
ua_inv = [1, 2, 0, 2]
761-
ua_cnt = [1, 1, 2]
762-
assert_equal(np.unique(a), ua)
763-
assert_equal(np.unique(a, return_index=True), (ua, ua_idx))
764-
assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))
765-
assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))
766-
767-
# test for ticket 2111 - timedelta
768-
nat = np.timedelta64('nat')
769-
a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]
770-
ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]
771-
ua_idx = [2, 0, 1]
772-
ua_inv = [1, 2, 0, 2]
773-
ua_cnt = [1, 1, 2]
774-
assert_equal(np.unique(a), ua)
775-
assert_equal(np.unique(a, return_index=True), (ua, ua_idx))
776-
assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))
777-
assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))
778-
779745
# test for gh-19300
780746
all_nans = [np.nan] * 4
781747
ua = [np.nan]
@@ -802,14 +768,11 @@ def test_unique_axis_list(self):
802768
assert_array_equal(unique(inp, axis=0), unique(inp_arr, axis=0), msg)
803769
assert_array_equal(unique(inp, axis=1), unique(inp_arr, axis=1), msg)
804770

771+
@pytest.mark.xfail(reason='TODO: implement take')
805772
def test_unique_axis(self):
806773
types = []
807774
types.extend(np.typecodes['AllInteger'])
808775
types.extend(np.typecodes['AllFloat'])
809-
types.append('datetime64[D]')
810-
types.append('timedelta64[D]')
811-
types.append([('a', int), ('b', int)])
812-
types.append([('a', int), ('b', float)])
813776

814777
for dtype in types:
815778
self._run_axis_tests(dtype)
@@ -830,6 +793,7 @@ def test_unique_1d_with_axis(self, axis):
830793
uniq = unique(x, axis=axis)
831794
assert_array_equal(uniq, [1, 2, 3, 4])
832795

796+
@pytest.mark.xfail(reason='unique / return_index')
833797
def test_unique_axis_zeros(self):
834798
# issue 15559
835799
single_zero = np.empty(shape=(2, 0), dtype=np.int8)
@@ -866,24 +830,11 @@ def test_unique_axis_zeros(self):
866830
assert_array_equal(unique(multiple_zeros, axis=axis),
867831
np.empty(shape=expected_shape))
868832

869-
def test_unique_masked(self):
870-
# issue 8664
871-
x = np.array([64, 0, 1, 2, 3, 63, 63, 0, 0, 0, 1, 2, 0, 63, 0],
872-
dtype='uint8')
873-
y = np.ma.masked_equal(x, 0)
874-
875-
v = np.unique(y)
876-
v2, i, c = np.unique(y, return_index=True, return_counts=True)
877-
878-
msg = 'Unique returned different results when asked for index'
879-
assert_array_equal(v.data, v2.data, msg)
880-
assert_array_equal(v.mask, v2.mask, msg)
881-
882833
def test_unique_sort_order_with_axis(self):
883834
# These tests fail if sorting along axis is done by treating subarrays
884835
# as unsigned byte strings. See gh-10495.
885836
fmt = "sort order incorrect for integer type '%s'"
886-
for dt in 'bhilq':
837+
for dt in 'bhil':
887838
a = np.array([[-1], [0]], dt)
888839
b = np.unique(a, axis=0)
889840
assert_array_equal(a, b, fmt % dt)
@@ -932,6 +883,7 @@ def _run_axis_tests(self, dtype):
932883
msg = "Unique's return_counts=True failed with axis=1"
933884
assert_array_equal(cnt, np.array([2, 1, 1]), msg)
934885

886+
@pytest.mark.xfail(reason='unique / return_index / nans')
935887
def test_unique_nanequals(self):
936888
# issue 20326
937889
a = np.array([1, 1, np.nan, np.nan, np.nan])

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929
bartlett, blackman,
3030
delete, digitize, extract, gradient, hamming, hanning,
3131
insert, interp, kaiser, msort, piecewise, place,
32-
select, setxor1d, trapz, trim_zeros, unique, unwrap, vectorize
32+
select, setxor1d, trapz, trim_zeros, unwrap, vectorize
3333
)
3434
from torch_np._detail._util import normalize_axis_tuple
3535

36-
from torch_np import corrcoef, cov, i0, angle, sinc, diff, meshgrid
36+
from torch_np import corrcoef, cov, i0, angle, sinc, diff, meshgrid, unique
3737

3838
def get_mat(n):
3939
data = np.arange(n)
@@ -1864,15 +1864,16 @@ def test_array_like(self):
18641864
assert_array_equal(y1, y3)
18651865

18661866

1867-
@pytest.mark.xfail(reason='TODO: implement')
18681867
class TestUnique:
18691868

18701869
def test_simple(self):
18711870
x = np.array([4, 3, 2, 1, 1, 2, 3, 4, 0])
18721871
assert_(np.all(unique(x) == [0, 1, 2, 3, 4]))
1872+
18731873
assert_(unique(np.array([1, 1, 1, 1, 1])) == np.array([1]))
1874-
x = ['widget', 'ham', 'foo', 'bar', 'foo', 'ham']
1875-
assert_(np.all(unique(x) == ['bar', 'foo', 'ham', 'widget']))
1874+
1875+
@pytest.mark.xfail(reason="unique not implemented for 'ComplexDouble'")
1876+
def test_simple_complex(self):
18761877
x = np.array([5 + 6j, 1 + 1j, 1 + 10j, 10, 5 + 6j])
18771878
assert_(np.all(unique(x) == [1 + 1j, 1 + 10j, 5 + 6j, 10]))
18781879

0 commit comments

Comments
 (0)