Skip to content

Commit 9fafd50

Browse files
Implement take_along_axis function per Python Array API
The function is planned for Python Array API 2024.12 specification.
1 parent 1ecd8a8 commit 9fafd50

File tree

2 files changed

+84
-2
lines changed

2 files changed

+84
-2
lines changed

dpctl/tensor/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,14 @@
6060
)
6161
from dpctl.tensor._device import Device
6262
from dpctl.tensor._dlpack import from_dlpack
63-
from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take
63+
from dpctl.tensor._indexing_functions import (
64+
extract,
65+
nonzero,
66+
place,
67+
put,
68+
take,
69+
take_along_axis,
70+
)
6471
from dpctl.tensor._linear_algebra_functions import (
6572
matmul,
6673
matrix_transpose,
@@ -376,4 +383,5 @@
376383
"nextafter",
377384
"diff",
378385
"count_nonzero",
386+
"take_along_axis",
379387
]

dpctl/tensor/_indexing_functions.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import dpctl.tensor._tensor_impl as ti
2222
import dpctl.utils
2323

24-
from ._copy_utils import _extract_impl, _nonzero_impl
24+
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
2525
from ._numpy_helper import normalize_axis_index
2626

2727

@@ -423,3 +423,77 @@ def nonzero(arr):
423423
if arr.ndim == 0:
424424
raise ValueError("Array of positive rank is expected")
425425
return _nonzero_impl(arr)
426+
427+
428+
def _range(sh_i, i, nd, q, usm_t, dt):
429+
ind = dpt.arange(sh_i, dtype=dt, usm_type=usm_t, sycl_queue=q)
430+
ind.shape = tuple(sh_i if i == j else 1 for j in range(nd))
431+
return ind
432+
433+
434+
def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
435+
"""
436+
Returns elements from an array at the one-dimensional indices specified
437+
by ``indices`` along a provided ``axis``.
438+
439+
Args:
440+
x (usm_ndarray):
441+
input array. Must be compatible with ``indices``, except for the
442+
axis (dimension) specified by ``axis``.
443+
indices (usm_ndarray):
444+
array indices. Must have the same rank (i.e., number of dimensions)
445+
as ``x``.
446+
axis: int
447+
axis along which to select values. If ``axis`` is negative, the
448+
function determines the axis along which to select values by
449+
counting from the last dimension. Default: ``-1``.
450+
mode (str, optional):
451+
How out-of-bounds indices will be handled. Possible values
452+
are:
453+
454+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
455+
negative indices.
456+
- ``"clip"``: clips indices to (``0 <= i < n``).
457+
458+
Default: ``"wrap"``.
459+
460+
Returns:
461+
usm_ndarray:
462+
an array having the same data type as ``x``. The returned array has
463+
the same rank (i.e., number of dimensions) as ``x`` and a shape
464+
determined according to :ref:`broadcasting`, except for the axis
465+
(dimension) specified by ``axis`` whose size must equal the size
466+
of the corresponding axis (dimension) in ``indices``.
467+
468+
Note:
469+
Treatment of the out-of-bound indices in ``indices`` array is controlled
470+
by the value of ``mode`` keyword.
471+
"""
472+
if not isinstance(x, dpt.usm_ndarray):
473+
raise TypeError
474+
if not isinstance(indices, dpt.usm_ndarray):
475+
raise TypeError
476+
x_nd = x.ndim
477+
if x_nd != indices.ndim:
478+
raise ValueError
479+
pp = normalize_axis_index(operator.index(axis), x_nd)
480+
out_usm_type = dpctl.utils.get_coerced_usm_type(
481+
(x.usm_type, indices.usm_type)
482+
)
483+
exec_q = dpctl.utils.get_execution_queue((x.sycl_queue, indices.sycl_queue))
484+
if exec_q is None:
485+
raise dpctl.utils.ExecutionPlacementError(
486+
"Execution placement can not be unambiguously inferred "
487+
"from input arguments. "
488+
)
489+
mode_i = _get_indexing_mode(mode)
490+
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
491+
_ind = tuple(
492+
(
493+
indices
494+
if i == pp
495+
else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt)
496+
)
497+
for i in range(x_nd)
498+
)
499+
return _take_multi_index(x, _ind, 0, mode=mode_i)

0 commit comments

Comments
 (0)