Skip to content

Commit d658ebc

Browse files
Added doxygen docs for kernels
1 parent 34bfc6d commit d658ebc

File tree

2 files changed

+337
-3
lines changed

2 files changed

+337
-3
lines changed

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 173 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ namespace kernels
3939
namespace constructors
4040
{
4141

42+
/*!
43+
@defgroup CtorKernels
44+
*/
45+
4246
template <typename Ty> class linear_sequence_step_kernel;
4347
template <typename Ty, typename wTy> class linear_sequence_affine_kernel;
4448
template <typename Ty> class eye_kernel;
@@ -47,6 +51,10 @@ namespace py = pybind11;
4751

4852
/* =========== Unboxing Python scalar =============== */
4953

54+
/*!
55+
* @brief Cast pybind11 class managing Python object to specified type `T`.
56+
* @ingroup CtorKernels
57+
*/
5058
template <typename T> T unbox_py_scalar(py::object o)
5159
{
5260
return py::cast<T>(o);
@@ -96,6 +104,23 @@ template <typename Ty> class LinearSequenceStepFunctor
96104
}
97105
};
98106

107+
/*!
108+
* @brief Function to submit kernel to populate given contiguous memory
109+
* allocation with linear sequence specified by typed starting value and
110+
* increment.
111+
*
112+
* @param exec_q Sycl queue to which the kernel is submitted
113+
* @param nelems Length of the sequence
114+
* @param start_v Typed starting value of the sequence
115+
* @param step_v Typed increment of the sequence
116+
* @param array_data Kernel accessible USM pointer to the start of array to be
117+
* populated.
118+
* @param depends List of events to wait for before starting computations, if
119+
* any.
120+
*
121+
* @return Event to wait on to ensure that computation completes.
122+
* @ingroup CtorKernels
123+
*/
99124
template <typename Ty>
100125
sycl::event lin_space_step_impl(sycl::queue exec_q,
101126
size_t nelems,
@@ -114,6 +139,25 @@ sycl::event lin_space_step_impl(sycl::queue exec_q,
114139
return lin_space_step_event;
115140
}
116141

142+
/*!
143+
* @brief Function to submit kernel to populate given contiguous memory
144+
* allocation with linear sequence specified by starting value and increment
145+
* given as Python objects.
146+
*
147+
* @param exec_q Sycl queue to which the kernel is submitted
148+
* @param nelems Length of the sequence
149+
* @param start Starting value of the sequence as Python object. Must be
150+
* convertible to array element data type `Ty`.
151+
* @param step Increment of the sequence as Python object. Must be convertible
152+
* to array element data type `Ty`.
153+
* @param array_data Kernel accessible USM pointer to the start of array to be
154+
* populated.
155+
* @param depends List of events to wait for before starting computations, if
156+
* any.
157+
*
158+
* @return Event to wait on to ensure that computation completes.
159+
* @ingroup CtorKernels
160+
*/
117161
template <typename Ty>
118162
sycl::event lin_space_step_impl(sycl::queue exec_q,
119163
size_t nelems,
@@ -137,6 +181,11 @@ sycl::event lin_space_step_impl(sycl::queue exec_q,
137181
return lin_space_step_event;
138182
}
139183

184+
/*!
185+
* @brief Factory to get function pointer of type `fnT` for array with elements
186+
* of type `Ty`.
187+
* @ingroup CtorKernels
188+
*/
140189
template <typename fnT, typename Ty> struct LinSpaceStepFactory
141190
{
142191
fnT get()
@@ -195,6 +244,23 @@ template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
195244
}
196245
};
197246

247+
/*!
248+
* @brief Function to submit kernel to populate given contiguous memory
249+
* allocation with linear sequence specified by typed starting and end values.
250+
*
251+
* @param exec_q Sycl queue to which kernel is submitted for execution.
252+
* @param nelems Length of the sequence.
253+
* @param start_v Starting value of the sequence.
254+
* @param end_v End-value of the sequence.
255+
* @param include_endpoint Whether the end-value is included in the sequence.
256+
* @param array_data Kernel accessible USM pointer to the start of array to be
257+
* populated.
258+
* @param depends List of events to wait for before starting computations, if
259+
* any.
260+
*
261+
* @return Event to wait on to ensure that computation completes.
262+
* @ingroup CtorKernels
263+
*/
198264
template <typename Ty>
199265
sycl::event lin_space_affine_impl(sycl::queue exec_q,
200266
size_t nelems,
@@ -226,6 +292,26 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q,
226292
return lin_space_affine_event;
227293
}
228294

295+
/*!
296+
* @brief Function to submit kernel to populate given contiguous memory
297+
* allocation with linear sequence specified by starting and end values given
298+
* as Python objects.
299+
*
300+
* @param exec_q Sycl queue to which kernel is submitted for execution.
301+
* @param nelems Length of the sequence
302+
* @param start Starting value of the sequence as Python object. Must be
303+
* convertible to array data element type `Ty`.
304+
* @param end End-value of the sequence as Python object. Must be convertible
305+
* to array data element type `Ty`.
306+
* @param include_endpoint Whether the end-value is included in the sequence
307+
* @param array_data Kernel accessible USM pointer to the start of array to be
308+
* populated.
309+
* @param depends List of events to wait for before starting computations, if
310+
* any.
311+
*
312+
* @return Event to wait on to ensure that computation completes.
313+
* @ingroup CtorKernels
314+
*/
229315
template <typename Ty>
230316
sycl::event lin_space_affine_impl(sycl::queue exec_q,
231317
size_t nelems,
@@ -249,6 +335,10 @@ sycl::event lin_space_affine_impl(sycl::queue exec_q,
249335
return lin_space_affine_event;
250336
}
251337

338+
/*!
339+
* @brief Factory to get function pointer of type `fnT` for array data type
340+
* `Ty`.
341+
*/
252342
template <typename fnT, typename Ty> struct LinSpaceAffineFactory
253343
{
254344
fnT get()
@@ -266,6 +356,21 @@ typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue,
266356
char *,
267357
const std::vector<sycl::event> &);
268358

359+
/*!
360+
* @brief Function to submit kernel to fill given contiguous memory allocation
361+
* with specified value.
362+
*
363+
* @param q Sycl queue to which kernel is submitted for execution.
364+
* @param nelems Length of the sequence
365+
* @param fill_v Value to fill the array with
366+
* @param dst_p Kernel accessible USM pointer to the start of array to be
367+
* populated.
368+
* @param depends List of events to wait for before starting computations, if
369+
* any.
370+
*
371+
* @return Event to wait on to ensure that computation completes.
372+
* @ingroup CtorKernels
373+
*/
269374
template <typename dstTy>
270375
sycl::event full_contig_impl(sycl::queue q,
271376
size_t nelems,
@@ -282,6 +387,22 @@ sycl::event full_contig_impl(sycl::queue q,
282387
return fill_ev;
283388
}
284389

390+
/*!
391+
* @brief Function to submit kernel to fill given contiguous memory allocation
392+
* with specified value.
393+
*
394+
* @param exec_q Sycl queue to which kernel is submitted for execution.
395+
* @param nelems Length of the sequence
396+
* @param py_value Python object representing the value to fill the array with.
397+
* Must be convertible to `dstTy`.
398+
* @param dst_p Kernel accessible USM pointer to the start of array to be
399+
* populated.
400+
* @param depends List of events to wait for before starting computations, if
401+
* any.
402+
*
403+
* @return Event to wait on to ensure that computation completes.
404+
* @ingroup CtorKernels
405+
*/
285406
template <typename dstTy>
286407
sycl::event full_contig_impl(sycl::queue exec_q,
287408
size_t nelems,
@@ -351,6 +472,21 @@ template <typename Ty> class EyeFunctor
351472
}
352473
};
353474

475+
/*!
476+
* @brief Function to populate 2D array with eye matrix.
477+
*
478+
* @param exec_q Sycl queue to which kernel is submitted for execution.
479+
* @param nelems Number of elements to assign.
480+
* @param start Position of the first non-zero value.
481+
* @param end Position of the last non-zero value.
482+
* @param step Number of array elements between non-zeros.
483+
* @param array_data Kernel accessible USM pointer for the destination array.
484+
* @param depends List of events to wait for before starting computations, if
485+
* any.
486+
*
487+
* @return Event to wait on to ensure that computation completes.
488+
* @ingroup CtorKernels
489+
*/
354490
template <typename Ty>
355491
sycl::event eye_impl(sycl::queue exec_q,
356492
size_t nelems,
@@ -370,6 +506,10 @@ sycl::event eye_impl(sycl::queue exec_q,
370506
return eye_event;
371507
}
372508

509+
/*!
510+
* @brief Factory to get function pointer of type `fnT` for data type `Ty`.
511+
* @ingroup CtorKernels
512+
*/
373513
template <typename fnT, typename Ty> struct EyeFactory
374514
{
375515
fnT get()
@@ -393,8 +533,30 @@ typedef sycl::event (*tri_fn_ptr_t)(sycl::queue,
393533
const std::vector<sycl::event> &,
394534
const std::vector<sycl::event> &);
395535

536+
/*!
537+
* @brief Function to copy triangular matrices from source stack to destination
538+
* stack.
539+
*
540+
* @param exec_q Sycl queue to which kernel is submitted for execution.
541+
* @param inner_range Number of elements in each matrix.
542+
* @param outer_range Number of matrices to copy.
543+
* @param src_p Kernel accessible USM pointer for the source array.
544+
* @param dst_p Kernel accessible USM pointer for the destination array.
545+
* @param nd The array dimensionality of source and destination arrays.
546+
* @param shape_and_strides Kernel accessible USM pointer to packed shape and
547+
* strides of arrays.
548+
* @param k Position of the diagonal above/below which to copy filling the rest
549+
* with zero elements.
550+
* @param depends List of events to wait for before starting computations, if
551+
* any.
552+
* @param additional_depends List of additional events to wait for before
553+
* starting computations, if any.
554+
*
555+
* @return Event to wait on to ensure that computation completes.
556+
* @ingroup CtorKernels
557+
*/
396558
template <typename Ty, bool> class tri_kernel;
397-
template <typename Ty, bool l>
559+
template <typename Ty, bool upper>
398560
sycl::event tri_impl(sycl::queue exec_q,
399561
py::ssize_t inner_range,
400562
py::ssize_t outer_range,
@@ -417,7 +579,7 @@ sycl::event tri_impl(sycl::queue exec_q,
417579
sycl::event tri_ev = exec_q.submit([&](sycl::handler &cgh) {
418580
cgh.depends_on(depends);
419581
cgh.depends_on(additional_depends);
420-
cgh.parallel_for<tri_kernel<Ty, l>>(
582+
cgh.parallel_for<tri_kernel<Ty, upper>>(
421583
sycl::range<1>(inner_range * outer_range), [=](sycl::id<1> idx) {
422584
py::ssize_t outer_gid = idx[0] / inner_range;
423585
py::ssize_t inner_gid = idx[0] - inner_range * outer_gid;
@@ -438,7 +600,7 @@ sycl::event tri_impl(sycl::queue exec_q,
438600
inner[0] * shape_and_strides[dst_s + nd_2] +
439601
inner[1] * shape_and_strides[dst_s + nd_1];
440602

441-
if (l)
603+
if constexpr (upper)
442604
to_copy = (inner[0] + k >= inner[1]);
443605
else
444606
to_copy = (inner[0] + k <= inner[1]);
@@ -463,6 +625,10 @@ sycl::event tri_impl(sycl::queue exec_q,
463625
return tri_ev;
464626
}
465627

628+
/*!
629+
* @brief Factory to get function pointer of type `fnT` for data type `Ty`.
630+
* @ingroup CtorKernels
631+
*/
466632
template <typename fnT, typename Ty> struct TrilGenericFactory
467633
{
468634
fnT get()
@@ -472,6 +638,10 @@ template <typename fnT, typename Ty> struct TrilGenericFactory
472638
}
473639
};
474640

641+
/*!
642+
* @brief Factory to get function pointer of type `fnT` for data type `Ty`.
643+
* @ingroup CtorKernels
644+
*/
475645
template <typename fnT, typename Ty> struct TriuGenericFactory
476646
{
477647
fnT get()

0 commit comments

Comments
 (0)