Skip to content

Commit 67802a3

Browse files
Merge pull request #933 from IntelPython/cleanup-tensor-step4
[cleanup/tensor, part 4] Made contract_iter and contract_iter2 functions templated to avoid needing to include pybind11 headers
2 parents 1c77305 + f1779bd commit 67802a3

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -499,16 +499,16 @@ int simplify_iteration_two_strides(const int nd,
499499
return nd_;
500500
}
501501

502-
using vecT = std::vector<py::ssize_t>;
503-
std::tuple<vecT, vecT, py::size_t> contract_iter(vecT shape, vecT strides)
502+
template <typename T, class Error, typename vecT = std::vector<T>>
503+
std::tuple<vecT, vecT, T> contract_iter(vecT shape, vecT strides)
504504
{
505505
const size_t dim = shape.size();
506506
if (dim != strides.size()) {
507-
throw py::value_error("Shape and strides must be of equal size.");
507+
throw Error("Shape and strides must be of equal size.");
508508
}
509509
vecT out_shape = shape;
510510
vecT out_strides = strides;
511-
py::ssize_t disp(0);
511+
T disp(0);
512512

513513
int nd = simplify_iteration_stride(dim, out_shape.data(),
514514
out_strides.data(), disp);
@@ -517,18 +517,19 @@ std::tuple<vecT, vecT, py::size_t> contract_iter(vecT shape, vecT strides)
517517
return std::make_tuple(out_shape, out_strides, disp);
518518
}
519519

520-
std::tuple<vecT, vecT, py::size_t, vecT, py::ssize_t>
520+
template <typename T, class Error, typename vecT = std::vector<T>>
521+
std::tuple<vecT, vecT, T, vecT, T>
521522
contract_iter2(vecT shape, vecT strides1, vecT strides2)
522523
{
523524
const size_t dim = shape.size();
524525
if (dim != strides1.size() || dim != strides2.size()) {
525-
throw py::value_error("Shape and strides must be of equal size.");
526+
throw Error("Shape and strides must be of equal size.");
526527
}
527528
vecT out_shape = shape;
528529
vecT out_strides1 = strides1;
529530
vecT out_strides2 = strides2;
530-
py::ssize_t disp1(0);
531-
py::ssize_t disp2(0);
531+
T disp1(0);
532+
T disp2(0);
532533

533534
int nd = simplify_iteration_two_strides(dim, out_shape.data(),
534535
out_strides1.data(),

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,7 +1449,7 @@ PYBIND11_MODULE(_tensor_impl, m)
14491449
import_dpctl();
14501450

14511451
m.def(
1452-
"_contract_iter", &contract_iter,
1452+
"_contract_iter", &contract_iter<py::ssize_t, py::value_error>,
14531453
"Simplifies iteration of array of given shape & stride. Returns "
14541454
"a triple: shape, stride and offset for the new iterator of possible "
14551455
"smaller dimension, which traverses the same elements as the original "
@@ -1464,7 +1464,7 @@ PYBIND11_MODULE(_tensor_impl, m)
14641464
py::arg("depends") = py::list());
14651465

14661466
m.def(
1467-
"_contract_iter2", &contract_iter2,
1467+
"_contract_iter2", &contract_iter2<py::ssize_t, py::value_error>,
14681468
"Simplifies iteration over elements of pair of arrays of given shape "
14691469
"with strides stride1 and stride2. Returns "
14701470
"a 5-tuple: shape, stride and offset for the new iterator of possible "

0 commit comments

Comments
 (0)