@@ -1698,3 +1698,203 @@ def median(
1698
1698
return quantile (
1699
1699
a , 0.5 , axis = axis , overwrite_input = overwrite_input , out = out , keepdims = keepdims
1700
1700
)
1701
+
1702
+
1703
+ @normalizer
1704
+ def average (
1705
+ a : ArrayLike ,
1706
+ axis = None ,
1707
+ weights : ArrayLike = None ,
1708
+ returned = False ,
1709
+ * ,
1710
+ keepdims = NoValue ,
1711
+ ):
1712
+ result , wsum = _impl .average (a , axis , weights , returned = returned , keepdims = keepdims )
1713
+ if returned :
1714
+ return result , wsum
1715
+ else :
1716
+ return result
1717
+
1718
+
1719
+ @normalizer
1720
+ def diff (
1721
+ a : ArrayLike ,
1722
+ n = 1 ,
1723
+ axis = - 1 ,
1724
+ prepend : Optional [ArrayLike ] = NoValue ,
1725
+ append : Optional [ArrayLike ] = NoValue ,
1726
+ ):
1727
+ axis = _util .normalize_axis_index (axis , a .ndim )
1728
+
1729
+ if n < 0 :
1730
+ raise ValueError (f"order must be non-negative but got { n } " )
1731
+
1732
+ if n == 0 :
1733
+ # match numpy and return the input immediately
1734
+ return a
1735
+
1736
+ if prepend is not None :
1737
+ shape = list (a .shape )
1738
+ shape [axis ] = prepend .shape [axis ] if prepend .ndim > 0 else 1
1739
+ prepend = torch .broadcast_to (prepend , shape )
1740
+
1741
+ if append is not None :
1742
+ shape = list (a .shape )
1743
+ shape [axis ] = append .shape [axis ] if append .ndim > 0 else 1
1744
+ append = torch .broadcast_to (append , shape )
1745
+
1746
+ result = torch .diff (a , n , axis = axis , prepend = prepend , append = append )
1747
+
1748
+ return result
1749
+
1750
+
1751
+ # ### math functions ###
1752
+
1753
+
1754
+ @normalizer
1755
+ def angle (z : ArrayLike , deg = False ):
1756
+ result = torch .angle (z )
1757
+ if deg :
1758
+ result = result * 180 / torch .pi
1759
+ return result
1760
+
1761
+
1762
+ @normalizer
1763
+ def sinc (x : ArrayLike ):
1764
+ result = torch .sinc (x )
1765
+ return result
1766
+
1767
+
1768
+ @normalizer
1769
+ def real (a : ArrayLike ):
1770
+ result = torch .real (a )
1771
+ return result
1772
+
1773
+
1774
+ @normalizer
1775
+ def imag (a : ArrayLike ):
1776
+ if a .is_complex ():
1777
+ result = a .imag
1778
+ else :
1779
+ result = torch .zeros_like (a )
1780
+ return result
1781
+
1782
+
1783
+ @normalizer
1784
+ def round_ (a : ArrayLike , decimals = 0 , out : Optional [NDArray ] = None ) -> OutArray :
1785
+ if a .is_floating_point ():
1786
+ result = torch .round (a , decimals = decimals )
1787
+ elif a .is_complex ():
1788
+ # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
1789
+ result = (
1790
+ torch .round (a .real , decimals = decimals )
1791
+ + torch .round (a .imag , decimals = decimals ) * 1j
1792
+ )
1793
+ else :
1794
+ # RuntimeError: "round_cpu" not implemented for 'int'
1795
+ result = a
1796
+ return result , out
1797
+
1798
+
1799
+ around = round_
1800
+ round = round_
1801
+
1802
+
1803
+ @normalizer
1804
+ def real_if_close (a : ArrayLike , tol = 100 ):
1805
+ # XXX: copies vs views; numpy seems to return a copy?
1806
+ if not torch .is_complex (a ):
1807
+ return a
1808
+ if tol > 1 :
1809
+ # Undocumented in numpy: if tol < 1, it's an absolute tolerance!
1810
+ # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon
1811
+ # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577
1812
+ tol = tol * torch .finfo (a .dtype ).eps
1813
+
1814
+ mask = torch .abs (a .imag ) < tol
1815
+ return a .real if mask .all () else a
1816
+
1817
+
1818
+ @normalizer
1819
+ def iscomplex (x : ArrayLike ):
1820
+ if torch .is_complex (x ):
1821
+ return torch .as_tensor (x ).imag != 0
1822
+ result = torch .zeros_like (x , dtype = torch .bool )
1823
+ if result .ndim == 0 :
1824
+ result = result .item ()
1825
+ return result
1826
+
1827
+
1828
+ @normalizer
1829
+ def isreal (x : ArrayLike ):
1830
+ if torch .is_complex (x ):
1831
+ return torch .as_tensor (x ).imag == 0
1832
+ result = torch .ones_like (x , dtype = torch .bool )
1833
+ if result .ndim == 0 :
1834
+ result = result .item ()
1835
+ return result
1836
+
1837
+
1838
+ @normalizer
1839
+ def iscomplexobj (x : ArrayLike ):
1840
+ result = torch .is_complex (x )
1841
+ return result
1842
+
1843
+
1844
+ @normalizer
1845
+ def isrealobj (x : ArrayLike ):
1846
+ result = not torch .is_complex (x )
1847
+ return result
1848
+
1849
+
1850
+ @normalizer
1851
+ def isneginf (x : ArrayLike , out : Optional [NDArray ] = None ):
1852
+ result = torch .isneginf (x , out = out )
1853
+ return result
1854
+
1855
+
1856
+ @normalizer
1857
+ def isposinf (x : ArrayLike , out : Optional [NDArray ] = None ):
1858
+ result = torch .isposinf (x , out = out )
1859
+ return result
1860
+
1861
+
1862
+ @normalizer
1863
+ def i0 (x : ArrayLike ):
1864
+ result = torch .special .i0 (x )
1865
+ return result
1866
+
1867
+
1868
+ @normalizer (return_on_failure = False )
1869
+ def isscalar (a : ArrayLike ):
1870
+ # XXX: this is a stub
1871
+ if a is False :
1872
+ return a
1873
+ return a .numel () == 1
1874
+
1875
+
1876
+ """
1877
+ Vendored objects from numpy.lib.index_tricks
1878
+ """
1879
+
1880
+
1881
+ class IndexExpression :
1882
+ """
1883
+ Written by Konrad Hinsen <[email protected] >
1884
+ last revision: 1999-7-23
1885
+
1886
+ Cosmetic changes by T. Oliphant 2001
1887
+ """
1888
+
1889
+ def __init__ (self , maketuple ):
1890
+ self .maketuple = maketuple
1891
+
1892
+ def __getitem__ (self , item ):
1893
+ if self .maketuple and not isinstance (item , tuple ):
1894
+ return (item ,)
1895
+ else :
1896
+ return item
1897
+
1898
+
1899
+ index_exp = IndexExpression (maketuple = True )
1900
+ s_ = IndexExpression (maketuple = False )
0 commit comments