
Merge wrapper and implementations #94

Merged · 10 commits · Mar 31, 2023
2 changes: 0 additions & 2 deletions torch_np/__init__.py
@@ -1,13 +1,11 @@
from . import random
from ._binary_ufuncs import *
-from ._detail._index_tricks import *
from ._detail._util import AxisError, UFuncTypeError
from ._dtypes import *
from ._funcs import *
from ._getlimits import errstate, finfo, iinfo
from ._ndarray import array, asarray, can_cast, ndarray, newaxis, result_type
from ._unary_ufuncs import *
-from ._wrapper import *

# from . import testing

2 changes: 1 addition & 1 deletion torch_np/_binary_ufuncs.py
@@ -86,7 +86,7 @@ def divmod(
    out1: Optional[NDArray] = None,
    out2: Optional[NDArray] = None,
    /,
-    out: Optional[tuple[NDArray]] = (None, None),
+    out: tuple[Optional[NDArray], Optional[NDArray]] = (None, None),
    *,
    where=True,
    casting="same_kind",
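The annotation fix above is worth a note: with a default of `(None, None)`, each slot of the `out` tuple is independently optional, which `tuple[Optional[NDArray], Optional[NDArray]]` expresses and the old `Optional[tuple[NDArray]]` did not. Below is a minimal sketch of how a divmod-style wrapper can reconcile the positional `out1`/`out2` with the `out=` tuple; the helper name `_resolve_out` and the `NDArray` stand-in are hypothetical, not part of this diff:

```python
from typing import Optional, Tuple

NDArray = object  # hypothetical stand-in; torch_np defines its own NDArray


def _resolve_out(
    out1: Optional[NDArray],
    out2: Optional[NDArray],
    out: Tuple[Optional[NDArray], Optional[NDArray]] = (None, None),
) -> Tuple[Optional[NDArray], Optional[NDArray]]:
    """Pick the output targets for a two-result ufunc like divmod."""
    if out1 is not None or out2 is not None:
        if out != (None, None):
            # Conflicting output specifications: reject rather than guess.
            raise TypeError("cannot specify both out1/out2 and out=")
        return out1, out2
    return out
```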
5 changes: 0 additions & 5 deletions torch_np/_detail/__init__.py
@@ -1,6 +1 @@
-from ._flips import *
from ._reductions import *
-
-# leading underscore (ndarray.flatten yes, np.flatten no)
-from .implementations import *
-from .implementations import _flatten
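The deleted comment above records the naming rule behind the underscore: NumPy exposes `flatten` only as an `ndarray` method, so the function-level helper stayed private as `_flatten` rather than becoming a public `np.flatten`. A quick check against NumPy itself:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
print(a.flatten())             # the method exists: [0 1 2 3 4 5]
print(hasattr(np, "flatten"))  # False: there is no top-level np.flatten
```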
61 changes: 0 additions & 61 deletions torch_np/_detail/_flips.py

This file was deleted.

26 changes: 0 additions & 26 deletions torch_np/_detail/_index_tricks.py

This file was deleted.

30 changes: 0 additions & 30 deletions torch_np/_detail/_util.py
@@ -170,36 +170,6 @@ def typecast_tensors(tensors, target_dtype, casting):
    return tuple(cast_tensors)


-def axis_expand_func(func, tensor, axis, *args, **kwds):
-    """Generically handle axis arguments in reductions."""
-    if axis is not None:
-        if not isinstance(axis, (list, tuple)):
-            axis = (axis,)
-        axis = normalize_axis_tuple(axis, tensor.ndim)
-
-    if axis == ():
-        newshape = expand_shape(tensor.shape, axis=0)
-        tensor = tensor.reshape(newshape)
-        axis = (0,)
-
-    result = func(tensor, axis=axis, *args, **kwds)
-
-    return result
-
-
-def axis_ravel_func(func, tensor, axis, *args, **kwds):
-    """Generically handle axis arguments in cumsum/cumprod."""
-    if axis is not None:
-        axis = normalize_axis_index(axis, tensor.ndim)
-
-    tensors, axis = axis_none_ravel(tensor, axis=axis)
-    tensor = tensors[0]
-
-    result = func(tensor, axis=axis, *args, **kwds)
-
-    return result
-
-
def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """The core logic of the array(...) function.

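For context on what the two deleted helpers did: `axis_expand_func` normalized `axis` to a tuple and gave `axis=()` NumPy's no-op-reduction semantics by prepending a length-1 axis and reducing over it, while `axis_ravel_func` normalized a single axis index (ravelling when `axis` is None). Here is a self-contained sketch of the first idiom, specialized to one reduction; `sum_with_axis` is a hypothetical name, and the inline normalization stands in for `normalize_axis_tuple`/`expand_shape`:

```python
import torch


def sum_with_axis(tensor, axis=None):
    """Sketch of the axis_expand_func idiom, specialized to sum."""
    if axis is None:
        return tensor.sum()  # reduce over everything
    if not isinstance(axis, (list, tuple)):
        axis = (axis,)
    # Crude stand-in for normalize_axis_tuple: wrap negative axes.
    axis = tuple(a % tensor.ndim for a in axis)
    if axis == ():
        # NumPy treats axis=() as a no-op reduction; prepend a length-1
        # axis and reduce over that instead, as the deleted helper did.
        tensor = tensor.reshape((1,) + tuple(tensor.shape))
        axis = (0,)
    return tensor.sum(dim=axis)


t = torch.arange(6.0).reshape(2, 3)
print(sum_with_axis(t, axis=0))   # tensor([3., 5., 7.])
print(sum_with_axis(t, axis=()))  # unchanged values, same shape as t
```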