Skip to content

Commit 47d9266

Browse files
committed
update elementwise func tests and doc_string
1 parent 179ce15 commit 47d9266

File tree

3 files changed

+168
-75
lines changed

3 files changed

+168
-75
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@
3636
Returns:
3737
usm_narray:
3838
An array containing the element-wise absolute values.
39-
For complex input, the absolute value is its magnitude. The data type
40-
of the returned array is determined by the Type Promotion Rules.
39+
For complex input, the absolute value is its magnitude.
40+
If `x` has a real-valued data type, the returned array has the
41+
same data type as `x`. If `x` has a complex floating-point data type,
42+
the returned array has a real-valued floating-point data type whose
43+
precision matches the precision of `x`.
4144
"""
4245

4346
abs = UnaryElementwiseFunc("abs", ti._abs_result_type, ti._abs, _abs_docstring_)
@@ -133,8 +136,8 @@
133136
Default: "K".
134137
Returns:
135138
usm_narray:
136-
An array containing the element-wise conjugate values. The data type
137-
of the returned array is determined by the Type Promotion Rules.
139+
An array containing the element-wise conjugate values.
140+
The returned array has the same data type as `x`.
138141
"""
139142

140143
conj = UnaryElementwiseFunc(
@@ -216,8 +219,7 @@
216219
Returns:
217220
usm_narray:
218221
An array containing the result of element-wise equality comparison.
219-
The data type of the returned array is determined by the
220-
Type Promotion Rules.
222+
The returned array has a data type of `bool`.
221223
"""
222224

223225
equal = BinaryElementwiseFunc(
@@ -264,6 +266,8 @@
264266
Return:
265267
usm_ndarray:
266268
An array containing the element-wise exp(x)-1 values.
269+
The data type of the returned array is determined by the Type
270+
Promotion Rules..
267271
"""
268272

269273
expm1 = UnaryElementwiseFunc(
@@ -288,7 +292,7 @@
288292
Second input array, also expected to have numeric data type.
289293
Returns:
290294
usm_narray:
291-
an array containing the result of element-wise floor division.
295+
an array containing the result of element-wise floor division.
292296
The data type of the returned array is determined by the Type
293297
Promotion Rules.
294298
"""
@@ -319,8 +323,7 @@
319323
Returns:
320324
usm_narray:
321325
An array containing the result of element-wise greater-than comparison.
322-
The data type of the returned array is determined by the
323-
Type Promotion Rules.
326+
The returned array has a data type of `bool`.
324327
"""
325328

326329
greater = BinaryElementwiseFunc(
@@ -347,8 +350,7 @@
347350
usm_narray:
348351
An array containing the result of element-wise greater-than or equal-to
349352
comparison.
350-
The data type of the returned array is determined by the
351-
Type Promotion Rules.
353+
The returned array has a data type of `bool`.
352354
"""
353355

354356
greater_equal = BinaryElementwiseFunc(
@@ -376,8 +378,10 @@
376378
Returns:
377379
usm_narray:
378380
An array containing the element-wise imaginary component of input.
379-
The data type of the returned array is determined
380-
by the Type Promotion Rules.
381+
If the input is a real-valued data type, the returned array has
382+
the same datat type. If the input is a complex floating-point
383+
data type, the returned array has a floating-point data type
384+
with the same floating-point precision as complex input.
381385
"""
382386

383387
imag = UnaryElementwiseFunc(
@@ -403,7 +407,7 @@
403407
usm_narray:
404408
An array which is True where `x` is not positive infinity,
405409
negative infinity, or NaN, False otherwise.
406-
The data type of the returned array is boolean.
410+
The data type of the returned array is `bool`.
407411
"""
408412

409413
isfinite = UnaryElementwiseFunc(
@@ -428,7 +432,7 @@
428432
Returns:
429433
usm_narray:
430434
An array which is True where `x` is positive or negative infinity,
431-
False otherwise. The data type of the returned array is boolean.
435+
False otherwise. The data type of the returned array is `bool`.
432436
"""
433437

434438
isinf = UnaryElementwiseFunc(
@@ -453,7 +457,7 @@
453457
Returns:
454458
usm_narray:
455459
An array which is True where x is NaN, False otherwise.
456-
The data type of the returned array is boolean.
460+
The data type of the returned array is `bool`.
457461
"""
458462

459463
isnan = UnaryElementwiseFunc(
@@ -481,8 +485,7 @@
481485
Returns:
482486
usm_narray:
483487
An array containing the result of element-wise less-than comparison.
484-
The data type of the returned array is determined by the
485-
Type Promotion Rules.
488+
The returned array has a data type of `bool`.
486489
"""
487490

488491
less = BinaryElementwiseFunc(
@@ -508,9 +511,7 @@
508511
Returns:
509512
usm_narray:
510513
An array containing the result of element-wise less-than or equal-to
511-
comparison.
512-
The data type of the returned array is determined by the
513-
Type Promotion Rules.
514+
comparison. The returned array has a data type of `bool`.
514515
"""
515516

516517
less_equal = BinaryElementwiseFunc(
@@ -536,6 +537,8 @@
536537
Return:
537538
usm_ndarray:
538539
An array containing the element-wise natural logarithm values.
540+
The data type of the returned array is determined by the Type
541+
Promotion Rules.
539542
"""
540543

541544
log = UnaryElementwiseFunc("log", ti._log_result_type, ti._log, _log_docstring)
@@ -555,7 +558,8 @@
555558
Default: "K".
556559
Return:
557560
usm_ndarray:
558-
An array containing the element-wise log(1+x) values.
561+
An array containing the element-wise log(1+x) values. The data type
562+
of the returned array is determined by the Type Promotion Rules.
559563
"""
560564

561565
log1p = UnaryElementwiseFunc(
@@ -638,8 +642,7 @@
638642
Returns:
639643
usm_narray:
640644
an array containing the result of element-wise inequality comparison.
641-
The data type of the returned array is determined by the
642-
Type Promotion Rules.
645+
The returned array has a data type of `bool`.
643646
"""
644647

645648
not_equal = BinaryElementwiseFunc(
@@ -669,8 +672,8 @@
669672
Default: "K".
670673
Returns:
671674
usm_narray:
672-
An array containing the element-wise projection. The data
673-
type of the returned array is determined by the Type Promotion Rules.
675+
An array containing the element-wise projection.
676+
The returned array has the same data type as `x`.
674677
"""
675678

676679
proj = UnaryElementwiseFunc(
@@ -694,8 +697,11 @@
694697
Default: "K".
695698
Returns:
696699
usm_narray:
697-
An array containing the element-wise real component of input. The data
698-
type of the returned array is determined by the Type Promotion Rules.
700+
An array containing the element-wise real component of input.
701+
If the input is a real-valued data type, the returned array has
702+
the same datat type. If the input is a complex floating-point
703+
data type, the returned array has a floating-point data type
704+
with the same floating-point precision as complex input.
699705
"""
700706

701707
real = UnaryElementwiseFunc(

dpctl/tests/elementwise/test_complex.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -164,35 +164,51 @@ def test_projection(dtype):
164164
"np_call, dpt_call",
165165
[(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)],
166166
)
167-
@pytest.mark.parametrize("dtype", ["f4", "f8"])
168-
@pytest.mark.parametrize("stride", [-1, 1, 2, 4, 5])
169-
def test_complex_strided(np_call, dpt_call, dtype, stride):
167+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
168+
def test_complex_strided(np_call, dpt_call, dtype):
170169
q = get_queue_or_skip()
171170
skip_if_dtype_not_supported(dtype, q)
172171

173-
N = 100
174-
rng = np.random.default_rng(42)
175-
x1 = rng.standard_normal(N, dtype)
176-
x2 = 1j * rng.standard_normal(N, dtype)
177-
x = x1 + x2
178-
y = np_call(x[::stride])
179-
z = dpt_call(dpt.asarray(x[::stride]))
172+
np.random.seed(42)
173+
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
174+
sizes = np.arange(2, 100)
175+
tol = 8 * dpt.finfo(dtype).resolution
180176

181-
tol = 8 * dpt.finfo(y.dtype).resolution
182-
assert_allclose(y, dpt.asnumpy(z), atol=tol, rtol=tol)
177+
low = -1000.0
178+
high = 1000.0
179+
for ii in sizes:
180+
x1 = np.random.uniform(low=low, high=high, size=ii)
181+
x2 = np.random.uniform(low=low, high=high, size=ii)
182+
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
183+
X = dpt.asarray(Xnp)
184+
Ynp = np_call(Xnp)
185+
for jj in strides:
186+
assert_allclose(
187+
dpt.asnumpy(dpt_call(X[::jj])),
188+
Ynp[::jj],
189+
atol=tol,
190+
rtol=tol,
191+
)
183192

184193

185-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
194+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
186195
def test_complex_special_cases(dtype):
187196
q = get_queue_or_skip()
188197
skip_if_dtype_not_supported(dtype, q)
189198

190-
x = [np.nan, -np.nan, np.inf, -np.inf]
191-
with np.errstate(all="ignore"):
192-
Xnp = 1j * np.array(x, dtype=dtype)
193-
X = dpt.asarray(Xnp, dtype=Xnp.dtype)
199+
x = [np.nan, -np.nan, np.inf, -np.inf, +0.0, -0.0]
200+
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
201+
202+
Xc_np = np.array(xc, dtype=dtype)
203+
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
194204

195205
tol = 8 * dpt.finfo(dtype).resolution
196-
assert_allclose(dpt.asnumpy(dpt.real(X)), np.real(Xnp), atol=tol, rtol=tol)
197-
assert_allclose(dpt.asnumpy(dpt.imag(X)), np.imag(Xnp), atol=tol, rtol=tol)
198-
assert_allclose(dpt.asnumpy(dpt.conj(X)), np.conj(Xnp), atol=tol, rtol=tol)
206+
assert_allclose(
207+
dpt.asnumpy(dpt.real(Xc)), np.real(Xc_np), atol=tol, rtol=tol
208+
)
209+
assert_allclose(
210+
dpt.asnumpy(dpt.imag(Xc)), np.imag(Xc_np), atol=tol, rtol=tol
211+
)
212+
assert_allclose(
213+
dpt.asnumpy(dpt.conj(Xc)), np.conj(Xc_np), atol=tol, rtol=tol
214+
)

0 commit comments

Comments
 (0)