Skip to content

update elementwise tests for exp, real, imag, and conj #1258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 34 additions & 28 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@
Returns:
usm_narray:
An array containing the element-wise absolute values.
For complex input, the absolute value is its magnitude. The data type
of the returned array is determined by the Type Promotion Rules.
For complex input, the absolute value is its magnitude.
If `x` has a real-valued data type, the returned array has the
same data type as `x`. If `x` has a complex floating-point data type,
the returned array has a real-valued floating-point data type whose
precision matches the precision of `x`.
"""

abs = UnaryElementwiseFunc("abs", ti._abs_result_type, ti._abs, _abs_docstring_)
Expand Down Expand Up @@ -133,8 +136,8 @@
Default: "K".
Returns:
usm_narray:
An array containing the element-wise conjugate values. The data type
of the returned array is determined by the Type Promotion Rules.
An array containing the element-wise conjugate values.
The returned array has the same data type as `x`.
"""

conj = UnaryElementwiseFunc(
Expand Down Expand Up @@ -216,8 +219,7 @@
Returns:
usm_narray:
An array containing the result of element-wise equality comparison.
The data type of the returned array is determined by the
Type Promotion Rules.
The returned array has a data type of `bool`.
"""

equal = BinaryElementwiseFunc(
Expand Down Expand Up @@ -264,6 +266,8 @@
Return:
usm_ndarray:
An array containing the element-wise exp(x)-1 values.
The data type of the returned array is determined by the Type
Promotion Rules.
"""

expm1 = UnaryElementwiseFunc(
Expand All @@ -288,7 +292,7 @@
Second input array, also expected to have numeric data type.
Returns:
usm_narray:
an array containing the result of element-wise floor division.
an array containing the result of element-wise floor division.
The data type of the returned array is determined by the Type
Promotion Rules.
"""
Expand Down Expand Up @@ -319,8 +323,7 @@
Returns:
usm_narray:
An array containing the result of element-wise greater-than comparison.
The data type of the returned array is determined by the
Type Promotion Rules.
The returned array has a data type of `bool`.
"""

greater = BinaryElementwiseFunc(
Expand All @@ -347,8 +350,7 @@
usm_narray:
An array containing the result of element-wise greater-than or equal-to
comparison.
The data type of the returned array is determined by the
Type Promotion Rules.
The returned array has a data type of `bool`.
"""

greater_equal = BinaryElementwiseFunc(
Expand Down Expand Up @@ -376,8 +378,10 @@
Returns:
usm_narray:
An array containing the element-wise imaginary component of input.
The data type of the returned array is determined
by the Type Promotion Rules.
If the input is a real-valued data type, the returned array has
the same datat type. If the input is a complex floating-point
data type, the returned array has a floating-point data type
with the same floating-point precision as complex input.
"""

imag = UnaryElementwiseFunc(
Expand All @@ -403,7 +407,7 @@
usm_narray:
An array which is True where `x` is not positive infinity,
negative infinity, or NaN, False otherwise.
The data type of the returned array is boolean.
The data type of the returned array is `bool`.
"""

isfinite = UnaryElementwiseFunc(
Expand All @@ -428,7 +432,7 @@
Returns:
usm_narray:
An array which is True where `x` is positive or negative infinity,
False otherwise. The data type of the returned array is boolean.
False otherwise. The data type of the returned array is `bool`.
"""

isinf = UnaryElementwiseFunc(
Expand All @@ -453,7 +457,7 @@
Returns:
usm_narray:
An array which is True where x is NaN, False otherwise.
The data type of the returned array is boolean.
The data type of the returned array is `bool`.
"""

isnan = UnaryElementwiseFunc(
Expand Down Expand Up @@ -481,8 +485,7 @@
Returns:
usm_narray:
An array containing the result of element-wise less-than comparison.
The data type of the returned array is determined by the
Type Promotion Rules.
The returned array has a data type of `bool`.
"""

less = BinaryElementwiseFunc(
Expand All @@ -508,9 +511,7 @@
Returns:
usm_narray:
An array containing the result of element-wise less-than or equal-to
comparison.
The data type of the returned array is determined by the
Type Promotion Rules.
comparison. The returned array has a data type of `bool`.
"""

less_equal = BinaryElementwiseFunc(
Expand All @@ -536,6 +537,8 @@
Return:
usm_ndarray:
An array containing the element-wise natural logarithm values.
The data type of the returned array is determined by the Type
Promotion Rules.
"""

log = UnaryElementwiseFunc("log", ti._log_result_type, ti._log, _log_docstring)
Expand All @@ -555,7 +558,8 @@
Default: "K".
Return:
usm_ndarray:
An array containing the element-wise log(1+x) values.
An array containing the element-wise log(1+x) values. The data type
of the returned array is determined by the Type Promotion Rules.
"""

log1p = UnaryElementwiseFunc(
Expand Down Expand Up @@ -804,8 +808,7 @@
Returns:
usm_narray:
an array containing the result of element-wise inequality comparison.
The data type of the returned array is determined by the
Type Promotion Rules.
The returned array has a data type of `bool`.
"""

not_equal = BinaryElementwiseFunc(
Expand Down Expand Up @@ -873,8 +876,8 @@
Default: "K".
Returns:
usm_narray:
An array containing the element-wise projection. The data
type of the returned array is determined by the Type Promotion Rules.
An array containing the element-wise projection.
The returned array has the same data type as `x`.
"""

proj = UnaryElementwiseFunc(
Expand All @@ -898,8 +901,11 @@
Default: "K".
Returns:
usm_narray:
An array containing the element-wise real component of input. The data
type of the returned array is determined by the Type Promotion Rules.
An array containing the element-wise real component of input.
If the input is a real-valued data type, the returned array has
the same datat type. If the input is a complex floating-point
data type, the returned array has a floating-point data type
with the same floating-point precision as complex input.
"""

real = UnaryElementwiseFunc(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,58 @@ template <typename argT, typename resT> struct ExpFunctor

resT operator()(const argT &in)
{
return std::exp(in);
if constexpr (is_complex<argT>::value) {
using realT = typename argT::value_type;

constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();

const realT x = std::real(in);
const realT y = std::imag(in);
if (std::isfinite(x)) {
if (std::isfinite(y)) {
return std::exp(in);
}
else {
return resT{q_nan, q_nan};
}
}
else if (std::isnan(x)) {
/* x is nan */
if (y == realT(0)) {
return resT{in};
}
else {
return resT{x, q_nan};
}
}
else {
if (!std::signbit(x)) { /* x is +inf */
if (y == realT(0)) {
return resT{x, y};
}
else if (std::isfinite(y)) {
return resT{x * std::cos(y), x * std::sin(y)};
}
else {
/* x = +inf, y = +-inf || nan */
return resT{x, q_nan};
}
}
else { /* x is -inf */
if (std::isfinite(y)) {
realT exp_x = std::exp(x);
return resT{exp_x * std::cos(y), exp_x * std::sin(y)};
}
else {
/* x = -inf, y = +-inf || nan */
return resT{0, 0};
}
}
}
}
else {
return std::exp(in);
}
}
};

Expand Down
64 changes: 40 additions & 24 deletions dpctl/tests/elementwise/test_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ def test_complex_order(np_call, dpt_call, dtype):
X[..., 0::2] = np.pi / 6 + 1j * np.pi / 3
X[..., 1::2] = np.pi / 3 + 1j * np.pi / 6

for ord in ["C", "F", "A", "K"]:
for perms in itertools.permutations(range(4)):
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
for perms in itertools.permutations(range(4)):
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
expected_Y = np_call(dpt.asnumpy(U))
for ord in ["C", "F", "A", "K"]:
Y = dpt_call(U, order=ord)
expected_Y = np_call(dpt.asnumpy(U))
assert np.allclose(dpt.asnumpy(Y), expected_Y)


Expand Down Expand Up @@ -164,35 +164,51 @@ def test_projection(dtype):
"np_call, dpt_call",
[(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)],
)
@pytest.mark.parametrize("dtype", ["f4", "f8"])
@pytest.mark.parametrize("stride", [-1, 1, 2, 4, 5])
def test_complex_strided(np_call, dpt_call, dtype, stride):
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_complex_strided(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

N = 100
rng = np.random.default_rng(42)
x1 = rng.standard_normal(N, dtype)
x2 = 1j * rng.standard_normal(N, dtype)
x = x1 + x2
y = np_call(x[::stride])
z = dpt_call(dpt.asarray(x[::stride]))
np.random.seed(42)
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
sizes = [2, 4, 6, 8, 9, 24, 72]
tol = 8 * dpt.finfo(dtype).resolution

tol = 8 * dpt.finfo(y.dtype).resolution
assert_allclose(y, dpt.asnumpy(z), atol=tol, rtol=tol)
low = -1000.0
high = 1000.0
for ii in sizes:
x1 = np.random.uniform(low=low, high=high, size=ii)
x2 = np.random.uniform(low=low, high=high, size=ii)
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
X = dpt.asarray(Xnp)
Ynp = np_call(Xnp)
for jj in strides:
assert_allclose(
dpt.asnumpy(dpt_call(X[::jj])),
Ynp[::jj],
atol=tol,
rtol=tol,
)


@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_complex_special_cases(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = [np.nan, -np.nan, np.inf, -np.inf]
with np.errstate(all="ignore"):
Xnp = 1j * np.array(x, dtype=dtype)
X = dpt.asarray(Xnp, dtype=Xnp.dtype)
x = [np.nan, -np.nan, np.inf, -np.inf, +0.0, -0.0]
xc = [complex(*val) for val in itertools.product(x, repeat=2)]

Xc_np = np.array(xc, dtype=dtype)
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)

tol = 8 * dpt.finfo(dtype).resolution
assert_allclose(dpt.asnumpy(dpt.real(X)), np.real(Xnp), atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(dpt.imag(X)), np.imag(Xnp), atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(dpt.conj(X)), np.conj(Xnp), atol=tol, rtol=tol)
assert_allclose(
dpt.asnumpy(dpt.real(Xc)), np.real(Xc_np), atol=tol, rtol=tol
)
assert_allclose(
dpt.asnumpy(dpt.imag(Xc)), np.imag(Xc_np), atol=tol, rtol=tol
)
assert_allclose(
dpt.asnumpy(dpt.conj(Xc)), np.conj(Xc_np), atol=tol, rtol=tol
)
Loading