Skip to content

Commit 3273a6f

Browse files
Implement dpctl.tensor.multiply, dpctl.tensor.subtract
1 parent 18d1728 commit 3273a6f

File tree

5 files changed

+1266
-10
lines changed

5 files changed

+1266
-10
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@
100100
isfinite,
101101
isinf,
102102
isnan,
103+
multiply,
103104
sqrt,
105+
subtract,
104106
)
105107

106108
__all__ = [
@@ -186,5 +188,7 @@
186188
"isfinite",
187189
"sqrt",
188190
"divide",
191+
"multiply",
192+
"subtract",
189193
"equal",
190194
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
# B01: ===== ADD (x1, x2)
3535

3636
_add_docstring_ = """
37-
add(x1, x2, order='K')
37+
add(x1, x2, out=None, order='K')
3838
3939
Calculates the sum for each element `x1_i` of the input array `x1` with
4040
the respective element `x2_i` of the input array `x2`.
@@ -94,7 +94,7 @@
9494

9595
# U11: ==== COS (x)
9696
_cos_docstring = """
97-
cos(x, order='K')
97+
cos(x, out=None, order='K')
9898
9999
Computes cosine for each element `x_i` for input array `x`.
100100
"""
@@ -106,7 +106,7 @@
106106

107107
# B08: ==== DIVIDE (x1, x2)
108108
_divide_docstring_ = """
109-
divide(x1, x2, order='K')
109+
divide(x1, x2, out=None, order='K')
110110
111111
Calculates the ratio for each element `x1_i` of the input array `x1` with
112112
the respective element `x2_i` of the input array `x2`.
@@ -128,7 +128,7 @@
128128

129129
# B09: ==== EQUAL (x1, x2)
130130
_equal_docstring_ = """
131-
equal(x1, x2, order='K')
131+
equal(x1, x2, out=None, order='K')
132132
133133
Calculates equality test results for each element `x1_i` of the input array `x1`
134134
with the respective element `x2_i` of the input array `x2`.
@@ -172,6 +172,8 @@
172172

173173
# U17: ==== ISFINITE (x)
174174
_isfinite_docstring_ = """
175+
isfinite(x, out=None, order='K')
176+
175177
Computes if every element of input array is a finite number.
176178
"""
177179

@@ -181,6 +183,8 @@
181183

182184
# U18: ==== ISINF (x)
183185
_isinf_docstring_ = """
186+
isinf(x, out=None, order='K')
187+
184188
Computes if every element of input array is an infinity.
185189
"""
186190

@@ -190,6 +194,8 @@
190194

191195
# U19: ==== ISNAN (x)
192196
_isnan_docstring_ = """
197+
isnan(x, out=None, order='K')
198+
193199
Computes if every element of input array is a NaN.
194200
"""
195201

@@ -231,7 +237,25 @@
231237
# FIXME: implement B18
232238

233239
# B19: ==== MULTIPLY (x1, x2)
234-
# FIXME: implement B19
240+
_multiply_docstring_ = """
241+
multiply(x1, x2, out=None, order='K')
242+
243+
Calculates the product for each element `x1_i` of the input array `x1`
244+
with the respective element `x2_i` of the input array `x2`.
245+
246+
Args:
247+
x1 (usm_ndarray):
248+
First input array, expected to have numeric data type.
249+
x2 (usm_ndarray):
250+
Second input array, also expected to have numeric data type.
251+
Returns:
252+
usm_narray:
253+
an array containing the element-wise products. The data type of
254+
the returned array is determined by the Type Promotion Rules.
255+
"""
256+
multiply = BinaryElementwiseFunc(
257+
"multiply", ti._multiply_result_type, ti._multiply, _multiply_docstring_
258+
)
235259

236260
# U25: ==== NEGATIVE (x)
237261
# FIXME: implement U25
@@ -268,6 +292,8 @@
268292

269293
# U33: ==== SQRT (x)
270294
_sqrt_docstring_ = """
295+
sqrt(x, out=None, order='K')
296+
271297
Computes sqrt for each element `x_i` for input array `x`.
272298
"""
273299

@@ -276,7 +302,26 @@
276302
)
277303

278304
# B23: ==== SUBTRACT (x1, x2)
279-
# FIXME: implement B23
305+
_subtract_docstring_ = """
306+
subtract(x1, x2, out=None, order='K')
307+
308+
Calculates the difference bewteen each element `x1_i` of the input
309+
array `x1` and the respective element `x2_i` of the input array `x2`.
310+
311+
Args:
312+
x1 (usm_ndarray):
313+
First input array, expected to have numeric data type.
314+
x2 (usm_ndarray):
315+
Second input array, also expected to have numeric data type.
316+
Returns:
317+
usm_narray:
318+
an array containing the element-wise differences. The data type
319+
of the returned array is determined by the Type Promotion Rules.
320+
"""
321+
subtract = BinaryElementwiseFunc(
322+
"subtract", ti._subtract_result_type, ti._subtract, _subtract_docstring_
323+
)
324+
280325

281326
# U34: ==== TAN (x)
282327
# FIXME: implement U34

0 commit comments

Comments
 (0)