Skip to content

Commit 258a2d7

Browse files
authored
Merge branch 'master' into update_tests
2 parents fe565b5 + 2bab446 commit 258a2d7

File tree

10 files changed

+372
-33
lines changed

10 files changed

+372
-33
lines changed

dpnp/dpnp_array.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,14 @@ def clip(self, min=None, max=None, out=None, **kwargs):
786786

787787
return dpnp.clip(self, min, max, out=out, **kwargs)
788788

789-
# 'compress',
789+
def compress(self, condition, axis=None, out=None):
790+
"""
791+
Select slices of an array along a given axis.
792+
793+
Refer to :obj:`dpnp.compress` for full documentation.
794+
"""
795+
796+
return dpnp.compress(condition, self, axis=axis, out=out)
790797

791798
def conj(self):
792799
"""

dpnp/dpnp_iface_bitwise.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def binary_repr(num, width=None):
140140
x2 : {dpnp.ndarray, usm_ndarray, scalar}
141141
Second input array, also expected to have integer or boolean data
142142
type. Both inputs `x1` and `x2` can not be scalars at the same time.
143+
If ``x1.shape != x2.shape``, they must be broadcastable to a common shape
144+
(which becomes the shape of the output).
143145
out : {None, dpnp.ndarray, usm_ndarray}, optional
144146
Output array to populate.
145147
Array must have the correct shape and the expected data type.
@@ -224,6 +226,8 @@ def binary_repr(num, width=None):
224226
x2 : {dpnp.ndarray, usm_ndarray, scalar}
225227
Second input array, also expected to have integer or boolean data
226228
type. Both inputs `x1` and `x2` can not be scalars at the same time.
229+
If ``x1.shape != x2.shape``, they must be broadcastable to a common shape
230+
(which becomes the shape of the output).
227231
out : {None, dpnp.ndarray, usm_ndarray}, optional
228232
Output array to populate.
229233
Array must have the correct shape and the expected data type.
@@ -299,6 +303,8 @@ def binary_repr(num, width=None):
299303
x2 : {dpnp.ndarray, usm_ndarray, scalar}
300304
Second input array, also expected to have integer or boolean data
301305
type. Both inputs `x1` and `x2` can not be scalars at the same time.
306+
If ``x1.shape != x2.shape``, they must be broadcastable to a common shape
307+
(which becomes the shape of the output).
302308
out : {None, dpnp.ndarray, usm_ndarray}, optional
303309
Output array to populate.
304310
Array must have the correct shape and the expected data type.
@@ -458,6 +464,8 @@ def binary_repr(num, width=None):
458464
Second input array, also expected to have integer data type.
459465
Each element must be greater than or equal to ``0``.
460466
Both inputs `x1` and `x2` can not be scalars at the same time.
467+
If ``x1.shape != x2.shape``, they must be broadcastable to a common shape
468+
(which becomes the shape of the output).
461469
out : {None, dpnp.ndarray, usm_ndarray}, optional
462470
Output array to populate.
463471
Array must have the correct shape and the expected data type.
@@ -532,6 +540,8 @@ def binary_repr(num, width=None):
532540
Second input array, also expected to have integer data type.
533541
Each element must be greater than or equal to ``0``.
534542
Both inputs `x1` and `x2` can not be scalars at the same time.
543+
If ``x1.shape != x2.shape``, they must be broadcastable to a common shape
544+
(which becomes the shape of the output).
535545
out : {None, dpnp.ndarray, usm_ndarray}, optional
536546
Output array to populate.
537547
Array must have the correct shape and the expected data type.

dpnp/dpnp_iface_indexing.py

Lines changed: 168 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,16 @@
3737
3838
"""
3939

40+
# pylint: disable=protected-access
41+
4042
import operator
4143

4244
import dpctl.tensor as dpt
45+
import dpctl.tensor._tensor_impl as ti
46+
import dpctl.utils as dpu
4347
import numpy
48+
from dpctl.tensor._copy_utils import _nonzero_impl
49+
from dpctl.tensor._indexing_functions import _get_indexing_mode
4450
from dpctl.tensor._numpy_helper import normalize_axis_index
4551

4652
import dpnp
@@ -55,6 +61,7 @@
5561

5662
__all__ = [
5763
"choose",
64+
"compress",
5865
"diag_indices",
5966
"diag_indices_from",
6067
"diagonal",
@@ -155,6 +162,157 @@ def choose(x1, choices, out=None, mode="raise"):
155162
return call_origin(numpy.choose, x1, choices, out, mode)
156163

157164

165+
def _take_index(x, inds, axis, q, usm_type, out=None, mode=0):
166+
# arg validation assumed done by caller
167+
x_sh = x.shape
168+
axis_end = axis + 1
169+
if 0 in x_sh[axis:axis_end] and inds.size != 0:
170+
raise IndexError("cannot take non-empty indices from an empty axis")
171+
res_sh = x_sh[:axis] + inds.shape + x_sh[axis_end:]
172+
173+
if out is not None:
174+
out = dpnp.get_usm_ndarray(out)
175+
176+
if not out.flags.writable:
177+
raise ValueError("provided `out` array is read-only")
178+
179+
if out.shape != res_sh:
180+
raise ValueError(
181+
"The shape of input and output arrays are inconsistent. "
182+
f"Expected output shape is {res_sh}, got {out.shape}"
183+
)
184+
185+
if x.dtype != out.dtype:
186+
raise TypeError(
187+
f"Output array of type {x.dtype} is needed, " f"got {out.dtype}"
188+
)
189+
190+
if dpu.get_execution_queue((q, out.sycl_queue)) is None:
191+
raise dpu.ExecutionPlacementError(
192+
"Input and output allocation queues are not compatible"
193+
)
194+
195+
if ti._array_overlap(x, out):
196+
# Allocate a temporary buffer to avoid memory overlapping.
197+
out = dpt.empty_like(out)
198+
else:
199+
out = dpt.empty(res_sh, dtype=x.dtype, usm_type=usm_type, sycl_queue=q)
200+
201+
_manager = dpu.SequentialOrderManager[q]
202+
dep_evs = _manager.submitted_events
203+
204+
h_ev, take_ev = ti._take(
205+
src=x,
206+
ind=(inds,),
207+
dst=out,
208+
axis_start=axis,
209+
mode=mode,
210+
sycl_queue=q,
211+
depends=dep_evs,
212+
)
213+
_manager.add_event_pair(h_ev, take_ev)
214+
215+
return out
216+
217+
218+
def compress(condition, a, axis=None, out=None):
219+
"""
220+
Return selected slices of an array along given axis.
221+
222+
A slice of `a` is returned for each index along `axis` where `condition`
223+
is ``True``.
224+
225+
For full documentation refer to :obj:`numpy.choose`.
226+
227+
Parameters
228+
----------
229+
condition : {array_like, dpnp.ndarray, usm_ndarray}
230+
Array that selects which entries to extract. If the length of
231+
`condition` is less than the size of `a` along `axis`, then
232+
the output is truncated to the length of `condition`.
233+
a : {dpnp.ndarray, usm_ndarray}
234+
Array to extract from.
235+
axis : {None, int}, optional
236+
Axis along which to extract slices. If ``None``, works over the
237+
flattened array.
238+
Default: ``None``.
239+
out : {None, dpnp.ndarray, usm_ndarray}, optional
240+
If provided, the result will be placed in this array. It should
241+
be of the appropriate shape and dtype.
242+
Default: ``None``.
243+
244+
Returns
245+
-------
246+
out : dpnp.ndarray
247+
A copy of the slices of `a` where `condition` is ``True``.
248+
249+
See also
250+
--------
251+
:obj:`dpnp.take` : Take elements from an array along an axis.
252+
:obj:`dpnp.choose` : Construct an array from an index array and a set of
253+
arrays to choose from.
254+
:obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
255+
:obj:`dpnp.diagonal` : Return specified diagonals.
256+
:obj:`dpnp.select` : Return an array drawn from elements in `choicelist`,
257+
depending on conditions.
258+
:obj:`dpnp.ndarray.compress` : Equivalent method.
259+
:obj:`dpnp.extract` : Equivalent function when working on 1-D arrays.
260+
261+
Examples
262+
--------
263+
>>> import numpy as np
264+
>>> a = np.array([[1, 2], [3, 4], [5, 6]])
265+
>>> a
266+
array([[1, 2],
267+
[3, 4],
268+
[5, 6]])
269+
>>> np.compress([0, 1], a, axis=0)
270+
array([[3, 4]])
271+
>>> np.compress([False, True, True], a, axis=0)
272+
array([[3, 4],
273+
[5, 6]])
274+
>>> np.compress([False, True], a, axis=1)
275+
array([[2],
276+
[4],
277+
[6]])
278+
279+
Working on the flattened array does not return slices along an axis but
280+
selects elements.
281+
282+
>>> np.compress([False, True], a)
283+
array([2])
284+
"""
285+
286+
dpnp.check_supported_arrays_type(a)
287+
if axis is None:
288+
if a.ndim != 1:
289+
a = dpnp.ravel(a)
290+
axis = 0
291+
axis = normalize_axis_index(operator.index(axis), a.ndim)
292+
293+
a_ary = dpnp.get_usm_ndarray(a)
294+
cond_ary = dpnp.as_usm_ndarray(
295+
condition,
296+
dtype=dpnp.bool,
297+
usm_type=a_ary.usm_type,
298+
sycl_queue=a_ary.sycl_queue,
299+
)
300+
301+
if not cond_ary.ndim == 1:
302+
raise ValueError(
303+
"`condition` must be a 1-D array or un-nested sequence"
304+
)
305+
306+
res_usm_type, exec_q = get_usm_allocations([a_ary, cond_ary])
307+
308+
# _nonzero_impl synchronizes and returns a tuple of usm_ndarray indices
309+
inds = _nonzero_impl(cond_ary)
310+
311+
res = _take_index(a_ary, inds[0], axis, exec_q, res_usm_type, out=out)
312+
313+
return dpnp.get_result_array(res, out=out)
314+
315+
158316
def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):
159317
"""
160318
Return the indices to access the main diagonal of an array.
@@ -1806,8 +1964,8 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
18061964
18071965
"""
18081966

1809-
if mode not in ("wrap", "clip"):
1810-
raise ValueError(f"`mode` must be 'wrap' or 'clip', but got `{mode}`.")
1967+
# sets mode to 0 for "wrap" and 1 for "clip", raises otherwise
1968+
mode = _get_indexing_mode(mode)
18111969

18121970
usm_a = dpnp.get_usm_ndarray(a)
18131971
if not dpnp.is_supported_array_type(indices):
@@ -1817,34 +1975,28 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
18171975
else:
18181976
usm_ind = dpnp.get_usm_ndarray(indices)
18191977

1978+
res_usm_type, exec_q = get_usm_allocations([usm_a, usm_ind])
1979+
18201980
a_ndim = a.ndim
18211981
if axis is None:
1822-
res_shape = usm_ind.shape
1823-
18241982
if a_ndim > 1:
1825-
# dpt.take requires flattened input array
1983+
# flatten input array
18261984
usm_a = dpt.reshape(usm_a, -1)
1985+
axis = 0
18271986
elif a_ndim == 0:
18281987
axis = normalize_axis_index(operator.index(axis), 1)
1829-
res_shape = usm_ind.shape
18301988
else:
18311989
axis = normalize_axis_index(operator.index(axis), a_ndim)
1832-
a_sh = a.shape
1833-
res_shape = a_sh[:axis] + usm_ind.shape + a_sh[axis + 1 :]
1834-
1835-
if usm_ind.ndim != 1:
1836-
# dpt.take supports only 1-D array of indices
1837-
usm_ind = dpt.reshape(usm_ind, -1)
18381990

18391991
if not dpnp.issubdtype(usm_ind.dtype, dpnp.integer):
18401992
# dpt.take supports only integer dtype for array of indices
18411993
usm_ind = dpt.astype(usm_ind, dpnp.intp, copy=False, casting="safe")
18421994

1843-
usm_res = dpt.take(usm_a, usm_ind, axis=axis, mode=mode)
1995+
usm_res = _take_index(
1996+
usm_a, usm_ind, axis, exec_q, res_usm_type, out=out, mode=mode
1997+
)
18441998

1845-
# need to reshape the result if shape of indices array was changed
1846-
result = dpnp.reshape(usm_res, res_shape)
1847-
return dpnp.get_result_array(result, out)
1999+
return dpnp.get_result_array(usm_res, out=out)
18482000

18492001

18502002
def take_along_axis(a, indices, axis, mode="wrap"):

0 commit comments

Comments
 (0)