Skip to content

Commit e753e72

Browse files
committed
dpnp.power() doesn't work properly with a scalar
1 parent 1c7b85f commit e753e72

File tree

9 files changed

+177
-101
lines changed

9 files changed

+177
-101
lines changed

dpnp/backend/include/dpnp_gen_2arg_3type_tbl.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,11 @@ MACRO_2ARG_3TYPES_OP(dpnp_multiply_c,
176176
MACRO_UNPACK_TYPES(float, double, std::complex<float>, std::complex<double>))
177177

178178
MACRO_2ARG_3TYPES_OP(dpnp_power_c,
179-
sycl::pow((double)input1_elem, (double)input2_elem),
180-
nullptr,
181-
std::false_type,
179+
static_cast<_DataType_output>(std::pow(input1_elem, input2_elem)),
180+
sycl::pow(x1, x2),
181+
MACRO_UNPACK_TYPES(float, double),
182182
oneapi::mkl::vm::pow,
183-
MACRO_UNPACK_TYPES(float, double))
183+
MACRO_UNPACK_TYPES(float, double, std::complex<float>, std::complex<double>))
184184

185185
MACRO_2ARG_3TYPES_OP(dpnp_subtract_c,
186186
input1_elem - input2_elem,

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,28 +1247,34 @@ static void func_map_elemwise_2arg_3type_core(func_map_t& fmap)
12471247
func_type_map_t::find_type<FT1>,
12481248
func_type_map_t::find_type<FTs>>}),
12491249
...);
1250+
((fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][FT1][FTs] =
1251+
{get_divide_res_type<FT1, FTs>(),
1252+
(void*)dpnp_divide_c_ext<func_type_map_t::find_type<get_divide_res_type<FT1, FTs>()>,
1253+
func_type_map_t::find_type<FT1>,
1254+
func_type_map_t::find_type<FTs>>,
1255+
get_divide_res_type<FT1, FTs, std::false_type>(),
1256+
(void*)dpnp_divide_c_ext<func_type_map_t::find_type<get_divide_res_type<FT1, FTs, std::false_type>()>,
1257+
func_type_map_t::find_type<FT1>,
1258+
func_type_map_t::find_type<FTs>>}),
1259+
...);
12501260
((fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][FT1][FTs] =
12511261
{populate_func_types<FT1, FTs>(),
12521262
(void*)dpnp_multiply_c_ext<func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
12531263
func_type_map_t::find_type<FT1>,
12541264
func_type_map_t::find_type<FTs>>}),
12551265
...);
1266+
((fmap[DPNPFuncName::DPNP_FN_POWER_EXT][FT1][FTs] =
1267+
{populate_func_types<FT1, FTs>(),
1268+
(void*)dpnp_power_c_ext<func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
1269+
func_type_map_t::find_type<FT1>,
1270+
func_type_map_t::find_type<FTs>>}),
1271+
...);
12561272
((fmap[DPNPFuncName::DPNP_FN_SUBTRACT_EXT][FT1][FTs] =
12571273
{populate_func_types<FT1, FTs>(),
12581274
(void*)dpnp_subtract_c_ext<func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
12591275
func_type_map_t::find_type<FT1>,
12601276
func_type_map_t::find_type<FTs>>}),
12611277
...);
1262-
((fmap[DPNPFuncName::DPNP_FN_DIVIDE_EXT][FT1][FTs] =
1263-
{get_divide_res_type<FT1, FTs>(),
1264-
(void*)dpnp_divide_c_ext<func_type_map_t::find_type<get_divide_res_type<FT1, FTs>()>,
1265-
func_type_map_t::find_type<FT1>,
1266-
func_type_map_t::find_type<FTs>>,
1267-
get_divide_res_type<FT1, FTs, std::false_type>(),
1268-
(void*)dpnp_divide_c_ext<func_type_map_t::find_type<get_divide_res_type<FT1, FTs, std::false_type>()>,
1269-
func_type_map_t::find_type<FT1>,
1270-
func_type_map_t::find_type<FTs>>}),
1271-
...);
12721278
}
12731279

12741280
template <DPNPFuncType... FTs>
@@ -1855,39 +1861,6 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
18551861
fmap[DPNPFuncName::DPNP_FN_POWER][eft_DBL][eft_DBL] = {eft_DBL,
18561862
(void*)dpnp_power_c_default<double, double, double>};
18571863

1858-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_INT][eft_INT] = {eft_INT,
1859-
(void*)dpnp_power_c_ext<int32_t, int32_t, int32_t>};
1860-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_INT][eft_LNG] = {eft_LNG,
1861-
(void*)dpnp_power_c_ext<int64_t, int32_t, int64_t>};
1862-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_INT][eft_FLT] = {eft_DBL,
1863-
(void*)dpnp_power_c_ext<double, int32_t, float>};
1864-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_INT][eft_DBL] = {eft_DBL,
1865-
(void*)dpnp_power_c_ext<double, int32_t, double>};
1866-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_LNG][eft_INT] = {eft_LNG,
1867-
(void*)dpnp_power_c_ext<int64_t, int64_t, int32_t>};
1868-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_LNG][eft_LNG] = {eft_LNG,
1869-
(void*)dpnp_power_c_ext<int64_t, int64_t, int64_t>};
1870-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_LNG][eft_FLT] = {eft_DBL,
1871-
(void*)dpnp_power_c_ext<double, int64_t, float>};
1872-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_LNG][eft_DBL] = {eft_DBL,
1873-
(void*)dpnp_power_c_ext<double, int64_t, double>};
1874-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_FLT][eft_INT] = {eft_DBL,
1875-
(void*)dpnp_power_c_ext<double, float, int32_t>};
1876-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_FLT][eft_LNG] = {eft_DBL,
1877-
(void*)dpnp_power_c_ext<double, float, int64_t>};
1878-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1879-
(void*)dpnp_power_c_ext<float, float, float>};
1880-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_FLT][eft_DBL] = {eft_DBL,
1881-
(void*)dpnp_power_c_ext<double, float, double>};
1882-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_DBL][eft_INT] = {eft_DBL,
1883-
(void*)dpnp_power_c_ext<double, double, int32_t>};
1884-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_DBL][eft_LNG] = {eft_DBL,
1885-
(void*)dpnp_power_c_ext<double, double, int64_t>};
1886-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_DBL][eft_FLT] = {eft_DBL,
1887-
(void*)dpnp_power_c_ext<double, double, float>};
1888-
fmap[DPNPFuncName::DPNP_FN_POWER_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1889-
(void*)dpnp_power_c_ext<double, double, double>};
1890-
18911864
fmap[DPNPFuncName::DPNP_FN_SUBTRACT][eft_INT][eft_INT] = {
18921865
eft_INT, (void*)dpnp_subtract_c_default<int32_t, int32_t, int32_t>};
18931866
fmap[DPNPFuncName::DPNP_FN_SUBTRACT][eft_INT][eft_LNG] = {

dpnp/dpnp_array.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,11 @@ def __int__(self):
211211

212212
# '__invert__',
213213
# '__ior__',
214-
# '__ipow__',
214+
215+
def __ipow__(self, other):
216+
dpnp.power(self, other, out=self)
217+
return self
218+
215219
# '__irshift__',
216220
# '__isub__',
217221
# '__iter__',
@@ -279,7 +283,10 @@ def __rmul__(self, other):
279283
return dpnp.multiply(other, self)
280284

281285
# '__ror__',
282-
# '__rpow__',
286+
287+
def __rpow__(self, other):
288+
return dpnp.power(other, self)
289+
283290
# '__rrshift__',
284291
# '__rshift__',
285292

dpnp/dpnp_iface_mathematical.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
import dpnp
4747
import numpy
48+
import dpctl.tensor as dpt
4849

4950

5051
__all__ = [
@@ -1325,18 +1326,35 @@ def negative(x1, **kwargs):
13251326
return call_origin(numpy.negative, x1, **kwargs)
13261327

13271328

1328-
def power(x1, x2, dtype=None, out=None, where=True, **kwargs):
1329+
def power(x1,
1330+
x2,
1331+
/,
1332+
out=None,
1333+
*,
1334+
where=True,
1335+
dtype=None,
1336+
subok=True,
1337+
**kwargs):
13291338
"""
13301339
First array elements raised to powers from second array, element-wise.
13311340
1341+
An integer type (of either negative or positive value, but not zero)
1342+
raised to a negative integer power will return an array of zeroes.
1343+
13321344
For full documentation refer to :obj:`numpy.power`.
13331345
1346+
Returns
1347+
-------
1348+
y : dpnp.ndarray
1349+
The bases in `x1` raised to the exponents in `x2`.
1350+
13341351
Limitations
13351352
-----------
1336-
Parameters ``x1`` and ``x2`` are supported as either :obj:`dpnp.ndarray` or scalar.
1337-
Parameters ``dtype``, ``out`` and ``where`` are supported with their default values.
1353+
Parameters `x1` and `x2` are supported as either :class:`dpnp.ndarray` or scalar,
1354+
but not both (at least either `x1` or `x2` should be as :class:`dpnp.ndarray`).
1355+
Parameters `where`, `dtype` and `subok` are supported with their default values.
13381356
Keyword arguments ``kwargs`` are currently unsupported.
1339-
Otherwise the functions will be executed sequentially on CPU.
1357+
Otherwise the function will be executed sequentially on CPU.
13401358
Input array data types are limited by supported DPNP :ref:`Data types`.
13411359
13421360
See Also
@@ -1348,40 +1366,44 @@ def power(x1, x2, dtype=None, out=None, where=True, **kwargs):
13481366
13491367
Example
13501368
-------
1351-
>>> import dpnp as np
1352-
>>> a = np.array([1, 2, 3, 4, 5])
1353-
>>> b = np.array([2, 2, 2, 2, 2])
1354-
>>> result = np.power(a, b)
1369+
>>> import dpnp as dp
1370+
>>> a = dp.array([1, 2, 3, 4, 5])
1371+
>>> b = dp.array([2, 2, 2, 2, 2])
1372+
>>> result = dp.power(a, b)
13551373
>>> [x for x in result]
13561374
[1, 4, 9, 16, 25]
13571375
13581376
"""
13591377

1360-
x1_is_scalar = dpnp.isscalar(x1)
1361-
x2_is_scalar = dpnp.isscalar(x2)
1362-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
1363-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)
1378+
if where is not True:
1379+
pass
1380+
elif dtype is not None:
1381+
pass
1382+
elif subok is not True:
1383+
pass
1384+
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
1385+
# at least either x1 or x2 has to be an array
1386+
pass
1387+
else:
1388+
# get USM type and queue to copy scalar from the host memory into a USM allocation
1389+
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)
13641390

1365-
if x1_desc and x2_desc and not kwargs:
1366-
if not x1_desc and not x1_is_scalar:
1367-
pass
1368-
elif not x2_desc and not x2_is_scalar:
1369-
pass
1370-
elif x1_is_scalar and x2_is_scalar:
1371-
pass
1372-
elif x1_desc and x1_desc.ndim == 0:
1373-
pass
1374-
elif x2_desc and x2_desc.ndim == 0:
1375-
pass
1376-
elif dtype is not None:
1377-
pass
1378-
elif not where:
1379-
pass
1391+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
1392+
alloc_usm_type=usm_type, alloc_queue=queue)
1393+
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
1394+
alloc_usm_type=usm_type, alloc_queue=queue)
1395+
1396+
if out is not None:
1397+
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
1398+
raise TypeError("return array must be of supported array type")
1399+
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
13801400
else:
1381-
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) if out is not None else None
1382-
return dpnp_power(x1_desc, x2_desc, dtype, out_desc, where).get_pyobj()
1401+
out_desc = None
1402+
1403+
if x1_desc and x2_desc:
1404+
return dpnp_power(x1_desc, x2_desc, dtype=dtype, out=out_desc, where=where).get_pyobj()
13831405

1384-
return call_origin(numpy.power, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
1406+
return call_origin(numpy.power, x1, x2, out=out, where=where, dtype=dtype, subok=subok, **kwargs)
13851407

13861408

13871409
def prod(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=True):

tests/skipped_tests.tbl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,6 @@ tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticModf::test_m
761761
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_10_{name='remainder', nargs=2}::test_raises_with_numpy_input
762762
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_11_{name='mod', nargs=2}::test_raises_with_numpy_input
763763
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_1_{name='angle', nargs=1}::test_raises_with_numpy_input
764-
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_5_{name='power', nargs=2}::test_raises_with_numpy_input
765764
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_8_{name='floor_divide', nargs=2}::test_raises_with_numpy_input
766765

767766
tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,6 @@ tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticBinary2_para
976976
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_10_{name='remainder', nargs=2}::test_raises_with_numpy_input
977977
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_11_{name='mod', nargs=2}::test_raises_with_numpy_input
978978
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_1_{name='angle', nargs=1}::test_raises_with_numpy_input
979-
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_5_{name='power', nargs=2}::test_raises_with_numpy_input
980979
tests/third_party/cupy/math_tests/test_arithmetic.py::TestArithmeticRaisesWithNumpyInput_param_8_{name='floor_divide', nargs=2}::test_raises_with_numpy_input
981980

982981
tests/third_party/cupy/math_tests/test_explog.py::TestExplog::test_logaddexp

0 commit comments

Comments
 (0)