Skip to content

Commit 1e70120

Browse files
committed
ENH: add a naive divmod, un-xfail relevant tests
Reviewed at #84
1 parent 37d6f4f commit 1e70120

File tree

4 files changed

+94
-44
lines changed

4 files changed

+94
-44
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,58 @@ def wrapped(
5050
decorated.__qualname__ = name # XXX: is this really correct?
5151
decorated.__name__ = name
5252
vars()[name] = decorated
53+
54+
55+
# a stub implementation of divmod, should be improved after
56+
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
57+
#
58+
# Implementation details: we just call two ufuncs which have been created
59+
# just above, for x1 // x2 and x1 % x2.
60+
# This means we are normalizing x1, x2 in each of the ufuncs --- note that there
61+
# is no @normalizer on divmod.
62+
63+
64+
def divmod(
65+
x1,
66+
x2,
67+
/,
68+
out=None,
69+
*,
70+
where=True,
71+
casting="same_kind",
72+
order="K",
73+
dtype=None,
74+
subok: SubokLike = False,
75+
signature=None,
76+
extobj=None,
77+
):
78+
out1, out2 = None, None
79+
if out is not None:
80+
out1, out2 = out
81+
82+
kwds = dict(
83+
where=where,
84+
casting=casting,
85+
order=order,
86+
dtype=dtype,
87+
subok=subok,
88+
signature=signature,
89+
extobj=extobj,
90+
)
91+
92+
# NB: use local names for
93+
quot = floor_divide(x1, x2, out=out1, **kwds)
94+
rem = remainder(x1, x2, out=out2, **kwds)
95+
96+
quot = _helpers.result_or_out(quot.tensor, out1)
97+
rem = _helpers.result_or_out(rem.tensor, out2)
98+
99+
return quot, rem
100+
101+
102+
def modf(x, /, *args, **kwds):
103+
quot, rem = divmod(x, 1, *args, **kwds)
104+
return rem, quot
105+
106+
107+
__all__ = __all__ + ["divmod", "modf"]

torch_np/_ndarray.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,8 @@ def __rfloordiv__(self, other):
266266
def __ifloordiv__(self, other):
267267
return _binary_ufuncs.floor_divide(self, other, out=self)
268268

269+
__divmod__ = _binary_ufuncs.divmod
270+
269271
# power, self**exponent
270272
__pow__ = __rpow__ = _binary_ufuncs.float_power
271273

torch_np/_normalizations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def normalize_dtype(dtype, name=None):
4545
return torch_dtype
4646

4747

48-
def normalize_subok_like(arg, name):
48+
def normalize_subok_like(arg, name="subok"):
4949
if arg:
5050
raise ValueError(f"'{name}' parameter is not supported.")
5151

@@ -88,7 +88,7 @@ def maybe_normalize(arg, parm, return_on_failure=_sentinel):
8888
"""Normalize arg if a normalizer is registred."""
8989
normalizer = normalizers.get(parm.annotation, None)
9090
try:
91-
return normalizer(arg) if normalizer else arg
91+
return normalizer(arg, parm.name) if normalizer else arg
9292
except Exception as exc:
9393
if return_on_failure is not _sentinel:
9494
return return_on_failure

torch_np/tests/numpy_tests/core/test_scalarmath.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,6 @@ class TestModulus:
262262
def test_modulus_basic(self):
263263
dt = np.typecodes['AllInteger'] + np.typecodes['Float']
264264
for op in [floordiv_and_mod, divmod]:
265-
266-
if op == divmod:
267-
pytest.xfail(reason="__divmod__ not implemented")
268-
269265
for dt1, dt2 in itertools.product(dt, dt):
270266
for sg1, sg2 in itertools.product(_signs(dt1), _signs(dt2)):
271267
fmt = 'op: %s, dt1: %s, dt2: %s, sg1: %s, sg2: %s'
@@ -279,7 +275,7 @@ def test_modulus_basic(self):
279275
else:
280276
assert_(b > rem >= 0, msg)
281277

282-
@pytest.mark.xfail(reason='divmod not implemented')
278+
@pytest.mark.slow
283279
def test_float_modulus_exact(self):
284280
# test that float results are exact for small integers. This also
285281
# holds for the same integers scaled by powers of two.
@@ -311,10 +307,6 @@ def test_float_modulus_roundoff(self):
311307
# gh-6127
312308
dt = np.typecodes['Float']
313309
for op in [floordiv_and_mod, divmod]:
314-
315-
if op == divmod:
316-
pytest.xfail(reason="__divmod__ not implemented")
317-
318310
for dt1, dt2 in itertools.product(dt, dt):
319311
for sg1, sg2 in itertools.product((+1, -1), (+1, -1)):
320312
fmt = 'op: %s, dt1: %s, dt2: %s, sg1: %s, sg2: %s'
@@ -329,41 +321,42 @@ def test_float_modulus_roundoff(self):
329321
else:
330322
assert_(b > rem >= 0, msg)
331323

332-
@pytest.mark.skip(reason='float16 on cpu is incomplete in pytorch')
333-
def test_float_modulus_corner_cases(self):
334-
# Check remainder magnitude.
335-
for dt in np.typecodes['Float']:
336-
b = np.array(1.0, dtype=dt)
337-
a = np.nextafter(np.array(0.0, dtype=dt), -b)
338-
rem = operator.mod(a, b)
339-
assert_(rem <= b, 'dt: %s' % dt)
340-
rem = operator.mod(-a, -b)
341-
assert_(rem >= -b, 'dt: %s' % dt)
324+
@pytest.mark.parametrize('dt', np.typecodes['Float'])
325+
def test_float_modulus_corner_cases(self, dt):
326+
if dt == 'e':
327+
pytest.xfail(reason="RuntimeError: 'nextafter_cpu' not implemented for 'Half'")
328+
329+
b = np.array(1.0, dtype=dt)
330+
a = np.nextafter(np.array(0.0, dtype=dt), -b)
331+
rem = operator.mod(a, b)
332+
assert_(rem <= b, 'dt: %s' % dt)
333+
rem = operator.mod(-a, -b)
334+
assert_(rem >= -b, 'dt: %s' % dt)
342335

343336
# Check nans, inf
344-
with suppress_warnings() as sup:
345-
sup.filter(RuntimeWarning, "invalid value encountered in remainder")
346-
sup.filter(RuntimeWarning, "divide by zero encountered in remainder")
347-
sup.filter(RuntimeWarning, "divide by zero encountered in floor_divide")
348-
sup.filter(RuntimeWarning, "divide by zero encountered in divmod")
349-
sup.filter(RuntimeWarning, "invalid value encountered in divmod")
350-
for dt in np.typecodes['Float']:
351-
fone = np.array(1.0, dtype=dt)
352-
fzer = np.array(0.0, dtype=dt)
353-
finf = np.array(np.inf, dtype=dt)
354-
fnan = np.array(np.nan, dtype=dt)
355-
rem = operator.mod(fone, fzer)
356-
assert_(np.isnan(rem), 'dt: %s' % dt)
357-
# MSVC 2008 returns NaN here, so disable the check.
358-
#rem = operator.mod(fone, finf)
359-
#assert_(rem == fone, 'dt: %s' % dt)
360-
rem = operator.mod(fone, fnan)
361-
assert_(np.isnan(rem), 'dt: %s' % dt)
362-
rem = operator.mod(finf, fone)
363-
assert_(np.isnan(rem), 'dt: %s' % dt)
364-
for op in [floordiv_and_mod, divmod]:
365-
div, mod = op(fone, fzer)
366-
assert_(np.isinf(div)) and assert_(np.isnan(mod))
337+
# with suppress_warnings() as sup:
338+
# sup.filter(RuntimeWarning, "invalid value encountered in remainder")
339+
# sup.filter(RuntimeWarning, "divide by zero encountered in remainder")
340+
# sup.filter(RuntimeWarning, "divide by zero encountered in floor_divide")
341+
# sup.filter(RuntimeWarning, "divide by zero encountered in divmod")
342+
# sup.filter(RuntimeWarning, "invalid value encountered in divmod")
343+
for dt in np.typecodes['Float']:
344+
fone = np.array(1.0, dtype=dt)
345+
fzer = np.array(0.0, dtype=dt)
346+
finf = np.array(np.inf, dtype=dt)
347+
fnan = np.array(np.nan, dtype=dt)
348+
rem = operator.mod(fone, fzer)
349+
assert_(np.isnan(rem), 'dt: %s' % dt)
350+
# MSVC 2008 returns NaN here, so disable the check.
351+
#rem = operator.mod(fone, finf)
352+
#assert_(rem == fone, 'dt: %s' % dt)
353+
rem = operator.mod(fone, fnan)
354+
assert_(np.isnan(rem), 'dt: %s' % dt)
355+
rem = operator.mod(finf, fone)
356+
assert_(np.isnan(rem), 'dt: %s' % dt)
357+
for op in [floordiv_and_mod, divmod]:
358+
div, mod = op(fone, fzer)
359+
assert_(np.isinf(div)) and assert_(np.isnan(mod))
367360

368361

369362
class TestComplexDivision:

0 commit comments

Comments
 (0)