@@ -411,22 +411,35 @@ def test_fft_error(self, xp):
411
411
assert_raises (IndexError , xp .fft .fft2 , a )
412
412
413
413
414
+ @pytest .mark .parametrize ("func" , ["fftfreq" , "rfftfreq" ])
414
415
class TestFftfreq :
415
- @pytest .mark .parametrize ("func" , ["fftfreq" , "rfftfreq" ])
416
416
@pytest .mark .parametrize ("n" , [10 , 20 ])
417
417
@pytest .mark .parametrize ("d" , [0.5 , 2 ])
418
418
def test_fftfreq (self , func , n , d ):
419
- expected = getattr (dpnp .fft , func )(n , d )
420
- result = getattr (numpy .fft , func )(n , d )
421
- assert_dtype_allclose (expected , result )
419
+ result = getattr (dpnp .fft , func )(n , d )
420
+ expected = getattr (numpy .fft , func )(n , d )
421
+ assert_dtype_allclose (result , expected )
422
422
423
- @pytest .mark .parametrize ("func" , ["fftfreq" , "rfftfreq" ])
424
- def test_error (self , func ):
425
- # n should be an integer
426
- assert_raises (ValueError , getattr (dpnp .fft , func ), 10.0 )
423
+ @pytest .mark .parametrize ("dt" , [None ] + get_float_dtypes ())
424
+ def test_dtype (self , func , dt ):
425
+ n = 15
426
+ result = getattr (dpnp .fft , func )(n , dtype = dt )
427
+ expected = getattr (numpy .fft , func )(n ).astype (dt )
428
+ assert_dtype_allclose (result , expected )
427
429
428
- # d should be an scalar
429
- assert_raises (ValueError , getattr (dpnp .fft , func ), 10 , (2 ,))
430
+ def test_error (self , func ):
431
+ func = getattr (dpnp .fft , func )
432
+ # n must be an integer
433
+ assert_raises (ValueError , func , 10.0 )
434
+
435
+ # d must be an scalar
436
+ assert_raises (ValueError , func , 10 , (2 ,))
437
+
438
+ # dtype must be None or a real-valued floating-point dtype
439
+ # which is passed as a keyword argument only
440
+ assert_raises (TypeError , func , 10 , 2 , None )
441
+ assert_raises (ValueError , func , 10 , 2 , dtype = dpnp .intp )
442
+ assert_raises (ValueError , func , 10 , 2 , dtype = dpnp .complex64 )
430
443
431
444
432
445
class TestFftn :
0 commit comments