3
3
import array_api_strict
4
4
5
5
6
- @pytest .mark .parametrize ("func_name" , ("fft" , "ifft" , "fftn" , "ifftn" , "irfft" ,
7
- "irfftn" , "hfft" , "fftshift" , "ifftshift" ))
6
+ @pytest .mark .parametrize (
7
+ "func_name" ,
8
+ (
9
+ "fft" ,
10
+ "ifft" ,
11
+ "fftn" ,
12
+ "ifftn" ,
13
+ "irfft" ,
14
+ "irfftn" ,
15
+ "hfft" ,
16
+ "fftshift" ,
17
+ "ifftshift" ,
18
+ ),
19
+ )
8
20
def test_fft_device_support_complex (func_name ):
9
21
func = getattr (array_api_strict .fft , func_name )
10
- x = array_api_strict .asarray ([1 , 2. ],
11
- dtype = array_api_strict .complex64 ,
12
- device = array_api_strict .Device ("device1" ))
22
+ x = array_api_strict .asarray (
23
+ [1 , 2.0 ],
24
+ dtype = array_api_strict .complex64 ,
25
+ device = array_api_strict .Device ("device1" ),
26
+ )
13
27
y = func (x )
14
28
15
29
assert x .device == y .device
@@ -18,8 +32,7 @@ def test_fft_device_support_complex(func_name):
18
32
@pytest .mark .parametrize ("func_name" , ("rfft" , "rfftn" , "ihfft" ))
19
33
def test_fft_device_support_real (func_name ):
20
34
func = getattr (array_api_strict .fft , func_name )
21
- x = array_api_strict .asarray ([1 , 2. ],
22
- device = array_api_strict .Device ("device1" ))
35
+ x = array_api_strict .asarray ([1 , 2.0 ], device = array_api_strict .Device ("device1" ))
23
36
y = func (x )
24
37
25
- assert x .device == y .device
38
+ assert x .device == y .device
0 commit comments