Skip to content

Commit 96b9759

Browse files
Reuse dpctl.tensor.pow for dpnp.power
1 parent 771653b commit 96b9759

File tree

4 files changed

+64
-67
lines changed

4 files changed

+64
-67
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,8 +490,6 @@ cpdef dpnp_descriptor dpnp_maximum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
490490
cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
491491
dpnp_descriptor out=*, object where=*)
492492
cpdef dpnp_descriptor dpnp_negative(dpnp_descriptor array1)
493-
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
494-
dpnp_descriptor out=*, object where=*)
495493
cpdef dpnp_descriptor dpnp_remainder(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
496494
dpnp_descriptor out=*, object where=*)
497495

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ __all__ += [
6060
"dpnp_nanprod",
6161
"dpnp_nansum",
6262
"dpnp_negative",
63-
"dpnp_power",
6463
"dpnp_prod",
6564
"dpnp_remainder",
6665
"dpnp_sign",
@@ -476,14 +475,6 @@ cpdef utils.dpnp_descriptor dpnp_negative(dpnp_descriptor x1):
476475
return call_fptr_1in_1out_strides(DPNP_FN_NEGATIVE_EXT, x1)
477476

478477

479-
cpdef utils.dpnp_descriptor dpnp_power(utils.dpnp_descriptor x1_obj,
480-
utils.dpnp_descriptor x2_obj,
481-
object dtype=None,
482-
utils.dpnp_descriptor out=None,
483-
object where=True):
484-
return call_fptr_2in_1out_strides(DPNP_FN_POWER_EXT, x1_obj, x2_obj, dtype, out, where, func_name="power")
485-
486-
487478
cpdef utils.dpnp_descriptor dpnp_prod(utils.dpnp_descriptor x1,
488479
object axis=None,
489480
object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
"dpnp_logical_xor",
5959
"dpnp_multiply",
6060
"dpnp_not_equal",
61+
"dpnp_power",
6162
"dpnp_sin",
6263
"dpnp_sqrt",
6364
"dpnp_square",
@@ -844,6 +845,44 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
844845
return dpnp_array._create_from_usm_ndarray(res_usm)
845846

846847

848+
_power_docstring_ = """
849+
power(x1, x2, out=None, order="K")
850+
Calculates `x1_i` raised to `x2_i` for each element `x1_i` of the input array
851+
`x1` with the respective element `x2_i` of the input array `x2`.
852+
Args:
853+
x1 (dpnp.ndarray):
854+
First input array, expected to have numeric data type.
855+
x2 (dpnp.ndarray):
856+
Second input array, also expected to have numeric data type.
857+
out ({None, dpnp.ndarray}, optional):
858+
Output array to populate. Array must have the correct
859+
shape and the expected data type.
860+
order ("C","F","A","K", None, optional):
861+
Output array, if parameter `out` is `None`.
862+
Default: "K".
863+
Returns:
864+
usm_ndarray:
865+
An array containing the result of element-wise of raising each element
866+
to a specified power.
867+
The data type of the returned array is determined by the Type Promotion Rules.
868+
"""
869+
870+
871+
def dpnp_power(x1, x2, out=None, order="K"):
872+
"""Invokes pow() from dpctl.tensor implementation for power() function."""
873+
874+
# dpctl.tensor only works with usm_ndarray or scalar
875+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
876+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
877+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
878+
879+
func = BinaryElementwiseFunc(
880+
"pow", ti._pow_result_type, ti._pow, _power_docstring_
881+
)
882+
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
883+
return dpnp_array._create_from_usm_ndarray(res_usm)
884+
885+
847886
_sin_docstring = """
848887
sin(x, out=None, order='K')
849888
Computes sine for each element `x_i` of input array `x`.

dpnp/dpnp_iface_mathematical.py

Lines changed: 25 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@
4848

4949
from .dpnp_algo import *
5050
from .dpnp_algo.dpnp_elementwise_common import (
51+
check_nd_call_func,
5152
dpnp_add,
5253
dpnp_divide,
5354
dpnp_floor_divide,
5455
dpnp_multiply,
56+
dpnp_power,
5557
dpnp_subtract,
5658
)
5759
from .dpnp_utils import *
@@ -1436,7 +1438,18 @@ def negative(x1, **kwargs):
14361438
return call_origin(numpy.negative, x1, **kwargs)
14371439

14381440

1439-
def power(x1, x2, /, out=None, *, where=True, dtype=None, subok=True, **kwargs):
1441+
def power(
1442+
x1,
1443+
x2,
1444+
/,
1445+
out=None,
1446+
*,
1447+
order="K",
1448+
where=True,
1449+
dtype=None,
1450+
subok=True,
1451+
**kwargs,
1452+
):
14401453
"""
14411454
First array elements raised to powers from second array, element-wise.
14421455
@@ -1455,7 +1468,6 @@ def power(x1, x2, /, out=None, *, where=True, dtype=None, subok=True, **kwargs):
14551468
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
14561469
or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
14571470
Parameters `where`, `dtype` and `subok` are supported with their default values.
1458-
Keyword arguments ``kwargs`` are currently unsupported.
14591471
Otherwise the function will be executed sequentially on CPU.
14601472
Input array data types are limited by supported DPNP :ref:`Data types`.
14611473
@@ -1477,60 +1489,17 @@ def power(x1, x2, /, out=None, *, where=True, dtype=None, subok=True, **kwargs):
14771489
14781490
"""
14791491

1480-
if kwargs:
1481-
pass
1482-
elif where is not True:
1483-
pass
1484-
elif dtype is not None:
1485-
pass
1486-
elif subok is not True:
1487-
pass
1488-
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
1489-
# at least either x1 or x2 has to be an array
1490-
pass
1491-
else:
1492-
# get USM type and queue to copy scalar from the host memory into a USM allocation
1493-
usm_type, queue = (
1494-
get_usm_allocations([x1, x2])
1495-
if dpnp.isscalar(x1) or dpnp.isscalar(x2)
1496-
else (None, None)
1497-
)
1498-
1499-
x1_desc = dpnp.get_dpnp_descriptor(
1500-
x1,
1501-
copy_when_strides=False,
1502-
copy_when_nondefault_queue=False,
1503-
alloc_usm_type=usm_type,
1504-
alloc_queue=queue,
1505-
)
1506-
x2_desc = dpnp.get_dpnp_descriptor(
1507-
x2,
1508-
copy_when_strides=False,
1509-
copy_when_nondefault_queue=False,
1510-
alloc_usm_type=usm_type,
1511-
alloc_queue=queue,
1512-
)
1513-
if x1_desc and x2_desc:
1514-
if out is not None:
1515-
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
1516-
raise TypeError(
1517-
"return array must be of supported array type"
1518-
)
1519-
out_desc = (
1520-
dpnp.get_dpnp_descriptor(
1521-
out, copy_when_nondefault_queue=False
1522-
)
1523-
or None
1524-
)
1525-
else:
1526-
out_desc = None
1527-
1528-
return dpnp_power(
1529-
x1_desc, x2_desc, dtype=dtype, out=out_desc, where=where
1530-
).get_pyobj()
1531-
1532-
return call_origin(
1533-
numpy.power, x1, x2, dtype=dtype, out=out, where=where, **kwargs
1492+
return check_nd_call_func(
1493+
numpy.power,
1494+
dpnp_power,
1495+
x1,
1496+
x2,
1497+
out=out,
1498+
where=where,
1499+
order=order,
1500+
dtype=dtype,
1501+
subok=subok,
1502+
**kwargs,
15341503
)
15351504

15361505

0 commit comments

Comments
 (0)