@@ -184,32 +184,10 @@ sycl::event add_contig_impl(sycl::queue exec_q,
                             py::ssize_t res_offset,
                             const std::vector<sycl::event> &depends = {})
 {
-    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends);
-
-        size_t lws = 64;
-        constexpr unsigned int vec_sz = 4;
-        constexpr unsigned int n_vecs = 2;
-        const size_t n_groups =
-            ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
-        const auto gws_range = sycl::range<1>(n_groups * lws);
-        const auto lws_range = sycl::range<1>(lws);
-
-        using resTy = typename AddOutputType<argTy1, argTy2>::value_type;
-
-        const argTy1 *arg1_tp =
-            reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
-        const argTy2 *arg2_tp =
-            reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
-        resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
-
-        cgh.parallel_for<
-            add_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
-            sycl::nd_range<1>(gws_range, lws_range),
-            AddContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
-                arg1_tp, arg2_tp, res_tp, nelems));
-    });
-    return comp_ev;
+    return elementwise_common::binary_contig_impl<
+        argTy1, argTy2, AddOutputType, AddContigFunctor, add_contig_kernel>(
+        exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p,
+        res_offset, depends);
 }
 
 template <typename fnT, typename T1, typename T2> struct AddContigFactory
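
The submission logic deleted in the hunk above is not dropped; it moves into a shared template that every binary elementwise operation can instantiate. Below is a minimal sketch of what `elementwise_common::binary_contig_impl` plausibly looks like, reconstructed from the removed body: the enclosing namespace, includes, and template-parameter spellings are assumptions here, and the authoritative definition lives in the common elementwise header introduced by this PR.

```cpp
// Sketch reconstructed from the removed add_contig_impl body; assumes the
// helper merely template-parameterizes the output-type trait, functor, and
// kernel name that were previously hard-coded for "add".
#include <sycl/sycl.hpp>
#include <pybind11/pybind11.h>
#include <vector>

namespace py = pybind11;

namespace elementwise_common
{

template <typename argTy1,
          typename argTy2,
          template <typename, typename> class BinaryOutputType,
          template <typename, typename, typename, unsigned int, unsigned int>
          class BinaryContigFunctorT,
          template <typename, typename, typename, unsigned int, unsigned int>
          class kernel_name>
sycl::event binary_contig_impl(sycl::queue exec_q,
                               size_t nelems,
                               const char *arg1_p,
                               py::ssize_t arg1_offset,
                               const char *arg2_p,
                               py::ssize_t arg2_offset,
                               char *res_p,
                               py::ssize_t res_offset,
                               const std::vector<sycl::event> &depends = {})
{
    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);

        // Same launch geometry as the removed body: each work-item
        // processes n_vecs vectors of vec_sz elements.
        size_t lws = 64;
        constexpr unsigned int vec_sz = 4;
        constexpr unsigned int n_vecs = 2;
        const size_t n_groups =
            (nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz);
        const auto gws_range = sycl::range<1>(n_groups * lws);
        const auto lws_range = sycl::range<1>(lws);

        using resTy = typename BinaryOutputType<argTy1, argTy2>::value_type;

        const argTy1 *arg1_tp =
            reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
        const argTy2 *arg2_tp =
            reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
        resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;

        cgh.parallel_for<kernel_name<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
            sycl::nd_range<1>(gws_range, lws_range),
            BinaryContigFunctorT<argTy1, argTy2, resTy, vec_sz, n_vecs>(
                arg1_tp, arg2_tp, res_tp, nelems));
    });
    return comp_ev;
}

} // namespace elementwise_common
```

With a helper of this shape in place, `add_contig_impl` shrinks to the forwarding call shown in the hunk, and sibling operations (subtract, multiply, and so on) can reuse the identical launch logic by passing their own traits.
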
@@ -256,28 +234,11 @@ sycl::event add_strided_impl(sycl::queue exec_q,
                              const std::vector<sycl::event> &depends,
                              const std::vector<sycl::event> &additional_depends)
 {
-    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends);
-        cgh.depends_on(additional_depends);
-
-        using resTy = typename AddOutputType<argTy1, argTy2>::value_type;
-
-        using IndexerT =
-            typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
-
-        IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
-                         shape_and_strides};
-
-        const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
-        const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
-        resTy *res_tp = reinterpret_cast<resTy *>(res_p);
-
-        cgh.parallel_for<
-            add_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
-            {nelems}, AddStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
-                          arg1_tp, arg2_tp, res_tp, indexer));
-    });
-    return comp_ev;
+    return elementwise_common::binary_strided_impl<
+        argTy1, argTy2, AddOutputType, AddStridedFunctor,
+        add_strided_strided_kernel>(
+        exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
+        arg2_offset, res_p, res_offset, depends, additional_depends);
 }
 
 template <typename fnT, typename T1, typename T2> struct AddStridedFactory
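
The strided variant forwards the same way; the only extra machinery in the removed body is the packed `shape_and_strides` buffer consumed by `ThreeOffsets_StridedIndexer`. As a simplified illustration only (not the dpctl implementation, whose exact layout and member names live in `offset_utils`), such an indexer can unravel a flat element id into per-array element offsets, under the assumption that the descriptor packs the common shape followed by the three stride vectors:

```cpp
#include <cstddef>

// Hypothetical stand-in for ThreeOffsets_StridedIndexer: maps a flat
// element id to element offsets into arg1, arg2, and res. The packed
// layout [shape[0..nd), arg1_strides[0..nd), arg2_strides[0..nd),
// res_strides[0..nd)] is an assumption for illustration only.
struct ThreeOffsetsIndexerSketch
{
    int nd;
    std::ptrdiff_t arg1_offset, arg2_offset, res_offset;
    const std::ptrdiff_t *shape_and_strides;

    void operator()(std::ptrdiff_t flat_id,
                    std::ptrdiff_t &off1,
                    std::ptrdiff_t &off2,
                    std::ptrdiff_t &off3) const
    {
        off1 = arg1_offset;
        off2 = arg2_offset;
        off3 = res_offset;
        // Peel off C-order multi-index digits, fastest-varying axis first.
        for (int d = nd - 1; d >= 0; --d) {
            const std::ptrdiff_t dim = shape_and_strides[d];
            const std::ptrdiff_t i_d = flat_id % dim;
            flat_id /= dim;
            off1 += i_d * shape_and_strides[nd + d];     // arg1 strides
            off2 += i_d * shape_and_strides[2 * nd + d]; // arg2 strides
            off3 += i_d * shape_and_strides[3 * nd + d]; // res strides
        }
    }
};
```

Consistent with the removed body, the `AddStridedFunctor` would then read `arg1_tp[off1]` and `arg2_tp[off2]` and write `res_tp[off3]` for each of the `nelems` items.
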
@@ -322,62 +283,11 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl(
     py::ssize_t res_offset,
     const std::vector<sycl::event> &depends = {})
 {
-    const argT1 *mat = reinterpret_cast<const argT1 *>(mat_p) + mat_offset;
-    const argT2 *vec = reinterpret_cast<const argT2 *>(vec_p) + vec_offset;
-    resT *res = reinterpret_cast<resT *>(res_p) + res_offset;
-
-    const auto &dev = exec_q.get_device();
-    const auto &sg_sizes = dev.get_info<sycl::info::device::sub_group_sizes>();
-    // Get device-specific kernel info max_sub_group_size
-    size_t max_sgSize =
-        *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes)));
-
-    size_t n1_padded = n1 + max_sgSize;
-    argT2 *padded_vec = sycl::malloc_device<argT2>(n1_padded, exec_q);
-
-    if (padded_vec == nullptr) {
-        throw std::runtime_error("Could not allocate memory on the device");
-    }
-    sycl::event make_padded_vec_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(depends); // ensure vec contains actual data
-        cgh.parallel_for({n1_padded}, [=](sycl::id<1> id) {
-            auto i = id[0];
-            padded_vec[i] = vec[i % n1];
-        });
-    });
-
-    // sub-group spans work-items [I, I + sgSize)
-    // base = ndit.get_global_linear_id() - sg.get_local_id()[0]
-    // Generically, sg.load(&mat[base]) may load arrays from
-    // different rows of mat. The start corresponds to row (base / n0)
-    // We read sg.load(&padded_vec[(base / n0)]). The vector is padded to
-    // ensure that reads are accessible
-
-    size_t lws = 64;
-
-    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(make_padded_vec_ev);
-
-        auto lwsRange = sycl::range<1>(lws);
-        size_t n_elems = n0 * n1;
-        size_t n_groups = (n_elems + lws - 1) / lws;
-        auto gwsRange = sycl::range<1>(n_groups * lws);
-
-        cgh.parallel_for<
-            class add_matrix_row_broadcast_sg_krn<argT1, argT2, resT>>(
-            sycl::nd_range<1>(gwsRange, lwsRange),
-            AddContigMatrixContigRowBroadcastingFunctor<argT1, argT2, resT>(
-                mat, padded_vec, res, n_elems, n1));
-    });
-
-    sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(comp_ev);
-        sycl::context ctx = exec_q.get_context();
-        cgh.host_task([ctx, padded_vec]() { sycl::free(padded_vec, ctx); });
-    });
-    host_tasks.push_back(tmp_cleanup_ev);
-
-    return comp_ev;
+    return elementwise_common::binary_contig_matrix_contig_row_broadcast_impl<
+        argT1, argT2, resT, AddContigMatrixContigRowBroadcastingFunctor,
+        add_matrix_row_broadcast_sg_krn>(exec_q, host_tasks, n0, n1, mat_p,
+                                         mat_offset, vec_p, vec_offset, res_p,
+                                         res_offset, depends);
 }
 
 template <typename fnT, typename T1, typename T2>
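
The comment block removed in the hunk above documents the padded-vector trick that the shared broadcast helper now carries: the temporary vector is allocated with `n1 + max_sgSize` elements (filled as `vec[i % n1]`) so that a sub-group-wide contiguous load anchored at any index below `n1` never reads past the buffer. A small worked check of that invariant, using hypothetical sizes chosen only for illustration:

```cpp
#include <cassert>
#include <cstddef>

// Worked check of the padding invariant from the removed comment: if a
// sub-group of size max_sg_size starts a contiguous load at any index
// base < n1, the last element touched is base + max_sg_size - 1, which
// is strictly less than n1 + max_sg_size == n1_padded.
int main()
{
    const std::size_t n1 = 5;          // row length of the broadcast vector
    const std::size_t max_sg_size = 8; // widest sub-group on the device
    const std::size_t n1_padded = n1 + max_sg_size;

    for (std::size_t base = 0; base < n1; ++base) {
        // highest index a sub-group-wide load starting at `base` reads
        const std::size_t last = base + max_sg_size - 1;
        assert(last < n1_padded); // never reads past the padded buffer
    }
    return 0;
}
```

This is why the removed body sizes the allocation with the device's maximum sub-group size rather than a fixed constant: the bound must hold for whichever sub-group width the runtime picks.
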