@@ -664,6 +664,63 @@ template <int nd> struct TwoOffsets_FixedDimStridedIndexer
664
664
py::ssize_t starting_offset2;
665
665
};
666
666
667
+ template <int nd> struct ThreeOffsets_FixedDimStridedIndexer
668
+ {
669
+ ThreeOffsets_FixedDimStridedIndexer (
670
+ const std::array<py::ssize_t , nd> _shape,
671
+ const std::array<py::ssize_t , nd> _strides1,
672
+ const std::array<py::ssize_t , nd> _strides2,
673
+ const std::array<py::ssize_t , nd> _strides3,
674
+ py::ssize_t _offset1,
675
+ py::ssize_t _offset2,
676
+ py::ssize_t _offset3)
677
+ : _ind(_shape), strides1(_strides1), strides2(_strides2),
678
+ strides3 (_strides3), starting_offset1(_offset1),
679
+ starting_offset2(_offset2), starting_offset3(_offset3)
680
+ {
681
+ }
682
+
683
+ ThreeOffsets<py::ssize_t > operator ()(size_t gid) const
684
+ {
685
+ dpctl::tensor::strides::CIndexer_array<nd, py::ssize_t > local_indexer (
686
+ std::move (_ind));
687
+ local_indexer.set (gid);
688
+ auto mi = local_indexer.get ();
689
+
690
+ py::ssize_t relative_offset1 = 0 ;
691
+ #pragma unroll
692
+ for (int i = 0 ; i < nd; ++i) {
693
+ relative_offset1 += mi[i] * strides1[i];
694
+ }
695
+
696
+ py::ssize_t relative_offset2 = 0 ;
697
+ #pragma unroll
698
+ for (int i = 0 ; i < nd; ++i) {
699
+ relative_offset2 += mi[i] * strides2[i];
700
+ }
701
+
702
+ py::ssize_t relative_offset3 = 0 ;
703
+ #pragma unroll
704
+ for (int i = 0 ; i < nd; ++i) {
705
+ relative_offset3 += mi[i] * strides3[i];
706
+ }
707
+
708
+ return ThreeOffsets<py::ssize_t >(starting_offset1 + relative_offset1,
709
+ starting_offset2 + relative_offset2,
710
+ starting_offset3 + relative_offset3);
711
+ }
712
+
713
+ private:
714
+ dpctl::tensor::strides::CIndexer_array<nd, py::ssize_t > _ind;
715
+
716
+ const std::array<py::ssize_t , nd> strides1;
717
+ const std::array<py::ssize_t , nd> strides2;
718
+ const std::array<py::ssize_t , nd> strides3;
719
+ py::ssize_t starting_offset1;
720
+ py::ssize_t starting_offset2;
721
+ py::ssize_t starting_offset3;
722
+ };
723
+
667
724
} // namespace offset_utils
668
725
} // namespace tensor
669
726
} // namespace dpctl
0 commit comments