Skip to content

Commit 8ef97a0

Browse files
committed
update doc_string
1 parent 9b3b889 commit 8ef97a0

File tree

3 files changed

+171
-78
lines changed

3 files changed

+171
-78
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@
3636
Returns:
3737
usm_narray:
3838
An array containing the element-wise absolute values.
39-
For complex input, the absolute value is its magnitude. The data type
40-
of the returned array is determined by the Type Promotion Rules.
39+
For complex input, the absolute value is its magnitude.
40+
If `x` has a real-valued data type, the returned array has the
41+
same data type as `x`. If `x` has a complex floating-point data type,
42+
the returned array has a real-valued floating-point data type whose
43+
precision matches the precision of `x`.
4144
"""
4245

4346
abs = UnaryElementwiseFunc("abs", ti._abs_result_type, ti._abs, _abs_docstring_)
@@ -131,7 +134,7 @@
131134
Default: "K".
132135
Returns:
133136
usm_narray:
134-
An array containing the element-wise ceiling of input array.
137+
An array containing the ceiling of each element in `x`.
135138
The returned array has the same data type as `x`.
136139
"""
137140

@@ -156,8 +159,8 @@
156159
Default: "K".
157160
Returns:
158161
usm_narray:
159-
An array containing the element-wise conjugate values. The data type
160-
of the returned array is determined by the Type Promotion Rules.
162+
An array containing the element-wise conjugate values.
163+
The returned array has the same data type as `x`.
161164
"""
162165

163166
conj = UnaryElementwiseFunc(
@@ -239,8 +242,7 @@
239242
Returns:
240243
usm_narray:
241244
An array containing the result of element-wise equality comparison.
242-
The data type of the returned array is determined by the
243-
Type Promotion Rules.
245+
The returned array has a data type of `bool`.
244246
"""
245247

246248
equal = BinaryElementwiseFunc(
@@ -287,6 +289,8 @@
287289
Return:
288290
usm_ndarray:
289291
An array containing the element-wise exp(x)-1 values.
292+
The data type of the returned array is determined by the Type
293+
Promotion Rules..
290294
"""
291295

292296
expm1 = UnaryElementwiseFunc(
@@ -311,7 +315,7 @@
311315
Default: "K".
312316
Returns:
313317
usm_narray:
314-
An array containing the element-wise floor of input array.
318+
An array containing the floor of each element in `x`.
315319
The returned array has the same data type as `x`.
316320
"""
317321

@@ -334,7 +338,7 @@
334338
Second input array, also expected to have numeric data type.
335339
Returns:
336340
usm_narray:
337-
an array containing the result of element-wise floor division.
341+
an array containing the result of element-wise floor division.
338342
The data type of the returned array is determined by the Type
339343
Promotion Rules.
340344
"""
@@ -365,8 +369,7 @@
365369
Returns:
366370
usm_narray:
367371
An array containing the result of element-wise greater-than comparison.
368-
The data type of the returned array is determined by the
369-
Type Promotion Rules.
372+
The returned array has a data type of `bool`.
370373
"""
371374

372375
greater = BinaryElementwiseFunc(
@@ -393,8 +396,7 @@
393396
usm_narray:
394397
An array containing the result of element-wise greater-than or equal-to
395398
comparison.
396-
The data type of the returned array is determined by the
397-
Type Promotion Rules.
399+
The returned array has a data type of `bool`.
398400
"""
399401

400402
greater_equal = BinaryElementwiseFunc(
@@ -422,8 +424,10 @@
422424
Returns:
423425
usm_narray:
424426
An array containing the element-wise imaginary component of input.
425-
The data type of the returned array is determined
426-
by the Type Promotion Rules.
427+
If the input is a real-valued data type, the returned array has
428+
the same datat type. If the input is a complex floating-point
429+
data type, the returned array has a floating-point data type
430+
with the same floating-point precision as complex input.
427431
"""
428432

429433
imag = UnaryElementwiseFunc(
@@ -449,7 +453,7 @@
449453
usm_narray:
450454
An array which is True where `x` is not positive infinity,
451455
negative infinity, or NaN, False otherwise.
452-
The data type of the returned array is boolean.
456+
The data type of the returned array is `bool`.
453457
"""
454458

455459
isfinite = UnaryElementwiseFunc(
@@ -474,7 +478,7 @@
474478
Returns:
475479
usm_narray:
476480
An array which is True where `x` is positive or negative infinity,
477-
False otherwise. The data type of the returned array is boolean.
481+
False otherwise. The data type of the returned array is `bool`.
478482
"""
479483

480484
isinf = UnaryElementwiseFunc(
@@ -499,7 +503,7 @@
499503
Returns:
500504
usm_narray:
501505
An array which is True where x is NaN, False otherwise.
502-
The data type of the returned array is boolean.
506+
The data type of the returned array is `bool`.
503507
"""
504508

505509
isnan = UnaryElementwiseFunc(
@@ -527,8 +531,7 @@
527531
Returns:
528532
usm_narray:
529533
An array containing the result of element-wise less-than comparison.
530-
The data type of the returned array is determined by the
531-
Type Promotion Rules.
534+
The returned array has a data type of `bool`.
532535
"""
533536

534537
less = BinaryElementwiseFunc(
@@ -554,9 +557,7 @@
554557
Returns:
555558
usm_narray:
556559
An array containing the result of element-wise less-than or equal-to
557-
comparison.
558-
The data type of the returned array is determined by the
559-
Type Promotion Rules.
560+
comparison. The returned array has a data type of `bool`.
560561
"""
561562

562563
less_equal = BinaryElementwiseFunc(
@@ -582,6 +583,8 @@
582583
Return:
583584
usm_ndarray:
584585
An array containing the element-wise natural logarithm values.
586+
The data type of the returned array is determined by the Type
587+
Promotion Rules.
585588
"""
586589

587590
log = UnaryElementwiseFunc("log", ti._log_result_type, ti._log, _log_docstring)
@@ -601,7 +604,8 @@
601604
Default: "K".
602605
Return:
603606
usm_ndarray:
604-
An array containing the element-wise log(1+x) values.
607+
An array containing the element-wise log(1+x) values. The data type
608+
of the returned array is determined by the Type Promotion Rules.
605609
"""
606610

607611
log1p = UnaryElementwiseFunc(
@@ -684,8 +688,7 @@
684688
Returns:
685689
usm_narray:
686690
an array containing the result of element-wise inequality comparison.
687-
The data type of the returned array is determined by the
688-
Type Promotion Rules.
691+
The returned array has a data type of `bool`.
689692
"""
690693

691694
not_equal = BinaryElementwiseFunc(
@@ -715,8 +718,8 @@
715718
Default: "K".
716719
Returns:
717720
usm_narray:
718-
An array containing the element-wise projection. The data
719-
type of the returned array is determined by the Type Promotion Rules.
721+
An array containing the element-wise projection.
722+
The returned array has the same data type as `x`.
720723
"""
721724

722725
proj = UnaryElementwiseFunc(
@@ -740,8 +743,11 @@
740743
Default: "K".
741744
Returns:
742745
usm_narray:
743-
An array containing the element-wise real component of input. The data
744-
type of the returned array is determined by the Type Promotion Rules.
746+
An array containing the element-wise real component of input.
747+
If the input is a real-valued data type, the returned array has
748+
the same datat type. If the input is a complex floating-point
749+
data type, the returned array has a floating-point data type
750+
with the same floating-point precision as complex input.
745751
"""
746752

747753
real = UnaryElementwiseFunc(
@@ -870,7 +876,7 @@
870876
Default: "K".
871877
Returns:
872878
usm_narray:
873-
An array containing the element-wise truncated value of input array.
879+
An array containing the truncated value of each element in `x`.
874880
The returned array has the same data type as `x`.
875881
"""
876882

dpctl/tests/elementwise/test_complex.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -164,35 +164,51 @@ def test_projection(dtype):
164164
"np_call, dpt_call",
165165
[(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)],
166166
)
167-
@pytest.mark.parametrize("dtype", ["f4", "f8"])
168-
@pytest.mark.parametrize("stride", [-1, 1, 2, 4, 5])
169-
def test_complex_strided(np_call, dpt_call, dtype, stride):
167+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
168+
def test_complex_strided(np_call, dpt_call, dtype):
170169
q = get_queue_or_skip()
171170
skip_if_dtype_not_supported(dtype, q)
172171

173-
N = 100
174-
rng = np.random.default_rng(42)
175-
x1 = rng.standard_normal(N, dtype)
176-
x2 = 1j * rng.standard_normal(N, dtype)
177-
x = x1 + x2
178-
y = np_call(x[::stride])
179-
z = dpt_call(dpt.asarray(x[::stride]))
172+
np.random.seed(42)
173+
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
174+
sizes = np.arange(2, 100)
175+
tol = 8 * dpt.finfo(dtype).resolution
180176

181-
tol = 8 * dpt.finfo(y.dtype).resolution
182-
assert_allclose(y, dpt.asnumpy(z), atol=tol, rtol=tol)
177+
low = -1000.0
178+
high = 1000.0
179+
for ii in sizes:
180+
x1 = np.random.uniform(low=low, high=high, size=ii)
181+
x2 = np.random.uniform(low=low, high=high, size=ii)
182+
Xnp = np.array([complex(v1, v2) for v1, v2 in zip(x1, x2)], dtype=dtype)
183+
X = dpt.asarray(Xnp)
184+
Ynp = np_call(Xnp)
185+
for jj in strides:
186+
assert_allclose(
187+
dpt.asnumpy(dpt_call(X[::jj])),
188+
Ynp[::jj],
189+
atol=tol,
190+
rtol=tol,
191+
)
183192

184193

185-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
194+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
186195
def test_complex_special_cases(dtype):
187196
q = get_queue_or_skip()
188197
skip_if_dtype_not_supported(dtype, q)
189198

190-
x = [np.nan, -np.nan, np.inf, -np.inf]
191-
with np.errstate(all="ignore"):
192-
Xnp = 1j * np.array(x, dtype=dtype)
193-
X = dpt.asarray(Xnp, dtype=Xnp.dtype)
199+
x = [np.nan, -np.nan, np.inf, -np.inf, +0.0, -0.0]
200+
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
201+
202+
Xc_np = np.array(xc, dtype=dtype)
203+
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
194204

195205
tol = 8 * dpt.finfo(dtype).resolution
196-
assert_allclose(dpt.asnumpy(dpt.real(X)), np.real(Xnp), atol=tol, rtol=tol)
197-
assert_allclose(dpt.asnumpy(dpt.imag(X)), np.imag(Xnp), atol=tol, rtol=tol)
198-
assert_allclose(dpt.asnumpy(dpt.conj(X)), np.conj(Xnp), atol=tol, rtol=tol)
206+
assert_allclose(
207+
dpt.asnumpy(dpt.real(Xc)), np.real(Xc_np), atol=tol, rtol=tol
208+
)
209+
assert_allclose(
210+
dpt.asnumpy(dpt.imag(Xc)), np.imag(Xc_np), atol=tol, rtol=tol
211+
)
212+
assert_allclose(
213+
dpt.asnumpy(dpt.conj(Xc)), np.conj(Xc_np), atol=tol, rtol=tol
214+
)

0 commit comments

Comments
 (0)