Skip to content

Commit a469906

Browse files
Merge pull request #1211 from IntelPython/feature/multiply-and-subtract
Feature/multiply and subtract
2 parents c822d41 + f25053f commit a469906

File tree

17 files changed

+1774
-723
lines changed

17 files changed

+1774
-723
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@
100100
isfinite,
101101
isinf,
102102
isnan,
103+
multiply,
103104
sqrt,
105+
subtract,
104106
)
105107

106108
__all__ = [
@@ -186,5 +188,7 @@
186188
"isfinite",
187189
"sqrt",
188190
"divide",
191+
"multiply",
192+
"subtract",
189193
"equal",
190194
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
# B01: ===== ADD (x1, x2)
3535

3636
_add_docstring_ = """
37-
add(x1, x2, order='K')
37+
add(x1, x2, out=None, order='K')
3838
3939
Calculates the sum for each element `x1_i` of the input array `x1` with
4040
the respective element `x2_i` of the input array `x2`.
@@ -94,7 +94,7 @@
9494

9595
# U11: ==== COS (x)
9696
_cos_docstring = """
97-
cos(x, order='K')
97+
cos(x, out=None, order='K')
9898
9999
Computes cosine for each element `x_i` for input array `x`.
100100
"""
@@ -106,7 +106,7 @@
106106

107107
# B08: ==== DIVIDE (x1, x2)
108108
_divide_docstring_ = """
109-
divide(x1, x2, order='K')
109+
divide(x1, x2, out=None, order='K')
110110
111111
Calculates the ratio for each element `x1_i` of the input array `x1` with
112112
the respective element `x2_i` of the input array `x2`.
@@ -128,7 +128,7 @@
128128

129129
# B09: ==== EQUAL (x1, x2)
130130
_equal_docstring_ = """
131-
equal(x1, x2, order='K')
131+
equal(x1, x2, out=None, order='K')
132132
133133
Calculates equality test results for each element `x1_i` of the input array `x1`
134134
with the respective element `x2_i` of the input array `x2`.
@@ -172,6 +172,8 @@
172172

173173
# U17: ==== ISFINITE (x)
174174
_isfinite_docstring_ = """
175+
isfinite(x, out=None, order='K')
176+
175177
Computes if every element of input array is a finite number.
176178
"""
177179

@@ -181,6 +183,8 @@
181183

182184
# U18: ==== ISINF (x)
183185
_isinf_docstring_ = """
186+
isinf(x, out=None, order='K')
187+
184188
Computes if every element of input array is an infinity.
185189
"""
186190

@@ -190,6 +194,8 @@
190194

191195
# U19: ==== ISNAN (x)
192196
_isnan_docstring_ = """
197+
isnan(x, out=None, order='K')
198+
193199
Computes if every element of input array is a NaN.
194200
"""
195201

@@ -231,7 +237,25 @@
231237
# FIXME: implement B18
232238

233239
# B19: ==== MULTIPLY (x1, x2)
234-
# FIXME: implement B19
240+
_multiply_docstring_ = """
241+
multiply(x1, x2, out=None, order='K')
242+
243+
Calculates the product for each element `x1_i` of the input array `x1`
244+
with the respective element `x2_i` of the input array `x2`.
245+
246+
Args:
247+
x1 (usm_ndarray):
248+
First input array, expected to have numeric data type.
249+
x2 (usm_ndarray):
250+
Second input array, also expected to have numeric data type.
251+
Returns:
252+
usm_ndarray:
253+
an array containing the element-wise products. The data type of
254+
the returned array is determined by the Type Promotion Rules.
255+
"""
256+
multiply = BinaryElementwiseFunc(
257+
"multiply", ti._multiply_result_type, ti._multiply, _multiply_docstring_
258+
)
235259

236260
# U25: ==== NEGATIVE (x)
237261
# FIXME: implement U25
@@ -268,6 +292,8 @@
268292

269293
# U33: ==== SQRT (x)
270294
_sqrt_docstring_ = """
295+
sqrt(x, out=None, order='K')
296+
271297
Computes sqrt for each element `x_i` for input array `x`.
272298
"""
273299

@@ -276,7 +302,26 @@
276302
)
277303

278304
# B23: ==== SUBTRACT (x1, x2)
279-
# FIXME: implement B23
305+
_subtract_docstring_ = """
306+
subtract(x1, x2, out=None, order='K')
307+
308+
Calculates the difference between each element `x1_i` of the input
309+
array `x1` and the respective element `x2_i` of the input array `x2`.
310+
311+
Args:
312+
x1 (usm_ndarray):
313+
First input array, expected to have numeric data type.
314+
x2 (usm_ndarray):
315+
Second input array, also expected to have numeric data type.
316+
Returns:
317+
usm_ndarray:
318+
an array containing the element-wise differences. The data type
319+
of the returned array is determined by the Type Promotion Rules.
320+
"""
321+
subtract = BinaryElementwiseFunc(
322+
"subtract", ti._subtract_result_type, ti._subtract, _subtract_docstring_
323+
)
324+
280325

281326
# U34: ==== TAN (x)
282327
# FIXME: implement U34

dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -115,27 +115,9 @@ sycl::event abs_contig_impl(sycl::queue exec_q,
115115
char *res_p,
116116
const std::vector<sycl::event> &depends = {})
117117
{
118-
sycl::event abs_ev = exec_q.submit([&](sycl::handler &cgh) {
119-
cgh.depends_on(depends);
120-
121-
size_t lws = 64;
122-
constexpr unsigned int vec_sz = 4;
123-
constexpr unsigned int n_vecs = 2;
124-
const size_t n_groups =
125-
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
126-
const auto gws_range = sycl::range<1>(n_groups * lws);
127-
const auto lws_range = sycl::range<1>(lws);
128-
129-
using resTy = typename AbsOutputType<argTy>::value_type;
130-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
131-
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
132-
133-
cgh.parallel_for<abs_contig_kernel<argTy, resTy, vec_sz, n_vecs>>(
134-
sycl::nd_range<1>(gws_range, lws_range),
135-
AbsContigFunctor<argTy, resTy, vec_sz, n_vecs>(arg_tp, res_tp,
136-
nelems));
137-
});
138-
return abs_ev;
118+
return elementwise_common::unary_contig_impl<
119+
argTy, AbsOutputType, AbsContigFunctor, abs_contig_kernel>(
120+
exec_q, nelems, arg_p, res_p, depends);
139121
}
140122

141123
template <typename fnT, typename T> struct AbsContigFactory
@@ -182,24 +164,10 @@ sycl::event abs_strided_impl(sycl::queue exec_q,
182164
const std::vector<sycl::event> &depends,
183165
const std::vector<sycl::event> &additional_depends)
184166
{
185-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
186-
cgh.depends_on(depends);
187-
cgh.depends_on(additional_depends);
188-
189-
using resTy = typename AbsOutputType<argTy>::value_type;
190-
using IndexerT =
191-
typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
192-
193-
IndexerT indexer{nd, arg_offset, res_offset, shape_and_strides};
194-
195-
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
196-
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
197-
198-
cgh.parallel_for<abs_strided_kernel<argTy, resTy, IndexerT>>(
199-
{nelems},
200-
AbsStridedFunctor<argTy, resTy, IndexerT>(arg_tp, res_tp, indexer));
201-
});
202-
return comp_ev;
167+
return elementwise_common::unary_strided_impl<
168+
argTy, AbsOutputType, AbsStridedFunctor, abs_strided_kernel>(
169+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
170+
res_offset, depends, additional_depends);
203171
}
204172

205173
template <typename fnT, typename T> struct AbsStridedFactory

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 14 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -184,32 +184,10 @@ sycl::event add_contig_impl(sycl::queue exec_q,
184184
py::ssize_t res_offset,
185185
const std::vector<sycl::event> &depends = {})
186186
{
187-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
188-
cgh.depends_on(depends);
189-
190-
size_t lws = 64;
191-
constexpr unsigned int vec_sz = 4;
192-
constexpr unsigned int n_vecs = 2;
193-
const size_t n_groups =
194-
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
195-
const auto gws_range = sycl::range<1>(n_groups * lws);
196-
const auto lws_range = sycl::range<1>(lws);
197-
198-
using resTy = typename AddOutputType<argTy1, argTy2>::value_type;
199-
200-
const argTy1 *arg1_tp =
201-
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
202-
const argTy2 *arg2_tp =
203-
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
204-
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
205-
206-
cgh.parallel_for<
207-
add_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
208-
sycl::nd_range<1>(gws_range, lws_range),
209-
AddContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
210-
arg1_tp, arg2_tp, res_tp, nelems));
211-
});
212-
return comp_ev;
187+
return elementwise_common::binary_contig_impl<
188+
argTy1, argTy2, AddOutputType, AddContigFunctor, add_contig_kernel>(
189+
exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p,
190+
res_offset, depends);
213191
}
214192

215193
template <typename fnT, typename T1, typename T2> struct AddContigFactory
@@ -256,28 +234,11 @@ sycl::event add_strided_impl(sycl::queue exec_q,
256234
const std::vector<sycl::event> &depends,
257235
const std::vector<sycl::event> &additional_depends)
258236
{
259-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
260-
cgh.depends_on(depends);
261-
cgh.depends_on(additional_depends);
262-
263-
using resTy = typename AddOutputType<argTy1, argTy2>::value_type;
264-
265-
using IndexerT =
266-
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
267-
268-
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
269-
shape_and_strides};
270-
271-
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
272-
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
273-
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
274-
275-
cgh.parallel_for<
276-
add_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
277-
{nelems}, AddStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
278-
arg1_tp, arg2_tp, res_tp, indexer));
279-
});
280-
return comp_ev;
237+
return elementwise_common::binary_strided_impl<
238+
argTy1, argTy2, AddOutputType, AddStridedFunctor,
239+
add_strided_strided_kernel>(
240+
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
241+
arg2_offset, res_p, res_offset, depends, additional_depends);
281242
}
282243

283244
template <typename fnT, typename T1, typename T2> struct AddStridedFactory
@@ -322,62 +283,11 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl(
322283
py::ssize_t res_offset,
323284
const std::vector<sycl::event> &depends = {})
324285
{
325-
const argT1 *mat = reinterpret_cast<const argT1 *>(mat_p) + mat_offset;
326-
const argT2 *vec = reinterpret_cast<const argT2 *>(vec_p) + vec_offset;
327-
resT *res = reinterpret_cast<resT *>(res_p) + res_offset;
328-
329-
const auto &dev = exec_q.get_device();
330-
const auto &sg_sizes = dev.get_info<sycl::info::device::sub_group_sizes>();
331-
// Get device-specific kernel info max_sub_group_size
332-
size_t max_sgSize =
333-
*(std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));
334-
335-
size_t n1_padded = n1 + max_sgSize;
336-
argT2 *padded_vec = sycl::malloc_device<argT2>(n1_padded, exec_q);
337-
338-
if (padded_vec == nullptr) {
339-
throw std::runtime_error("Could not allocate memory on the device");
340-
}
341-
sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
342-
cgh.depends_on(depends); // ensure vec contains actual data
343-
cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
344-
auto i = id[0];
345-
padded_vec[i] = vec[i % n1];
346-
});
347-
});
348-
349-
// sub-group spans work-items [I, I + sgSize)
350-
// base = ndit.get_global_linear_id() - sg.get_local_id()[0]
351-
// Generically, sg.load( &mat[base]) may load arrays from
352-
// different rows of mat. The start corresponds to row (base / n0)
353-
// We read sg.load(&padded_vec[(base / n0)]). The vector is padded to
354-
// ensure that reads are accessible
355-
356-
size_t lws = 64;
357-
358-
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
359-
cgh.depends_on(make_padded_vec_ev);
360-
361-
auto lwsRange = sycl::range<1>(lws);
362-
size_t n_elems = n0 * n1;
363-
size_t n_groups = (n_elems + lws - 1) / lws;
364-
auto gwsRange = sycl::range<1>(n_groups * lws);
365-
366-
cgh.parallel_for<
367-
class add_matrix_row_broadcast_sg_krn<argT1, argT2, resT>>(
368-
sycl::nd_range<1>(gwsRange, lwsRange),
369-
AddContigMatrixContigRowBroadcastingFunctor<argT1, argT2, resT>(
370-
mat, padded_vec, res, n_elems, n1));
371-
});
372-
373-
sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
374-
cgh.depends_on(comp_ev);
375-
sycl::context ctx = exec_q.get_context();
376-
cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); });
377-
});
378-
host_tasks.push_back(tmp_cleanup_ev);
379-
380-
return comp_ev;
286+
return elementwise_common::binary_contig_matrix_contig_row_broadcast_impl<
287+
argT1, argT2, resT, AddContigMatrixContigRowBroadcastingFunctor,
288+
add_matrix_row_broadcast_sg_krn>(exec_q, host_tasks, n0, n1, mat_p,
289+
mat_offset, vec_p, vec_offset, res_p,
290+
res_offset, depends);
381291
}
382292

383293
template <typename fnT, typename T1, typename T2>

0 commit comments

Comments
 (0)