Skip to content

Commit aea79dd

Browse files
Add method CIndexer_vector::get_left_rolled_displacement
This is used to compute displacement for a[(i0 - shifts[0]) % shape[0], (i1 - shifts[1]) % shape[1], ... ]
1 parent 2c3f748 commit aea79dd

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,30 @@ template <typename indT = std::ptrdiff_t> class CIndexer_vector
238238
}
239239
return;
240240
}
241+
242+
template <class ShapeTy, class StridesTy>
243+
void get_left_rolled_displacement(indT i,
244+
ShapeTy shape,
245+
StridesTy stride,
246+
StridesTy shifts,
247+
indT &disp) const
248+
{
249+
indT i_ = i;
250+
indT d = 0;
251+
for (int dim = nd; --dim > 0;) {
252+
const indT si = shape[dim];
253+
const indT q = i_ / si;
254+
const indT r = (i_ - q * si);
255+
// assumes si > shifts[dim] >= 0
256+
const indT shifted_r =
257+
(r < shifts[dim] ? r + si - shifts[dim] : r - shifts[dim]);
258+
d += shifted_r * stride[dim];
259+
i_ = q;
260+
}
261+
const indT shifted_r =
262+
(i_ < shifts[0] ? i_ + shape[0] - shifts[0] : i_ - shifts[0]);
263+
disp = d + shifted_r * stride[0];
264+
}
241265
};
242266

243267
/*

0 commit comments

Comments
 (0)