Skip to content

Commit 862c12d

Browse files
committed
MAINT: filter windows only accept ints, not array_likes
1 parent 6c44658 commit 862c12d

File tree

2 files changed

+15
-20
lines changed

2 files changed

+15
-20
lines changed

torch_np/_funcs.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,35 +1913,30 @@ def __getitem__(self, item):
19131913

19141914

19151915
@normalizer
1916-
def hamming(M: ArrayLike):
1917-
M = int(M.item())
1916+
def hamming(M):
19181917
dtype = _dtypes_impl.default_float_dtype
19191918
return torch.hamming_window(M, periodic=False, dtype=dtype)
19201919

19211920

19221921
@normalizer
1923-
def hanning(M: ArrayLike):
1924-
M = int(M.item())
1922+
def hanning(M):
19251923
dtype = _dtypes_impl.default_float_dtype
19261924
return torch.hann_window(M, periodic=False, dtype=dtype)
19271925

19281926

19291927
@normalizer
1930-
def kaiser(M: ArrayLike, beta):
1931-
M = int(M.item())
1928+
def kaiser(M, beta):
19321929
dtype = _dtypes_impl.default_float_dtype
19331930
return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype)
19341931

19351932

19361933
@normalizer
1937-
def blackman(M: ArrayLike):
1938-
M = int(M.item())
1934+
def blackman(M):
19391935
dtype = _dtypes_impl.default_float_dtype
19401936
return torch.blackman_window(M, periodic=False, dtype=dtype)
19411937

19421938

19431939
@normalizer
1944-
def bartlett(M: ArrayLike):
1945-
M = int(M.item())
1940+
def bartlett(M):
19461941
dtype = _dtypes_impl.default_float_dtype
19471942
return torch.bartlett_window(M, periodic=False, dtype=dtype)

torch_np/tests/numpy_tests/lib/test_function_base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,10 +1709,10 @@ def test_period(self):
17091709
class TestFilterwindows:
17101710

17111711
def test_hanning(self, dtype: str, M: int) -> None:
1712-
scalar = np.array(M, dtype=dtype)[()]
1712+
scalar = M
17131713

17141714
w = hanning(scalar)
1715-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1715+
ref_dtype = np.result_type(dtype, np.float64)
17161716
assert w.dtype == ref_dtype
17171717

17181718
# check symmetry
@@ -1727,10 +1727,10 @@ def test_hanning(self, dtype: str, M: int) -> None:
17271727
assert_almost_equal(np.sum(w, axis=0), 4.500, 4)
17281728

17291729
def test_hamming(self, dtype: str, M: int) -> None:
1730-
scalar = np.array(M, dtype=dtype)[()]
1730+
scalar = M
17311731

17321732
w = hamming(scalar)
1733-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1733+
ref_dtype = np.result_type(dtype, np.float64)
17341734
assert w.dtype == ref_dtype
17351735

17361736
# check symmetry
@@ -1745,10 +1745,10 @@ def test_hamming(self, dtype: str, M: int) -> None:
17451745
assert_almost_equal(np.sum(w, axis=0), 4.9400, 4)
17461746

17471747
def test_bartlett(self, dtype: str, M: int) -> None:
1748-
scalar = np.array(M, dtype=dtype)[()]
1748+
scalar = M
17491749

17501750
w = bartlett(scalar)
1751-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1751+
ref_dtype = np.result_type(dtype, np.float64)
17521752
assert w.dtype == ref_dtype
17531753

17541754
# check symmetry
@@ -1763,10 +1763,10 @@ def test_bartlett(self, dtype: str, M: int) -> None:
17631763
assert_almost_equal(np.sum(w, axis=0), 4.4444, 4)
17641764

17651765
def test_blackman(self, dtype: str, M: int) -> None:
1766-
scalar = np.array(M, dtype=dtype)[()]
1766+
scalar = M
17671767

17681768
w = blackman(scalar)
1769-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1769+
ref_dtype = np.result_type(dtype, np.float64)
17701770
assert w.dtype == ref_dtype
17711771

17721772
# check symmetry
@@ -1781,10 +1781,10 @@ def test_blackman(self, dtype: str, M: int) -> None:
17811781
assert_almost_equal(np.sum(w, axis=0), 3.7800, 4)
17821782

17831783
def test_kaiser(self, dtype: str, M: int) -> None:
1784-
scalar = np.array(M, dtype=dtype)[()]
1784+
scalar = M
17851785

17861786
w = kaiser(scalar, 0)
1787-
ref_dtype = np.result_type(scalar.dtype, np.float64)
1787+
ref_dtype = np.result_type(dtype, np.float64)
17881788
assert w.dtype == ref_dtype
17891789

17901790
# check symmetry

0 commit comments

Comments
 (0)