Skip to content

Commit f01991b

Browse files
committed
Docstrings added for argmax, argmin, max, and min
1 parent 78829e7 commit f01991b

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed

dpctl/tensor/_reduction.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,58 @@ def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
246246

247247

248248
def max(x, axis=None, keepdims=False):
249+
"""max(x, axis=None, dtype=None, keepdims=False)
250+
251+
Calculates the maximum value of the input array `x`.
252+
253+
Args:
254+
x (usm_ndarray):
255+
input array.
256+
axis (Optional[int, Tuple[int,...]]):
257+
axis or axes along which maxima must be computed. If a tuple
258+
of unique integers, the maxima are computed over multiple axes.
259+
If `None`, the max is computed over the entire array.
260+
Default: `None`.
261+
keepdims (Optional[bool]):
262+
if `True`, the reduced axes (dimensions) are included in the result
263+
as singleton dimensions, so that the returned array remains
264+
compatible with the input arrays according to Array Broadcasting
265+
rules. Otherwise, if `False`, the reduced axes are not included in
266+
the returned array. Default: `False`.
267+
Returns:
268+
usm_ndarray:
269+
an array containing the maxima. If the max was computed over the
270+
entire array, a zero-dimensional array is returned. The returned
271+
array has the same data type as `x`.
272+
"""
249273
return _comparison_over_axis(x, axis, keepdims, ti._max_over_axis)
250274

251275

252276
def min(x, axis=None, keepdims=False):
277+
"""min(x, axis=None, dtype=None, keepdims=False)
278+
279+
Calculates the minimum value of the input array `x`.
280+
281+
Args:
282+
x (usm_ndarray):
283+
input array.
284+
axis (Optional[int, Tuple[int,...]]):
285+
axis or axes along which minima must be computed. If a tuple
286+
of unique integers, the minima are computed over multiple axes.
287+
If `None`, the min is computed over the entire array.
288+
Default: `None`.
289+
keepdims (Optional[bool]):
290+
if `True`, the reduced axes (dimensions) are included in the result
291+
as singleton dimensions, so that the returned array remains
292+
compatible with the input arrays according to Array Broadcasting
293+
rules. Otherwise, if `False`, the reduced axes are not included in
294+
the returned array. Default: `False`.
295+
Returns:
296+
usm_ndarray:
297+
an array containing the minima. If the min was computed over the
298+
entire array, a zero-dimensional array is returned. The returned
299+
array has the same data type as `x`.
300+
"""
253301
return _comparison_over_axis(x, axis, keepdims, ti._min_over_axis)
254302

255303

@@ -303,8 +351,64 @@ def _search_over_axis(x, axis, keepdims, _reduction_fn):
303351

304352

305353
def argmax(x, axis=None, keepdims=False):
354+
"""argmax(x, axis=None, dtype=None, keepdims=False)
355+
356+
Returns the indices of the maximum values of the input array `x` along a
357+
specified axis.
358+
359+
When the maximum value occurs multiple times, the indices corresponding to
360+
the first occurrence are returned.
361+
362+
Args:
363+
x (usm_ndarray):
364+
input array.
365+
axis (Optional[int]):
366+
axis along which to search. If `None`, returns the index of the
367+
maximum value of the flattened array.
368+
Default: `None`.
369+
keepdims (Optional[bool]):
370+
if `True`, the reduced axes (dimensions) are included in the result
371+
as singleton dimensions, so that the returned array remains
372+
compatible with the input arrays according to Array Broadcasting
373+
rules. Otherwise, if `False`, the reduced axes are not included in
374+
the returned array. Default: `False`.
375+
Returns:
376+
usm_ndarray:
377+
an array containing the indices of the first occurrence of the
378+
maximum values. If the entire array was searched, a
379+
zero-dimensional array is returned. The returned array has the
380+
default array index data type for the device of `x`.
381+
"""
306382
return _search_over_axis(x, axis, keepdims, ti._argmax_over_axis)
307383

308384

309385
def argmin(x, axis=None, keepdims=False):
386+
"""argmin(x, axis=None, dtype=None, keepdims=False)
387+
388+
Returns the indices of the minimum values of the input array `x` along a
389+
specified axis.
390+
391+
When the minimum value occurs multiple times, the indices corresponding to
392+
the first occurrence are returned.
393+
394+
Args:
395+
x (usm_ndarray):
396+
input array.
397+
axis (Optional[int]):
398+
axis along which to search. If `None`, returns the index of the
399+
minimum value of the flattened array.
400+
Default: `None`.
401+
keepdims (Optional[bool]):
402+
if `True`, the reduced axes (dimensions) are included in the result
403+
as singleton dimensions, so that the returned array remains
404+
compatible with the input arrays according to Array Broadcasting
405+
rules. Otherwise, if `False`, the reduced axes are not included in
406+
the returned array. Default: `False`.
407+
Returns:
408+
usm_ndarray:
409+
an array containing the indices of the first occurrence of the
410+
minimum values. If the entire array was searched, a
411+
zero-dimensional array is returned. The returned array has the
412+
default array index data type for the device of `x`.
413+
"""
310414
return _search_over_axis(x, axis, keepdims, ti._argmin_over_axis)

0 commit comments

Comments
 (0)