@@ -39,6 +39,10 @@ namespace kernels
39
39
namespace constructors
40
40
{
41
41
42
+ /*!
43
+ @defgroup CtorKernels
44
+ */
45
+
42
46
template <typename Ty> class linear_sequence_step_kernel ;
43
47
template <typename Ty, typename wTy> class linear_sequence_affine_kernel ;
44
48
template <typename Ty> class eye_kernel ;
@@ -47,6 +51,10 @@ namespace py = pybind11;
47
51
48
52
/* =========== Unboxing Python scalar =============== */
49
53
54
+ /*!
55
+ * @brief Cast pybind11 class managing Python object to specified type `T`.
56
+ * @ingroup CtorKernels
57
+ */
50
58
template <typename T> T unbox_py_scalar (py::object o)
51
59
{
52
60
return py::cast<T>(o);
@@ -96,6 +104,23 @@ template <typename Ty> class LinearSequenceStepFunctor
96
104
}
97
105
};
98
106
107
+ /*!
108
+ * @brief Function to submit kernel to populate given contiguous memory
109
+ * allocation with linear sequence specified by typed starting value and
110
+ * increment.
111
+ *
112
+ * @param exec_q Sycl queue to which the kernel is submitted
113
+ * @param nelems Length of the sequence
114
+ * @param start_v Typed starting value of the sequence
115
+ * @param step_v Typed increment of the sequence
116
+ * @param array_data Kernel accessible USM pointer to the start of array to be
117
+ * populated.
118
+ * @param depends List of events to wait for before starting computations, if
119
+ * any.
120
+ *
121
+ * @return Event to wait on to ensure that computation completes.
122
+ * @ingroup CtorKernels
123
+ */
99
124
template <typename Ty>
100
125
sycl::event lin_space_step_impl (sycl::queue exec_q,
101
126
size_t nelems,
@@ -114,6 +139,25 @@ sycl::event lin_space_step_impl(sycl::queue exec_q,
114
139
return lin_space_step_event;
115
140
}
116
141
142
+ /*!
143
+ * @brief Function to submit kernel to populate given contiguous memory
144
+ * allocation with linear sequence specified by starting value and increment
145
+ * given as Python objects.
146
+ *
147
+ * @param exec_q Sycl queue to which the kernel is submitted
148
+ * @param nelems Length of the sequence
149
+ * @param start Starting value of the sequence as Python object. Must be
150
+ * convertible to array element data type `Ty`.
151
+ * @param step Increment of the sequence as Python object. Must be convertible
152
+ * to array element data type `Ty`.
153
+ * @param array_data Kernel accessible USM pointer to the start of array to be
154
+ * populated.
155
+ * @param depends List of events to wait for before starting computations, if
156
+ * any.
157
+ *
158
+ * @return Event to wait on to ensure that computation completes.
159
+ * @ingroup CtorKernels
160
+ */
117
161
template <typename Ty>
118
162
sycl::event lin_space_step_impl (sycl::queue exec_q,
119
163
size_t nelems,
@@ -137,6 +181,11 @@ sycl::event lin_space_step_impl(sycl::queue exec_q,
137
181
return lin_space_step_event;
138
182
}
139
183
184
+ /*!
185
+ * @brief Factory to get function pointer of type `fnT` for array with elements
186
+ * of type `Ty`.
187
+ * @ingroup CtorKernels
188
+ */
140
189
template <typename fnT, typename Ty> struct LinSpaceStepFactory
141
190
{
142
191
fnT get ()
@@ -195,6 +244,23 @@ template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
195
244
}
196
245
};
197
246
247
+ /*!
248
+ * @brief Function to submit kernel to populate given contiguous memory
249
+ * allocation with linear sequence specified by typed starting and end values.
250
+ *
251
+ * @param exec_q Sycl queue to which kernel is submitted for execution.
252
+ * @param nelems Length of the sequence.
253
+ * @param start_v Starting value of the sequence.
254
+ * @param end_v End-value of the sequence.
255
+ * @param include_endpoint Whether the end-value is included in the sequence.
256
+ * @param array_data Kernel accessible USM pointer to the start of array to be
257
+ * populated.
258
+ * @param depends List of events to wait for before starting computations, if
259
+ * any.
260
+ *
261
+ * @return Event to wait on to ensure that computation completes.
262
+ * @ingroup CtorKernels
263
+ */
198
264
template <typename Ty>
199
265
sycl::event lin_space_affine_impl (sycl::queue exec_q,
200
266
size_t nelems,
@@ -226,6 +292,26 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q,
226
292
return lin_space_affine_event;
227
293
}
228
294
295
+ /*!
296
+ * @brief Function to submit kernel to populate given contiguous memory
297
+ * allocation with linear sequence specified by starting and end values given
298
+ * as Python objects.
299
+ *
300
+ * @param exec_q Sycl queue to which kernel is submitted for execution.
301
+ * @param nelems Length of the sequence
302
+ * @param start Starting value of the sequence as Python object. Must be
303
+ * convertible to array data element type `Ty`.
304
+ * @param end End-value of the sequence as Python object. Must be convertible
305
+ * to array data element type `Ty`.
306
+ * @param include_endpoint Whether the end-value is included in the sequence
307
+ * @param array_data Kernel accessible USM pointer to the start of array to be
308
+ * populated.
309
+ * @param depends List of events to wait for before starting computations, if
310
+ * any.
311
+ *
312
+ * @return Event to wait on to ensure that computation completes.
313
+ * @ingroup CtorKernels
314
+ */
229
315
template <typename Ty>
230
316
sycl::event lin_space_affine_impl (sycl::queue exec_q,
231
317
size_t nelems,
@@ -249,6 +335,10 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q,
249
335
return lin_space_affine_event;
250
336
}
251
337
338
+ /*!
339
+ * @brief Factory to get function pointer of type `fnT` for array data type
340
+ * `Ty`.
341
+ */
252
342
template <typename fnT, typename Ty> struct LinSpaceAffineFactory
253
343
{
254
344
fnT get ()
@@ -266,6 +356,21 @@ typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue,
266
356
char *,
267
357
const std::vector<sycl::event> &);
268
358
359
+ /*!
360
+ * @brief Function to submit kernel to fill given contiguous memory allocation
361
+ * with specified value.
362
+ *
363
+ * @param q Sycl queue to which kernel is submitted for execution.
364
+ * @param nelems Length of the sequence
365
+ * @param fill_v Value to fill the array with
366
+ * @param dst_p Kernel accessible USM pointer to the start of array to be
367
+ * populated.
368
+ * @param depends List of events to wait for before starting computations, if
369
+ * any.
370
+ *
371
+ * @return Event to wait on to ensure that computation completes.
372
+ * @ingroup CtorKernels
373
+ */
269
374
template <typename dstTy>
270
375
sycl::event full_contig_impl (sycl::queue q,
271
376
size_t nelems,
@@ -282,6 +387,22 @@ sycl::event full_contig_impl(sycl::queue q,
282
387
return fill_ev;
283
388
}
284
389
390
+ /*!
391
+ * @brief Function to submit kernel to fill given contiguous memory allocation
392
+ * with specified value.
393
+ *
394
+ * @param exec_q Sycl queue to which kernel is submitted for execution.
395
+ * @param nelems Length of the sequence
396
+ * @param py_value Python object representing the value to fill the array with.
397
+ * Must be convertible to `dstTy`.
398
+ * @param dst_p Kernel accessible USM pointer to the start of array to be
399
+ * populated.
400
+ * @param depends List of events to wait for before starting computations, if
401
+ * any.
402
+ *
403
+ * @return Event to wait on to ensure that computation completes.
404
+ * @ingroup CtorKernels
405
+ */
285
406
template <typename dstTy>
286
407
sycl::event full_contig_impl (sycl::queue exec_q,
287
408
size_t nelems,
@@ -351,6 +472,21 @@ template <typename Ty> class EyeFunctor
351
472
}
352
473
};
353
474
475
+ /*!
476
+ * @brief Function to populate 2D array with eye matrix.
477
+ *
478
+ * @param exec_q Sycl queue to which kernel is submitted for execution.
479
+ * @param nelems Number of elements to assign.
480
+ * @param start Position of the first non-zero value.
481
+ * @param end Position of the last non-zero value.
482
+ * @param step Number of array elements between non-zeros.
483
+ * @param array_data Kernel accessible USM pointer for the destination array.
484
+ * @param depends List of events to wait for before starting computations, if
485
+ * any.
486
+ *
487
+ * @return Event to wait on to ensure that computation completes.
488
+ * @ingroup CtorKernels
489
+ */
354
490
template <typename Ty>
355
491
sycl::event eye_impl (sycl::queue exec_q,
356
492
size_t nelems,
@@ -370,6 +506,10 @@ sycl::event eye_impl(sycl::queue exec_q,
370
506
return eye_event;
371
507
}
372
508
509
+ /*!
510
+ * @brief Factory to get function pointer of type `fnT` for data type `Ty`.
511
+ * @ingroup CtorKernels
512
+ */
373
513
template <typename fnT, typename Ty> struct EyeFactory
374
514
{
375
515
fnT get ()
@@ -393,8 +533,30 @@ typedef sycl::event (*tri_fn_ptr_t)(sycl::queue,
393
533
const std::vector<sycl::event> &,
394
534
const std::vector<sycl::event> &);
395
535
536
+ /*!
537
+ * @brief Function to copy triangular matrices from source stack to destination
538
+ * stack.
539
+ *
540
+ * @param exec_q Sycl queue to which kernel is submitted for execution.
541
+ * @param inner_range Number of elements in each matrix.
542
+ * @param outer_range Number of matrices to copy.
543
+ * @param src_p Kernel accessible USM pointer for the source array.
544
+ * @param dst_p Kernel accessible USM pointer for the destination array.
545
+ * @param nd The array dimensionality of source and destination arrays.
546
+ * @param shape_and_strides Kernel accessible USM pointer to packed shape and
547
+ * strides of arrays.
548
+ * @param k Position of the diagonal above/below which to copy filling the rest
549
+ * with zero elements.
550
+ * @param depends List of events to wait for before starting computations, if
551
+ * any.
552
+ * @param additional_depends List of additional events to wait for before
553
+ * starting computations, if any.
554
+ *
555
+ * @return Event to wait on to ensure that computation completes.
556
+ * @ingroup CtorKernels
557
+ */
396
558
template <typename Ty, bool > class tri_kernel ;
397
- template <typename Ty, bool l >
559
+ template <typename Ty, bool upper >
398
560
sycl::event tri_impl (sycl::queue exec_q,
399
561
py::ssize_t inner_range,
400
562
py::ssize_t outer_range,
@@ -417,7 +579,7 @@ sycl::event tri_impl(sycl::queue exec_q,
417
579
sycl::event tri_ev = exec_q.submit ([&](sycl::handler &cgh) {
418
580
cgh.depends_on (depends);
419
581
cgh.depends_on (additional_depends);
420
- cgh.parallel_for <tri_kernel<Ty, l >>(
582
+ cgh.parallel_for <tri_kernel<Ty, upper >>(
421
583
sycl::range<1 >(inner_range * outer_range), [=](sycl::id<1 > idx) {
422
584
py::ssize_t outer_gid = idx[0 ] / inner_range;
423
585
py::ssize_t inner_gid = idx[0 ] - inner_range * outer_gid;
@@ -438,7 +600,7 @@ sycl::event tri_impl(sycl::queue exec_q,
438
600
inner[0 ] * shape_and_strides[dst_s + nd_2] +
439
601
inner[1 ] * shape_and_strides[dst_s + nd_1];
440
602
441
- if (l )
603
+ if constexpr (upper )
442
604
to_copy = (inner[0 ] + k >= inner[1 ]);
443
605
else
444
606
to_copy = (inner[0 ] + k <= inner[1 ]);
@@ -463,6 +625,10 @@ sycl::event tri_impl(sycl::queue exec_q,
463
625
return tri_ev;
464
626
}
465
627
628
+ /*!
629
+ * @brief Factory to get function pointer of type `fnT` for data type `Ty`.
630
+ * @ingroup CtorKernels
631
+ */
466
632
template <typename fnT, typename Ty> struct TrilGenericFactory
467
633
{
468
634
fnT get ()
@@ -472,6 +638,10 @@ template <typename fnT, typename Ty> struct TrilGenericFactory
472
638
}
473
639
};
474
640
641
+ /*!
642
+ * @brief Factory to get function pointer of type `fnT` for data type `Ty`.
643
+ * @ingroup CtorKernels
644
+ */
475
645
template <typename fnT, typename Ty> struct TriuGenericFactory
476
646
{
477
647
fnT get ()
0 commit comments