Skip to content

Commit 9235a29

Browse files
committed
update tests
1 parent cb88d05 commit 9235a29

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

tests/test_sycl_queue.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,21 +1135,21 @@ def test_out_multi_dot(device):
11351135
assert_sycl_queue_equal(result.sycl_queue, exec_q)
11361136

11371137

1138-
@pytest.mark.parametrize("type", ["complex128"])
1138+
@pytest.mark.parametrize("func", ["fft", "ifft"])
11391139
@pytest.mark.parametrize(
11401140
"device",
11411141
valid_devices,
11421142
ids=[device.filter_string for device in valid_devices],
11431143
)
1144-
def test_fft(type, device):
1145-
data = numpy.arange(100, dtype=numpy.dtype(type))
1144+
def test_fft(func, device):
1145+
data = numpy.arange(100, dtype=numpy.complex128)
11461146

11471147
dpnp_data = dpnp.array(data, device=device)
11481148

1149-
expected = numpy.fft.fft(data)
1150-
result = dpnp.fft.fft(dpnp_data)
1149+
expected = getattr(numpy.fft, func)(data)
1150+
result = getattr(dpnp.fft, func)(dpnp_data)
11511151

1152-
assert_allclose(result, expected, rtol=1e-4, atol=1e-7)
1152+
assert_dtype_allclose(result, expected)
11531153

11541154
expected_queue = dpnp_data.get_array().sycl_queue
11551155
result_queue = result.get_array().sycl_queue

tests/test_usm_type.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,17 @@ def test_eigenvalue(func, shape, usm_type):
914914
assert a.usm_type == dp_val.usm_type
915915

916916

917+
@pytest.mark.parametrize("func", ["fft", "ifft"])
918+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
919+
def test_fft(func, usm_type):
920+
data = numpy.arange(100, dtype=numpy.complex128)
921+
dpnp_data = dp.array(data, usm_type=usm_type)
922+
result = getattr(dp.fft, func)(dpnp_data)
923+
924+
assert dpnp_data.usm_type == usm_type
925+
assert result.usm_type == usm_type
926+
927+
917928
@pytest.mark.parametrize(
918929
"usm_type_matrix", list_of_usm_types, ids=list_of_usm_types
919930
)

0 commit comments

Comments
 (0)