Skip to content

Commit 689831a

Browse files
Reuse dpctl.tensort.take for dpnp.take (#1492)
* Reuse dpctl.tensort.take for dpnp.take * Add examples and use dpnp.is_supported_array_type * Use dpnp.get_usm_ndarray in take and update examples --------- Co-authored-by: Anton <[email protected]>
1 parent 85ea6f9 commit 689831a

File tree

10 files changed

+97
-141
lines changed

10 files changed

+97
-141
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,7 @@ enum class DPNPFuncName : size_t
487487
DPNP_FN_SVD_EXT, /**< Used in numpy.linalg.svd() impl, requires extra
488488
parameters */
489489
DPNP_FN_TAKE, /**< Used in numpy.take() impl */
490-
DPNP_FN_TAKE_EXT, /**< Used in numpy.take() impl, requires extra parameters
491-
*/
492-
DPNP_FN_TAN, /**< Used in numpy.tan() impl */
490+
DPNP_FN_TAN, /**< Used in numpy.tan() impl */
493491
DPNP_FN_TAN_EXT, /**< Used in numpy.tan() impl, requires extra parameters */
494492
DPNP_FN_TANH, /**< Used in numpy.tanh() impl */
495493
DPNP_FN_TANH_EXT, /**< Used in numpy.tanh() impl, requires extra parameters

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,32 +1059,5 @@ void func_map_init_indexing_func(func_map_t &fmap)
10591059
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_C128][eft_LNG] = {
10601060
eft_C128, (void *)dpnp_take_default_c<std::complex<double>, int64_t>};
10611061

1062-
// TODO: add a handling of other indexes types once DPCtl implementation of
1063-
// data copy is ready
1064-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_BLN][eft_INT] = {
1065-
eft_BLN, (void *)dpnp_take_ext_c<bool, int32_t>};
1066-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_INT][eft_INT] = {
1067-
eft_INT, (void *)dpnp_take_ext_c<int32_t, int32_t>};
1068-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_LNG][eft_INT] = {
1069-
eft_LNG, (void *)dpnp_take_ext_c<int64_t, int32_t>};
1070-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_FLT][eft_INT] = {
1071-
eft_FLT, (void *)dpnp_take_ext_c<float, int32_t>};
1072-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_DBL][eft_INT] = {
1073-
eft_DBL, (void *)dpnp_take_ext_c<double, int32_t>};
1074-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_C128][eft_INT] = {
1075-
eft_C128, (void *)dpnp_take_ext_c<std::complex<double>, int32_t>};
1076-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_BLN][eft_LNG] = {
1077-
eft_BLN, (void *)dpnp_take_ext_c<bool, int64_t>};
1078-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_INT][eft_LNG] = {
1079-
eft_INT, (void *)dpnp_take_ext_c<int32_t, int64_t>};
1080-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_LNG][eft_LNG] = {
1081-
eft_LNG, (void *)dpnp_take_ext_c<int64_t, int64_t>};
1082-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_FLT][eft_LNG] = {
1083-
eft_FLT, (void *)dpnp_take_ext_c<float, int64_t>};
1084-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_DBL][eft_LNG] = {
1085-
eft_DBL, (void *)dpnp_take_ext_c<double, int64_t>};
1086-
fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_C128][eft_LNG] = {
1087-
eft_C128, (void *)dpnp_take_ext_c<std::complex<double>, int64_t>};
1088-
10891062
return;
10901063
}

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
295295
DPNP_FN_SUM_EXT
296296
DPNP_FN_SVD
297297
DPNP_FN_SVD_EXT
298-
DPNP_FN_TAKE
299-
DPNP_FN_TAKE_EXT
300298
DPNP_FN_TAN
301299
DPNP_FN_TAN_EXT
302300
DPNP_FN_TANH

dpnp/dpnp_algo/dpnp_algo_indexing.pxi

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ __all__ += [
4545
"dpnp_put_along_axis",
4646
"dpnp_putmask",
4747
"dpnp_select",
48-
"dpnp_take",
4948
"dpnp_take_along_axis",
5049
"dpnp_tril_indices",
5150
"dpnp_tril_indices_from",
@@ -59,13 +58,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_choose_t)(c_dpctl.DPCTLSyclQueueRe
5958
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_diag_indices)(c_dpctl.DPCTLSyclQueueRef,
6059
void * , size_t,
6160
const c_dpctl.DPCTLEventVectorRef)
62-
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_1out_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
63-
void *,
64-
const size_t,
65-
void * ,
66-
void * ,
67-
size_t,
68-
const c_dpctl.DPCTLEventVectorRef)
6961
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_1out_func_ptr_t_)(c_dpctl.DPCTLSyclQueueRef,
7062
void * ,
7163
const size_t,
@@ -417,42 +409,6 @@ cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default)
417409
return res_array
418410

419411

420-
cpdef utils.dpnp_descriptor dpnp_take(utils.dpnp_descriptor x1, utils.dpnp_descriptor indices):
421-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
422-
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(indices.dtype)
423-
424-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TAKE_EXT, param1_type, param2_type)
425-
426-
x1_obj = x1.get_array()
427-
428-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(indices.shape,
429-
kernel_data.return_type,
430-
None,
431-
device=x1_obj.sycl_device,
432-
usm_type=x1_obj.usm_type,
433-
sycl_queue=x1_obj.sycl_queue)
434-
435-
result_sycl_queue = result.get_array().sycl_queue
436-
437-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
438-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
439-
440-
cdef custom_indexing_2in_1out_func_ptr_t func = <custom_indexing_2in_1out_func_ptr_t > kernel_data.ptr
441-
442-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
443-
x1.get_data(),
444-
x1.size,
445-
indices.get_data(),
446-
result.get_data(),
447-
indices.size,
448-
NULL) # dep_events_ref
449-
450-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
451-
c_dpctl.DPCTLEvent_Delete(event_ref)
452-
453-
return result
454-
455-
456412
cpdef object dpnp_take_along_axis(object arr, object indices, int axis):
457413
cdef long size_arr = arr.size
458414
cdef shape_type_c shape_arr = arr.shape

dpnp/dpnp_array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,15 +1033,15 @@ def sum(
10331033

10341034
# 'swapaxes',
10351035

1036-
def take(self, indices, axis=None, out=None, mode="raise"):
1036+
def take(self, indices, /, *, axis=None, out=None, mode="wrap"):
10371037
"""
1038-
Take elements from an array.
1038+
Take elements from an array along an axis.
10391039
10401040
For full documentation refer to :obj:`numpy.take`.
10411041
10421042
"""
10431043

1044-
return dpnp.take(self, indices, axis, out, mode)
1044+
return dpnp.take(self, indices, axis=axis, out=out, mode=mode)
10451045

10461046
# 'tobytes',
10471047
# 'tofile',

dpnp/dpnp_iface_indexing.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -539,39 +539,82 @@ def select(condlist, choicelist, default=0):
539539
return call_origin(numpy.select, condlist, choicelist, default)
540540

541541

542-
def take(x1, indices, axis=None, out=None, mode="raise"):
542+
def take(x, indices, /, *, axis=None, out=None, mode="wrap"):
543543
"""
544-
Take elements from an array.
544+
Take elements from an array along an axis.
545545
546546
For full documentation refer to :obj:`numpy.take`.
547547
548+
Returns
549+
-------
550+
dpnp.ndarray
551+
An array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
552+
filled with elements from `x`.
553+
548554
Limitations
549555
-----------
550-
Input array is supported as :obj:`dpnp.ndarray`.
551-
Parameters ``axis``, ``out`` and ``mode`` are supported only with default values.
552-
Parameter ``indices`` is supported as :obj:`dpnp.ndarray`.
556+
Parameters `x` and `indices` are supported either as :class:`dpnp.ndarray`
557+
or :class:`dpctl.tensor.usm_ndarray`.
558+
Parameter `indices` is supported as 1-D array of integer data type.
559+
Parameter `out` is supported only with default value.
560+
Parameter `mode` is supported with ``wrap``(default) and ``clip`` mode.
561+
Providing parameter `axis` is optional when `x` is a 1-D array.
562+
Otherwise the function will be executed sequentially on CPU.
553563
554564
See Also
555565
--------
556566
:obj:`dpnp.compress` : Take elements using a boolean mask.
557567
:obj:`take_along_axis` : Take elements by matching the array and the index arrays.
568+
569+
Notes
570+
-----
571+
How out-of-bounds indices will be handled.
572+
"wrap" - clamps indices to (-n <= i < n), then wraps negative indices.
573+
"clip" - clips indices to (0 <= i < n)
574+
575+
Examples
576+
--------
577+
>>> import dpnp as np
578+
>>> x = np.array([4, 3, 5, 7, 6, 8])
579+
>>> indices = np.array([0, 1, 4])
580+
>>> np.take(x, indices)
581+
array([4, 3, 6])
582+
583+
In this example "fancy" indexing can be used.
584+
585+
>>> x[indices]
586+
array([4, 3, 6])
587+
588+
>>> indices = dpnp.array([-1, -6, -7, 5, 6])
589+
>>> np.take(x, indices)
590+
array([8, 4, 4, 8, 8])
591+
592+
>>> np.take(x, indices, mode="clip")
593+
array([4, 4, 4, 8, 8])
594+
558595
"""
559596

560-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
561-
indices_desc = dpnp.get_dpnp_descriptor(
562-
indices, copy_when_nondefault_queue=False
563-
)
564-
if x1_desc and indices_desc:
565-
if axis is not None:
597+
if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type(
598+
indices
599+
):
600+
if indices.ndim != 1 or not dpnp.issubdtype(
601+
indices.dtype, dpnp.integer
602+
):
603+
pass
604+
elif axis is None and x.ndim > 1:
566605
pass
567606
elif out is not None:
568607
pass
569-
elif mode != "raise":
608+
elif mode not in ("clip", "wrap"):
570609
pass
571610
else:
572-
return dpnp_take(x1_desc, indices_desc).get_pyobj()
611+
dpt_array = dpnp.get_usm_ndarray(x)
612+
dpt_indices = dpnp.get_usm_ndarray(indices)
613+
return dpnp_array._create_from_usm_ndarray(
614+
dpt.take(dpt_array, dpt_indices, axis=axis, mode=mode)
615+
)
573616

574-
return call_origin(numpy.take, x1, indices, axis, out, mode)
617+
return call_origin(numpy.take, x, indices, axis, out, mode)
575618

576619

577620
def take_along_axis(x1, indices, axis):

tests/skipped_tests.tbl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compr
401401
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
402402
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
403403
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
404-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
405404
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
406405
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
407406
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ tests/test_sycl_queue.py::test_modf[level_zero:gpu:0]
4949
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-trapz-data19]
5050
tests/test_sycl_queue.py::test_1in_1out[opencl:cpu:0-trapz-data19]
5151

52-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_no_axis
5352
tests/third_party/cupy/indexing_tests/test_insert.py::TestDiagIndices_param_0_{n=2, ndim=2}::test_diag_indices
5453
tests/third_party/cupy/indexing_tests/test_insert.py::TestDiagIndices_param_1_{n=2, ndim=3}::test_diag_indices
5554
tests/third_party/cupy/indexing_tests/test_insert.py::TestDiagIndices_param_2_{n=2, ndim=1}::test_diag_indices
@@ -597,7 +596,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compr
597596
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
598597
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
599598
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
600-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
601599
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
602600
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
603601
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast

tests/test_indexing.py

Lines changed: 34 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -592,61 +592,51 @@ def test_select():
592592
assert_array_equal(expected, result)
593593

594594

595+
@pytest.mark.parametrize("array_type", get_all_dtypes())
595596
@pytest.mark.parametrize(
596-
"array_type",
597-
[
598-
numpy.bool8,
599-
numpy.int32,
600-
numpy.int64,
601-
numpy.float32,
602-
numpy.float64,
603-
numpy.complex128,
604-
],
605-
ids=["bool8", "int32", "int64", "float32", "float64", "complex128"],
597+
"indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"]
606598
)
607599
@pytest.mark.parametrize(
608-
"indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"]
600+
"indices", [[-2, 2], [-5, 4]], ids=["[-2, 2]", "[-5, 4]"]
609601
)
602+
@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"])
603+
def test_take_1d(indices, array_type, indices_type, mode):
604+
a = numpy.array([-2, -1, 0, 1, 2], dtype=array_type)
605+
ind = numpy.array(indices, dtype=indices_type)
606+
ia = dpnp.array(a)
607+
iind = dpnp.array(ind)
608+
expected = numpy.take(a, ind, mode=mode)
609+
result = dpnp.take(ia, iind, mode=mode)
610+
assert_array_equal(expected, result)
611+
612+
613+
@pytest.mark.parametrize("array_type", get_all_dtypes())
610614
@pytest.mark.parametrize(
611-
"indices",
612-
[[[0, 0], [0, 0]], [[1, 2], [1, 2]], [[1, 2], [3, 4]]],
613-
ids=["[[0, 0], [0, 0]]", "[[1, 2], [1, 2]]", "[[1, 2], [3, 4]]"],
615+
"indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"]
614616
)
615617
@pytest.mark.parametrize(
616-
"array",
617-
[
618-
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
619-
[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
620-
[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]],
621-
[
622-
[[[1, 2], [3, 4]], [[1, 2], [2, 1]]],
623-
[[[1, 3], [3, 1]], [[0, 1], [1, 3]]],
624-
],
625-
[
626-
[[[1, 2, 3], [3, 4, 5]], [[1, 2, 3], [2, 1, 0]]],
627-
[[[1, 3, 5], [3, 1, 0]], [[0, 1, 2], [1, 3, 4]]],
628-
],
629-
[
630-
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
631-
[[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]],
632-
],
633-
],
634-
ids=[
635-
"[[0, 1, 2], [3, 4, 5], [6, 7, 8]]",
636-
"[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]",
637-
"[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]",
638-
"[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]",
639-
"[[[[1, 2, 3], [3, 4, 5]], [[1, 2, 3], [2, 1, 0]]], [[[1, 3, 5], [3, 1, 0]], [[0, 1, 2], [1, 3, 4]]]]",
640-
"[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]]",
641-
],
618+
"indices", [[-1, 0], [-3, 2]], ids=["[-1, 0]", "[-3, 2]"]
642619
)
643-
def test_take(array, indices, array_type, indices_type):
644-
a = numpy.array(array, dtype=array_type)
620+
@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"])
621+
@pytest.mark.parametrize("axis", [0, 1], ids=["0", "1"])
622+
def test_take_2d(indices, array_type, indices_type, axis, mode):
623+
a = numpy.array([[-1, 0, 1], [-2, -3, -4], [2, 3, 4]], dtype=array_type)
645624
ind = numpy.array(indices, dtype=indices_type)
646625
ia = dpnp.array(a)
647626
iind = dpnp.array(ind)
648-
expected = numpy.take(a, ind)
649-
result = dpnp.take(ia, iind)
627+
expected = numpy.take(a, ind, axis=axis, mode=mode)
628+
result = dpnp.take(ia, iind, axis=axis, mode=mode)
629+
assert_array_equal(expected, result)
630+
631+
632+
@pytest.mark.parametrize("array_type", get_all_dtypes())
633+
@pytest.mark.parametrize("indices", [[-5, 5]], ids=["[-5, 5]"])
634+
@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"])
635+
def test_take_over_index(indices, array_type, mode):
636+
a = dpnp.array([-2, -1, 0, 1, 2], dtype=array_type)
637+
ind = dpnp.array(indices, dtype=dpnp.int64)
638+
expected = dpnp.array([-2, 2], dtype=a.dtype)
639+
result = dpnp.take(a, ind, mode=mode)
650640
assert_array_equal(expected, result)
651641

652642

tests/third_party/cupy/indexing_tests/test_indexing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def test_take_by_array(self, xp):
2828
b = xp.array([[1, 3], [2, 0]])
2929
return a.take(b, axis=1)
3030

31+
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
3132
@testing.numpy_cupy_array_equal()
3233
def test_take_no_axis(self, xp):
3334
a = testing.shaped_arange((2, 3, 4), xp)
@@ -46,7 +47,7 @@ def test_take_index_range_overflow(self, xp, dtype):
4647
if dtype in (numpy.int32, numpy.uint32):
4748
pytest.skip()
4849
iinfo = numpy.iinfo(dtype)
49-
a = xp.broadcast_to(xp.ones(1), (iinfo.max + 1,))
50+
a = xp.broadcast_to(xp.ones(1, dtype=dtype), (iinfo.max + 1,))
5051
b = xp.array([0], dtype=dtype)
5152
return a.take(b)
5253

0 commit comments

Comments
 (0)