Skip to content

Changes to elementwise function tests that use complex data types #1412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions dpctl/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import sys

import pytest
from _device_attributes_checks import (
check,
device_selector,
Expand All @@ -38,3 +39,31 @@
"suppress_invalid_numpy_warnings",
"valid_filter",
]


def pytest_configure(config):
config.addinivalue_line(
"markers",
"broken_complex: Specified again to remove warnings ",
)


def pytest_addoption(parser):
parser.addoption(
"--runcomplex",
action="store_true",
default=False,
help="run broken complex tests on Windows",
)


def pytest_collection_modifyitems(config, items):
if config.getoption("--runcomplex"):
return
skip_complex = pytest.mark.skipif(
os.name == "nt",
reason="need --runcomplex option to run on Windows",
)
for item in items:
if "broken_complex" in item.keywords:
item.add_marker(skip_complex)
31 changes: 10 additions & 21 deletions dpctl/tests/elementwise/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,7 @@
import dpctl.tensor as dpt
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

from .utils import (
_all_dtypes,
_complex_fp_dtypes,
_no_complex_dtypes,
_real_fp_dtypes,
_usm_types,
)
from .utils import _all_dtypes, _complex_fp_dtypes, _real_fp_dtypes, _usm_types


@pytest.mark.parametrize("dtype", _all_dtypes)
Expand Down Expand Up @@ -131,26 +125,21 @@ def test_abs_complex(dtype):
)


@pytest.mark.parametrize("dtype", _no_complex_dtypes)
def test_abs_out_overlap(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
X = dpt.reshape(X, (3, 5, 4))

Xnp = dpt.asnumpy(X)
Ynp = np.abs(Xnp, out=Xnp)
def test_abs_out_overlap():
get_queue_or_skip()

X = dpt.arange(-3, 3, 1, dtype="i4")
expected = dpt.asarray([3, 2, 1, 0, 1, 2], dtype="i4")
Y = dpt.abs(X, out=X)

assert Y is X
assert np.allclose(dpt.asnumpy(X), Xnp)
assert dpt.all(expected == X)

Ynp = np.abs(Xnp, out=Xnp[::-1])
X = dpt.arange(-3, 3, 1, dtype="i4")
expected = expected[::-1]
Y = dpt.abs(X, out=X[::-1])
assert Y is not X
assert np.allclose(dpt.asnumpy(X), Xnp)
assert np.allclose(dpt.asnumpy(Y), Ynp)
assert dpt.all(expected == X)


@pytest.mark.parametrize("dtype", _real_fp_dtypes)
Expand Down
23 changes: 0 additions & 23 deletions dpctl/tests/elementwise/test_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,26 +216,3 @@ def test_exp_complex_special_cases(dtype):
tol = 8 * dpt.finfo(dtype).resolution
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)


@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
def test_exp_out_overlap(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

X = dpt.linspace(0, 1, 15, dtype=dtype, sycl_queue=q)
X = dpt.reshape(X, (3, 5))

Xnp = dpt.asnumpy(X)
Ynp = np.exp(Xnp, out=Xnp)

Y = dpt.exp(X, out=X)
tol = 8 * dpt.finfo(Y.dtype).resolution
assert Y is X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)

Ynp = np.exp(Xnp, out=Xnp[::-1])
Y = dpt.exp(X, out=X[::-1])
assert Y is not X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)
29 changes: 1 addition & 28 deletions dpctl/tests/elementwise/test_hyperbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.

import itertools
import os

import numpy as np
import pytest
Expand Down Expand Up @@ -271,7 +270,7 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype):
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)


@pytest.mark.skipif(os.name == "nt", reason="Known problems on Windows")
@pytest.mark.broken_complex
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
Expand All @@ -294,29 +293,3 @@ def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
assert_allclose(
dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol
)


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
def test_hyper_out_overlap(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

X = dpt.linspace(-np.pi / 2, np.pi / 2, 60, dtype=dtype, sycl_queue=q)
X = dpt.reshape(X, (3, 5, 4))

tol = 8 * dpt.finfo(dtype).resolution
Xnp = dpt.asnumpy(X)
with np.errstate(all="ignore"):
Ynp = np_call(Xnp, out=Xnp)

Y = dpt_call(X, out=X)
assert Y is X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)

with np.errstate(all="ignore"):
Ynp = np_call(Xnp, out=Xnp[::-1])
Y = dpt_call(X, out=X[::-1])
assert Y is not X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)
24 changes: 0 additions & 24 deletions dpctl/tests/elementwise/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,27 +128,3 @@ def test_log_special_cases():
)

assert_equal(dpt.asnumpy(Y), expected)


@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
def test_log_out_overlap(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

X = dpt.linspace(5, 35, 60, dtype=dtype, sycl_queue=q)
X = dpt.reshape(X, (3, 5, 4))

Xnp = dpt.asnumpy(X)
Ynp = np.log(Xnp, out=Xnp)

Y = dpt.log(X, out=X)
assert Y is X

tol = 8 * dpt.finfo(Y.dtype).resolution
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)

Ynp = np.log(Xnp, out=Xnp[::-1])
Y = dpt.log(X, out=X[::-1])
assert Y is not X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)
23 changes: 0 additions & 23 deletions dpctl/tests/elementwise/test_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,26 +213,3 @@ def test_round_complex_special_cases(dtype):
tol = 8 * dpt.finfo(dtype).resolution
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)


@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
def test_round_out_overlap(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

X = dpt.linspace(0, 1, 15, dtype=dtype, sycl_queue=q)
X = dpt.reshape(X, (3, 5))

Xnp = dpt.asnumpy(X)
Ynp = np.round(Xnp, out=Xnp)

Y = dpt.round(X, out=X)
tol = 8 * dpt.finfo(Y.dtype).resolution
assert Y is X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)

Ynp = np.round(Xnp, out=Xnp[::-1])
Y = dpt.round(X, out=X[::-1])
assert Y is not X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)
24 changes: 0 additions & 24 deletions dpctl/tests/elementwise/test_sqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,30 +122,6 @@ def test_sqrt_order(dtype):
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)


@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
def test_sqrt_out_overlap(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
X = dpt.reshape(X, (3, 5, 4))

Xnp = dpt.asnumpy(X)
Ynp = np.sqrt(Xnp, out=Xnp)

Y = dpt.sqrt(X, out=X)
assert Y is X

tol = 8 * dpt.finfo(Y.dtype).resolution
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)

Ynp = np.sqrt(Xnp, out=Xnp[::-1])
Y = dpt.sqrt(X, out=X[::-1])
assert Y is not X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)


@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
def test_sqrt_special_cases():
q = get_queue_or_skip()
Expand Down
26 changes: 0 additions & 26 deletions dpctl/tests/elementwise/test_square.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,29 +97,3 @@ def test_square_special_cases(dtype):
rtol=tol,
equal_nan=True,
)


@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
def test_square_out_overlap(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
X = dpt.reshape(X, (3, 5, 4))

Xnp = dpt.asnumpy(X)
Ynp = np.square(Xnp, out=Xnp)

Y = dpt.square(X, out=X)
assert Y is X
assert np.allclose(dpt.asnumpy(X), Xnp)

X = dpt.linspace(0, 35, 60, dtype=dtype, sycl_queue=q)
X = dpt.reshape(X, (3, 5, 4))
Xnp = dpt.asnumpy(X)

Ynp = np.square(Xnp, out=Xnp[::-1])
Y = dpt.square(X, out=X[::-1])
assert Y is not X
assert np.allclose(dpt.asnumpy(X), Xnp)
assert np.allclose(dpt.asnumpy(Y), Ynp)
38 changes: 1 addition & 37 deletions dpctl/tests/elementwise/test_trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.

import itertools
import os

import numpy as np
import pytest
Expand Down Expand Up @@ -268,7 +267,7 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype):
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)


@pytest.mark.skipif(os.name == "nt", reason="Known problem on Windows")
@pytest.mark.broken_complex
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_trig_complex_special_cases(np_call, dpt_call, dtype):
Expand All @@ -291,38 +290,3 @@ def test_trig_complex_special_cases(np_call, dpt_call, dtype):
assert_allclose(
dpt.asnumpy(dpt.imag(dpt_call(Xc))), np.imag(Ynp), atol=tol, rtol=tol
)


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
def test_trig_out_overlap(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

if os.name == "nt" and dpt.isdtype(dpt.dtype(dtype), "complex floating"):
pytest.skip("Know problems on Windows")

if np_call == np.tan:
X = dpt.linspace(-np.pi / 2, np.pi / 2, 64, dtype=dtype, sycl_queue=q)[
2:-2
]
tol = 50 * dpt.finfo(dtype).resolution
else:
X = dpt.linspace(-np.pi / 2, np.pi / 2, 60, dtype=dtype, sycl_queue=q)
tol = 8 * dpt.finfo(dtype).resolution
X = dpt.reshape(X, (3, 5, 4))

Xnp = dpt.asnumpy(X)
with np.errstate(all="ignore"):
Ynp = np_call(Xnp, out=Xnp)

Y = dpt_call(X, out=X)
assert Y is X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)

with np.errstate(all="ignore"):
Ynp = np_call(Xnp, out=Xnp[::-1])
Y = dpt_call(X, out=X[::-1])
assert Y is not X
assert_allclose(dpt.asnumpy(X), Xnp, atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(Y), Ynp, atol=tol, rtol=tol)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ omit = [
]

[tool.pytest.ini.options]
markers = [
"broken_complex: mark a test that is skipped on Windows due to complex implementation",
]
minversion = "6.0"
norecursedirs= [
".*", "*.egg*", "build", "dist", "conda-recipe",
Expand Down