Skip to content

Commit 68a652e

Browse files
committed
Moved split_iteration_space into simplify_iteration_space.cpp
1 parent 9e88171 commit 68a652e

File tree

3 files changed

+74
-61
lines changed

3 files changed

+74
-61
lines changed

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 20 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -46,51 +46,6 @@ namespace tensor
4646
namespace py_internal
4747
{
4848

49-
/* @brief Split shape/strides into dir1 (complementary to axis_start <= i <
50-
* axis_end) and dir2 (along given set of axes)
51-
*/
52-
template <typename shT>
53-
void _split_iteration_space(const shT &shape_vec,
54-
const shT &strides_vec,
55-
int axis_start,
56-
int axis_end,
57-
shT &dir1_shape_vec,
58-
shT &dir2_shape_vec,
59-
shT &dir1_strides_vec,
60-
shT &dir2_strides_vec)
61-
{
62-
int nd = static_cast<int>(shape_vec.size());
63-
int dir2_sz = axis_end - axis_start;
64-
int dir1_sz = nd - dir2_sz;
65-
66-
assert(dir1_sz > 0);
67-
assert(dir2_sz > 0);
68-
69-
dir1_shape_vec.resize(dir1_sz);
70-
dir2_shape_vec.resize(dir2_sz);
71-
72-
std::copy(shape_vec.begin(), shape_vec.begin() + axis_start,
73-
dir1_shape_vec.begin());
74-
std::copy(shape_vec.begin() + axis_end, shape_vec.end(),
75-
dir1_shape_vec.begin() + axis_start);
76-
77-
std::copy(shape_vec.begin() + axis_start, shape_vec.begin() + axis_end,
78-
dir2_shape_vec.begin());
79-
80-
dir1_strides_vec.resize(dir1_sz);
81-
dir2_strides_vec.resize(dir2_sz);
82-
83-
std::copy(strides_vec.begin(), strides_vec.begin() + axis_start,
84-
dir1_strides_vec.begin());
85-
std::copy(strides_vec.begin() + axis_end, strides_vec.end(),
86-
dir1_strides_vec.begin() + axis_start);
87-
88-
std::copy(strides_vec.begin() + axis_start, strides_vec.begin() + axis_end,
89-
dir2_strides_vec.begin());
90-
91-
return;
92-
}
93-
9449
// Masked extraction
9550

9651
namespace td_ns = dpctl::tensor::type_dispatch;
@@ -334,19 +289,21 @@ py_extract(dpctl::tensor::usm_ndarray src,
334289
shT masked_src_shape;
335290
shT ortho_src_strides;
336291
shT masked_src_strides;
337-
_split_iteration_space(src_shape_vec, src_strides_vec, axis_start,
338-
axis_end, ortho_src_shape,
339-
masked_src_shape, // 4 vectors modified
340-
ortho_src_strides, masked_src_strides);
292+
dpctl::tensor::py_internal::split_iteration_space(
293+
src_shape_vec, src_strides_vec, axis_start, axis_end,
294+
ortho_src_shape,
295+
masked_src_shape, // 4 vectors modified
296+
ortho_src_strides, masked_src_strides);
341297

342298
shT ortho_dst_shape;
343299
shT masked_dst_shape;
344300
shT ortho_dst_strides;
345301
shT masked_dst_strides;
346-
_split_iteration_space(dst_shape_vec, dst_strides_vec, axis_start,
347-
axis_start + 1, ortho_dst_shape,
348-
masked_dst_shape, // 4 vectors modified
349-
ortho_dst_strides, masked_dst_strides);
302+
dpctl::tensor::py_internal::split_iteration_space(
303+
dst_shape_vec, dst_strides_vec, axis_start, axis_start + 1,
304+
ortho_dst_shape,
305+
masked_dst_shape, // 4 vectors modified
306+
ortho_dst_strides, masked_dst_strides);
350307

351308
assert(ortho_src_shape.size() == static_cast<size_t>(ortho_nd));
352309
assert(ortho_dst_shape.size() == static_cast<size_t>(ortho_nd));
@@ -662,19 +619,21 @@ py_place(dpctl::tensor::usm_ndarray dst,
662619
shT masked_dst_shape;
663620
shT ortho_dst_strides;
664621
shT masked_dst_strides;
665-
_split_iteration_space(dst_shape_vec, dst_strides_vec, axis_start,
666-
axis_end, ortho_dst_shape,
667-
masked_dst_shape, // 4 vectors modified
668-
ortho_dst_strides, masked_dst_strides);
622+
dpctl::tensor::py_internal::split_iteration_space(
623+
dst_shape_vec, dst_strides_vec, axis_start, axis_end,
624+
ortho_dst_shape,
625+
masked_dst_shape, // 4 vectors modified
626+
ortho_dst_strides, masked_dst_strides);
669627

670628
shT ortho_rhs_shape;
671629
shT masked_rhs_shape;
672630
shT ortho_rhs_strides;
673631
shT masked_rhs_strides;
674-
_split_iteration_space(rhs_shape_vec, rhs_strides_vec, axis_start,
675-
axis_start + 1, ortho_rhs_shape,
676-
masked_rhs_shape, // 4 vectors modified
677-
ortho_rhs_strides, masked_rhs_strides);
632+
dpctl::tensor::py_internal::split_iteration_space(
633+
rhs_shape_vec, rhs_strides_vec, axis_start, axis_start + 1,
634+
ortho_rhs_shape,
635+
masked_rhs_shape, // 4 vectors modified
636+
ortho_rhs_strides, masked_rhs_strides);
678637

679638
assert(ortho_dst_shape.size() == static_cast<size_t>(ortho_nd));
680639
assert(ortho_rhs_shape.size() == static_cast<size_t>(ortho_nd));

dpctl/tensor/libtensor/source/simplify_iteration_space.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,50 @@ void compact_iteration_space(int &nd,
408408
}
409409
}
410410

411+
/* @brief Split shape/strides into dir1 (complementary to axis_start <= i <
412+
* axis_end) and dir2 (along given set of axes)
413+
*/
414+
void split_iteration_space(const std::vector<py::ssize_t> &shape_vec,
415+
const std::vector<py::ssize_t> &strides_vec,
416+
int axis_start,
417+
int axis_end,
418+
std::vector<py::ssize_t> &dir1_shape_vec,
419+
std::vector<py::ssize_t> &dir2_shape_vec,
420+
std::vector<py::ssize_t> &dir1_strides_vec,
421+
std::vector<py::ssize_t> &dir2_strides_vec)
422+
{
423+
int nd = static_cast<int>(shape_vec.size());
424+
int dir2_sz = axis_end - axis_start;
425+
int dir1_sz = nd - dir2_sz;
426+
427+
assert(dir1_sz > 0);
428+
assert(dir2_sz > 0);
429+
430+
dir1_shape_vec.resize(dir1_sz);
431+
dir2_shape_vec.resize(dir2_sz);
432+
433+
std::copy(shape_vec.begin(), shape_vec.begin() + axis_start,
434+
dir1_shape_vec.begin());
435+
std::copy(shape_vec.begin() + axis_end, shape_vec.end(),
436+
dir1_shape_vec.begin() + axis_start);
437+
438+
std::copy(shape_vec.begin() + axis_start, shape_vec.begin() + axis_end,
439+
dir2_shape_vec.begin());
440+
441+
dir1_strides_vec.resize(dir1_sz);
442+
dir2_strides_vec.resize(dir2_sz);
443+
444+
std::copy(strides_vec.begin(), strides_vec.begin() + axis_start,
445+
dir1_strides_vec.begin());
446+
std::copy(strides_vec.begin() + axis_end, strides_vec.end(),
447+
dir1_strides_vec.begin() + axis_start);
448+
449+
std::copy(strides_vec.begin() + axis_start, strides_vec.begin() + axis_end,
450+
dir2_strides_vec.begin());
451+
452+
return;
453+
}
454+
411455
py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &mi,
412456
std::vector<py::ssize_t> const &shape)
413457
{

dpctl/tensor/libtensor/source/simplify_iteration_space.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,16 @@ void compact_iteration_space(int &,
9797
std::vector<py::ssize_t> &,
9898
std::vector<py::ssize_t> &);
9999

100+
void split_iteration_space(const std::vector<py::ssize_t> &,
101+
const std::vector<py::ssize_t> &,
102+
int,
103+
int,
104+
// output
105+
std::vector<py::ssize_t> &,
106+
std::vector<py::ssize_t> &,
107+
std::vector<py::ssize_t> &,
108+
std::vector<py::ssize_t> &);
109+
100110
py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &,
101111
std::vector<py::ssize_t> const &);
102112
py::ssize_t _ravel_multi_index_f(std::vector<py::ssize_t> const &,

0 commit comments

Comments
 (0)