Skip to content

Commit 64093b2

Browse files
Add support of dpnp.extract() (#1340)
* Add dpnp.extract() using dpctl.tensor.extract()
1 parent 29a2063 commit 64093b2

File tree

8 files changed

+75
-15
lines changed

8 files changed

+75
-15
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"diag_indices",
5555
"diag_indices_from",
5656
"diagonal",
57+
"extract",
5758
"fill_diagonal",
5859
"indices",
5960
"nonzero",
@@ -232,6 +233,40 @@ def diagonal(x1, offset=0, axis1=0, axis2=1):
232233
return call_origin(numpy.diagonal, x1, offset, axis1, axis2)
233234

234235

236+
def extract(condition, x):
237+
"""
238+
Return the elements of an array that satisfy some condition.
239+
For full documentation refer to :obj:`numpy.extract`.
240+
241+
Returns
242+
-------
243+
y : dpnp.ndarray
244+
Rank 1 array of values from `x` where `condition` is True.
245+
246+
Limitations
247+
-----------
248+
Parameters `condition` and `x` are supported either as
249+
:class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
250+
Parameter `x` must be the same shape as `condition`.
251+
Otherwise the function will be executed sequentially on CPU.
252+
"""
253+
254+
check_input_type = lambda x: isinstance(x, (dpnp_array, dpt.usm_ndarray))
255+
if check_input_type(condition) and check_input_type(x):
256+
if condition.shape != x.shape:
257+
pass
258+
else:
259+
dpt_condition = (
260+
condition.get_array()
261+
if isinstance(condition, dpnp_array)
262+
else condition
263+
)
264+
dpt_array = x.get_array() if isinstance(x, dpnp_array) else x
265+
return dpnp_array._create_from_usm_ndarray(dpt.extract(dpt_condition, dpt_array))
266+
267+
return call_origin(numpy.extract, condition, x)
268+
269+
235270
def fill_diagonal(x1, val, wrap=False):
236271
"""
237272
Fill the main diagonal of the given array of any dimensionality.
@@ -296,7 +331,7 @@ def nonzero(x, /):
296331
-------
297332
y : tuple[dpnp.ndarray]
298333
Indices of elements that are non-zero.
299-
334+
300335
Limitations
301336
-----------
302337
Parameters `x` is supported as either :class:`dpnp.ndarray`

tests/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from tests.third_party.cupy import testing as cupy_testing
2+
from .helper import has_support_aspect64
23
import dpnp
34
import numpy
45

@@ -17,6 +18,8 @@
1718

1819

1920
def _shaped_arange(shape, xp=dpnp, dtype=dpnp.float64, order='C'):
21+
if dtype is dpnp.float64:
22+
dtype = dpnp.float32 if not has_support_aspect64() else dtype
2023
res = xp.array(orig_shaped_arange(shape, xp=numpy, dtype=dtype, order=order), dtype=dtype)
2124
return res
2225

tests/helper.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,12 @@ def is_win_platform():
9191
Return True if a test is runing on Windows OS, False otherwise.
9292
"""
9393
return platform.startswith('win')
94+
95+
96+
def has_support_aspect64(device=None):
97+
"""
98+
Return True if the device supports 64-bit precision floating point operations,
99+
False otherwise.
100+
"""
101+
dev = dpctl.select_default_device() if device is None else device
102+
return dev.has_aspect_fp64

tests/skipped_tests.tbl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -456,12 +456,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compr
456456
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
457457
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
458458
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
459-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract
460-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_empty_1dim
461-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_no_bool
462-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_shape_mismatch
463-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch
464-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch2
465459
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
466460
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
467461
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -650,12 +650,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compr
650650
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
651651
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
652652
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
653-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract
654-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_empty_1dim
655-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_no_bool
656-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_shape_mismatch
657-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch
658-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch2
659653
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
660654
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
661655
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist

tests/test_indexing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import pytest
2+
from .helper import get_all_dtypes
3+
24

35
import dpnp
46

@@ -53,6 +55,18 @@ def test_diagonal(array, offset):
5355
assert_array_equal(expected, result)
5456

5557

58+
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
59+
@pytest.mark.parametrize("cond_dtype", get_all_dtypes())
60+
def test_extract_1d(arr_dtype, cond_dtype):
61+
a = numpy.array([-2, -1, 0, 1, 2, 3], dtype=arr_dtype)
62+
ia = dpnp.array(a)
63+
cond = numpy.array([1, -1, 2, 0, -2, 3], dtype=cond_dtype)
64+
icond = dpnp.array(cond)
65+
expected = numpy.extract(cond, a)
66+
result = dpnp.extract(icond, ia)
67+
assert_array_equal(expected, result)
68+
69+
5670
@pytest.mark.parametrize("val",
5771
[-1, 0, 1],
5872
ids=['-1', '0', '1'])

tests/third_party/cupy/indexing_tests/test_indexing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def test_extract_no_bool(self, xp, dtype):
166166
b = xp.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype=dtype)
167167
return xp.extract(b, a)
168168

169+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
169170
@testing.numpy_cupy_array_equal()
170171
def test_extract_shape_mismatch(self, xp):
171172
a = testing.shaped_arange((2, 3), xp)
@@ -174,20 +175,23 @@ def test_extract_shape_mismatch(self, xp):
174175
[True, False]])
175176
return xp.extract(b, a)
176177

178+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
177179
@testing.numpy_cupy_array_equal()
178180
def test_extract_size_mismatch(self, xp):
179181
a = testing.shaped_arange((3, 3), xp)
180182
b = xp.array([[True, False, True],
181183
[False, True, False]])
182184
return xp.extract(b, a)
183185

186+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
184187
@testing.numpy_cupy_array_equal()
185188
def test_extract_size_mismatch2(self, xp):
186189
a = testing.shaped_arange((3, 3), xp)
187190
b = xp.array([[True, False, True, False],
188191
[False, True, False, True]])
189192
return xp.extract(b, a)
190193

194+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
191195
@testing.numpy_cupy_array_equal()
192196
def test_extract_empty_1dim(self, xp):
193197
a = testing.shaped_arange((3, 3), xp)

tests/third_party/cupy/testing/helper.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# from dpnp.core import internal
1616
from tests.third_party.cupy.testing import array
1717
from tests.third_party.cupy.testing import parameterized
18+
from dpctl import select_default_device
1819
# import dpnp
1920
# import dpnp.scipy.sparse
2021

@@ -654,9 +655,15 @@ def test_func(self, *args, **kw):
654655
return test_func
655656
return decorator
656657

658+
def _get_supported_float_dtypes():
659+
if select_default_device().has_aspect_fp64:
660+
return (numpy.float64, numpy.float32)
661+
else:
662+
return (numpy.float32,)
663+
657664

658665
_complex_dtypes = ()
659-
_regular_float_dtypes = (numpy.float64, numpy.float32)
666+
_regular_float_dtypes = _get_supported_float_dtypes()
660667
_float_dtypes = _regular_float_dtypes
661668
_signed_dtypes = ()
662669
_unsigned_dtypes = tuple(numpy.dtype(i).type for i in 'BHILQ')
@@ -667,7 +674,7 @@ def test_func(self, *args, **kw):
667674

668675

669676
def _make_all_dtypes(no_float16, no_bool, no_complex):
670-
return (numpy.float64, numpy.float32, numpy.int64, numpy.int32)
677+
return (numpy.int64, numpy.int32) + _get_supported_float_dtypes()
671678
# if no_float16:
672679
# dtypes = _regular_float_dtypes
673680
# else:

0 commit comments

Comments
 (0)