Skip to content

Commit 0e8e6b5

Browse files
authored
Merge pull request #95 from Quansight-Labs/filter_windows
Add filter windows, tensordot
2 parents faf2350 + ee6875c commit 0e8e6b5

File tree

6 files changed

+74
-43
lines changed

6 files changed

+74
-43
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,6 @@ def binary_repr(num, width=None):
134134
raise NotImplementedError
135135

136136

137-
def blackman(M):
138-
raise NotImplementedError
139-
140-
141137
def block(arrays):
142138
raise NotImplementedError
143139

@@ -337,14 +333,6 @@ def gradient(f, *varargs, axis=None, edge_order=1):
337333
raise NotImplementedError
338334

339335

340-
def hamming(M):
341-
raise NotImplementedError
342-
343-
344-
def hanning(M):
345-
raise NotImplementedError
346-
347-
348336
def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
349337
raise NotImplementedError
350338

@@ -409,10 +397,6 @@ def ix_(*args):
409397
raise NotImplementedError
410398

411399

412-
def kaiser(M, beta):
413-
raise NotImplementedError
414-
415-
416400
def lexsort(keys, axis=-1):
417401
raise NotImplementedError
418402

@@ -759,10 +743,6 @@ def take(a, indices, axis=None, out=None, mode="raise"):
759743
raise NotImplementedError
760744

761745

762-
def tensordot(a, b, axes=2):
763-
raise NotImplementedError
764-
765-
766746
def trapz(y, x=None, dx=1.0, axis=-1):
767747
raise NotImplementedError
768748

torch_np/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
inf = float("inf")
1616
nan = float("nan")
17-
from math import pi # isort: skip
17+
from math import pi, e # isort: skip
1818

1919
False_ = asarray(False, bool_)
2020
True_ = asarray(True, bool_)

torch_np/_funcs.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,13 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
11581158
return result.item()
11591159

11601160

1161+
@normalizer
1162+
def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
1163+
if isinstance(axes, (list, tuple)):
1164+
axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
1165+
return torch.tensordot(a, b, dims=axes)
1166+
1167+
11611168
@normalizer
11621169
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
11631170
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
@@ -1850,3 +1857,36 @@ def __getitem__(self, item):
18501857

18511858
index_exp = IndexExpression(maketuple=True)
18521859
s_ = IndexExpression(maketuple=False)
1860+
1861+
1862+
# ### Filter windows ###
1863+
1864+
1865+
@normalizer
1866+
def hamming(M):
1867+
dtype = _dtypes_impl.default_float_dtype
1868+
return torch.hamming_window(M, periodic=False, dtype=dtype)
1869+
1870+
1871+
@normalizer
1872+
def hanning(M):
1873+
dtype = _dtypes_impl.default_float_dtype
1874+
return torch.hann_window(M, periodic=False, dtype=dtype)
1875+
1876+
1877+
@normalizer
1878+
def kaiser(M, beta):
1879+
dtype = _dtypes_impl.default_float_dtype
1880+
return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype)
1881+
1882+
1883+
@normalizer
1884+
def blackman(M):
1885+
dtype = _dtypes_impl.default_float_dtype
1886+
return torch.blackman_window(M, periodic=False, dtype=dtype)
1887+
1888+
1889+
@normalizer
1890+
def bartlett(M):
1891+
dtype = _dtypes_impl.default_float_dtype
1892+
return torch.bartlett_window(M, periodic=False, dtype=dtype)

torch_np/_ndarray.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,13 @@ class ndarray:
6464
def __init__(self, t=None):
6565
if t is None:
6666
self.tensor = torch.Tensor()
67+
elif isinstance(t, torch.Tensor):
68+
self.tensor = t
6769
else:
68-
self.tensor = torch.as_tensor(t)
70+
raise ValueError(
71+
"ndarray constructor is not recommended; prefer"
72+
"either array(...) or zeros/empty(...)"
73+
)
6974

7075
@property
7176
def shape(self):

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2988,15 +2988,21 @@ def test_shape_mismatch_error_message(self):
29882988
np.broadcast([[1, 2, 3]], [[4], [5]], [6, 7])
29892989

29902990

2991-
@pytest.mark.xfail(reason="TODO")
29922991
class TestTensordot:
29932992

29942993
def test_zero_dimension(self):
29952994
# Test resolution to issue #5663
2996-
a = np.ndarray((3,0))
2997-
b = np.ndarray((0,4))
2995+
a = np.zeros((3,0))
2996+
b = np.zeros((0,4))
29982997
td = np.tensordot(a, b, (1, 0))
29992998
assert_array_equal(td, np.dot(a, b))
2999+
3000+
@pytest.mark.xfail(reason="no einsum")
3001+
def test_zero_dimension_einsum(self):
3002+
# Test resolution to issue #5663
3003+
a = np.zeros((3,0))
3004+
b = np.zeros((0,4))
3005+
td = np.tensordot(a, b, (1, 0))
30003006
assert_array_equal(td, np.einsum('ij,jk', a, b))
30013007

30023008
def test_zero_dimensional(self):

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@
2626

2727
# FIXME: make from torch_np
2828
from numpy.lib import (
29-
bartlett, blackman,
30-
delete, digitize, extract, gradient, hamming, hanning,
31-
insert, interp, kaiser, msort, piecewise, place,
29+
delete, digitize, extract, gradient,
30+
insert, interp, msort, piecewise, place,
3231
select, setxor1d, trapz, trim_zeros, unwrap, vectorize
3332
)
3433
from torch_np._detail._util import normalize_axis_tuple
3534

3635
from torch_np import corrcoef, cov, i0, angle, sinc, diff, meshgrid, unique
36+
from torch_np import flipud, hamming, hanning, kaiser, blackman, bartlett
37+
3738

3839
def get_mat(n):
3940
data = np.arange(n)
@@ -1701,22 +1702,21 @@ def test_period(self):
17011702
assert sm_discont.dtype == wrap_uneven.dtype
17021703

17031704

1704-
@pytest.mark.xfail(reason='TODO: implement')
17051705
@pytest.mark.parametrize(
17061706
"dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]
17071707
)
17081708
@pytest.mark.parametrize("M", [0, 1, 10])
17091709
class TestFilterwindows:
17101710

17111711
def test_hanning(self, dtype: str, M: int) -> None:
1712-
scalar = np.array(M, dtype=dtype)[()]
1712+
scalar = M
17131713

17141714
w = hanning(scalar)
1715-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1715+
ref_dtype = np.result_type(dtype, np.float64)
17161716
assert w.dtype == ref_dtype
17171717

17181718
# check symmetry
1719-
assert_equal(w, flipud(w))
1719+
assert_allclose(w, flipud(w), atol=1e-15)
17201720

17211721
# check known value
17221722
if scalar < 1:
@@ -1727,14 +1727,14 @@ def test_hanning(self, dtype: str, M: int) -> None:
17271727
assert_almost_equal(np.sum(w, axis=0), 4.500, 4)
17281728

17291729
def test_hamming(self, dtype: str, M: int) -> None:
1730-
scalar = np.array(M, dtype=dtype)[()]
1730+
scalar = M
17311731

17321732
w = hamming(scalar)
1733-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1733+
ref_dtype = np.result_type(dtype, np.float64)
17341734
assert w.dtype == ref_dtype
17351735

17361736
# check symmetry
1737-
assert_equal(w, flipud(w))
1737+
assert_allclose(w, flipud(w), atol=1e-15)
17381738

17391739
# check known value
17401740
if scalar < 1:
@@ -1745,14 +1745,14 @@ def test_hamming(self, dtype: str, M: int) -> None:
17451745
assert_almost_equal(np.sum(w, axis=0), 4.9400, 4)
17461746

17471747
def test_bartlett(self, dtype: str, M: int) -> None:
1748-
scalar = np.array(M, dtype=dtype)[()]
1748+
scalar = M
17491749

17501750
w = bartlett(scalar)
1751-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1751+
ref_dtype = np.result_type(dtype, np.float64)
17521752
assert w.dtype == ref_dtype
17531753

17541754
# check symmetry
1755-
assert_equal(w, flipud(w))
1755+
assert_allclose(w, flipud(w), atol=1e-15)
17561756

17571757
# check known value
17581758
if scalar < 1:
@@ -1763,14 +1763,14 @@ def test_bartlett(self, dtype: str, M: int) -> None:
17631763
assert_almost_equal(np.sum(w, axis=0), 4.4444, 4)
17641764

17651765
def test_blackman(self, dtype: str, M: int) -> None:
1766-
scalar = np.array(M, dtype=dtype)[()]
1766+
scalar = M
17671767

17681768
w = blackman(scalar)
1769-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1769+
ref_dtype = np.result_type(dtype, np.float64)
17701770
assert w.dtype == ref_dtype
17711771

17721772
# check symmetry
1773-
assert_equal(w, flipud(w))
1773+
assert_allclose(w, flipud(w), atol=1e-15)
17741774

17751775
# check known value
17761776
if scalar < 1:
@@ -1781,10 +1781,10 @@ def test_blackman(self, dtype: str, M: int) -> None:
17811781
assert_almost_equal(np.sum(w, axis=0), 3.7800, 4)
17821782

17831783
def test_kaiser(self, dtype: str, M: int) -> None:
1784-
scalar = np.array(M, dtype=dtype)[()]
1784+
scalar = M
17851785

17861786
w = kaiser(scalar, 0)
1787-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1787+
ref_dtype = np.result_type(dtype, np.float64)
17881788
assert w.dtype == ref_dtype
17891789

17901790
# check symmetry

0 commit comments

Comments
 (0)