Skip to content

Commit 7c0a54f

Browse files
authored
Merge pull request #1258 from IntelPython/update_elementwise_tests
update elementwise tests for exp, real, imag, and conj
2 parents 73a2b68 + ac20262 commit 7c0a54f

File tree

4 files changed

+230
-86
lines changed

4 files changed

+230
-86
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@
3636
Returns:
3737
usm_narray:
3838
An array containing the element-wise absolute values.
39-
For complex input, the absolute value is its magnitude. The data type
40-
of the returned array is determined by the Type Promotion Rules.
39+
For complex input, the absolute value is its magnitude.
40+
If `x` has a real-valued data type, the returned array has the
41+
same data type as `x`. If `x` has a complex floating-point data type,
42+
the returned array has a real-valued floating-point data type whose
43+
precision matches the precision of `x`.
4144
"""
4245

4346
abs = UnaryElementwiseFunc("abs", ti._abs_result_type, ti._abs, _abs_docstring_)
@@ -156,8 +159,8 @@
156159
Default: "K".
157160
Returns:
158161
usm_narray:
159-
An array containing the element-wise conjugate values. The data type
160-
of the returned array is determined by the Type Promotion Rules.
162+
An array containing the element-wise conjugate values.
163+
The returned array has the same data type as `x`.
161164
"""
162165

163166
conj = UnaryElementwiseFunc(
@@ -239,8 +242,7 @@
239242
Returns:
240243
usm_narray:
241244
An array containing the result of element-wise equality comparison.
242-
The data type of the returned array is determined by the
243-
Type Promotion Rules.
245+
The returned array has a data type of `bool`.
244246
"""
245247

246248
equal = BinaryElementwiseFunc(
@@ -287,6 +289,8 @@
287289
Return:
288290
usm_ndarray:
289291
An array containing the element-wise exp(x)-1 values.
292+
The data type of the returned array is determined by the Type
293+
Promotion Rules.
290294
"""
291295

292296
expm1 = UnaryElementwiseFunc(
@@ -334,7 +338,7 @@
334338
Second input array, also expected to have numeric data type.
335339
Returns:
336340
usm_narray:
337-
an array containing the result of element-wise floor division.
341+
an array containing the result of element-wise floor division.
338342
The data type of the returned array is determined by the Type
339343
Promotion Rules.
340344
"""
@@ -365,8 +369,7 @@
365369
Returns:
366370
usm_narray:
367371
An array containing the result of element-wise greater-than comparison.
368-
The data type of the returned array is determined by the
369-
Type Promotion Rules.
372+
The returned array has a data type of `bool`.
370373
"""
371374

372375
greater = BinaryElementwiseFunc(
@@ -393,8 +396,7 @@
393396
usm_narray:
394397
An array containing the result of element-wise greater-than or equal-to
395398
comparison.
396-
The data type of the returned array is determined by the
397-
Type Promotion Rules.
399+
The returned array has a data type of `bool`.
398400
"""
399401

400402
greater_equal = BinaryElementwiseFunc(
@@ -422,8 +424,10 @@
422424
Returns:
423425
usm_narray:
424426
An array containing the element-wise imaginary component of input.
425-
The data type of the returned array is determined
426-
by the Type Promotion Rules.
427+
If the input is a real-valued data type, the returned array has
428+
the same datat type. If the input is a complex floating-point
429+
data type, the returned array has a floating-point data type
430+
with the same floating-point precision as complex input.
427431
"""
428432

429433
imag = UnaryElementwiseFunc(
@@ -449,7 +453,7 @@
449453
usm_narray:
450454
An array which is True where `x` is not positive infinity,
451455
negative infinity, or NaN, False otherwise.
452-
The data type of the returned array is boolean.
456+
The data type of the returned array is `bool`.
453457
"""
454458

455459
isfinite = UnaryElementwiseFunc(
@@ -474,7 +478,7 @@
474478
Returns:
475479
usm_narray:
476480
An array which is True where `x` is positive or negative infinity,
477-
False otherwise. The data type of the returned array is boolean.
481+
False otherwise. The data type of the returned array is `bool`.
478482
"""
479483

480484
isinf = UnaryElementwiseFunc(
@@ -499,7 +503,7 @@
499503
Returns:
500504
usm_narray:
501505
An array which is True where x is NaN, False otherwise.
502-
The data type of the returned array is boolean.
506+
The data type of the returned array is `bool`.
503507
"""
504508

505509
isnan = UnaryElementwiseFunc(
@@ -527,8 +531,7 @@
527531
Returns:
528532
usm_narray:
529533
An array containing the result of element-wise less-than comparison.
530-
The data type of the returned array is determined by the
531-
Type Promotion Rules.
534+
The returned array has a data type of `bool`.
532535
"""
533536

534537
less = BinaryElementwiseFunc(
@@ -554,9 +557,7 @@
554557
Returns:
555558
usm_narray:
556559
An array containing the result of element-wise less-than or equal-to
557-
comparison.
558-
The data type of the returned array is determined by the
559-
Type Promotion Rules.
560+
comparison. The returned array has a data type of `bool`.
560561
"""
561562

562563
less_equal = BinaryElementwiseFunc(
@@ -582,6 +583,8 @@
582583
Return:
583584
usm_ndarray:
584585
An array containing the element-wise natural logarithm values.
586+
The data type of the returned array is determined by the Type
587+
Promotion Rules.
585588
"""
586589

587590
log = UnaryElementwiseFunc("log", ti._log_result_type, ti._log, _log_docstring)
@@ -601,7 +604,8 @@
601604
Default: "K".
602605
Return:
603606
usm_ndarray:
604-
An array containing the element-wise log(1+x) values.
607+
An array containing the element-wise log(1+x) values. The data type
608+
of the returned array is determined by the Type Promotion Rules.
605609
"""
606610

607611
log1p = UnaryElementwiseFunc(
@@ -850,8 +854,7 @@
850854
Returns:
851855
usm_narray:
852856
an array containing the result of element-wise inequality comparison.
853-
The data type of the returned array is determined by the
854-
Type Promotion Rules.
857+
The returned array has a data type of `bool`.
855858
"""
856859

857860
not_equal = BinaryElementwiseFunc(
@@ -919,8 +922,8 @@
919922
Default: "K".
920923
Returns:
921924
usm_narray:
922-
An array containing the element-wise projection. The data
923-
type of the returned array is determined by the Type Promotion Rules.
925+
An array containing the element-wise projection.
926+
The returned array has the same data type as `x`.
924927
"""
925928

926929
proj = UnaryElementwiseFunc(
@@ -944,8 +947,11 @@
944947
Default: "K".
945948
Returns:
946949
usm_narray:
947-
An array containing the element-wise real component of input. The data
948-
type of the returned array is determined by the Type Promotion Rules.
950+
An array containing the element-wise real component of input.
951+
If the input is a real-valued data type, the returned array has
952+
the same datat type. If the input is a complex floating-point
953+
data type, the returned array has a floating-point data type
954+
with the same floating-point precision as complex input.
949955
"""
950956

951957
real = UnaryElementwiseFunc(

dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,58 @@ template <typename argT, typename resT> struct ExpFunctor
6464

6565
resT operator()(const argT &in)
6666
{
67-
return std::exp(in);
67+
if constexpr (is_complex<argT>::value) {
68+
using realT = typename argT::value_type;
69+
70+
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
71+
72+
const realT x = std::real(in);
73+
const realT y = std::imag(in);
74+
if (std::isfinite(x)) {
75+
if (std::isfinite(y)) {
76+
return std::exp(in);
77+
}
78+
else {
79+
return resT{q_nan, q_nan};
80+
}
81+
}
82+
else if (std::isnan(x)) {
83+
/* x is nan */
84+
if (y == realT(0)) {
85+
return resT{in};
86+
}
87+
else {
88+
return resT{x, q_nan};
89+
}
90+
}
91+
else {
92+
if (!std::signbit(x)) { /* x is +inf */
93+
if (y == realT(0)) {
94+
return resT{x, y};
95+
}
96+
else if (std::isfinite(y)) {
97+
return resT{x * std::cos(y), x * std::sin(y)};
98+
}
99+
else {
100+
/* x = +inf, y = +-inf || nan */
101+
return resT{x, q_nan};
102+
}
103+
}
104+
else { /* x is -inf */
105+
if (std::isfinite(y)) {
106+
realT exp_x = std::exp(x);
107+
return resT{exp_x * std::cos(y), exp_x * std::sin(y)};
108+
}
109+
else {
110+
/* x = -inf, y = +-inf || nan */
111+
return resT{0, 0};
112+
}
113+
}
114+
}
115+
}
116+
else {
117+
return std::exp(in);
118+
}
68119
}
69120
};
70121

dpctl/tests/elementwise/test_complex.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ def test_complex_order(np_call, dpt_call, dtype):
114114
X[..., 0::2] = np.pi / 6 + 1j * np.pi / 3
115115
X[..., 1::2] = np.pi / 3 + 1j * np.pi / 6
116116

117-
for ord in ["C", "F", "A", "K"]:
118-
for perms in itertools.permutations(range(4)):
119-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
117+
for perms in itertools.permutations(range(4)):
118+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
119+
expected_Y = np_call(dpt.asnumpy(U))
120+
for ord in ["C", "F", "A", "K"]:
120121
Y = dpt_call(U, order=ord)
121-
expected_Y = np_call(dpt.asnumpy(U))
122122
assert np.allclose(dpt.asnumpy(Y), expected_Y)
123123

124124

@@ -164,35 +164,51 @@ def test_projection(dtype):
164164
"np_call, dpt_call",
165165
[(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)],
166166
)
167-
@pytest.mark.parametrize("dtype", ["f4", "f8"])
168-
@pytest.mark.parametrize("stride", [-1, 1, 2, 4, 5])
169-
def test_complex_strided(np_call, dpt_call, dtype, stride):
167+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
168+
def test_complex_strided(np_call, dpt_call, dtype):
170169
q = get_queue_or_skip()
171170
skip_if_dtype_not_supported(dtype, q)
172171

173-
N = 100
174-
rng = np.random.default_rng(42)
175-
x1 = rng.standard_normal(N, dtype)
176-
x2 = 1j * rng.standard_normal(N, dtype)
177-
x = x1 + x2
178-
y = np_call(x[::stride])
179-
z = dpt_call(dpt.asarray(x[::stride]))
172+
np.random.seed(42)
173+
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
174+
sizes = [2, 4, 6, 8, 9, 24, 72]
175+
tol = 8 * dpt.finfo(dtype).resolution
180176

181-
tol = 8 * dpt.finfo(y.dtype).resolution
182-
assert_allclose(y, dpt.asnumpy(z), atol=tol, rtol=tol)
177+
low = -1000.0
178+
high = 1000.0
179+
for ii in sizes:
180+
x1 = np.random.uniform(low=low, high=high, size=ii)
181+
x2 = np.random.uniform(low=low, high=high, size=ii)
182+
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
183+
X = dpt.asarray(Xnp)
184+
Ynp = np_call(Xnp)
185+
for jj in strides:
186+
assert_allclose(
187+
dpt.asnumpy(dpt_call(X[::jj])),
188+
Ynp[::jj],
189+
atol=tol,
190+
rtol=tol,
191+
)
183192

184193

185-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
194+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
186195
def test_complex_special_cases(dtype):
187196
q = get_queue_or_skip()
188197
skip_if_dtype_not_supported(dtype, q)
189198

190-
x = [np.nan, -np.nan, np.inf, -np.inf]
191-
with np.errstate(all="ignore"):
192-
Xnp = 1j * np.array(x, dtype=dtype)
193-
X = dpt.asarray(Xnp, dtype=Xnp.dtype)
199+
x = [np.nan, -np.nan, np.inf, -np.inf, +0.0, -0.0]
200+
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
201+
202+
Xc_np = np.array(xc, dtype=dtype)
203+
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
194204

195205
tol = 8 * dpt.finfo(dtype).resolution
196-
assert_allclose(dpt.asnumpy(dpt.real(X)), np.real(Xnp), atol=tol, rtol=tol)
197-
assert_allclose(dpt.asnumpy(dpt.imag(X)), np.imag(Xnp), atol=tol, rtol=tol)
198-
assert_allclose(dpt.asnumpy(dpt.conj(X)), np.conj(Xnp), atol=tol, rtol=tol)
206+
assert_allclose(
207+
dpt.asnumpy(dpt.real(Xc)), np.real(Xc_np), atol=tol, rtol=tol
208+
)
209+
assert_allclose(
210+
dpt.asnumpy(dpt.imag(Xc)), np.imag(Xc_np), atol=tol, rtol=tol
211+
)
212+
assert_allclose(
213+
dpt.asnumpy(dpt.conj(Xc)), np.conj(Xc_np), atol=tol, rtol=tol
214+
)

0 commit comments

Comments
 (0)