Skip to content

Commit e5506bf

Browse files
authored
Implement dpnp.ravel_multi_index and dpnp.unravel_index (#2022)
* implement dpnp.ravel_multi_index and dpnp.unravel_index * Applied review comments
1 parent acb5767 commit e5506bf

File tree

7 files changed

+355
-32
lines changed

7 files changed

+355
-32
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ env:
5454
third_party/cupy/fft_tests
5555
third_party/cupy/creation_tests
5656
third_party/cupy/indexing_tests/test_indexing.py
57+
third_party/cupy/indexing_tests/test_generate.py
5758
third_party/cupy/lib_tests
5859
third_party/cupy/linalg_tests
5960
third_party/cupy/logic_tests

dpnp/dpnp_iface_indexing.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,15 @@
7272
"put",
7373
"put_along_axis",
7474
"putmask",
75+
"ravel_multi_index",
7576
"select",
7677
"take",
7778
"take_along_axis",
7879
"tril_indices",
7980
"tril_indices_from",
8081
"triu_indices",
8182
"triu_indices_from",
83+
"unravel_index",
8284
]
8385

8486

@@ -133,6 +135,33 @@ def _build_along_axis_index(a, ind, axis):
133135
return tuple(fancy_index)
134136

135137

138+
def _ravel_multi_index_checks(multi_index, dims, order):
139+
dpnp.check_supported_arrays_type(*multi_index)
140+
ndim = len(dims)
141+
if len(multi_index) != ndim:
142+
raise ValueError(
143+
f"parameter multi_index must be a sequence of length {ndim}"
144+
)
145+
dim_mul = 1.0
146+
for d in dims:
147+
if not isinstance(d, int):
148+
raise TypeError(
149+
f"{type(d)} object cannot be interpreted as an integer"
150+
)
151+
dim_mul *= d
152+
153+
if dim_mul > dpnp.iinfo(dpnp.int64).max:
154+
raise ValueError(
155+
"invalid dims: array size defined by dims is larger than the "
156+
"maximum possible size"
157+
)
158+
if order not in ("C", "c", "F", "f", None):
159+
raise ValueError(
160+
"Unrecognized `order` keyword value, expecting "
161+
f"'C' or 'F', but got '{order}'"
162+
)
163+
164+
136165
def choose(x1, choices, out=None, mode="raise"):
137166
"""
138167
Construct an array from an index array and a set of arrays to choose from.
@@ -1429,6 +1458,112 @@ def putmask(x1, mask, values):
14291458
return call_origin(numpy.putmask, x1, mask, values, dpnp_inplace=True)
14301459

14311460

1461+
def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
1462+
"""
1463+
Converts a tuple of index arrays into an array of flat indices, applying
1464+
boundary modes to the multi-index.
1465+
1466+
For full documentation refer to :obj:`numpy.ravel_multi_index`.
1467+
1468+
Parameters
1469+
----------
1470+
multi_index : tuple of {dpnp.ndarray, usm_ndarray}
1471+
A tuple of integer arrays, one array for each dimension.
1472+
dims : tuple or list of ints
1473+
The shape of array into which the indices from ``multi_index`` apply.
1474+
mode : {"raise", "wrap" or "clip'}, optional
1475+
Specifies how out-of-bounds indices are handled. Can specify either
1476+
one mode or a tuple of modes, one mode per index:
1477+
- "raise" -- raise an error
1478+
- "wrap" -- wrap around
1479+
- "clip" -- clip to the range
1480+
In "clip" mode, a negative index which would normally wrap will
1481+
clip to 0 instead.
1482+
Default: ``"raise"``.
1483+
order : {None, "C", "F"}, optional
1484+
Determines whether the multi-index should be viewed as indexing in
1485+
row-major (C-style) or column-major (Fortran-style) order.
1486+
Default: ``"C"``.
1487+
1488+
Returns
1489+
-------
1490+
raveled_indices : dpnp.ndarray
1491+
An array of indices into the flattened version of an array of
1492+
dimensions ``dims``.
1493+
1494+
See Also
1495+
--------
1496+
:obj:`dpnp.unravel_index` : Converts array of flat indices into a tuple of
1497+
coordinate arrays.
1498+
1499+
Examples
1500+
--------
1501+
>>> import dpnp as np
1502+
>>> arr = np.array([[3, 6, 6], [4, 5, 1]])
1503+
>>> np.ravel_multi_index(arr, (7, 6))
1504+
array([22, 41, 37])
1505+
>>> np.ravel_multi_index(arr, (7, 6), order="F")
1506+
array([31, 41, 13])
1507+
>>> np.ravel_multi_index(arr, (4, 6), mode="clip")
1508+
array([22, 23, 19])
1509+
>>> np.ravel_multi_index(arr, (4, 4), mode=("clip", "wrap"))
1510+
array([12, 13, 13])
1511+
>>> arr = np.array([3, 1, 4, 1])
1512+
>>> np.ravel_multi_index(arr, (6, 7, 8, 9))
1513+
array(1621)
1514+
1515+
"""
1516+
1517+
_ravel_multi_index_checks(multi_index, dims, order)
1518+
1519+
ndim = len(dims)
1520+
if isinstance(mode, str):
1521+
mode = (mode,) * ndim
1522+
1523+
s = 1
1524+
ravel_strides = [1] * ndim
1525+
1526+
multi_index = tuple(multi_index)
1527+
usm_type_alloc, sycl_queue_alloc = get_usm_allocations(multi_index)
1528+
1529+
order = "C" if order is None else order.upper()
1530+
if order == "C":
1531+
for i in range(ndim - 2, -1, -1):
1532+
s = s * dims[i + 1]
1533+
ravel_strides[i] = s
1534+
else:
1535+
for i in range(1, ndim):
1536+
s = s * dims[i - 1]
1537+
ravel_strides[i] = s
1538+
1539+
multi_index = dpnp.broadcast_arrays(*multi_index)
1540+
raveled_indices = dpnp.zeros(
1541+
multi_index[0].shape,
1542+
dtype=dpnp.int64,
1543+
usm_type=usm_type_alloc,
1544+
sycl_queue=sycl_queue_alloc,
1545+
)
1546+
for d, stride, idx, _mode in zip(dims, ravel_strides, multi_index, mode):
1547+
if not dpnp.can_cast(idx, dpnp.int64, "same_kind"):
1548+
raise TypeError(
1549+
f"multi_index entries could not be cast from dtype({idx.dtype})"
1550+
f" to dtype({dpnp.int64}) according to the rule 'same_kind'"
1551+
)
1552+
idx = idx.astype(dpnp.int64, copy=False)
1553+
1554+
if _mode == "raise":
1555+
if dpnp.any(dpnp.logical_or(idx >= d, idx < 0)):
1556+
raise ValueError("invalid entry in coordinates array")
1557+
elif _mode == "clip":
1558+
idx = dpnp.clip(idx, 0, d - 1)
1559+
elif _mode == "wrap":
1560+
idx = idx % d
1561+
else:
1562+
raise ValueError(f"Unrecognized mode: {_mode}")
1563+
raveled_indices += stride * idx
1564+
return raveled_indices
1565+
1566+
14321567
def select(condlist, choicelist, default=0):
14331568
"""
14341569
Return an array drawn from elements in `choicelist`, depending on
@@ -2177,3 +2312,78 @@ def triu_indices_from(arr, k=0):
21772312
usm_type=arr.usm_type,
21782313
sycl_queue=arr.sycl_queue,
21792314
)
2315+
2316+
2317+
def unravel_index(indices, shape, order="C"):
2318+
"""Converts array of flat indices into a tuple of coordinate arrays.
2319+
2320+
For full documentation refer to :obj:`numpy.unravel_index`.
2321+
2322+
Parameters
2323+
----------
2324+
indices : {dpnp.ndarray, usm_ndarray}
2325+
An integer array whose elements are indices into the flattened version
2326+
of an array of dimensions ``shape``.
2327+
shape : tuple or list of ints
2328+
The shape of the array to use for unraveling ``indices``.
2329+
order : {None, "C", "F"}, optional
2330+
Determines whether the indices should be viewed as indexing in
2331+
row-major (C-style) or column-major (Fortran-style) order.
2332+
Default: ``"C"``.
2333+
2334+
Returns
2335+
-------
2336+
unraveled_coords : tuple of dpnp.ndarray
2337+
Each array in the tuple has the same shape as the indices array.
2338+
2339+
See Also
2340+
--------
2341+
:obj:`dpnp.ravel_multi_index` : Converts a tuple of index arrays into an
2342+
array of flat indices.
2343+
2344+
2345+
Examples
2346+
--------
2347+
import dpnp as np
2348+
>>> np.unravel_index(np.array([22, 41, 37]), (7, 6))
2349+
(array([3, 6, 6]), array([4, 5, 1]))
2350+
>>> np.unravel_index(np.array([31, 41, 13]), (7, 6), order="F")
2351+
(array([3, 6, 6]), array([4, 5, 1]))
2352+
2353+
>>> np.unravel_index(np.array(1621), (6, 7, 8, 9))
2354+
(array(3), array(1), array(4), array(1))
2355+
2356+
"""
2357+
2358+
dpnp.check_supported_arrays_type(indices)
2359+
2360+
if order not in ("C", "c", "F", "f", None):
2361+
raise ValueError(
2362+
"Unrecognized `order` keyword value, expecting "
2363+
f"'C' or 'F', but got '{order}'"
2364+
)
2365+
order = "C" if order is None else order.upper()
2366+
if order == "C":
2367+
shape = reversed(shape)
2368+
2369+
if not dpnp.can_cast(indices, dpnp.int64, "same_kind"):
2370+
raise TypeError(
2371+
"Iterator operand 0 dtype could not be cast from dtype("
2372+
f"{indices.dtype}) to dtype({dpnp.int64}) according to the rule "
2373+
"'same_kind'"
2374+
)
2375+
2376+
if (indices < 0).any():
2377+
raise ValueError("invalid entry in index array")
2378+
2379+
unraveled_coords = []
2380+
for dim in shape:
2381+
unraveled_coords.append(indices % dim)
2382+
indices = indices // dim
2383+
2384+
if (indices > 0).any():
2385+
raise ValueError("invalid entry in index array")
2386+
2387+
if order == "C":
2388+
unraveled_coords = reversed(unraveled_coords)
2389+
return tuple(unraveled_coords)

tests/skipped_tests.tbl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::te
6060
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1
6161
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_2
6262
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_3
63-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic
64-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic_clip
65-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic_nd_coords
66-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic_raise
67-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic_wrap
68-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_dims_overflow
69-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_float_dims
70-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_mode
71-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_multi_index_dtype
72-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_multi_index_shape
73-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_order
74-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_multi_index_broadcasting
7563
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_1
7664
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_2
7765
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_3
@@ -81,10 +69,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_6
8169
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_7
8270
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_8
8371
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_9
84-
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test
85-
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_dtype
86-
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
87-
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order
8872

8973
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmaskDifferentDtypes::test_putmask_differnt_dtypes_raises
9074
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestAxisConcatenator::te
113113
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_1
114114
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_2
115115
tests/third_party/cupy/indexing_tests/test_generate.py::TestC_::test_c_3
116-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic
117-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic_clip
118-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic_nd_coords
119-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic_raise
120-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_basic_wrap
121-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_dims_overflow
122-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_float_dims
123-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_mode
124-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_multi_index_dtype
125-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_multi_index_shape
126-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_invalid_order
127-
tests/third_party/cupy/indexing_tests/test_generate.py::TestRavelMultiIndex::test_multi_index_broadcasting
128116
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_1
129117
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_2
130118
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_3
@@ -134,10 +122,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_6
134122
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_7
135123
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_8
136124
tests/third_party/cupy/indexing_tests/test_generate.py::TestR_::test_r_9
137-
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test
138-
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_dtype
139-
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
140-
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order
141125

142126
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmaskDifferentDtypes::test_putmask_differnt_dtypes_raises
143127
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises

0 commit comments

Comments
 (0)