4
4
import torch_np as np
5
5
6
6
from torch_np .testing import (assert_array_equal , assert_equal )
7
+
8
+ from torch_np import unique
9
+
7
10
from numpy .lib .arraysetops import (
8
- ediff1d , intersect1d , setxor1d , union1d , setdiff1d , unique , in1d , isin
11
+ ediff1d , intersect1d , setxor1d , union1d , setdiff1d , in1d , isin
9
12
)
10
13
import pytest
11
14
from pytest import raises as assert_raises
@@ -615,7 +618,6 @@ def test_manyways(self):
615
618
assert_array_equal (c1 , c2 )
616
619
617
620
618
- @pytest .mark .xfail (reason = 'TODO' )
619
621
class TestUnique :
620
622
621
623
def test_unique_1d (self ):
@@ -627,10 +629,10 @@ def check_all(a, b, i1, i2, c, dt):
627
629
v = unique (a )
628
630
assert_array_equal (v , b , msg )
629
631
630
- msg = base_msg .format ('return_index' , dt )
631
- v , j = unique (a , True , False , False )
632
- assert_array_equal (v , b , msg )
633
- assert_array_equal (j , i1 , msg )
632
+ # msg = base_msg.format('return_index', dt)
633
+ # v, j = unique(a, True, False, False)
634
+ # assert_array_equal(v, b, msg)
635
+ # assert_array_equal(j, i1, msg)
634
636
635
637
msg = base_msg .format ('return_inverse' , dt )
636
638
v , j = unique (a , False , True , False )
@@ -642,31 +644,31 @@ def check_all(a, b, i1, i2, c, dt):
642
644
assert_array_equal (v , b , msg )
643
645
assert_array_equal (j , c , msg )
644
646
645
- msg = base_msg .format ('return_index and return_inverse' , dt )
646
- v , j1 , j2 = unique (a , True , True , False )
647
- assert_array_equal (v , b , msg )
648
- assert_array_equal (j1 , i1 , msg )
649
- assert_array_equal (j2 , i2 , msg )
647
+ # msg = base_msg.format('return_index and return_inverse', dt)
648
+ # v, j1, j2 = unique(a, True, True, False)
649
+ # assert_array_equal(v, b, msg)
650
+ # assert_array_equal(j1, i1, msg)
651
+ # assert_array_equal(j2, i2, msg)
650
652
651
- msg = base_msg .format ('return_index and return_counts' , dt )
652
- v , j1 , j2 = unique (a , True , False , True )
653
- assert_array_equal (v , b , msg )
654
- assert_array_equal (j1 , i1 , msg )
655
- assert_array_equal (j2 , c , msg )
653
+ # msg = base_msg.format('return_index and return_counts', dt)
654
+ # v, j1, j2 = unique(a, True, False, True)
655
+ # assert_array_equal(v, b, msg)
656
+ # assert_array_equal(j1, i1, msg)
657
+ # assert_array_equal(j2, c, msg)
656
658
657
659
msg = base_msg .format ('return_inverse and return_counts' , dt )
658
660
v , j1 , j2 = unique (a , False , True , True )
659
661
assert_array_equal (v , b , msg )
660
662
assert_array_equal (j1 , i2 , msg )
661
663
assert_array_equal (j2 , c , msg )
662
664
663
- msg = base_msg .format (('return_index, return_inverse '
664
- 'and return_counts' ), dt )
665
- v , j1 , j2 , j3 = unique (a , True , True , True )
666
- assert_array_equal (v , b , msg )
667
- assert_array_equal (j1 , i1 , msg )
668
- assert_array_equal (j2 , i2 , msg )
669
- assert_array_equal (j3 , c , msg )
665
+ # msg = base_msg.format(('return_index, return_inverse '
666
+ # 'and return_counts'), dt)
667
+ # v, j1, j2, j3 = unique(a, True, True, True)
668
+ # assert_array_equal(v, b, msg)
669
+ # assert_array_equal(j1, i1, msg)
670
+ # assert_array_equal(j2, i2, msg)
671
+ # assert_array_equal(j3, c, msg)
670
672
671
673
a = [5 , 7 , 1 , 2 , 1 , 5 , 7 ]* 10
672
674
b = [1 , 2 , 5 , 7 ]
@@ -678,30 +680,20 @@ def check_all(a, b, i1, i2, c, dt):
678
680
types = []
679
681
types .extend (np .typecodes ['AllInteger' ])
680
682
types .extend (np .typecodes ['AllFloat' ])
681
- types .append ('datetime64[D]' )
682
- types .append ('timedelta64[D]' )
683
683
for dt in types :
684
+
685
+ if dt in 'FD' :
686
+ # RuntimeError: "unique" not implemented for 'ComplexFloat'
687
+ continue
688
+
684
689
aa = np .array (a , dt )
685
690
bb = np .array (b , dt )
686
691
check_all (aa , bb , i1 , i2 , c , dt )
687
692
688
- # test for object arrays
689
- dt = 'O'
690
- aa = np .empty (len (a ), dt )
691
- aa [:] = a
692
- bb = np .empty (len (b ), dt )
693
- bb [:] = b
694
- check_all (aa , bb , i1 , i2 , c , dt )
695
-
696
- # test for structured arrays
697
- dt = [('' , 'i' ), ('' , 'i' )]
698
- aa = np .array (list (zip (a , a )), dt )
699
- bb = np .array (list (zip (b , b )), dt )
700
- check_all (aa , bb , i1 , i2 , c , dt )
701
-
702
693
# test for ticket #2799
703
- aa = [1. + 0.j , 1 - 1.j , 1 ]
704
- assert_array_equal (np .unique (aa ), [1. - 1.j , 1. + 0.j ])
694
+ # RuntimeError: "unique" not implemented for 'ComplexFloat'
695
+ # aa = [1. + 0.j, 1 - 1.j, 1]
696
+ # assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])
705
697
706
698
# test for ticket #4785
707
699
a = [(1 , 2 ), (1 , 2 ), (2 , 3 )]
@@ -713,23 +705,21 @@ def check_all(a, b, i1, i2, c, dt):
713
705
assert_array_equal (a2 , unq )
714
706
assert_array_equal (a2_inv , inv )
715
707
716
- # test for chararrays with return_inverse (gh-5099)
717
- a = np .chararray (5 )
718
- a [...] = ''
719
- a2 , a2_inv = np .unique (a , return_inverse = True )
720
- assert_array_equal (a2_inv , np .zeros (5 ))
721
708
722
709
# test for ticket #9137
723
710
a = []
724
- a1_idx = np .unique (a , return_index = True )[1 ]
711
+ # a1_idx = np.unique(a, return_index=True)[1]
725
712
a2_inv = np .unique (a , return_inverse = True )[1 ]
726
- a3_idx , a3_inv = np .unique (a , return_index = True ,
727
- return_inverse = True )[1 :]
728
- assert_equal (a1_idx .dtype , np .intp )
713
+ # a3_idx, a3_inv = np.unique(a, return_index=True,
714
+ # return_inverse=True)[1:]
715
+ # assert_equal(a1_idx.dtype, np.intp)
729
716
assert_equal (a2_inv .dtype , np .intp )
730
- assert_equal (a3_idx .dtype , np .intp )
731
- assert_equal (a3_inv .dtype , np .intp )
717
+ # assert_equal(a3_idx.dtype, np.intp)
718
+ # assert_equal(a3_inv.dtype, np.intp)
732
719
720
+
721
+ @pytest .mark .xfail (reason = 'unique with nans' )
722
+ def test_unique_1d_2 (self ):
733
723
# test for ticket 2111 - float
734
724
a = [2.0 , np .nan , 1.0 , np .nan ]
735
725
ua = [1.0 , 2.0 , np .nan ]
@@ -752,30 +742,6 @@ def check_all(a, b, i1, i2, c, dt):
752
742
assert_equal (np .unique (a , return_inverse = True ), (ua , ua_inv ))
753
743
assert_equal (np .unique (a , return_counts = True ), (ua , ua_cnt ))
754
744
755
- # test for ticket 2111 - datetime64
756
- nat = np .datetime64 ('nat' )
757
- a = [np .datetime64 ('2020-12-26' ), nat , np .datetime64 ('2020-12-24' ), nat ]
758
- ua = [np .datetime64 ('2020-12-24' ), np .datetime64 ('2020-12-26' ), nat ]
759
- ua_idx = [2 , 0 , 1 ]
760
- ua_inv = [1 , 2 , 0 , 2 ]
761
- ua_cnt = [1 , 1 , 2 ]
762
- assert_equal (np .unique (a ), ua )
763
- assert_equal (np .unique (a , return_index = True ), (ua , ua_idx ))
764
- assert_equal (np .unique (a , return_inverse = True ), (ua , ua_inv ))
765
- assert_equal (np .unique (a , return_counts = True ), (ua , ua_cnt ))
766
-
767
- # test for ticket 2111 - timedelta
768
- nat = np .timedelta64 ('nat' )
769
- a = [np .timedelta64 (1 , 'D' ), nat , np .timedelta64 (1 , 'h' ), nat ]
770
- ua = [np .timedelta64 (1 , 'h' ), np .timedelta64 (1 , 'D' ), nat ]
771
- ua_idx = [2 , 0 , 1 ]
772
- ua_inv = [1 , 2 , 0 , 2 ]
773
- ua_cnt = [1 , 1 , 2 ]
774
- assert_equal (np .unique (a ), ua )
775
- assert_equal (np .unique (a , return_index = True ), (ua , ua_idx ))
776
- assert_equal (np .unique (a , return_inverse = True ), (ua , ua_inv ))
777
- assert_equal (np .unique (a , return_counts = True ), (ua , ua_cnt ))
778
-
779
745
# test for gh-19300
780
746
all_nans = [np .nan ] * 4
781
747
ua = [np .nan ]
@@ -802,14 +768,11 @@ def test_unique_axis_list(self):
802
768
assert_array_equal (unique (inp , axis = 0 ), unique (inp_arr , axis = 0 ), msg )
803
769
assert_array_equal (unique (inp , axis = 1 ), unique (inp_arr , axis = 1 ), msg )
804
770
771
+ @pytest .mark .xfail (reason = 'TODO: implement take' )
805
772
def test_unique_axis (self ):
806
773
types = []
807
774
types .extend (np .typecodes ['AllInteger' ])
808
775
types .extend (np .typecodes ['AllFloat' ])
809
- types .append ('datetime64[D]' )
810
- types .append ('timedelta64[D]' )
811
- types .append ([('a' , int ), ('b' , int )])
812
- types .append ([('a' , int ), ('b' , float )])
813
776
814
777
for dtype in types :
815
778
self ._run_axis_tests (dtype )
@@ -830,6 +793,7 @@ def test_unique_1d_with_axis(self, axis):
830
793
uniq = unique (x , axis = axis )
831
794
assert_array_equal (uniq , [1 , 2 , 3 , 4 ])
832
795
796
+ @pytest .mark .xfail (reason = 'unique / return_index' )
833
797
def test_unique_axis_zeros (self ):
834
798
# issue 15559
835
799
single_zero = np .empty (shape = (2 , 0 ), dtype = np .int8 )
@@ -866,24 +830,11 @@ def test_unique_axis_zeros(self):
866
830
assert_array_equal (unique (multiple_zeros , axis = axis ),
867
831
np .empty (shape = expected_shape ))
868
832
869
- def test_unique_masked (self ):
870
- # issue 8664
871
- x = np .array ([64 , 0 , 1 , 2 , 3 , 63 , 63 , 0 , 0 , 0 , 1 , 2 , 0 , 63 , 0 ],
872
- dtype = 'uint8' )
873
- y = np .ma .masked_equal (x , 0 )
874
-
875
- v = np .unique (y )
876
- v2 , i , c = np .unique (y , return_index = True , return_counts = True )
877
-
878
- msg = 'Unique returned different results when asked for index'
879
- assert_array_equal (v .data , v2 .data , msg )
880
- assert_array_equal (v .mask , v2 .mask , msg )
881
-
882
833
def test_unique_sort_order_with_axis (self ):
883
834
# These tests fail if sorting along axis is done by treating subarrays
884
835
# as unsigned byte strings. See gh-10495.
885
836
fmt = "sort order incorrect for integer type '%s'"
886
- for dt in 'bhilq ' :
837
+ for dt in 'bhil ' :
887
838
a = np .array ([[- 1 ], [0 ]], dt )
888
839
b = np .unique (a , axis = 0 )
889
840
assert_array_equal (a , b , fmt % dt )
@@ -932,6 +883,7 @@ def _run_axis_tests(self, dtype):
932
883
msg = "Unique's return_counts=True failed with axis=1"
933
884
assert_array_equal (cnt , np .array ([2 , 1 , 1 ]), msg )
934
885
886
+ @pytest .mark .xfail (reason = 'unique / return_index / nans' )
935
887
def test_unique_nanequals (self ):
936
888
# issue 20326
937
889
a = np .array ([1 , 1 , np .nan , np .nan , np .nan ])
0 commit comments