Skip to content

Commit db68b36

Browse files
Merge pull request #925 from IntelPython/cleanup-tensor-py
Moved implementation of kernels out to dedicated header files.
2 parents 1cdc2a6 + d658ebc commit db68b36

File tree

6 files changed

+1489
-897
lines changed

6 files changed

+1489
-897
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,57 @@ class usm_memory : public py::object
371371

372372
namespace tensor
373373
{
374+
375+
inline std::vector<py::ssize_t>
376+
c_contiguous_strides(int nd,
377+
const py::ssize_t *shape,
378+
py::ssize_t element_size = 1)
379+
{
380+
if (nd > 0) {
381+
std::vector<py::ssize_t> c_strides(nd, element_size);
382+
for (int ic = nd - 1; ic > 0;) {
383+
py::ssize_t next_v = c_strides[ic] * shape[ic];
384+
c_strides[--ic] = next_v;
385+
}
386+
return c_strides;
387+
}
388+
else {
389+
return std::vector<py::ssize_t>();
390+
}
391+
}
392+
393+
inline std::vector<py::ssize_t>
394+
f_contiguous_strides(int nd,
395+
const py::ssize_t *shape,
396+
py::ssize_t element_size = 1)
397+
{
398+
if (nd > 0) {
399+
std::vector<py::ssize_t> f_strides(nd, element_size);
400+
for (int i = 0; i < nd - 1;) {
401+
py::ssize_t next_v = f_strides[i] * shape[i];
402+
f_strides[++i] = next_v;
403+
}
404+
return f_strides;
405+
}
406+
else {
407+
return std::vector<py::ssize_t>();
408+
}
409+
}
410+
411+
inline std::vector<py::ssize_t>
412+
c_contiguous_strides(const std::vector<py::ssize_t> &shape,
413+
py::ssize_t element_size = 1)
414+
{
415+
return c_contiguous_strides(shape.size(), shape.data(), element_size);
416+
}
417+
418+
inline std::vector<py::ssize_t>
419+
f_contiguous_strides(const std::vector<py::ssize_t> &shape,
420+
py::ssize_t element_size = 1)
421+
{
422+
return f_contiguous_strides(shape.size(), shape.data(), element_size);
423+
}
424+
374425
class usm_ndarray : public py::object
375426
{
376427
public:

dpctl/apis/include/dpctl_capi.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@
4747
* C functions can use dpctl's C-API functions without linking to
4848
* shared objects defining these symbols, if they call `import_dpctl()`
4949
* prior to using those symbols.
50+
*
51+
* It is declared `static inline` to allow multiple definitions in
52+
* different translation units
5053
*/
51-
void import_dpctl(void)
54+
static inline void import_dpctl(void)
5255
{
5356
import_dpctl___sycl_device();
5457
import_dpctl___sycl_context();

0 commit comments

Comments
 (0)