Skip to content

Commit 3eeeeb5

Browse files
committed
update tests
1 parent a53f732 commit 3eeeeb5

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

dpnp/backend/extensions/window/hamming_kernel.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,14 @@ sycl::event hamming_impl(sycl::queue &q,
6969

7070
T *res = reinterpret_cast<T *>(result);
7171

72-
sycl::event choose_ev = q.submit([&](sycl::handler &cgh) {
72+
sycl::event hamming_ev = q.submit([&](sycl::handler &cgh) {
7373
cgh.depends_on(depends);
7474

7575
cgh.parallel_for(sycl::range<1>(nelems),
7676
HammingFunctor<T>(res, nelems));
7777
});
7878

79-
return choose_ev;
79+
return hamming_ev;
8080
}
8181

8282
} // namespace dpnp::extensions::window::kernels

dpnp/dpnp_iface_window.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,37 @@
4141
# pylint: disable=protected-access
4242

4343
import dpctl.utils as dpu
44+
import numpy
4445

4546
import dpnp
4647
import dpnp.backend.extensions.window._window_impl as wi
4748

4849
__all__ = ["hamming"]
4950

5051

52+
def _validate_input(val):
53+
54+
is_numpy_array = isinstance(val, numpy.ndarray)
55+
is_array = dpnp.is_supported_array_type(val) or is_numpy_array
56+
if is_array:
57+
is_0d_arr = val.ndim == 0
58+
is_int_float = dpnp.issubdtype(val.dtype, (dpnp.integer, dpnp.floating))
59+
raise_error = not (is_0d_arr and is_int_float)
60+
if not raise_error:
61+
is_nan = numpy.isnan(val) if is_numpy_array else dpnp.isnan(val)
62+
is_inf = numpy.isinf(val) if is_numpy_array else dpnp.isinf(val)
63+
raise_error = is_nan or is_inf
64+
else:
65+
is_int = isinstance(val, (int, numpy.integer, dpnp.integer))
66+
is_float = isinstance(val, (float, numpy.floating, dpnp.floating))
67+
raise_error = not (is_int or is_float)
68+
if not raise_error:
69+
raise_error = val in [numpy.inf, -numpy.inf, numpy.nan]
70+
71+
if raise_error:
72+
raise TypeError("M must be an integer")
73+
74+
5175
def hamming(M, device=None, usm_type=None, sycl_queue=None):
5276
r"""
5377
Return the Hamming window.
@@ -127,8 +151,9 @@ def hamming(M, device=None, usm_type=None, sycl_queue=None):
127151
128152
"""
129153

130-
if not isinstance(M, (int, float, dpnp.integer, dpnp.floating)):
131-
raise TypeError("M must be an integer")
154+
# if not isinstance(M, (int, float, dpnp.integer, dpnp.floating)):
155+
# raise TypeError("M must be an integer")
156+
_validate_input(M)
132157

133158
cfd_kwarg = {
134159
"device": device,

dpnp/tests/test_window.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,44 @@
77
from .helper import assert_dtype_allclose
88

99

10+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
1011
@pytest.mark.parametrize("func", ["hamming"])
1112
@pytest.mark.parametrize(
12-
"M", [0, 1, 5.0, numpy.int64(0), numpy.int32(1), numpy.float32(5)]
13+
"M",
14+
[
15+
True,
16+
False,
17+
0,
18+
dpnp.int32(1),
19+
4,
20+
5.0,
21+
dpnp.float32(6),
22+
dpnp.array(7),
23+
numpy.array(8),
24+
],
1325
)
1426
def test_window(func, M):
1527
result = getattr(dpnp, func)(M)
28+
29+
if isinstance(M, dpnp.ndarray):
30+
M = M.asnumpy()
1631
expected = getattr(numpy, func)(M)
1732

1833
assert_dtype_allclose(result, expected)
1934

2035

36+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
2137
@pytest.mark.parametrize("func", ["hamming"])
22-
@pytest.mark.parametrize("M", [5 + 4j, numpy.array(5)])
38+
@pytest.mark.parametrize(
39+
"M",
40+
[
41+
5 + 4j,
42+
numpy.array(5 + 4j),
43+
dpnp.array([5]),
44+
numpy.inf,
45+
numpy.array(-numpy.inf),
46+
dpnp.array(dpnp.nan),
47+
],
48+
)
2349
def test_window_error(func, M):
2450
assert_raises(TypeError, getattr(dpnp, func), M)

0 commit comments

Comments
 (0)