Skip to content

Commit da003cb

Browse files
committed
Rolled back removal of shape_type_c and dpnp_subtract_c_ext
1 parent 1588206 commit da003cb

File tree

3 files changed

+51
-27
lines changed

3 files changed

+51
-27
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -339,32 +339,34 @@ enum class DPNPFuncName : size_t
339339
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
340340
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
341341
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
342-
DPNP_FN_SQRT_EXT, /**< Used in numpy.sqrt() impl, requires extra parameters
343-
*/
344-
DPNP_FN_SQUARE, /**< Used in numpy.square() impl */
345-
DPNP_FN_STD, /**< Used in numpy.std() impl */
346-
DPNP_FN_SUBTRACT, /**< Used in numpy.subtract() impl */
347-
DPNP_FN_SUM, /**< Used in numpy.sum() impl */
348-
DPNP_FN_SVD, /**< Used in numpy.linalg.svd() impl */
349-
DPNP_FN_TAKE, /**< Used in numpy.take() impl */
350-
DPNP_FN_TAN, /**< Used in numpy.tan() impl */
351-
DPNP_FN_TANH, /**< Used in numpy.tanh() impl */
352-
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() impl */
353-
DPNP_FN_TRACE, /**< Used in numpy.trace() impl */
354-
DPNP_FN_TRACE_EXT, /**< Used in numpy.trace() impl, requires extra
355-
parameters */
356-
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() impl */
357-
DPNP_FN_TRAPZ_EXT, /**< Used in numpy.trapz() impl, requires extra
358-
parameters */
359-
DPNP_FN_TRI, /**< Used in numpy.tri() impl */
360-
DPNP_FN_TRIL, /**< Used in numpy.tril() impl */
361-
DPNP_FN_TRIU, /**< Used in numpy.triu() impl */
362-
DPNP_FN_TRUNC, /**< Used in numpy.trunc() impl */
363-
DPNP_FN_VANDER, /**< Used in numpy.vander() impl */
364-
DPNP_FN_VAR, /**< Used in numpy.var() impl */
365-
DPNP_FN_ZEROS, /**< Used in numpy.zeros() impl */
366-
DPNP_FN_ZEROS_LIKE, /**< Used in numpy.zeros_like() impl */
367-
DPNP_FN_LAST, /**< The latest element of the enumeration */
342+
DPNP_FN_SQRT_EXT, /**< Used in numpy.sqrt() impl, requires extra parameters
343+
*/
344+
DPNP_FN_SQUARE, /**< Used in numpy.square() impl */
345+
DPNP_FN_STD, /**< Used in numpy.std() impl */
346+
DPNP_FN_SUBTRACT, /**< Used in numpy.subtract() impl */
347+
DPNP_FN_SUBTRACT_EXT, /**< Used in numpy.subtract() impl, requires extra
348+
parameters */
349+
DPNP_FN_SUM, /**< Used in numpy.sum() impl */
350+
DPNP_FN_SVD, /**< Used in numpy.linalg.svd() impl */
351+
DPNP_FN_TAKE, /**< Used in numpy.take() impl */
352+
DPNP_FN_TAN, /**< Used in numpy.tan() impl */
353+
DPNP_FN_TANH, /**< Used in numpy.tanh() impl */
354+
DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() impl */
355+
DPNP_FN_TRACE, /**< Used in numpy.trace() impl */
356+
DPNP_FN_TRACE_EXT, /**< Used in numpy.trace() impl, requires extra
357+
parameters */
358+
DPNP_FN_TRAPZ, /**< Used in numpy.trapz() impl */
359+
DPNP_FN_TRAPZ_EXT, /**< Used in numpy.trapz() impl, requires extra
360+
parameters */
361+
DPNP_FN_TRI, /**< Used in numpy.tri() impl */
362+
DPNP_FN_TRIL, /**< Used in numpy.tril() impl */
363+
DPNP_FN_TRIU, /**< Used in numpy.triu() impl */
364+
DPNP_FN_TRUNC, /**< Used in numpy.trunc() impl */
365+
DPNP_FN_VANDER, /**< Used in numpy.vander() impl */
366+
DPNP_FN_VAR, /**< Used in numpy.var() impl */
367+
DPNP_FN_ZEROS, /**< Used in numpy.zeros() impl */
368+
DPNP_FN_ZEROS_LIKE, /**< Used in numpy.zeros_like() impl */
369+
DPNP_FN_LAST, /**< The latest element of the enumeration */
368370
};
369371

370372
/**

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,6 +1394,19 @@ static constexpr DPNPFuncType get_divide_res_type()
13941394
return widest_type;
13951395
}
13961396

1397+
template <DPNPFuncType FT1, DPNPFuncType... FTs>
1398+
static void func_map_elemwise_2arg_3type_core(func_map_t &fmap)
1399+
{
1400+
// dpnp_subtract_c_ext is implicitly used by dpnp_ptp_c
1401+
((fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][FT1][FTs] =
1402+
{populate_func_types<FT1, FTs>(),
1403+
(void *)dpnp_subtract_c_ext<
1404+
func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
1405+
func_type_map_t::find_type<FT1>,
1406+
func_type_map_t::find_type<FTs>>}),
1407+
...);
1408+
}
1409+
13971410
template <DPNPFuncType FT1, DPNPFuncType... FTs>
13981411
static void func_map_elemwise_2arg_3type_short_core(func_map_t &fmap)
13991412
{
@@ -1441,6 +1454,12 @@ static void func_map_elemwise_2arg_3type_short_core(func_map_t &fmap)
14411454
...);
14421455
}
14431456

1457+
template <DPNPFuncType... FTs>
1458+
static void func_map_elemwise_2arg_3type_helper(func_map_t &fmap)
1459+
{
1460+
((func_map_elemwise_2arg_3type_core<FTs, FTs...>(fmap)), ...);
1461+
}
1462+
14441463
template <DPNPFuncType... FTs>
14451464
static void func_map_elemwise_2arg_3type_short_helper(func_map_t &fmap)
14461465
{
@@ -1881,6 +1900,9 @@ static void func_map_init_elemwise_2arg_3type(func_map_t &fmap)
18811900
fmap[DPNPFuncName::DPNP_FN_SUBTRACT][eft_DBL][eft_DBL] = {
18821901
eft_DBL, (void *)dpnp_subtract_c_default<double, double, double>};
18831902

1903+
func_map_elemwise_2arg_3type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT,
1904+
eft_DBL, eft_C64, eft_C128>(fmap);
1905+
18841906
func_map_elemwise_2arg_3type_short_helper<eft_INT, eft_LNG, eft_FLT,
18851907
eft_DBL>(fmap);
18861908

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
cimport dpctl as c_dpctl
2828
from libcpp cimport bool as cpp_bool
2929

30-
from dpnp.dpnp_algo cimport shape_elem_type
30+
from dpnp.dpnp_algo cimport shape_elem_type, shape_type_c
3131
from dpnp.dpnp_utils.dpnp_algo_utils cimport dpnp_descriptor
3232

3333

0 commit comments

Comments
 (0)