Skip to content

Commit 949711e

Browse files
Merge pull request #1044 from IntelPython/add-contract-iter3
Added _contract_iter3 utility to simplify iteration space over 3 arrays
2 parents 6ca4bbb + e5c7552 commit 949711e

File tree

2 files changed

+159
-5
lines changed

2 files changed

+159
-5
lines changed

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 149 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -408,9 +408,8 @@ int simplify_iteration_stride(const int nd,
408408
409409
The new shape and new strides, as well as the offset
410410
`(new_shape, new_strides1, disp1, new_stride2, disp2)` are such that
411-
iterating over them will traverse the same pairs of elements, possibly in
412-
different order.
413-
411+
iterating over them will traverse the same set of pairs of elements,
412+
possibly in a different order.
414413
*/
415414
template <class ShapeTy, class StridesTy>
416415
int simplify_iteration_two_strides(const int nd,
@@ -447,7 +446,7 @@ int simplify_iteration_two_strides(const int nd,
447446
auto str1_p = strides1[p];
448447
auto str2_p = strides2[p];
449448
shape_w.push_back(sh_p);
450-
if (str1_p < 0 && str2_p < 0) {
449+
if (str1_p <= 0 && str2_p <= 0 && std::min(str1_p, str2_p) < 0) {
451450
disp1 += str1_p * (sh_p - 1);
452451
str1_p = -str1_p;
453452
disp2 += str2_p * (sh_p - 1);
@@ -468,7 +467,7 @@ int simplify_iteration_two_strides(const int nd,
468467
StridesTy jump1 = strides1_w[i] - (shape_w[i + 1] - 1) * str1;
469468
StridesTy jump2 = strides2_w[i] - (shape_w[i + 1] - 1) * str2;
470469

471-
if (jump1 == str1 and jump2 == str2) {
470+
if (jump1 == str1 && jump2 == str2) {
472471
changed = true;
473472
shape_w[i] *= shape_w[i + 1];
474473
for (int j = i; j < nd_; ++j) {
@@ -540,3 +539,148 @@ contract_iter2(vecT shape, vecT strides1, vecT strides2)
540539
out_strides2.resize(nd);
541540
return std::make_tuple(out_shape, out_strides1, disp1, out_strides2, disp2);
542541
}
542+
543+
/*
    For purposes of iterating over triples of elements of three arrays
    with `shape` and strides `strides1`, `strides2`, `strides3` given as
    pointers `simplify_iteration_three_strides(nd, shape_ptr, strides1_ptr,
    strides2_ptr, strides3_ptr, disp1, disp2, disp3)`
    may modify memory and returns new length of these arrays.

    The new shape and new strides, as well as the offsets
    `(new_shape, new_strides1, disp1, new_strides2, disp2, new_strides3,
    disp3)` are such that iterating over them will traverse the same set of
    triples of elements, possibly in a different order.

    Parameters:
      nd        number of dimensions described by the input arrays
      shape     in/out array of `nd` extents; the first `ret` entries are
                overwritten with the simplified shape
      strides1, strides2, strides3
                in/out arrays of `nd` strides, one per iterated array; the
                first `ret` entries of each are overwritten
      disp1, disp2, disp3
                output displacements; every one of them is reset to zero on
                entry and then accumulates the shift needed for each axis
                whose strides were negated

    Returns the simplified number of dimensions `ret` (`ret <= nd`).
    Inputs with `nd < 2` are returned untouched (displacements zeroed).
*/
template <class ShapeTy, class StridesTy>
int simplify_iteration_three_strides(const int nd,
                                     ShapeTy *shape,
                                     StridesTy *strides1,
                                     StridesTy *strides2,
                                     StridesTy *strides3,
                                     StridesTy &disp1,
                                     StridesTy &disp2,
                                     StridesTy &disp3)
{
    // All three displacements are outputs: reset each of them so that the
    // result never depends on what the caller happened to pass in.
    disp1 = StridesTy(0);
    disp2 = StridesTy(0);
    disp3 = StridesTy(0);
    if (nd < 2)
        return nd;

    // Visit axes in order of decreasing |strides1|; ties are resolved by
    // decreasing extent, and stable_sort keeps the original order otherwise.
    std::vector<int> pos(nd);
    std::iota(pos.begin(), pos.end(), 0);

    std::stable_sort(
        pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) {
            auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
            auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
            return (abs_str1 > abs_str2) ||
                   (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
        });

    std::vector<ShapeTy> shape_w;
    std::vector<StridesTy> strides1_w;
    std::vector<StridesTy> strides2_w;
    std::vector<StridesTy> strides3_w;
    shape_w.reserve(nd);
    strides1_w.reserve(nd);
    strides2_w.reserve(nd);
    strides3_w.reserve(nd);

    bool contractable = true;
    for (int i = 0; i < nd; ++i) {
        auto p = pos[i];
        auto sh_p = shape[p];
        auto str1_p = strides1[p];
        auto str2_p = strides2[p];
        auto str3_p = strides3[p];
        shape_w.push_back(sh_p);
        // An axis traversed backwards (or not at all) by every array can be
        // flipped: start from its last element and negate the strides.
        if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
            std::min(std::min(str1_p, str2_p), str3_p) < 0)
        {
            disp1 += str1_p * (sh_p - 1);
            str1_p = -str1_p;
            disp2 += str2_p * (sh_p - 1);
            str2_p = -str2_p;
            disp3 += str3_p * (sh_p - 1);
            str3_p = -str3_p;
        }
        // Mixed-sign strides on an axis cannot be normalized; contraction
        // is then skipped altogether.
        if (str1_p < 0 || str2_p < 0 || str3_p < 0) {
            contractable = false;
        }
        strides1_w.push_back(str1_p);
        strides2_w.push_back(str2_p);
        strides3_w.push_back(str3_p);
    }

    int nd_ = nd;
    while (contractable) {
        bool changed = false;
        for (int i = 0; i + 1 < nd_; ++i) {
            StridesTy str1 = strides1_w[i + 1];
            StridesTy str2 = strides2_w[i + 1];
            StridesTy str3 = strides3_w[i + 1];
            StridesTy jump1 = strides1_w[i] - (shape_w[i + 1] - 1) * str1;
            StridesTy jump2 = strides2_w[i] - (shape_w[i + 1] - 1) * str2;
            StridesTy jump3 = strides3_w[i] - (shape_w[i + 1] - 1) * str3;

            // Axes i and i + 1 fuse when stepping past the end of the inner
            // axis lands exactly one inner stride further, for all arrays.
            if (jump1 == str1 && jump2 == str2 && jump3 == str3) {
                changed = true;
                shape_w[i] *= shape_w[i + 1];
                // The fused axis keeps the inner stride. erase() shifts the
                // tail down without ever reading past the last element.
                shape_w.erase(shape_w.begin() + (i + 1));
                strides1_w.erase(strides1_w.begin() + i);
                strides2_w.erase(strides2_w.begin() + i);
                strides3_w.erase(strides3_w.begin() + i);
                --nd_;
                break;
            }
        }
        if (!changed)
            break;
    }

    // Publish the simplified description into the caller's arrays.
    std::copy(shape_w.begin(), shape_w.begin() + nd_, shape);
    std::copy(strides1_w.begin(), strides1_w.begin() + nd_, strides1);
    std::copy(strides2_w.begin(), strides2_w.begin() + nd_, strides2);
    std::copy(strides3_w.begin(), strides3_w.begin() + nd_, strides3);

    return nd_;
}
659+
660+
/*
    Vector-based convenience wrapper around
    `simplify_iteration_three_strides`.

    Takes the common `shape` and the per-array `strides1`, `strides2`,
    `strides3`, and returns the 7-tuple
    `(new_shape, new_strides1, disp1, new_strides2, disp2, new_strides3,
    disp3)` describing the simplified iteration.

    Throws `Error("Shape and strides must be of equal size.")` when the four
    input vectors do not all have the same length.
*/
template <typename T, class Error, typename vecT = std::vector<T>>
std::tuple<vecT, vecT, T, vecT, T, vecT, T>
contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
{
    const size_t dim = shape.size();
    if (dim != strides1.size() || dim != strides2.size() ||
        dim != strides3.size()) {
        throw Error("Shape and strides must be of equal size.");
    }

    T disp1(0);
    T disp2(0);
    T disp3(0);

    // The arguments were received by value, so they already are private
    // copies and can be simplified in place without a second copy.
    const int nd = simplify_iteration_three_strides(
        static_cast<int>(dim), shape.data(), strides1.data(), strides2.data(),
        strides3.data(), disp1, disp2, disp3);

    // Drop the trailing entries that no longer describe a dimension.
    shape.resize(nd);
    strides1.resize(nd);
    strides2.resize(nd);
    strides3.resize(nd);

    return std::make_tuple(shape, strides1, disp1, strides2, disp2, strides3,
                           disp3);
}

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,16 @@ PYBIND11_MODULE(_tensor_impl, m)
133133
"as the original "
134134
"iterator, possibly in a different order.");
135135

136+
m.def(
137+
"_contract_iter3", &contract_iter3<py::ssize_t, py::value_error>,
138+
"Simplifies iteration over elements of 3-tuple of arrays of given "
139+
"shape "
140+
"with strides stride1, stride2, and stride3. Returns "
141+
"a 7-tuple: shape, stride and offset for the new iterator of possible "
142+
"smaller dimension for each array, which traverses the same elements "
143+
"as the original "
144+
"iterator, possibly in a different order.");
145+
136146
m.def("_copy_usm_ndarray_for_reshape", &copy_usm_ndarray_for_reshape,
137147
"Copies from usm_ndarray `src` into usm_ndarray `dst` with the same "
138148
"number of elements using underlying 'C'-contiguous order for flat "

0 commit comments

Comments
 (0)