Skip to content

Commit 60cadce

Browse files
committed
TST: un-xfail tests for argmin/argmax
1 parent e4f80b6 commit 60cadce

File tree

1 file changed

+32
-62
lines changed

1 file changed

+32
-62
lines changed

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 32 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3711,6 +3711,7 @@ def test_np_argmin_argmax_keepdims(self, size, axis, method):
37113711
method(arr.T, axis=axis,
37123712
out=wrong_outarray, keepdims=True)
37133713

3714+
@pytest.mark.xfail(reason="TODO: implement choose")
37143715
@pytest.mark.parametrize('method', ['max', 'min'])
37153716
def test_all(self, method):
37163717
a = np.random.normal(0, 1, (4, 5, 6, 7, 8))
@@ -3753,15 +3754,7 @@ def test_ret_is_out(self, ndim, method):
37533754
ret = arg_method(axis=0, out=out)
37543755
assert ret is out
37553756

3756-
@pytest.mark.parametrize('np_array, method, idx, val',
3757-
[(np.zeros, 'argmax', 5942, "as"),
3758-
(np.ones, 'argmin', 6001, "0")])
3759-
def test_unicode(self, np_array, method, idx, val):
3760-
d = np_array(6031, dtype='<U9')
3761-
arg_method = getattr(d, method)
3762-
d[idx] = val
3763-
assert_equal(arg_method(), idx)
3764-
3757+
@pytest.mark.xfail(reason='FIXME: keepdims w/ positional args?')
37653758
@pytest.mark.parametrize('arr_method, np_method',
37663759
[('argmax', np.argmax),
37673760
('argmin', np.argmin)])
@@ -3784,22 +3777,7 @@ def test_np_vs_ndarray(self, arr_method, np_method):
37843777
np_method(a, out=out2, axis=0))
37853778
assert_equal(out1, out2)
37863779

3787-
@pytest.mark.leaks_references(reason="replaces None with NULL.")
3788-
@pytest.mark.parametrize('method, vals',
3789-
[('argmax', (10, 30)),
3790-
('argmin', (30, 10))])
3791-
def test_object_with_NULLs(self, method, vals):
3792-
# See gh-6032
3793-
a = np.empty(4, dtype='O')
3794-
arg_method = getattr(a, method)
3795-
ctypes.memset(a.ctypes.data, 0, a.nbytes)
3796-
assert_equal(arg_method(), 0)
3797-
a[3] = vals[0]
3798-
assert_equal(arg_method(), 3)
3799-
a[1] = vals[1]
3800-
assert_equal(arg_method(), 1)
38013780

3802-
@pytest.mark.xfail(reason='TODO')
38033781
class TestArgmax:
38043782
usg_data = [
38053783
([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], 0),
@@ -3839,19 +3817,20 @@ class TestArgmax:
38393817
))
38403818
)]
38413819
nan_arr = darr + [
3842-
([0, 1, 2, 3, complex(0, np.nan)], 4),
3843-
([0, 1, 2, 3, complex(np.nan, 0)], 4),
3844-
([0, 1, 2, complex(np.nan, 0), 3], 3),
3845-
([0, 1, 2, complex(0, np.nan), 3], 3),
3846-
([complex(0, np.nan), 0, 1, 2, 3], 0),
3847-
([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
3848-
([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
3849-
([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
3850-
([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
3851-
3852-
([complex(0, 0), complex(0, 2), complex(0, 1)], 1),
3853-
([complex(1, 0), complex(0, 2), complex(0, 1)], 0),
3854-
([complex(1, 0), complex(0, 2), complex(1, 1)], 2),
3820+
# RuntimeError: "max_values_cpu" not implemented for 'ComplexDouble'
3821+
# ([0, 1, 2, 3, complex(0, np.nan)], 4),
3822+
# ([0, 1, 2, 3, complex(np.nan, 0)], 4),
3823+
# ([0, 1, 2, complex(np.nan, 0), 3], 3),
3824+
# ([0, 1, 2, complex(0, np.nan), 3], 3),
3825+
# ([complex(0, np.nan), 0, 1, 2, 3], 0),
3826+
# ([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
3827+
# ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
3828+
# ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
3829+
# ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
3830+
3831+
# ([complex(0, 0), complex(0, 2), complex(0, 1)], 1),
3832+
# ([complex(1, 0), complex(0, 2), complex(0, 1)], 0),
3833+
# ([complex(1, 0), complex(0, 2), complex(1, 1)], 2),
38553834

38563835
([False, False, False, False, True], 4),
38573836
([False, False, False, True, False], 3),
@@ -3905,7 +3884,7 @@ def test_maximum_signed_integers(self):
39053884
a = a.repeat(129)
39063885
assert_equal(np.argmax(a), 129)
39073886

3908-
@pytest.mark.xfail(reason='TODO')
3887+
39093888
class TestArgmin:
39103889
usg_data = [
39113890
([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], 8),
@@ -3945,19 +3924,20 @@ class TestArgmin:
39453924
))
39463925
)]
39473926
nan_arr = darr + [
3948-
([0, 1, 2, 3, complex(0, np.nan)], 4),
3949-
([0, 1, 2, 3, complex(np.nan, 0)], 4),
3950-
([0, 1, 2, complex(np.nan, 0), 3], 3),
3951-
([0, 1, 2, complex(0, np.nan), 3], 3),
3952-
([complex(0, np.nan), 0, 1, 2, 3], 0),
3953-
([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
3954-
([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
3955-
([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
3956-
([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
3957-
3958-
([complex(0, 0), complex(0, 2), complex(0, 1)], 0),
3959-
([complex(1, 0), complex(0, 2), complex(0, 1)], 2),
3960-
([complex(1, 0), complex(0, 2), complex(1, 1)], 1),
3927+
# RuntimeError: "min_values_cpu" not implemented for 'ComplexDouble'
3928+
# ([0, 1, 2, 3, complex(0, np.nan)], 4),
3929+
# ([0, 1, 2, 3, complex(np.nan, 0)], 4),
3930+
# ([0, 1, 2, complex(np.nan, 0), 3], 3),
3931+
# ([0, 1, 2, complex(0, np.nan), 3], 3),
3932+
# ([complex(0, np.nan), 0, 1, 2, 3], 0),
3933+
# ([complex(np.nan, np.nan), 0, 1, 2, 3], 0),
3934+
# ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, 1)], 0),
3935+
# ([complex(np.nan, np.nan), complex(np.nan, 2), complex(np.nan, 1)], 0),
3936+
# ([complex(np.nan, 0), complex(np.nan, 2), complex(np.nan, np.nan)], 0),
3937+
3938+
# ([complex(0, 0), complex(0, 2), complex(0, 1)], 0),
3939+
# ([complex(1, 0), complex(0, 2), complex(0, 1)], 2),
3940+
# ([complex(1, 0), complex(0, 2), complex(1, 1)], 1),
39613941

39623942
([True, True, True, True, False], 4),
39633943
([True, True, True, False, True], 3),
@@ -4010,7 +3990,7 @@ def test_minimum_signed_integers(self):
40103990
a = a.repeat(129)
40113991
assert_equal(np.argmin(a), 129)
40123992

4013-
@pytest.mark.xfail(reason='TODO')
3993+
40143994
class TestMinMax:
40153995

40163996
def test_scalar(self):
@@ -4026,16 +4006,6 @@ def test_axis(self):
40264006
assert_raises(np.AxisError, np.amax, [1, 2, 3], 1000)
40274007
assert_equal(np.amax([[1, 2, 3]], axis=1), 3)
40284008

4029-
def test_datetime(self):
4030-
# Do not ignore NaT
4031-
for dtype in ('m8[s]', 'm8[Y]'):
4032-
a = np.arange(10).astype(dtype)
4033-
assert_equal(np.amin(a), a[0])
4034-
assert_equal(np.amax(a), a[9])
4035-
a[3] = 'NaT'
4036-
assert_equal(np.amin(a), a[3])
4037-
assert_equal(np.amax(a), a[3])
4038-
40394009

40404010
class TestNewaxis:
40414011
def test_basic(self):

0 commit comments

Comments
 (0)