Skip to content

Commit 165daad

Browse files
committed
Reuse add(), multiply() and subtract() from dpctl
1 parent 2e8f40c commit 165daad

File tree

10 files changed

+242
-145
lines changed

10 files changed

+242
-145
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
3636
cdef enum DPNPFuncName "DPNPFuncName":
3737
DPNP_FN_ABSOLUTE
3838
DPNP_FN_ABSOLUTE_EXT
39-
DPNP_FN_ADD
40-
DPNP_FN_ADD_EXT
4139
DPNP_FN_ALL
4240
DPNP_FN_ALL_EXT
4341
DPNP_FN_ALLCLOSE
@@ -117,7 +115,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
117115
DPNP_FN_DIAG_INDICES_EXT
118116
DPNP_FN_DIAGONAL
119117
DPNP_FN_DIAGONAL_EXT
120-
DPNP_FN_DIVIDE
121118
DPNP_FN_DOT
122119
DPNP_FN_DOT_EXT
123120
DPNP_FN_EDIFF1D
@@ -203,8 +200,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
203200
DPNP_FN_MINIMUM_EXT
204201
DPNP_FN_MODF
205202
DPNP_FN_MODF_EXT
206-
DPNP_FN_MULTIPLY
207-
DPNP_FN_MULTIPLY_EXT
208203
DPNP_FN_NANVAR
209204
DPNP_FN_NANVAR_EXT
210205
DPNP_FN_NEGATIVE
@@ -323,8 +318,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
323318
DPNP_FN_SQUARE_EXT
324319
DPNP_FN_STD
325320
DPNP_FN_STD_EXT
326-
DPNP_FN_SUBTRACT
327-
DPNP_FN_SUBTRACT_EXT
328321
DPNP_FN_SUM
329322
DPNP_FN_SUM_EXT
330323
DPNP_FN_SVD
@@ -523,8 +516,6 @@ cpdef dpnp_descriptor dpnp_copy(dpnp_descriptor x1)
523516
"""
524517
Mathematical functions
525518
"""
526-
cpdef dpnp_descriptor dpnp_add(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
527-
dpnp_descriptor out=*, object where=*)
528519
cpdef dpnp_descriptor dpnp_arctan2(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
529520
dpnp_descriptor out=*, object where=*)
530521
cpdef dpnp_descriptor dpnp_hypot(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
@@ -533,15 +524,11 @@ cpdef dpnp_descriptor dpnp_maximum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
533524
dpnp_descriptor out=*, object where=*)
534525
cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
535526
dpnp_descriptor out=*, object where=*)
536-
cpdef dpnp_descriptor dpnp_multiply(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
537-
dpnp_descriptor out=*, object where=*)
538527
cpdef dpnp_descriptor dpnp_negative(dpnp_descriptor array1)
539528
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
540529
dpnp_descriptor out=*, object where=*)
541530
cpdef dpnp_descriptor dpnp_remainder(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
542531
dpnp_descriptor out=*, object where=*)
543-
cpdef dpnp_descriptor dpnp_subtract(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
544-
dpnp_descriptor out=*, object where=*)
545532

546533

547534
"""

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ and the rest of the library
3737

3838
__all__ += [
3939
"dpnp_absolute",
40-
"dpnp_add",
4140
"dpnp_arctan2",
4241
"dpnp_around",
4342
"dpnp_ceil",
@@ -57,7 +56,6 @@ __all__ += [
5756
"dpnp_maximum",
5857
"dpnp_minimum",
5958
"dpnp_modf",
60-
"dpnp_multiply",
6159
"dpnp_nancumprod",
6260
"dpnp_nancumsum",
6361
"dpnp_nanprod",
@@ -67,7 +65,6 @@ __all__ += [
6765
"dpnp_prod",
6866
"dpnp_remainder",
6967
"dpnp_sign",
70-
"dpnp_subtract",
7168
"dpnp_sum",
7269
"dpnp_trapz",
7370
"dpnp_trunc"
@@ -123,14 +120,6 @@ cpdef utils.dpnp_descriptor dpnp_absolute(utils.dpnp_descriptor x1):
123120
return result
124121

125122

126-
cpdef utils.dpnp_descriptor dpnp_add(utils.dpnp_descriptor x1_obj,
127-
utils.dpnp_descriptor x2_obj,
128-
object dtype=None,
129-
utils.dpnp_descriptor out=None,
130-
object where=True):
131-
return call_fptr_2in_1out_strides(DPNP_FN_ADD_EXT, x1_obj, x2_obj, dtype, out, where)
132-
133-
134123
cpdef utils.dpnp_descriptor dpnp_arctan2(utils.dpnp_descriptor x1_obj,
135124
utils.dpnp_descriptor x2_obj,
136125
object dtype=None,
@@ -426,14 +415,6 @@ cpdef tuple dpnp_modf(utils.dpnp_descriptor x1):
426415
return (result1.get_pyobj(), result2.get_pyobj())
427416

428417

429-
cpdef utils.dpnp_descriptor dpnp_multiply(utils.dpnp_descriptor x1_obj,
430-
utils.dpnp_descriptor x2_obj,
431-
object dtype=None,
432-
utils.dpnp_descriptor out=None,
433-
object where=True):
434-
return call_fptr_2in_1out_strides(DPNP_FN_MULTIPLY_EXT, x1_obj, x2_obj, dtype, out, where)
435-
436-
437418
cpdef utils.dpnp_descriptor dpnp_nancumprod(utils.dpnp_descriptor x1):
438419
cur_x1 = dpnp_copy(x1).get_pyobj()
439420

@@ -586,14 +567,6 @@ cpdef utils.dpnp_descriptor dpnp_sign(utils.dpnp_descriptor x1):
586567
return call_fptr_1in_1out_strides(DPNP_FN_SIGN_EXT, x1)
587568

588569

589-
cpdef utils.dpnp_descriptor dpnp_subtract(utils.dpnp_descriptor x1_obj,
590-
utils.dpnp_descriptor x2_obj,
591-
object dtype=None,
592-
utils.dpnp_descriptor out=None,
593-
object where=True):
594-
return call_fptr_2in_1out_strides(DPNP_FN_SUBTRACT_EXT, x1_obj, x2_obj, dtype, out, where)
595-
596-
597570
cpdef utils.dpnp_descriptor dpnp_sum(utils.dpnp_descriptor x1,
598571
object axis=None,
599572
object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,54 @@
3838

3939

4040
__all__ = [
41-
"dpnp_divide"
41+
"dpnp_add",
42+
"dpnp_divide",
43+
"dpnp_multiply",
44+
"dpnp_subtract"
4245
]
4346

4447

48+
_add_docstring_ = """
49+
add(x1, x2, out=None, order='K')
50+
51+
Calculates the sum for each element `x1_i` of the input array `x1` with
52+
the respective element `x2_i` of the input array `x2`.
53+
54+
Args:
55+
x1 (dpnp.ndarray):
56+
First input array, expected to have numeric data type.
57+
x2 (dpnp.ndarray):
58+
Second input array, also expected to have numeric data type.
59+
out ({None, dpnp.ndarray}, optional):
60+
Output array to populate.
61+
Array have the correct shape and the expected data type.
62+
order ("C","F","A","K", None, optional):
63+
Memory layout of the newly output array, if parameter `out` is `None`.
64+
Default: "K".
65+
Returns:
66+
dpnp.ndarray:
67+
an array containing the result of element-wise division. The data type
68+
of the returned array is determined by the Type Promotion Rules.
69+
"""
70+
71+
def dpnp_add(x1, x2, out=None, order='K'):
72+
"""
73+
Invokes add() from dpctl.tensor implementation for add() function.
74+
TODO: add a pybind11 extension of add() from OneMKL VM where possible
75+
and would be performance effective.
76+
77+
"""
78+
79+
# dpctl.tensor only works with usm_ndarray or scalar
80+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
81+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
82+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
83+
84+
func = BinaryElementwiseFunc("add", ti._add_result_type, ti._add, _add_docstring_)
85+
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
86+
return dpnp_array._create_from_usm_ndarray(res_usm)
87+
88+
4589
_divide_docstring_ = """
4690
divide(x1, x2, out=None, order='K')
4791
@@ -88,3 +132,85 @@ def _call_divide(src1, src2, dst, sycl_queue, depends=[]):
88132
func = BinaryElementwiseFunc("divide", ti._divide_result_type, _call_divide, _divide_docstring_)
89133
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
90134
return dpnp_array._create_from_usm_ndarray(res_usm)
135+
136+
137+
_multiply_docstring_ = """
138+
multiply(x1, x2, out=None, order='K')
139+
140+
Calculates the product for each element `x1_i` of the input array `x1`
141+
with the respective element `x2_i` of the input array `x2`.
142+
143+
Args:
144+
x1 (dpnp.ndarray):
145+
First input array, expected to have numeric data type.
146+
x2 (dpnp.ndarray):
147+
Second input array, also expected to have numeric data type.
148+
out ({None, dpnp.ndarray}, optional):
149+
Output array to populate.
150+
Array have the correct shape and the expected data type.
151+
order ("C","F","A","K", None, optional):
152+
Memory layout of the newly output array, if parameter `out` is `None`.
153+
Default: "K".
154+
Returns:
155+
dpnp.ndarray:
156+
an array containing the result of element-wise division. The data type
157+
of the returned array is determined by the Type Promotion Rules.
158+
"""
159+
160+
def dpnp_multiply(x1, x2, out=None, order='K'):
161+
"""
162+
Invokes multiply() from dpctl.tensor implementation for multiply() function.
163+
TODO: add a pybind11 extension of mul() from OneMKL VM where possible
164+
and would be performance effective.
165+
166+
"""
167+
168+
# dpctl.tensor only works with usm_ndarray or scalar
169+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
170+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
171+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
172+
173+
func = BinaryElementwiseFunc("multiply", ti._multiply_result_type, ti._multiply, _multiply_docstring_)
174+
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
175+
return dpnp_array._create_from_usm_ndarray(res_usm)
176+
177+
178+
_subtract_docstring_ = """
179+
subtract(x1, x2, out=None, order='K')
180+
181+
Calculates the difference bewteen each element `x1_i` of the input
182+
array `x1` and the respective element `x2_i` of the input array `x2`.
183+
184+
Args:
185+
x1 (dpnp.ndarray):
186+
First input array, expected to have numeric data type.
187+
x2 (dpnp.ndarray):
188+
Second input array, also expected to have numeric data type.
189+
out ({None, dpnp.ndarray}, optional):
190+
Output array to populate.
191+
Array have the correct shape and the expected data type.
192+
order ("C","F","A","K", None, optional):
193+
Memory layout of the newly output array, if parameter `out` is `None`.
194+
Default: "K".
195+
Returns:
196+
dpnp.ndarray:
197+
an array containing the result of element-wise division. The data type
198+
of the returned array is determined by the Type Promotion Rules.
199+
"""
200+
201+
def dpnp_subtract(x1, x2, out=None, order='K'):
202+
"""
203+
Invokes subtract() from dpctl.tensor implementation for subtract() function.
204+
TODO: add a pybind11 extension of sub() from OneMKL VM where possible
205+
and would be performance effective.
206+
207+
"""
208+
209+
# dpctl.tensor only works with usm_ndarray or scalar
210+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
211+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
212+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
213+
214+
func = BinaryElementwiseFunc("subtract", ti._subtract_result_type, ti._subtract, _subtract_docstring_)
215+
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
216+
return dpnp_array._create_from_usm_ndarray(res_usm)

dpnp/dpnp_array.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,15 @@ def __irshift__(self, other):
249249
dpnp.right_shift(self, other, out=self)
250250
return self
251251

252-
# '__isub__',
252+
def __isub__(self, other):
253+
dpnp.subtract(self, other, out=self)
254+
return self
255+
253256
# '__iter__',
254-
# '__itruediv__',
257+
258+
def __itruediv__(self, other):
259+
dpnp.true_divide(self, other, out=self)
260+
return self
255261

256262
def __ixor__(self, other):
257263
dpnp.bitwise_xor(self, other, out=self)

0 commit comments

Comments
 (0)