Skip to content

Commit 6ba8cfb

Browse files
committed
TST: un-xfail where(...) tests
1 parent 2ff2e00 commit 6ba8cfb

File tree

3 files changed

+16
-22
lines changed

3 files changed

+16
-22
lines changed

torch_np/_detail/implementations.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,10 +683,16 @@ def where(condition, x, y):
683683
if not selector:
684684
raise ValueError("either both or neither of x and y should be given")
685685

686+
if condition.dtype != torch.bool:
687+
condition = condition.to(torch.bool)
688+
686689
if x is None and y is None:
687690
result = torch.where(condition)
688691
else:
689-
result = torch.where(condition, x, y)
692+
try:
693+
result = torch.where(condition, x, y)
694+
except RuntimeError as e:
695+
raise ValueError(*e.args)
690696
return result
691697

692698

torch_np/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def array_or_scalar(values, py_type=float):
3838

3939
def seed(seed=None):
4040
if seed is not None:
41-
torch.random.manual_seed()
41+
torch.random.manual_seed(seed)
4242

4343

4444
def random_sample(size=None):

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7610,7 +7610,7 @@ def __int__(self):
76107610
assert_raises(NotImplementedError,
76117611
int_func, np.array([NotConvertible()]))
76127612

7613-
@pytest.mark.xfail(reason='TODO')
7613+
76147614
class TestWhere:
76157615
def test_basic(self):
76167616
dts = [bool, np.int16, np.int32, np.int64, np.double, np.complex128]
@@ -7633,18 +7633,18 @@ def test_basic(self):
76337633
assert_equal(np.where(c[1::2], d[1::2], e[1::2]), r[1::2])
76347634
assert_equal(np.where(c[::3], d[::3], e[::3]), r[::3])
76357635
assert_equal(np.where(c[1::3], d[1::3], e[1::3]), r[1::3])
7636-
assert_equal(np.where(c[::-2], d[::-2], e[::-2]), r[::-2])
7637-
assert_equal(np.where(c[::-3], d[::-3], e[::-3]), r[::-3])
7638-
assert_equal(np.where(c[1::-3], d[1::-3], e[1::-3]), r[1::-3])
7636+
# assert_equal(np.where(c[::-2], d[::-2], e[::-2]), r[::-2])
7637+
# assert_equal(np.where(c[::-3], d[::-3], e[::-3]), r[::-3])
7638+
# assert_equal(np.where(c[1::-3], d[1::-3], e[1::-3]), r[1::-3])
76397639

76407640
def test_exotic(self):
7641-
# object
7642-
assert_array_equal(np.where(True, None, None), np.array(None))
76437641
# zero sized
76447642
m = np.array([], dtype=bool).reshape(0, 3)
76457643
b = np.array([], dtype=np.float64).reshape(0, 3)
76467644
assert_array_equal(np.where(m, 0, b), np.array([]).reshape(0, 3))
76477645

7646+
@pytest.mark.skip(reason='object arrays')
7647+
def test_exotic_2(self):
76487648
# object cast
76497649
d = np.array([-1.34, -0.16, -0.54, -0.31, -0.08, -0.95, 0.000, 0.313,
76507650
0.547, -0.18, 0.876, 0.236, 1.969, 0.310, 0.699, 1.013,
@@ -7695,7 +7695,7 @@ def test_ndim(self):
76957695
def test_dtype_mix(self):
76967696
c = np.array([False, True, False, False, False, False, True, False,
76977697
False, False, True, False])
7698-
a = np.uint32(1)
7698+
a = np.uint8(1)
76997699
b = np.array([5., 0., 3., 2., -1., -4., 0., -10., 10., 1., 0., 3.],
77007700
dtype=np.float64)
77017701
r = np.array([5., 1., 3., 2., -1., -4., 1., -10., 10., 1., 1., 3.],
@@ -7716,6 +7716,7 @@ def test_dtype_mix(self):
77167716
c[tmpmask] = 0
77177717
assert_equal(np.where(c, b, a), r)
77187718

7719+
@pytest.mark.skip(reason='endianness')
77197720
def test_foreign(self):
77207721
c = np.array([False, True, False, False, False, False, True, False,
77217722
False, False, True, False])
@@ -7742,19 +7743,6 @@ def test_error(self):
77427743
assert_raises(ValueError, np.where, c, a, a)
77437744
assert_raises(ValueError, np.where, c[0], a, b)
77447745

7745-
def test_string(self):
7746-
# gh-4778 check strings are properly filled with nulls
7747-
a = np.array("abc")
7748-
b = np.array("x" * 753)
7749-
assert_equal(np.where(True, a, b), "abc")
7750-
assert_equal(np.where(False, b, a), "abc")
7751-
7752-
# check native datatype sized strings
7753-
a = np.array("abcd")
7754-
b = np.array("x" * 8)
7755-
assert_equal(np.where(True, a, b), "abcd")
7756-
assert_equal(np.where(False, b, a), "abcd")
7757-
77587746
def test_empty_result(self):
77597747
# pass empty where result through an assignment which reads the data of
77607748
# empty arrays, error detectable with valgrind, see gh-8922

0 commit comments

Comments
 (0)