Skip to content

Unmute prod tests muted for 2024.0 compiler #1633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions dpctl/tests/test_tensor_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import pytest

import dpctl.tensor as dpt
import dpctl.utils as du
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

_all_dtypes = [
Expand Down Expand Up @@ -242,27 +241,12 @@ def test_sum_axis1_axis0():
assert dpt.allclose(m, expected, atol=tol, rtol=tol)


def _any_complex(dtypes):
return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes)


def _skip_on_this_device(sycl_dev):
device_mask = du.intel_device_info(sycl_dev).get("device_id", 0) & 0xFF00
return device_mask in [0x3E00, 0x9B00]


@pytest.mark.parametrize("arg_dtype", _all_dtypes[1:])
def test_prod_arg_dtype_default_output_dtype_matrix(arg_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(arg_dtype, q)

arg_dtype = dpt.dtype(arg_dtype)
if _any_complex((arg_dtype,)):
if _skip_on_this_device(q.sycl_device):
pytest.skip(
"Product reduction for complex output are known "
"to fail for Gen9 with 2024.0 compiler"
)

m = dpt.ones(100, dtype=arg_dtype)
r = dpt.prod(m)
Expand Down Expand Up @@ -316,12 +300,6 @@ def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype):

out_dtype = dpt.dtype(out_dtype)
arg_dtype = dpt.dtype(arg_dtype)
if _any_complex((arg_dtype, out_dtype)):
if _skip_on_this_device(q.sycl_device):
pytest.skip(
"Product reduction for complex output are known "
"to fail for Gen9 with 2024.0 compiler"
)

m = dpt.ones(100, dtype=arg_dtype)
r = dpt.prod(m, dtype=out_dtype)
Expand Down