Skip to content

Commit 3e3bd0a

Browse files
committed
Add tests for dtype keyword
1 parent cb055f1 commit 3e3bd0a

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

dpnp/tests/test_fft.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -411,22 +411,35 @@ def test_fft_error(self, xp):
411411
assert_raises(IndexError, xp.fft.fft2, a)
412412

413413

414+
@pytest.mark.parametrize("func", ["fftfreq", "rfftfreq"])
414415
class TestFftfreq:
415-
@pytest.mark.parametrize("func", ["fftfreq", "rfftfreq"])
416416
@pytest.mark.parametrize("n", [10, 20])
417417
@pytest.mark.parametrize("d", [0.5, 2])
418418
def test_fftfreq(self, func, n, d):
419-
expected = getattr(dpnp.fft, func)(n, d)
420-
result = getattr(numpy.fft, func)(n, d)
421-
assert_dtype_allclose(expected, result)
419+
result = getattr(dpnp.fft, func)(n, d)
420+
expected = getattr(numpy.fft, func)(n, d)
421+
assert_dtype_allclose(result, expected)
422422

423-
@pytest.mark.parametrize("func", ["fftfreq", "rfftfreq"])
424-
def test_error(self, func):
425-
# n should be an integer
426-
assert_raises(ValueError, getattr(dpnp.fft, func), 10.0)
423+
@pytest.mark.parametrize("dt", [None] + get_float_dtypes())
424+
def test_dtype(self, func, dt):
425+
n = 15
426+
result = getattr(dpnp.fft, func)(n, dtype=dt)
427+
expected = getattr(numpy.fft, func)(n).astype(dt)
428+
assert_dtype_allclose(result, expected)
427429

428-
# d should be an scalar
429-
assert_raises(ValueError, getattr(dpnp.fft, func), 10, (2,))
430+
def test_error(self, func):
431+
func = getattr(dpnp.fft, func)
432+
# n must be an integer
433+
assert_raises(ValueError, func, 10.0)
434+
435+
# d must be an scalar
436+
assert_raises(ValueError, func, 10, (2,))
437+
438+
# dtype must be None or a real-valued floating-point dtype
439+
# which is passed as a keyword argument only
440+
assert_raises(TypeError, func, 10, 2, None)
441+
assert_raises(ValueError, func, 10, 2, dtype=dpnp.intp)
442+
assert_raises(ValueError, func, 10, 2, dtype=dpnp.complex64)
430443

431444

432445
class TestFftn:

0 commit comments

Comments
 (0)