Skip to content

Commit bb01f05

Browse files
dpnp.where() doesn't work with 1 argument (#1326)
* Fix where operator to support passing 1 arg * Unskip and fix tests for where operator * Add support of dpnp.where() with x and y arguments * Update dpnp/backend/include/dpnp_iface_fptr.hpp * Use dpctl.tensor.nonzero() implementation --------- Co-authored-by: Anton Volkov <[email protected]> Co-authored-by: Anton <[email protected]>
1 parent cd24184 commit bb01f05

File tree

10 files changed

+484
-23
lines changed

10 files changed

+484
-23
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ typedef ssize_t shape_elem_type;
5757

5858
#include <dpctl_sycl_interface.h>
5959

60+
#include "dpnp_iface_fptr.hpp"
6061
#include "dpnp_iface_fft.hpp"
6162
#include "dpnp_iface_random.hpp"
6263

@@ -1683,6 +1684,57 @@ INP_DLLEXPORT void dpnp_var_c(void* array,
16831684
size_t naxis,
16841685
size_t ddof);
16851686

1687+
/**
1688+
* @ingroup BACKEND_API
1689+
* @brief Implementation of the where function: element-wise selection between two input arrays driven by a boolean condition array
1690+
*
1691+
* @param [in] q_ref Reference to SYCL queue.
1692+
* @param [out] result_out Output array.
1693+
* @param [in] result_size Size of output array.
1694+
* @param [in] result_ndim Number of output array dimensions.
1695+
* @param [in] result_shape Shape of output array.
1696+
* @param [in] result_strides Strides of output array.
1697+
* @param [in] condition_in Condition array.
1698+
* @param [in] condition_size Size of condition array.
1699+
* @param [in] condition_ndim Number of condition array dimensions.
1700+
* @param [in] condition_shape Shape of condition array.
1701+
* @param [in] condition_strides Strides of condition array.
1702+
* @param [in] input1_in First input array.
1703+
* @param [in] input1_size Size of first input array.
1704+
* @param [in] input1_ndim Number of first input array dimensions.
1705+
* @param [in] input1_shape Shape of first input array.
1706+
* @param [in] input1_strides Strides of first input array.
1707+
* @param [in] input2_in Second input array.
1708+
* @param [in] input2_size Size of second input array.
1709+
* @param [in] input2_ndim Number of second input array dimensions.
1710+
* @param [in] input2_shape Shape of second input array.
1711+
* @param [in] input2_strides Strides of second input array.
1712+
* @param [in] dep_event_vec_ref Reference to vector of SYCL events.
1713+
*/
1714+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
1715+
INP_DLLEXPORT DPCTLSyclEventRef dpnp_where_c(DPCTLSyclQueueRef q_ref,
1716+
void* result_out,
1717+
const size_t result_size,
1718+
const size_t result_ndim,
1719+
const shape_elem_type* result_shape,
1720+
const shape_elem_type* result_strides,
1721+
const void* condition_in,
1722+
const size_t condition_size,
1723+
const size_t condition_ndim,
1724+
const shape_elem_type* condition_shape,
1725+
const shape_elem_type* condition_strides,
1726+
const void* input1_in,
1727+
const size_t input1_size,
1728+
const size_t input1_ndim,
1729+
const shape_elem_type* input1_shape,
1730+
const shape_elem_type* input1_strides,
1731+
const void* input2_in,
1732+
const size_t input2_size,
1733+
const size_t input2_ndim,
1734+
const shape_elem_type* input2_shape,
1735+
const shape_elem_type* input2_strides,
1736+
const DPCTLEventVectorRef dep_event_vec_ref);
1737+
16861738
/**
16871739
* @ingroup BACKEND_API
16881740
* @brief Implementation of invert function

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ enum class DPNPFuncName : size_t
377377
DPNP_FN_VANDER_EXT, /**< Used in numpy.vander() impl, requires extra parameters */
378378
DPNP_FN_VAR, /**< Used in numpy.var() impl */
379379
DPNP_FN_VAR_EXT, /**< Used in numpy.var() impl, requires extra parameters */
380+
DPNP_FN_WHERE_EXT, /**< Used in numpy.where() impl, requires extra parameters */
380381
DPNP_FN_ZEROS, /**< Used in numpy.zeros() impl */
381382
DPNP_FN_ZEROS_LIKE, /**< Used in numpy.zeros_like() impl */
382383
DPNP_FN_LAST, /**< The latest element of the enumeration */

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,7 @@ void func_map_init_indexing_func(func_map_t& fmap)
10211021
fmap[DPNPFuncName::DPNP_FN_NONZERO][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_nonzero_default_c<float>};
10221022
fmap[DPNPFuncName::DPNP_FN_NONZERO][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_nonzero_default_c<double>};
10231023

1024+
fmap[DPNPFuncName::DPNP_FN_NONZERO_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_nonzero_ext_c<bool>};
10241025
fmap[DPNPFuncName::DPNP_FN_NONZERO_EXT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_nonzero_ext_c<int32_t>};
10251026
fmap[DPNPFuncName::DPNP_FN_NONZERO_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_nonzero_ext_c<int64_t>};
10261027
fmap[DPNPFuncName::DPNP_FN_NONZERO_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_nonzero_ext_c<float>};

dpnp/backend/kernels/dpnp_krnl_searching.cpp

Lines changed: 256 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2016-2020, Intel Corporation
2+
// Copyright (c) 2016-2023, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without
@@ -27,6 +27,7 @@
2727

2828
#include <dpnp_iface.hpp>
2929
#include "dpnp_fptr.hpp"
30+
#include "dpnp_iterator.hpp"
3031
#include "dpnpc_memory_adapter.hpp"
3132
#include "queue_sycl.hpp"
3233

@@ -139,6 +140,258 @@ DPCTLSyclEventRef (*dpnp_argmin_ext_c)(DPCTLSyclQueueRef,
139140
size_t,
140141
const DPCTLEventVectorRef) = dpnp_argmin_c<_DataType, _idx_DataType>;
141142

143+
144+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
145+
class dpnp_where_c_broadcast_kernel;
146+
147+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
148+
class dpnp_where_c_strides_kernel;
149+
150+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
151+
class dpnp_where_c_kernel;
152+
153+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
154+
DPCTLSyclEventRef dpnp_where_c(DPCTLSyclQueueRef q_ref,
155+
void* result_out,
156+
const size_t result_size,
157+
const size_t result_ndim,
158+
const shape_elem_type* result_shape,
159+
const shape_elem_type* result_strides,
160+
const void* condition_in,
161+
const size_t condition_size,
162+
const size_t condition_ndim,
163+
const shape_elem_type* condition_shape,
164+
const shape_elem_type* condition_strides,
165+
const void* input1_in,
166+
const size_t input1_size,
167+
const size_t input1_ndim,
168+
const shape_elem_type* input1_shape,
169+
const shape_elem_type* input1_strides,
170+
const void* input2_in,
171+
const size_t input2_size,
172+
const size_t input2_ndim,
173+
const shape_elem_type* input2_shape,
174+
const shape_elem_type* input2_strides,
175+
const DPCTLEventVectorRef dep_event_vec_ref)
176+
{
177+
/* avoid warning unused variable*/
178+
(void)dep_event_vec_ref;
179+
180+
DPCTLSyclEventRef event_ref = nullptr;
181+
182+
if (!condition_size || !input1_size || !input2_size)
183+
{
184+
return event_ref;
185+
}
186+
187+
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
188+
189+
bool* condition_data = static_cast<bool*>(const_cast<void*>(condition_in));
190+
_DataType_input1* input1_data = static_cast<_DataType_input1*>(const_cast<void*>(input1_in));
191+
_DataType_input2* input2_data = static_cast<_DataType_input2*>(const_cast<void*>(input2_in));
192+
_DataType_output* result = static_cast<_DataType_output*>(result_out);
193+
194+
bool use_broadcasting = !array_equal(input1_shape, input1_ndim, input2_shape, input2_ndim);
195+
use_broadcasting = use_broadcasting || !array_equal(condition_shape, condition_ndim, input1_shape, input1_ndim);
196+
use_broadcasting = use_broadcasting || !array_equal(condition_shape, condition_ndim, input2_shape, input2_ndim);
197+
198+
shape_elem_type* condition_shape_offsets = new shape_elem_type[condition_ndim];
199+
200+
get_shape_offsets_inkernel(condition_shape, condition_ndim, condition_shape_offsets);
201+
bool use_strides = !array_equal(condition_strides, condition_ndim, condition_shape_offsets, condition_ndim);
202+
delete[] condition_shape_offsets;
203+
204+
shape_elem_type* input1_shape_offsets = new shape_elem_type[input1_ndim];
205+
206+
get_shape_offsets_inkernel(input1_shape, input1_ndim, input1_shape_offsets);
207+
use_strides = use_strides || !array_equal(input1_strides, input1_ndim, input1_shape_offsets, input1_ndim);
208+
delete[] input1_shape_offsets;
209+
210+
shape_elem_type* input2_shape_offsets = new shape_elem_type[input2_ndim];
211+
212+
get_shape_offsets_inkernel(input2_shape, input2_ndim, input2_shape_offsets);
213+
use_strides = use_strides || !array_equal(input2_strides, input2_ndim, input2_shape_offsets, input2_ndim);
214+
delete[] input2_shape_offsets;
215+
216+
sycl::event event;
217+
sycl::range<1> gws(result_size);
218+
219+
if (use_broadcasting)
220+
{
221+
DPNPC_id<bool>* condition_it;
222+
const size_t condition_it_it_size_in_bytes = sizeof(DPNPC_id<bool>);
223+
condition_it = reinterpret_cast<DPNPC_id<bool>*>(dpnp_memory_alloc_c(q_ref, condition_it_it_size_in_bytes));
224+
new (condition_it) DPNPC_id<bool>(q_ref, condition_data, condition_shape, condition_strides, condition_ndim);
225+
226+
condition_it->broadcast_to_shape(result_shape, result_ndim);
227+
228+
DPNPC_id<_DataType_input1>* input1_it;
229+
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>);
230+
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(q_ref, input1_it_size_in_bytes));
231+
new (input1_it) DPNPC_id<_DataType_input1>(q_ref, input1_data, input1_shape, input1_strides, input1_ndim);
232+
233+
input1_it->broadcast_to_shape(result_shape, result_ndim);
234+
235+
DPNPC_id<_DataType_input2>* input2_it;
236+
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>);
237+
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(q_ref, input2_it_size_in_bytes));
238+
new (input2_it) DPNPC_id<_DataType_input2>(q_ref, input2_data, input2_shape, input2_strides, input2_ndim);
239+
240+
input2_it->broadcast_to_shape(result_shape, result_ndim);
241+
242+
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
243+
const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
244+
{
245+
const bool condition = (*condition_it)[i];
246+
const _DataType_output input1_elem = (*input1_it)[i];
247+
const _DataType_output input2_elem = (*input2_it)[i];
248+
result[i] = (condition) ? input1_elem : input2_elem;
249+
}
250+
};
251+
auto kernel_func = [&](sycl::handler& cgh) {
252+
cgh.parallel_for<class dpnp_where_c_broadcast_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
253+
gws, kernel_parallel_for_func);
254+
};
255+
256+
q.submit(kernel_func).wait();
257+
258+
condition_it->~DPNPC_id();
259+
input1_it->~DPNPC_id();
260+
input2_it->~DPNPC_id();
261+
262+
return event_ref;
263+
}
264+
else if (use_strides)
265+
{
266+
if ((result_ndim != condition_ndim) || (result_ndim != input1_ndim) || (result_ndim != input2_ndim))
267+
{
268+
throw std::runtime_error("Result ndim=" + std::to_string(result_ndim) +
269+
" mismatches with either condition ndim=" + std::to_string(condition_ndim) +
270+
" or input1 ndim=" + std::to_string(input1_ndim) +
271+
" or input2 ndim=" + std::to_string(input2_ndim));
272+
}
273+
274+
/* memory transfer optimization, use USM-host for temporary speeds up tranfer to device */
275+
using usm_host_allocatorT = sycl::usm_allocator<shape_elem_type, sycl::usm::alloc::host>;
276+
277+
size_t strides_size = 4 * result_ndim;
278+
shape_elem_type* dev_strides_data = sycl::malloc_device<shape_elem_type>(strides_size, q);
279+
280+
/* create host temporary for packed strides managed by shared pointer */
281+
auto strides_host_packed =
282+
std::vector<shape_elem_type, usm_host_allocatorT>(strides_size, usm_host_allocatorT(q));
283+
284+
/* packed vector is concatenation of result_strides, condition_strides, input1_strides and input2_strides */
285+
std::copy(result_strides, result_strides + result_ndim, strides_host_packed.begin());
286+
std::copy(condition_strides, condition_strides + result_ndim, strides_host_packed.begin() + result_ndim);
287+
std::copy(input1_strides, input1_strides + result_ndim, strides_host_packed.begin() + 2 * result_ndim);
288+
std::copy(input2_strides, input2_strides + result_ndim, strides_host_packed.begin() + 3 * result_ndim);
289+
290+
auto copy_strides_ev =
291+
q.copy<shape_elem_type>(strides_host_packed.data(), dev_strides_data, strides_host_packed.size());
292+
293+
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
294+
const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
295+
{
296+
const shape_elem_type* result_strides_data = &dev_strides_data[0];
297+
const shape_elem_type* condition_strides_data = &dev_strides_data[1];
298+
const shape_elem_type* input1_strides_data = &dev_strides_data[2];
299+
const shape_elem_type* input2_strides_data = &dev_strides_data[3];
300+
301+
size_t condition_id = 0;
302+
size_t input1_id = 0;
303+
size_t input2_id = 0;
304+
305+
for (size_t i = 0; i < result_ndim; ++i)
306+
{
307+
const size_t output_xyz_id =
308+
get_xyz_id_by_id_inkernel(output_id, result_strides_data, result_ndim, i);
309+
condition_id += output_xyz_id * condition_strides_data[i];
310+
input1_id += output_xyz_id * input1_strides_data[i];
311+
input2_id += output_xyz_id * input2_strides_data[i];
312+
}
313+
314+
const bool condition = condition_data[condition_id];
315+
const _DataType_output input1_elem = input1_data[input1_id];
316+
const _DataType_output input2_elem = input2_data[input2_id];
317+
result[output_id] = (condition) ? input1_elem : input2_elem;
318+
}
319+
};
320+
auto kernel_func = [&](sycl::handler& cgh) {
321+
cgh.depends_on(copy_strides_ev);
322+
cgh.parallel_for<class dpnp_where_c_strides_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
323+
gws, kernel_parallel_for_func);
324+
};
325+
326+
q.submit(kernel_func).wait();
327+
328+
sycl::free(dev_strides_data, q);
329+
return event_ref;
330+
}
331+
else
332+
{
333+
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
334+
const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
335+
336+
const bool condition = condition_data[i];
337+
const _DataType_output input1_elem = input1_data[i];
338+
const _DataType_output input2_elem = input2_data[i];
339+
result[i] = (condition) ? input1_elem : input2_elem;
340+
};
341+
auto kernel_func = [&](sycl::handler& cgh) {
342+
cgh.parallel_for<class dpnp_where_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
343+
gws, kernel_parallel_for_func);
344+
};
345+
event = q.submit(kernel_func);
346+
}
347+
348+
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
349+
return DPCTLEvent_Copy(event_ref);
350+
351+
return event_ref;
352+
}
353+
354+
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
355+
DPCTLSyclEventRef (*dpnp_where_ext_c)(DPCTLSyclQueueRef,
356+
void*,
357+
const size_t,
358+
const size_t,
359+
const shape_elem_type*,
360+
const shape_elem_type*,
361+
const void*,
362+
const size_t,
363+
const size_t,
364+
const shape_elem_type*,
365+
const shape_elem_type*,
366+
const void*,
367+
const size_t,
368+
const size_t,
369+
const shape_elem_type*,
370+
const shape_elem_type*,
371+
const void*,
372+
const size_t,
373+
const size_t,
374+
const shape_elem_type*,
375+
const shape_elem_type*,
376+
const DPCTLEventVectorRef) = dpnp_where_c<_DataType_output, _DataType_input1, _DataType_input2>;
377+
378+
template <DPNPFuncType FT1, DPNPFuncType... FTs>
379+
static void func_map_searching_2arg_3type_core(func_map_t& fmap)
380+
{
381+
((fmap[DPNPFuncName::DPNP_FN_WHERE_EXT][FT1][FTs] =
382+
{populate_func_types<FT1, FTs>(),
383+
(void*)dpnp_where_ext_c<func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
384+
func_type_map_t::find_type<FT1>,
385+
func_type_map_t::find_type<FTs>>}),
386+
...);
387+
}
388+
389+
template <DPNPFuncType... FTs>
390+
static void func_map_searching_2arg_3type_helper(func_map_t& fmap)
391+
{
392+
((func_map_searching_2arg_3type_core<FTs, FTs...>(fmap)), ...);
393+
}
394+
142395
void func_map_init_searching(func_map_t& fmap)
143396
{
144397
fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_argmax_default_c<int32_t, int32_t>};
@@ -177,5 +430,7 @@ void func_map_init_searching(func_map_t& fmap)
177430
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_argmin_ext_c<double, int32_t>};
178431
fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_argmin_ext_c<double, int64_t>};
179432

433+
func_map_searching_2arg_3type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL, eft_C64, eft_C128>(fmap);
434+
180435
return;
181436
}

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
356356
DPNP_FN_VANDER_EXT
357357
DPNP_FN_VAR
358358
DPNP_FN_VAR_EXT
359+
DPNP_FN_WHERE_EXT
359360
DPNP_FN_ZEROS
360361
DPNP_FN_ZEROS_LIKE
361362

@@ -578,6 +579,7 @@ Searching functions
578579
"""
579580
cpdef dpnp_descriptor dpnp_argmax(dpnp_descriptor array1)
580581
cpdef dpnp_descriptor dpnp_argmin(dpnp_descriptor array1)
582+
cpdef dpnp_descriptor dpnp_where(dpnp_descriptor cond_obj, dpnp_descriptor x_obj, dpnp_descriptor y_obj)
581583

582584
"""
583585
Trigonometric functions

0 commit comments

Comments
 (0)