Skip to content

Commit e697454

Browse files
committed
Remove skipping of complex prod tests on some hardware
These tests were skipped due to 2024.0 compiler issues. With 2024.1 compiler, this should no longer be an issue
1 parent a0c2aac commit e697454

File tree

1 file changed

+0
-22
lines changed

1 file changed

+0
-22
lines changed

dpctl/tests/test_tensor_sum.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import pytest
1818

1919
import dpctl.tensor as dpt
20-
import dpctl.utils as du
2120
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2221

2322
_all_dtypes = [
@@ -242,27 +241,12 @@ def test_sum_axis1_axis0():
242241
assert dpt.allclose(m, expected, atol=tol, rtol=tol)
243242

244243

245-
def _any_complex(dtypes):
246-
return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes)
247-
248-
249-
def _skip_on_this_device(sycl_dev):
250-
device_mask = du.intel_device_info(sycl_dev).get("device_id", 0) & 0xFF00
251-
return device_mask in [0x3E00, 0x9B00]
252-
253-
254244
@pytest.mark.parametrize("arg_dtype", _all_dtypes[1:])
255245
def test_prod_arg_dtype_default_output_dtype_matrix(arg_dtype):
256246
q = get_queue_or_skip()
257247
skip_if_dtype_not_supported(arg_dtype, q)
258248

259249
arg_dtype = dpt.dtype(arg_dtype)
260-
if _any_complex((arg_dtype,)):
261-
if _skip_on_this_device(q.sycl_device):
262-
pytest.skip(
263-
"Product reduction for complex output are known "
264-
"to fail for Gen9 with 2024.0 compiler"
265-
)
266250

267251
m = dpt.ones(100, dtype=arg_dtype)
268252
r = dpt.prod(m)
@@ -316,12 +300,6 @@ def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype):
316300

317301
out_dtype = dpt.dtype(out_dtype)
318302
arg_dtype = dpt.dtype(arg_dtype)
319-
if _any_complex((arg_dtype, out_dtype)):
320-
if _skip_on_this_device(q.sycl_device):
321-
pytest.skip(
322-
"Product reduction for complex output are known "
323-
"to fail for Gen9 with 2024.0 compiler"
324-
)
325303

326304
m = dpt.ones(100, dtype=arg_dtype)
327305
r = dpt.prod(m, dtype=out_dtype)

0 commit comments

Comments
 (0)