Skip to content

Commit 7e798b4

Browse files
Merge pull request #1305 from IntelPython/three-strides-fix-dim-indexer
2 parents 9380514 + bbb2466 commit 7e798b4

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

dpctl/tensor/libtensor/include/utils/offset_utils.hpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,63 @@ template <int nd> struct TwoOffsets_FixedDimStridedIndexer
664664
py::ssize_t starting_offset2;
665665
};
666666

667+
template <int nd> struct ThreeOffsets_FixedDimStridedIndexer
668+
{
669+
ThreeOffsets_FixedDimStridedIndexer(
670+
const std::array<py::ssize_t, nd> _shape,
671+
const std::array<py::ssize_t, nd> _strides1,
672+
const std::array<py::ssize_t, nd> _strides2,
673+
const std::array<py::ssize_t, nd> _strides3,
674+
py::ssize_t _offset1,
675+
py::ssize_t _offset2,
676+
py::ssize_t _offset3)
677+
: _ind(_shape), strides1(_strides1), strides2(_strides2),
678+
strides3(_strides3), starting_offset1(_offset1),
679+
starting_offset2(_offset2), starting_offset3(_offset3)
680+
{
681+
}
682+
683+
ThreeOffsets<py::ssize_t> operator()(size_t gid) const
684+
{
685+
dpctl::tensor::strides::CIndexer_array<nd, py::ssize_t> local_indexer(
686+
std::move(_ind));
687+
local_indexer.set(gid);
688+
auto mi = local_indexer.get();
689+
690+
py::ssize_t relative_offset1 = 0;
691+
#pragma unroll
692+
for (int i = 0; i < nd; ++i) {
693+
relative_offset1 += mi[i] * strides1[i];
694+
}
695+
696+
py::ssize_t relative_offset2 = 0;
697+
#pragma unroll
698+
for (int i = 0; i < nd; ++i) {
699+
relative_offset2 += mi[i] * strides2[i];
700+
}
701+
702+
py::ssize_t relative_offset3 = 0;
703+
#pragma unroll
704+
for (int i = 0; i < nd; ++i) {
705+
relative_offset3 += mi[i] * strides3[i];
706+
}
707+
708+
return ThreeOffsets<py::ssize_t>(starting_offset1 + relative_offset1,
709+
starting_offset2 + relative_offset2,
710+
starting_offset3 + relative_offset3);
711+
}
712+
713+
private:
714+
dpctl::tensor::strides::CIndexer_array<nd, py::ssize_t> _ind;
715+
716+
const std::array<py::ssize_t, nd> strides1;
717+
const std::array<py::ssize_t, nd> strides2;
718+
const std::array<py::ssize_t, nd> strides3;
719+
py::ssize_t starting_offset1;
720+
py::ssize_t starting_offset2;
721+
py::ssize_t starting_offset3;
722+
};
723+
667724
} // namespace offset_utils
668725
} // namespace tensor
669726
} // namespace dpctl

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ template <int _ndim, typename indT = std::ptrdiff_t> class CIndexer_array
296296
}
297297

298298
indT i_ = i;
299+
#pragma unroll
299300
for (int dim = ndim; --dim > 0;) {
300301
indT si = shape[dim];
301302
indT q = i_ / si;

0 commit comments

Comments
 (0)