Skip to content

Commit 3d0c873

Browse files
committed
MAINT: IndexExpression into _funcs.py from _detail/_index_tricks.py
1 parent 4266c86 commit 3d0c873

File tree

4 files changed

+201
-28
lines changed

4 files changed

+201
-28
lines changed

torch_np/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from . import random
22
from ._binary_ufuncs import *
3-
from ._detail._index_tricks import *
43
from ._detail._util import AxisError, UFuncTypeError
54
from ._dtypes import *
65
from ._funcs import *

torch_np/_detail/_index_tricks.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

torch_np/_funcs.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,3 +1698,203 @@ def median(
16981698
return quantile(
16991699
a, 0.5, axis=axis, overwrite_input=overwrite_input, out=out, keepdims=keepdims
17001700
)
1701+
1702+
1703+
@normalizer
1704+
def average(
1705+
a: ArrayLike,
1706+
axis=None,
1707+
weights: ArrayLike = None,
1708+
returned=False,
1709+
*,
1710+
keepdims=NoValue,
1711+
):
1712+
result, wsum = _impl.average(a, axis, weights, returned=returned, keepdims=keepdims)
1713+
if returned:
1714+
return result, wsum
1715+
else:
1716+
return result
1717+
1718+
1719+
@normalizer
1720+
def diff(
1721+
a: ArrayLike,
1722+
n=1,
1723+
axis=-1,
1724+
prepend: Optional[ArrayLike] = NoValue,
1725+
append: Optional[ArrayLike] = NoValue,
1726+
):
1727+
axis = _util.normalize_axis_index(axis, a.ndim)
1728+
1729+
if n < 0:
1730+
raise ValueError(f"order must be non-negative but got {n}")
1731+
1732+
if n == 0:
1733+
# match numpy and return the input immediately
1734+
return a
1735+
1736+
if prepend is not None:
1737+
shape = list(a.shape)
1738+
shape[axis] = prepend.shape[axis] if prepend.ndim > 0 else 1
1739+
prepend = torch.broadcast_to(prepend, shape)
1740+
1741+
if append is not None:
1742+
shape = list(a.shape)
1743+
shape[axis] = append.shape[axis] if append.ndim > 0 else 1
1744+
append = torch.broadcast_to(append, shape)
1745+
1746+
result = torch.diff(a, n, axis=axis, prepend=prepend, append=append)
1747+
1748+
return result
1749+
1750+
1751+
# ### math functions ###
1752+
1753+
1754+
@normalizer
1755+
def angle(z: ArrayLike, deg=False):
1756+
result = torch.angle(z)
1757+
if deg:
1758+
result = result * 180 / torch.pi
1759+
return result
1760+
1761+
1762+
@normalizer
1763+
def sinc(x: ArrayLike):
1764+
result = torch.sinc(x)
1765+
return result
1766+
1767+
1768+
@normalizer
1769+
def real(a: ArrayLike):
1770+
result = torch.real(a)
1771+
return result
1772+
1773+
1774+
@normalizer
1775+
def imag(a: ArrayLike):
1776+
if a.is_complex():
1777+
result = a.imag
1778+
else:
1779+
result = torch.zeros_like(a)
1780+
return result
1781+
1782+
1783+
@normalizer
1784+
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None) -> OutArray:
1785+
if a.is_floating_point():
1786+
result = torch.round(a, decimals=decimals)
1787+
elif a.is_complex():
1788+
# RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
1789+
result = (
1790+
torch.round(a.real, decimals=decimals)
1791+
+ torch.round(a.imag, decimals=decimals) * 1j
1792+
)
1793+
else:
1794+
# RuntimeError: "round_cpu" not implemented for 'int'
1795+
result = a
1796+
return result, out
1797+
1798+
1799+
around = round_
1800+
round = round_
1801+
1802+
1803+
@normalizer
1804+
def real_if_close(a: ArrayLike, tol=100):
1805+
# XXX: copies vs views; numpy seems to return a copy?
1806+
if not torch.is_complex(a):
1807+
return a
1808+
if tol > 1:
1809+
# Undocumented in numpy: if tol < 1, it's an absolute tolerance!
1810+
# Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon
1811+
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577
1812+
tol = tol * torch.finfo(a.dtype).eps
1813+
1814+
mask = torch.abs(a.imag) < tol
1815+
return a.real if mask.all() else a
1816+
1817+
1818+
@normalizer
1819+
def iscomplex(x: ArrayLike):
1820+
if torch.is_complex(x):
1821+
return torch.as_tensor(x).imag != 0
1822+
result = torch.zeros_like(x, dtype=torch.bool)
1823+
if result.ndim == 0:
1824+
result = result.item()
1825+
return result
1826+
1827+
1828+
@normalizer
1829+
def isreal(x: ArrayLike):
1830+
if torch.is_complex(x):
1831+
return torch.as_tensor(x).imag == 0
1832+
result = torch.ones_like(x, dtype=torch.bool)
1833+
if result.ndim == 0:
1834+
result = result.item()
1835+
return result
1836+
1837+
1838+
@normalizer
1839+
def iscomplexobj(x: ArrayLike):
1840+
result = torch.is_complex(x)
1841+
return result
1842+
1843+
1844+
@normalizer
1845+
def isrealobj(x: ArrayLike):
1846+
result = not torch.is_complex(x)
1847+
return result
1848+
1849+
1850+
@normalizer
1851+
def isneginf(x: ArrayLike, out: Optional[NDArray] = None):
1852+
result = torch.isneginf(x, out=out)
1853+
return result
1854+
1855+
1856+
@normalizer
1857+
def isposinf(x: ArrayLike, out: Optional[NDArray] = None):
1858+
result = torch.isposinf(x, out=out)
1859+
return result
1860+
1861+
1862+
@normalizer
1863+
def i0(x: ArrayLike):
1864+
result = torch.special.i0(x)
1865+
return result
1866+
1867+
1868+
@normalizer(return_on_failure=False)
1869+
def isscalar(a: ArrayLike):
1870+
# XXX: this is a stub
1871+
if a is False:
1872+
return a
1873+
return a.numel() == 1
1874+
1875+
1876+
"""
1877+
Vendored objects from numpy.lib.index_tricks
1878+
"""
1879+
1880+
1881+
class IndexExpression:
1882+
"""
1883+
Written by Konrad Hinsen <[email protected]>
1884+
last revision: 1999-7-23
1885+
1886+
Cosmetic changes by T. Oliphant 2001
1887+
"""
1888+
1889+
def __init__(self, maketuple):
1890+
self.maketuple = maketuple
1891+
1892+
def __getitem__(self, item):
1893+
if self.maketuple and not isinstance(item, tuple):
1894+
return (item,)
1895+
else:
1896+
return item
1897+
1898+
1899+
index_exp = IndexExpression(maketuple=True)
1900+
s_ = IndexExpression(maketuple=False)

torch_np/tests/numpy_tests/lib/test_index_tricks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414

1515
from torch_np import diag_indices, diag_indices_from, fill_diagonal
16-
from torch_np._detail._index_tricks import index_exp, s_
16+
from torch_np import index_exp, s_
1717

1818

1919
@pytest.mark.xfail(reason='unravel_index not implemented')

0 commit comments

Comments
 (0)