Use return annotation to wrap tensors into ndarrays/sequences of ndarrays #83

Closed
wants to merge 11 commits into from
13 changes: 10 additions & 3 deletions torch_np/_binary_ufuncs.py
@@ -2,7 +2,14 @@
 from . import _helpers
 from ._detail import _binary_ufuncs
-from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
+from ._normalizations import (
+    ArrayLike,
+    DTypeLike,
+    NDArray,
+    OutArray,
+    SubokLike,
+    normalizer,
+)

 __all__ = [
     name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
@@ -29,12 +36,12 @@ def wrapped(
         subok: SubokLike = False,
         signature=None,
         extobj=None,
-    ):
+    ) -> OutArray:
         tensors = _helpers.ufunc_preprocess(
             (x1, x2), out, where, casting, order, dtype, subok, signature, extobj
         )
         result = torch_func(*tensors)
-        return _helpers.result_or_out(result, out)
+        return result, out

     return wrapped

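The hunk above is the core of the PR: the ufunc body now returns a plain (tensor, out) pair and declares its return type as OutArray, so wrapping the result into an ndarray (and handling out=) moves out of every ufunc and into the normalizer decorator. The _normalizations side is not part of this diff, so the following is only a minimal sketch of annotation-driven postprocessing; SimpleNDArray, OutArray and annotation_normalizer are made-up names, not torch_np's actual API.

import functools
import inspect

import torch


class SimpleNDArray:
    """Hypothetical stand-in for torch_np's ndarray wrapper."""

    def __init__(self, tensor):
        self.tensor = tensor


class OutArray:
    """Marker type used only as a return annotation."""


def annotation_normalizer(func):
    # Look at the return annotation once, when the wrapper is built.
    returns_out = inspect.signature(func).return_annotation is OutArray

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        result = func(*args, **kwds)
        if returns_out:
            tensor, out = result  # the decorated body returns a (tensor, out) pair
            if out is not None:
                out.tensor.copy_(tensor)  # honour the out= argument
                return out
            return SimpleNDArray(tensor)  # otherwise wrap the raw tensor
        return result

    return wrapped


@annotation_normalizer
def add(x1, x2, out=None) -> OutArray:
    return torch.add(x1, x2), out


print(add(torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])).tensor)  # tensor([4., 6.])

The real normalizer presumably also consumes the parameter annotations (ArrayLike, DTypeLike, NDArray, ...) on the way in; only the return-side dispatch is sketched here.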
24 changes: 0 additions & 24 deletions torch_np/_decorators.py

This file was deleted.

6 changes: 6 additions & 0 deletions torch_np/_detail/__init__.py
@@ -0,0 +1,6 @@
+from ._flips import *
+from ._reductions import *
+
+# leading underscore (ndarray.flatten yes, np.flatten no)
+from .implementations import *
+from .implementations import _flatten
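The comment above is terse; the point is that NumPy exposes flatten only as an ndarray method, not as a module-level function, so the helper keeps its leading underscore when re-exported. A quick illustration with plain NumPy:

import numpy as np

a = np.arange(6).reshape(2, 3)
print(a.flatten())             # ndarray.flatten exists: [0 1 2 3 4 5]
print(hasattr(np, "flatten"))  # but there is no np.flatten: False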
107 changes: 70 additions & 37 deletions torch_np/_detail/_reductions.py
@@ -10,7 +10,7 @@

 from . import _dtypes_impl, _util

-NoValue = None
+NoValue = _util.NoValue


 import functools
@@ -313,55 +313,51 @@ def average(a, axis, weights, returned=False, keepdims=False):
     return result, wsum


-def average_noweights(a_tensor, axis, keepdims=False):
-    result = mean(a_tensor, axis=axis, keepdims=keepdims)
-    scl = torch.as_tensor(a_tensor.numel() / result.numel(), dtype=result.dtype)
+def average_noweights(a, axis, keepdims=False):
+    result = mean(a, axis=axis, keepdims=keepdims)
+    scl = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
     return result, scl


-def average_weights(a_tensor, axis, w_tensor, keepdims=False):
+def average_weights(a, axis, w, keepdims=False):

     # dtype
     # FIXME: 1. use result_type
     # 2. actually implement multiply w/dtype
-    if not a_tensor.dtype.is_floating_point:
+    if not a.dtype.is_floating_point:
         result_dtype = torch.float64
-        a_tensor = a_tensor.to(result_dtype)
+        a = a.to(result_dtype)

-    result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
+    result_dtype = _dtypes_impl.result_type_impl([a.dtype, w.dtype])

-    a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
-    w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
+    a = _util.cast_if_needed(a, result_dtype)
+    w = _util.cast_if_needed(w, result_dtype)

     # axis=None ravels, so store the originals to reuse with keepdims=True below
-    ax, ndim = axis, a_tensor.ndim
+    ax, ndim = axis, a.ndim

     # axis
     if axis is None:
-        (a_tensor, w_tensor), axis = _util.axis_none_ravel(
-            a_tensor, w_tensor, axis=axis
-        )
+        (a, w), axis = _util.axis_none_ravel(a, w, axis=axis)

     # axis & weights
-    if a_tensor.shape != w_tensor.shape:
+    if a.shape != w.shape:
         if axis is None:
             raise TypeError(
                 "Axis must be specified when shapes of a and weights " "differ."
             )
-        if w_tensor.ndim != 1:
+        if w.ndim != 1:
             raise TypeError("1D weights expected when shapes of a and weights differ.")
-        if w_tensor.shape[0] != a_tensor.shape[axis]:
+        if w.shape[0] != a.shape[axis]:
             raise ValueError("Length of weights not compatible with specified axis.")

     # setup weight to broadcast along axis
-    w_tensor = torch.broadcast_to(
-        w_tensor, (a_tensor.ndim - 1) * (1,) + w_tensor.shape
-    )
-    w_tensor = w_tensor.swapaxes(-1, axis)
+    w = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)
+    w = w.swapaxes(-1, axis)

     # do the work
-    numerator = torch.mul(a_tensor, w_tensor).sum(axis)
-    denominator = w_tensor.sum(axis)
+    numerator = torch.mul(a, w).sum(axis)
+    denominator = w.sum(axis)
     result = numerator / denominator

     # keepdims
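Both average helpers return a pair, the average plus the scale it was divided by, which mirrors np.average(..., returned=True) returning (average, sum_of_weights). Below is a small standalone demo of the broadcast-and-swapaxes trick used above to line 1-D weights up with the reduction axis; the sample values are arbitrary.

import torch

a = torch.arange(6.0).reshape(2, 3)  # shape (2, 3)
w = torch.tensor([1.0, 3.0])         # 1-D weights along axis=0
axis = 0

wb = torch.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)  # shape (1, 2)
wb = wb.swapaxes(-1, axis)                                 # shape (2, 1), broadcasts against a

numerator = torch.mul(a, wb).sum(axis)  # tensor([ 9., 13., 17.])
denominator = wb.sum(axis)              # tensor([4.])
print(numerator / denominator)          # tensor([2.2500, 3.2500, 4.2500]), same as np.average(a, 0, w)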
@@ -371,36 +367,73 @@ def average_weights(a_tensor, axis, w_tensor, keepdims=False):
     return result, denominator


-def quantile(a_tensor, q_tensor, axis, method, keepdims=False):
-
-    if (0 > q_tensor).any() or (q_tensor > 1).any():
-        raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
-
-    if not a_tensor.dtype.is_floating_point:
+def quantile(
+    a,
+    q,
+    axis,
+    overwrite_input,
+    method,
+    keepdims=False,
+    interpolation=None,
+):
+    if overwrite_input:
+        # raise NotImplementedError("overwrite_input in quantile not implemented.")
+        # NumPy documents that `overwrite_input` MAY modify inputs:
+        # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile
+        # Here we choose to work out-of-place because why not.
+        pass
+
+    if interpolation is not None:
+        raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
+
+    if (0 > q).any() or (q > 1).any():
+        raise ValueError("Quantiles must be in range [0, 1], got %s" % q)
+
+    if not a.dtype.is_floating_point:
         dtype = _dtypes_impl.default_float_dtype
-        a_tensor = a_tensor.to(dtype)
+        a = a.to(dtype)

     # edge case: torch.quantile only supports float32 and float64
-    if a_tensor.dtype == torch.float16:
-        a_tensor = a_tensor.to(torch.float32)
+    if a.dtype == torch.float16:
+        a = a.to(torch.float32)

     # TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all.
     # axis
     if axis is not None:
-        axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
+        axis = _util.normalize_axis_tuple(axis, a.ndim)
         axis = _util.allow_only_single_axis(axis)

-    q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
+    q = _util.cast_if_needed(q, a.dtype)

     # axis=None ravels, so store the originals to reuse with keepdims=True below
-    ax, ndim = axis, a_tensor.ndim
-    (a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
+    ax, ndim = axis, a.ndim
+    (a, q), axis = _util.axis_none_ravel(a, q, axis=axis)

-    result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method)
+    result = torch.quantile(a, q, axis=axis, interpolation=method)

+    # NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...)
+    # while the decorator expects (a, axis, ...)
+    # this can be fixed, of course, but the cure seems worse than the disease
     if keepdims:
         result = _util.apply_keepdims(result, ax, ndim)
     return result


+def percentile(
+    a,
+    q,
+    axis,
+    overwrite_input,
+    method,
+    keepdims=False,
+    interpolation=None,
+):
+    return quantile(
+        a,
+        q / 100.0,
+        axis=axis,
+        overwrite_input=overwrite_input,
+        method=method,
+        keepdims=keepdims,
+        interpolation=interpolation,
+    )
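The new percentile is a thin wrapper: it rescales q onto the unit interval and defers everything else to quantile. A quick sanity check of that relation with torch directly (not through the torch_np wrappers):

import torch

a = torch.arange(10.0)
q = torch.tensor(30.0)               # 30th percentile
print(torch.quantile(a, q / 100.0))  # tensor(2.7000)
# matches numpy.percentile(numpy.arange(10.), 30) == 2.7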
1 change: 1 addition & 0 deletions torch_np/_detail/_util.py
@@ -7,6 +7,7 @@

 from . import _dtypes_impl

+NoValue = None

 # https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
 def is_sequence(seq):
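The NoValue sentinel is now defined once in _util (as plain None) and re-used by _reductions, so every module's `arg is NoValue` check compares against the same object; NumPy itself uses a distinct np._NoValue object rather than None, so that an explicitly passed None can still be told apart from "argument not given". A minimal illustration of the pattern, with a hypothetical reduction signature:

NoValue = None  # the shared sentinel, as defined in _util above


def std(a, axis=NoValue, keepdims=NoValue):
    # Detect "not supplied" without inventing a second default per argument.
    if axis is NoValue:
        axis = None
    if keepdims is NoValue:
        keepdims = False
    return a, axis, keepdims  # placeholder body


print(std([1, 2, 3]))  # ([1, 2, 3], None, False)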