|
17 | 17 | import pytest
|
18 | 18 |
|
19 | 19 | import dpctl.tensor as dpt
|
20 |
| -import dpctl.utils as du |
21 | 20 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
|
22 | 21 |
|
23 | 22 | _all_dtypes = [
|
@@ -242,27 +241,12 @@ def test_sum_axis1_axis0():
|
242 | 241 | assert dpt.allclose(m, expected, atol=tol, rtol=tol)
|
243 | 242 |
|
244 | 243 |
|
245 |
| -def _any_complex(dtypes): |
246 |
| - return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes) |
247 |
| - |
248 |
| - |
249 |
| -def _skip_on_this_device(sycl_dev): |
250 |
| - device_mask = du.intel_device_info(sycl_dev).get("device_id", 0) & 0xFF00 |
251 |
| - return device_mask in [0x3E00, 0x9B00] |
252 |
| - |
253 |
| - |
254 | 244 | @pytest.mark.parametrize("arg_dtype", _all_dtypes[1:])
|
255 | 245 | def test_prod_arg_dtype_default_output_dtype_matrix(arg_dtype):
|
256 | 246 | q = get_queue_or_skip()
|
257 | 247 | skip_if_dtype_not_supported(arg_dtype, q)
|
258 | 248 |
|
259 | 249 | arg_dtype = dpt.dtype(arg_dtype)
|
260 |
| - if _any_complex((arg_dtype,)): |
261 |
| - if _skip_on_this_device(q.sycl_device): |
262 |
| - pytest.skip( |
263 |
| - "Product reduction for complex output are known " |
264 |
| - "to fail for Gen9 with 2024.0 compiler" |
265 |
| - ) |
266 | 250 |
|
267 | 251 | m = dpt.ones(100, dtype=arg_dtype)
|
268 | 252 | r = dpt.prod(m)
|
@@ -316,12 +300,6 @@ def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype):
|
316 | 300 |
|
317 | 301 | out_dtype = dpt.dtype(out_dtype)
|
318 | 302 | arg_dtype = dpt.dtype(arg_dtype)
|
319 |
| - if _any_complex((arg_dtype, out_dtype)): |
320 |
| - if _skip_on_this_device(q.sycl_device): |
321 |
| - pytest.skip( |
322 |
| - "Product reduction for complex output are known " |
323 |
| - "to fail for Gen9 with 2024.0 compiler" |
324 |
| - ) |
325 | 303 |
|
326 | 304 | m = dpt.ones(100, dtype=arg_dtype)
|
327 | 305 | r = dpt.prod(m, dtype=out_dtype)
|
|
0 commit comments