Skip to content

Commit 806b4ac

Browse files
committed
add ndarray.conj, conjugate; un-xfail test_multiarray::TestStats
1 parent 3b5eeea commit 806b4ac

File tree

3 files changed

+31
-48
lines changed

3 files changed

+31
-48
lines changed

torch_np/_detail/_reductions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,18 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
143143

144144
dtype = _atleast_float(dtype, tensor.dtype)
145145

146+
is_half = dtype == torch.float16
147+
if is_half:
148+
dtype=torch.float32
149+
146150
if axis is None:
147151
result = tensor.mean(dtype=dtype)
148152
else:
149153
result = tensor.mean(dtype=dtype, dim=axis)
150154

155+
if is_half:
156+
result = result.to(torch.float16)
157+
151158
return result
152159

153160

torch_np/_ndarray.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,9 @@ def __irshift__(self, other):
343343
__pos__ = _unary_ufuncs.positive
344344
__neg__ = _unary_ufuncs.negative
345345

346+
conjugate = _unary_ufuncs.conjugate
347+
conj = conjugate
348+
346349
### methods to match namespace functions
347350

348351
def squeeze(self, axis=None):

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 21 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5046,17 +5046,14 @@ def _std(a, **args):
50465046
return a.std(**args)
50475047

50485048

5049-
@pytest.mark.xfail(reason='TODO')
50505049
class TestStats:
50515050

50525051
funcs = [_mean, _var, _std]
50535052

50545053
def setup_method(self):
5055-
np.random.seed(range(3))
5054+
np.random.seed(3)
50565055
self.rmat = np.random.random((4, 5))
50575056
self.cmat = self.rmat + 1j * self.rmat
5058-
self.omat = np.array([Decimal(repr(r)) for r in self.rmat.flat])
5059-
self.omat = self.omat.reshape(4, 5)
50605057

50615058
def test_python_type(self):
50625059
for x in (np.float16(1.), 1, 1., 1+0j):
@@ -5093,16 +5090,6 @@ def test_dtype_from_input(self):
50935090
icodes = np.typecodes['AllInteger']
50945091
fcodes = np.typecodes['AllFloat']
50955092

5096-
# object type
5097-
for f in self.funcs:
5098-
mat = np.array([[Decimal(1)]*3]*3)
5099-
tgt = mat.dtype.type
5100-
res = f(mat, axis=1).dtype.type
5101-
assert_(res is tgt)
5102-
# scalar case
5103-
res = type(f(mat, axis=None))
5104-
assert_(res is Decimal)
5105-
51065093
# integer types
51075094
for f in self.funcs:
51085095
for c in icodes:
@@ -5137,6 +5124,7 @@ def test_dtype_from_input(self):
51375124
res = f(mat, axis=None).dtype.type
51385125
assert_(res is tgt)
51395126

5127+
@pytest.mark.xfail(reason='TODO: dtype in reductions')
51405128
def test_dtype_from_dtype(self):
51415129
mat = np.eye(3)
51425130

@@ -5182,29 +5170,29 @@ def test_ddof_too_big(self):
51825170
dim = self.rmat.shape[1]
51835171
for f in [_var, _std]:
51845172
for ddof in range(dim, dim + 2):
5185-
with warnings.catch_warnings(record=True) as w:
5186-
warnings.simplefilter('always')
5173+
# with warnings.catch_warnings(record=True) as w:
5174+
# warnings.simplefilter('always')
51875175
res = f(self.rmat, axis=1, ddof=ddof)
51885176
assert_(not (res < 0).any())
5189-
assert_(len(w) > 0)
5190-
assert_(issubclass(w[0].category, RuntimeWarning))
5177+
# assert_(len(w) > 0)
5178+
# assert_(issubclass(w[0].category, RuntimeWarning))
51915179

51925180
def test_empty(self):
51935181
A = np.zeros((0, 3))
51945182
for f in self.funcs:
51955183
for axis in [0, None]:
5196-
with warnings.catch_warnings(record=True) as w:
5197-
warnings.simplefilter('always')
5184+
# with warnings.catch_warnings(record=True) as w:
5185+
# warnings.simplefilter('always')
51985186
assert_(np.isnan(f(A, axis=axis)).all())
5199-
assert_(len(w) > 0)
5200-
assert_(issubclass(w[0].category, RuntimeWarning))
5187+
# assert_(len(w) > 0)
5188+
# assert_(issubclass(w[0].category, RuntimeWarning))
52015189
for axis in [1]:
5202-
with warnings.catch_warnings(record=True) as w:
5203-
warnings.simplefilter('always')
5190+
# with warnings.catch_warnings(record=True) as w:
5191+
# warnings.simplefilter('always')
52045192
assert_equal(f(A, axis=axis), np.zeros([]))
52055193

52065194
def test_mean_values(self):
5207-
for mat in [self.rmat, self.cmat, self.omat]:
5195+
for mat in [self.rmat, self.cmat]:
52085196
for axis in [0, 1]:
52095197
tgt = mat.sum(axis=axis)
52105198
res = _mean(mat, axis=axis) * mat.shape[axis]
@@ -5222,9 +5210,10 @@ def test_mean_float16(self):
52225210
def test_mean_axis_error(self):
52235211
# Ensure that AxisError is raised instead of IndexError when axis is
52245212
# out of bounds, see gh-15817.
5225-
with assert_raises(np.exceptions.AxisError):
5213+
with assert_raises(np.AxisError):
52265214
np.arange(10).mean(axis=2)
52275215

5216+
@pytest.mark.xfail(reason='implement mean(..., where=...)')
52285217
def test_mean_where(self):
52295218
a = np.arange(16).reshape((4, 4))
52305219
wh_full = np.array([[False, True, False, True],
@@ -5262,7 +5251,7 @@ def test_mean_where(self):
52625251
assert_equal(np.mean(a, where=False), np.nan)
52635252

52645253
def test_var_values(self):
5265-
for mat in [self.rmat, self.cmat, self.omat]:
5254+
for mat in [self.rmat, self.cmat]:
52665255
for axis in [0, 1, None]:
52675256
msqr = _mean(mat * mat.conj(), axis=axis)
52685257
mean = _mean(mat, axis=axis)
@@ -5295,6 +5284,7 @@ def test_var_dimensions(self):
52955284
res = _var(mat, axis=axis)
52965285
assert_almost_equal(res, tgt)
52975286

5287+
@pytest.mark.skip(reason='endianness')
52985288
def test_var_complex_byteorder(self):
52995289
# Test that var fast-path does not cause failures for complex arrays
53005290
# with non-native byteorder
@@ -5305,9 +5295,10 @@ def test_var_complex_byteorder(self):
53055295
def test_var_axis_error(self):
53065296
# Ensure that AxisError is raised instead of IndexError when axis is
53075297
# out of bounds, see gh-15817.
5308-
with assert_raises(np.exceptions.AxisError):
5298+
with assert_raises(np.AxisError):
53095299
np.arange(10).var(axis=2)
53105300

5301+
@pytest.mark.xfail(reason="implement var(..., where=...)")
53115302
def test_var_where(self):
53125303
a = np.arange(25).reshape((5, 5))
53135304
wh_full = np.array([[False, True, False, True, True],
@@ -5346,12 +5337,13 @@ def test_var_where(self):
53465337
assert_equal(np.var(a, where=False), np.nan)
53475338

53485339
def test_std_values(self):
5349-
for mat in [self.rmat, self.cmat, self.omat]:
5340+
for mat in [self.rmat, self.cmat]:
53505341
for axis in [0, 1, None]:
53515342
tgt = np.sqrt(_var(mat, axis=axis))
53525343
res = _std(mat, axis=axis)
53535344
assert_almost_equal(res, tgt)
53545345

5346+
@pytest.mark.xfail(reason="implement std(..., where=...)")
53555347
def test_std_where(self):
53565348
a = np.arange(25).reshape((5,5))[::-1]
53575349
whf = np.array([[False, True, False, True, True],
@@ -5396,25 +5388,6 @@ def test_std_where(self):
53965388
with pytest.warns(RuntimeWarning) as w:
53975389
assert_equal(np.std(a, where=False), np.nan)
53985390

5399-
def test_subclass(self):
5400-
class TestArray(np.ndarray):
5401-
def __new__(cls, data, info):
5402-
result = np.array(data)
5403-
result = result.view(cls)
5404-
result.info = info
5405-
return result
5406-
5407-
def __array_finalize__(self, obj):
5408-
self.info = getattr(obj, "info", '')
5409-
5410-
dat = TestArray([[1, 2, 3, 4], [5, 6, 7, 8]], 'jubba')
5411-
res = dat.mean(1)
5412-
assert_(res.info == dat.info)
5413-
res = dat.std(1)
5414-
assert_(res.info == dat.info)
5415-
res = dat.var(1)
5416-
assert_(res.info == dat.info)
5417-
54185391

54195392
class TestVdot:
54205393
def test_basic(self):

0 commit comments

Comments
 (0)