Skip to content

Commit c7d3b53

Browse files
committed
ENH: add filter windows
1 parent 5604758 commit c7d3b53

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
inf = float("inf")
1616
nan = float("nan")
17-
from math import pi # isort: skip
17+
from math import pi, e # isort: skip
1818

1919
False_ = asarray(False, bool_)
2020
True_ = asarray(True, bool_)

torch_np/_funcs.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,3 +1898,41 @@ def __getitem__(self, item):
18981898

18991899
index_exp = IndexExpression(maketuple=True)
19001900
s_ = IndexExpression(maketuple=False)
1901+
1902+
1903+
# ### Filter windows ###
1904+
1905+
1906+
@normalizer
1907+
def hamming(M: ArrayLike):
1908+
M = int(M.item())
1909+
dtype = _dtypes_impl.default_float_dtype
1910+
return torch.hamming_window(M, periodic=False, dtype=dtype)
1911+
1912+
1913+
@normalizer
1914+
def hanning(M: ArrayLike):
1915+
M = int(M.item())
1916+
dtype = _dtypes_impl.default_float_dtype
1917+
return torch.hann_window(M, periodic=False, dtype=dtype)
1918+
1919+
1920+
@normalizer
1921+
def kaiser(M: ArrayLike, beta):
1922+
M = int(M.item())
1923+
dtype = _dtypes_impl.default_float_dtype
1924+
return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype)
1925+
1926+
1927+
@normalizer
1928+
def blackman(M: ArrayLike):
1929+
M = int(M.item())
1930+
dtype = _dtypes_impl.default_float_dtype
1931+
return torch.blackman_window(M, periodic=False, dtype=dtype)
1932+
1933+
1934+
@normalizer
1935+
def bartlett(M: ArrayLike):
1936+
M = int(M.item())
1937+
dtype = _dtypes_impl.default_float_dtype
1938+
return torch.bartlett_window(M, periodic=False, dtype=dtype)

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@
2626

2727
# FIXME: make from torch_np
2828
from numpy.lib import (
29-
bartlett, blackman,
30-
delete, digitize, extract, gradient, hamming, hanning,
31-
insert, interp, kaiser, msort, piecewise, place,
29+
delete, digitize, extract, gradient,
30+
insert, interp, msort, piecewise, place,
3231
select, setxor1d, trapz, trim_zeros, unwrap, vectorize
3332
)
3433
from torch_np._detail._util import normalize_axis_tuple
3534

3635
from torch_np import corrcoef, cov, i0, angle, sinc, diff, meshgrid, unique
36+
from torch_np import flipud, hamming, hanning, kaiser, blackman, bartlett
37+
3738

3839
def get_mat(n):
3940
data = np.arange(n)
@@ -1701,7 +1702,6 @@ def test_period(self):
17011702
assert sm_discont.dtype == wrap_uneven.dtype
17021703

17031704

1704-
@pytest.mark.xfail(reason='TODO: implement')
17051705
@pytest.mark.parametrize(
17061706
"dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]
17071707
)
@@ -1716,7 +1716,7 @@ def test_hanning(self, dtype: str, M: int) -> None:
17161716
assert w.dtype == ref_dtype
17171717

17181718
# check symmetry
1719-
assert_equal(w, flipud(w))
1719+
assert_allclose(w, flipud(w), atol=1e-15)
17201720

17211721
# check known value
17221722
if scalar < 1:
@@ -1734,7 +1734,7 @@ def test_hamming(self, dtype: str, M: int) -> None:
17341734
assert w.dtype == ref_dtype
17351735

17361736
# check symmetry
1737-
assert_equal(w, flipud(w))
1737+
assert_allclose(w, flipud(w), atol=1e-15)
17381738

17391739
# check known value
17401740
if scalar < 1:
@@ -1752,7 +1752,7 @@ def test_bartlett(self, dtype: str, M: int) -> None:
17521752
assert w.dtype == ref_dtype
17531753

17541754
# check symmetry
1755-
assert_equal(w, flipud(w))
1755+
assert_allclose(w, flipud(w), atol=1e-15)
17561756

17571757
# check known value
17581758
if scalar < 1:
@@ -1770,7 +1770,7 @@ def test_blackman(self, dtype: str, M: int) -> None:
17701770
assert w.dtype == ref_dtype
17711771

17721772
# check symmetry
1773-
assert_equal(w, flipud(w))
1773+
assert_allclose(w, flipud(w), atol=1e-15)
17741774

17751775
# check known value
17761776
if scalar < 1:

0 commit comments

Comments
 (0)