Skip to content

Commit ba52635

Browse files
committed
update elementwise func tests and doc_string
1 parent f52182d commit ba52635

File tree

4 files changed

+234
-86
lines changed

4 files changed

+234
-86
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@
3636
Returns:
3737
usm_narray:
3838
An array containing the element-wise absolute values.
39-
For complex input, the absolute value is its magnitude. The data type
40-
of the returned array is determined by the Type Promotion Rules.
39+
For complex input, the absolute value is its magnitude.
40+
If `x` has a real-valued data type, the returned array has the
41+
same data type as `x`. If `x` has a complex floating-point data type,
42+
the returned array has a real-valued floating-point data type whose
43+
precision matches the precision of `x`.
4144
"""
4245

4346
abs = UnaryElementwiseFunc("abs", ti._abs_result_type, ti._abs, _abs_docstring_)
@@ -133,8 +136,8 @@
133136
Default: "K".
134137
Returns:
135138
usm_narray:
136-
An array containing the element-wise conjugate values. The data type
137-
of the returned array is determined by the Type Promotion Rules.
139+
An array containing the element-wise conjugate values.
140+
The returned array has the same data type as `x`.
138141
"""
139142

140143
conj = UnaryElementwiseFunc(
@@ -216,8 +219,7 @@
216219
Returns:
217220
usm_narray:
218221
An array containing the result of element-wise equality comparison.
219-
The data type of the returned array is determined by the
220-
Type Promotion Rules.
222+
The returned array has a data type of `bool`.
221223
"""
222224

223225
equal = BinaryElementwiseFunc(
@@ -264,6 +266,8 @@
264266
Return:
265267
usm_ndarray:
266268
An array containing the element-wise exp(x)-1 values.
269+
The data type of the returned array is determined by the Type
270+
Promotion Rules.
267271
"""
268272

269273
expm1 = UnaryElementwiseFunc(
@@ -288,7 +292,7 @@
288292
Second input array, also expected to have numeric data type.
289293
Returns:
290294
usm_narray:
291-
an array containing the result of element-wise floor division.
295+
an array containing the result of element-wise floor division.
292296
The data type of the returned array is determined by the Type
293297
Promotion Rules.
294298
"""
@@ -319,8 +323,7 @@
319323
Returns:
320324
usm_narray:
321325
An array containing the result of element-wise greater-than comparison.
322-
The data type of the returned array is determined by the
323-
Type Promotion Rules.
326+
The returned array has a data type of `bool`.
324327
"""
325328

326329
greater = BinaryElementwiseFunc(
@@ -347,8 +350,7 @@
347350
usm_narray:
348351
An array containing the result of element-wise greater-than or equal-to
349352
comparison.
350-
The data type of the returned array is determined by the
351-
Type Promotion Rules.
353+
The returned array has a data type of `bool`.
352354
"""
353355

354356
greater_equal = BinaryElementwiseFunc(
@@ -376,8 +378,10 @@
376378
Returns:
377379
usm_narray:
378380
An array containing the element-wise imaginary component of input.
379-
The data type of the returned array is determined
380-
by the Type Promotion Rules.
381+
If the input is a real-valued data type, the returned array has
382+
the same datat type. If the input is a complex floating-point
383+
data type, the returned array has a floating-point data type
384+
with the same floating-point precision as complex input.
381385
"""
382386

383387
imag = UnaryElementwiseFunc(
@@ -403,7 +407,7 @@
403407
usm_narray:
404408
An array which is True where `x` is not positive infinity,
405409
negative infinity, or NaN, False otherwise.
406-
The data type of the returned array is boolean.
410+
The data type of the returned array is `bool`.
407411
"""
408412

409413
isfinite = UnaryElementwiseFunc(
@@ -428,7 +432,7 @@
428432
Returns:
429433
usm_narray:
430434
An array which is True where `x` is positive or negative infinity,
431-
False otherwise. The data type of the returned array is boolean.
435+
False otherwise. The data type of the returned array is `bool`.
432436
"""
433437

434438
isinf = UnaryElementwiseFunc(
@@ -453,7 +457,7 @@
453457
Returns:
454458
usm_narray:
455459
An array which is True where x is NaN, False otherwise.
456-
The data type of the returned array is boolean.
460+
The data type of the returned array is `bool`.
457461
"""
458462

459463
isnan = UnaryElementwiseFunc(
@@ -481,8 +485,7 @@
481485
Returns:
482486
usm_narray:
483487
An array containing the result of element-wise less-than comparison.
484-
The data type of the returned array is determined by the
485-
Type Promotion Rules.
488+
The returned array has a data type of `bool`.
486489
"""
487490

488491
less = BinaryElementwiseFunc(
@@ -508,9 +511,7 @@
508511
Returns:
509512
usm_narray:
510513
An array containing the result of element-wise less-than or equal-to
511-
comparison.
512-
The data type of the returned array is determined by the
513-
Type Promotion Rules.
514+
comparison. The returned array has a data type of `bool`.
514515
"""
515516

516517
less_equal = BinaryElementwiseFunc(
@@ -536,6 +537,8 @@
536537
Return:
537538
usm_ndarray:
538539
An array containing the element-wise natural logarithm values.
540+
The data type of the returned array is determined by the Type
541+
Promotion Rules.
539542
"""
540543

541544
log = UnaryElementwiseFunc("log", ti._log_result_type, ti._log, _log_docstring)
@@ -555,7 +558,8 @@
555558
Default: "K".
556559
Return:
557560
usm_ndarray:
558-
An array containing the element-wise log(1+x) values.
561+
An array containing the element-wise log(1+x) values. The data type
562+
of the returned array is determined by the Type Promotion Rules.
559563
"""
560564

561565
log1p = UnaryElementwiseFunc(
@@ -804,8 +808,7 @@
804808
Returns:
805809
usm_narray:
806810
an array containing the result of element-wise inequality comparison.
807-
The data type of the returned array is determined by the
808-
Type Promotion Rules.
811+
The returned array has a data type of `bool`.
809812
"""
810813

811814
not_equal = BinaryElementwiseFunc(
@@ -873,8 +876,8 @@
873876
Default: "K".
874877
Returns:
875878
usm_narray:
876-
An array containing the element-wise projection. The data
877-
type of the returned array is determined by the Type Promotion Rules.
879+
An array containing the element-wise projection.
880+
The returned array has the same data type as `x`.
878881
"""
879882

880883
proj = UnaryElementwiseFunc(
@@ -898,8 +901,11 @@
898901
Default: "K".
899902
Returns:
900903
usm_narray:
901-
An array containing the element-wise real component of input. The data
902-
type of the returned array is determined by the Type Promotion Rules.
904+
An array containing the element-wise real component of input.
905+
If the input is a real-valued data type, the returned array has
906+
the same datat type. If the input is a complex floating-point
907+
data type, the returned array has a floating-point data type
908+
with the same floating-point precision as complex input.
903909
"""
904910

905911
real = UnaryElementwiseFunc(

dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,63 @@ template <typename argT, typename resT> struct ExpFunctor
6464

6565
resT operator()(const argT &in)
6666
{
67-
return std::exp(in);
67+
if constexpr (is_complex<argT>::value) {
68+
using realT = typename argT::value_type;
69+
70+
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
71+
72+
const realT x = std::real(in);
73+
const realT y = std::imag(in);
74+
if (std::isfinite(x)) {
75+
if (std::isfinite(y)) {
76+
return std::exp(in);
77+
}
78+
else {
79+
return resT{q_nan, q_nan};
80+
}
81+
}
82+
else if (std::isnan(x)) {
83+
/* x is nan */
84+
if (y == realT(0)) {
85+
return resT{in};
86+
}
87+
else {
88+
return resT{x, q_nan};
89+
}
90+
}
91+
else {
92+
if (x > realT(0)) { /* x is +inf */
93+
if (y == realT(0)) {
94+
return resT{x, y};
95+
}
96+
else if (std::isfinite(y)) {
97+
return resT{x * std::cos(y), x * std::sin(y)};
98+
}
99+
else {
100+
/* x = +inf, y = +-inf || nan */
101+
return resT{x, q_nan};
102+
}
103+
}
104+
else { /* x is -inf */
105+
if (std::isfinite(y)) {
106+
realT exp_x = std::exp(x);
107+
return resT{exp_x * std::cos(y), exp_x * std::sin(y)};
108+
}
109+
else {
110+
/* x = -inf, y = +-inf || nan */
111+
return resT{0, 0};
112+
}
113+
}
114+
}
115+
}
116+
else {
117+
return std::exp(in);
118+
}
68119
}
120+
// resT operator()(const argT &in)
121+
// {
122+
// return std::exp(in);
123+
// }
69124
};
70125

71126
template <typename argTy,

dpctl/tests/elementwise/test_complex.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ def test_complex_order(np_call, dpt_call, dtype):
114114
X[..., 0::2] = np.pi / 6 + 1j * np.pi / 3
115115
X[..., 1::2] = np.pi / 3 + 1j * np.pi / 6
116116

117-
for ord in ["C", "F", "A", "K"]:
118-
for perms in itertools.permutations(range(4)):
119-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
117+
for perms in itertools.permutations(range(4)):
118+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
119+
expected_Y = np_call(dpt.asnumpy(U))
120+
for ord in ["C", "F", "A", "K"]:
120121
Y = dpt_call(U, order=ord)
121-
expected_Y = np_call(dpt.asnumpy(U))
122122
assert np.allclose(dpt.asnumpy(Y), expected_Y)
123123

124124

@@ -164,35 +164,51 @@ def test_projection(dtype):
164164
"np_call, dpt_call",
165165
[(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)],
166166
)
167-
@pytest.mark.parametrize("dtype", ["f4", "f8"])
168-
@pytest.mark.parametrize("stride", [-1, 1, 2, 4, 5])
169-
def test_complex_strided(np_call, dpt_call, dtype, stride):
167+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
168+
def test_complex_strided(np_call, dpt_call, dtype):
170169
q = get_queue_or_skip()
171170
skip_if_dtype_not_supported(dtype, q)
172171

173-
N = 100
174-
rng = np.random.default_rng(42)
175-
x1 = rng.standard_normal(N, dtype)
176-
x2 = 1j * rng.standard_normal(N, dtype)
177-
x = x1 + x2
178-
y = np_call(x[::stride])
179-
z = dpt_call(dpt.asarray(x[::stride]))
172+
np.random.seed(42)
173+
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
174+
sizes = [2, 4, 6, 8, 9, 24, 72]
175+
tol = 8 * dpt.finfo(dtype).resolution
180176

181-
tol = 8 * dpt.finfo(y.dtype).resolution
182-
assert_allclose(y, dpt.asnumpy(z), atol=tol, rtol=tol)
177+
low = -1000.0
178+
high = 1000.0
179+
for ii in sizes:
180+
x1 = np.random.uniform(low=low, high=high, size=ii)
181+
x2 = np.random.uniform(low=low, high=high, size=ii)
182+
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
183+
X = dpt.asarray(Xnp)
184+
Ynp = np_call(Xnp)
185+
for jj in strides:
186+
assert_allclose(
187+
dpt.asnumpy(dpt_call(X[::jj])),
188+
Ynp[::jj],
189+
atol=tol,
190+
rtol=tol,
191+
)
183192

184193

185-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
194+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
186195
def test_complex_special_cases(dtype):
187196
q = get_queue_or_skip()
188197
skip_if_dtype_not_supported(dtype, q)
189198

190-
x = [np.nan, -np.nan, np.inf, -np.inf]
191-
with np.errstate(all="ignore"):
192-
Xnp = 1j * np.array(x, dtype=dtype)
193-
X = dpt.asarray(Xnp, dtype=Xnp.dtype)
199+
x = [np.nan, -np.nan, np.inf, -np.inf, +0.0, -0.0]
200+
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
201+
202+
Xc_np = np.array(xc, dtype=dtype)
203+
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
194204

195205
tol = 8 * dpt.finfo(dtype).resolution
196-
assert_allclose(dpt.asnumpy(dpt.real(X)), np.real(Xnp), atol=tol, rtol=tol)
197-
assert_allclose(dpt.asnumpy(dpt.imag(X)), np.imag(Xnp), atol=tol, rtol=tol)
198-
assert_allclose(dpt.asnumpy(dpt.conj(X)), np.conj(Xnp), atol=tol, rtol=tol)
206+
assert_allclose(
207+
dpt.asnumpy(dpt.real(Xc)), np.real(Xc_np), atol=tol, rtol=tol
208+
)
209+
assert_allclose(
210+
dpt.asnumpy(dpt.imag(Xc)), np.imag(Xc_np), atol=tol, rtol=tol
211+
)
212+
assert_allclose(
213+
dpt.asnumpy(dpt.conj(Xc)), np.conj(Xc_np), atol=tol, rtol=tol
214+
)

0 commit comments

Comments
 (0)