Skip to content

Commit d2c1526

Browse files
authored
Remove skipping of complex prod tests on some hardware (#1633)
These tests were skipped due to 2024.0 compiler issues. With 2024.1 compiler, this should no longer be an issue
1 parent a0c2aac commit d2c1526

File tree

1 file changed

+0
-22
lines changed

1 file changed

+0
-22
lines changed

dpctl/tests/test_tensor_sum.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import pytest
1818

1919
import dpctl.tensor as dpt
20-
import dpctl.utils as du
2120
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2221

2322
_all_dtypes = [
@@ -242,27 +241,12 @@ def test_sum_axis1_axis0():
242241
assert dpt.allclose(m, expected, atol=tol, rtol=tol)
243242

244243

245-
def _any_complex(dtypes):
246-
return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes)
247-
248-
249-
def _skip_on_this_device(sycl_dev):
250-
device_mask = du.intel_device_info(sycl_dev).get("device_id", 0) & 0xFF00
251-
return device_mask in [0x3E00, 0x9B00]
252-
253-
254244
@pytest.mark.parametrize("arg_dtype", _all_dtypes[1:])
255245
def test_prod_arg_dtype_default_output_dtype_matrix(arg_dtype):
256246
q = get_queue_or_skip()
257247
skip_if_dtype_not_supported(arg_dtype, q)
258248

259249
arg_dtype = dpt.dtype(arg_dtype)
260-
if _any_complex((arg_dtype,)):
261-
if _skip_on_this_device(q.sycl_device):
262-
pytest.skip(
263-
"Product reduction for complex output are known "
264-
"to fail for Gen9 with 2024.0 compiler"
265-
)
266250

267251
m = dpt.ones(100, dtype=arg_dtype)
268252
r = dpt.prod(m)
@@ -316,12 +300,6 @@ def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype):
316300

317301
out_dtype = dpt.dtype(out_dtype)
318302
arg_dtype = dpt.dtype(arg_dtype)
319-
if _any_complex((arg_dtype, out_dtype)):
320-
if _skip_on_this_device(q.sycl_device):
321-
pytest.skip(
322-
"Product reduction for complex output are known "
323-
"to fail for Gen9 with 2024.0 compiler"
324-
)
325303

326304
m = dpt.ones(100, dtype=arg_dtype)
327305
r = dpt.prod(m, dtype=out_dtype)

0 commit comments

Comments
 (0)