@@ -746,7 +746,85 @@ template <typename fnT, typename Ty> struct CopyForReshapeGenericFactory
    }
};

- // =============== Copying for reshape ================== //
+ // ================== Copying for roll ================== //
+
+ /*! @brief Functor to cyclically roll global_id to the left */
+ struct LeftRolled1DTransformer
+ {
+     LeftRolled1DTransformer(size_t offset, size_t size)
+         : offset_(offset), size_(size)
+     {
+     }
+
+     size_t operator()(size_t gid) const
+     {
+         const size_t shifted_gid =
+             ((gid < offset_) ? gid + size_ - offset_ : gid - offset_);
+         return shifted_gid;
+     }
+
+ private:
+     size_t offset_ = 0;
+     size_t size_ = 1;
+ };
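+
+ // Example: with offset_ = 2 and size_ = 5, gids 0..4 map to 3, 4, 0, 1, 2,
+ // i.e. shifted_gid = (gid - offset_) mod size_, so a copy
+ // dst[gid] = src[shifted_gid] realizes a roll of src by offset_ positions.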
+
+ /*! @brief Indexer functor to compose indexer and transformer */
+ template <typename IndexerT, typename TransformerT> struct CompositionIndexer
+ {
+     CompositionIndexer(IndexerT f, TransformerT t) : f_(f), t_(t) {}
+
+     auto operator()(size_t gid) const
+     {
+         return f_(t_(gid));
+     }
+
+ private:
+     IndexerT f_;
+     TransformerT t_;
+ };
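+
+ // CompositionIndexer(f, t)(gid) evaluates f(t(gid)): the transformer
+ // rewrites the flat iteration id first, and the indexer then maps the
+ // rewritten id to a memory displacement.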
+
+ /*! @brief Indexer functor to find offset for nd-shifted indices lifted from
+  * iteration id */
+ struct RolledNDIndexer
+ {
+     RolledNDIndexer(int nd,
+                     const py::ssize_t *shape,
+                     const py::ssize_t *strides,
+                     const py::ssize_t *ndshifts,
+                     py::ssize_t starting_offset)
+         : nd_(nd), shape_(shape), strides_(strides), ndshifts_(ndshifts),
+           starting_offset_(starting_offset)
+     {
+     }
+
+     py::ssize_t operator()(size_t gid) const
+     {
+         return compute_offset(gid);
+     }
+
+ private:
+     int nd_ = -1;
+     const py::ssize_t *shape_ = nullptr;
+     const py::ssize_t *strides_ = nullptr;
+     const py::ssize_t *ndshifts_ = nullptr;
+     py::ssize_t starting_offset_ = 0;
+
+     py::ssize_t compute_offset(py::ssize_t gid) const
+     {
+         using dpctl::tensor::strides::CIndexer_vector;
+
+         CIndexer_vector _ind(nd_);
+         py::ssize_t relative_offset_(0);
+         _ind.get_left_rolled_displacement<const py::ssize_t *,
+                                           const py::ssize_t *>(
+             gid,
+             shape_,    // shape ptr
+             strides_,  // strides ptr
+             ndshifts_, // shifts ptr
+             relative_offset_);
+         return starting_offset_ + relative_offset_;
+     }
+ };
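+
+ // In effect, compute_offset lifts gid to an nd-index over shape_, applies
+ // the cyclic per-axis shifts from ndshifts_, and folds the shifted index
+ // back into a linear displacement along strides_.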

template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
class copy_for_roll_strided_kernel;
@@ -755,32 +833,26 @@ template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
class StridedCopyForRollFunctor
{
private:
-     size_t offset = 0;
-     size_t size = 1;
    const Ty *src_p = nullptr;
    Ty *dst_p = nullptr;
    SrcIndexerT src_indexer_;
    DstIndexerT dst_indexer_;

public:
-     StridedCopyForRollFunctor(size_t shift,
-                               size_t nelems,
-                               const Ty *src_ptr,
+     StridedCopyForRollFunctor(const Ty *src_ptr,
                               Ty *dst_ptr,
                               SrcIndexerT src_indexer,
                               DstIndexerT dst_indexer)
-         : offset(shift), size(nelems), src_p(src_ptr), dst_p(dst_ptr),
-           src_indexer_(src_indexer), dst_indexer_(dst_indexer)
+         : src_p(src_ptr), dst_p(dst_ptr), src_indexer_(src_indexer),
+           dst_indexer_(dst_indexer)
    {
    }
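+
+     // Any cyclic shift is folded into SrcIndexerT (for example a
+     // CompositionIndexer<StridedIndexer, LeftRolled1DTransformer>),
+     // so the functor itself is shift-agnostic.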

    void operator()(sycl::id<1> wiid) const
    {
        const size_t gid = wiid.get(0);
-         const size_t shifted_gid =
-             ((gid < offset) ? gid + size - offset : gid - offset);

-         const py::ssize_t src_offset = src_indexer_(shifted_gid);
+         const py::ssize_t src_offset = src_indexer_(gid);
        const py::ssize_t dst_offset = dst_indexer_(gid);

        dst_p[dst_offset] = src_p[src_offset];
@@ -800,8 +872,6 @@ typedef sycl::event (*copy_for_roll_strided_fn_ptr_t)(
    py::ssize_t, // dst_offset
    const std::vector<sycl::event> &);

- template <typename Ty> class copy_for_roll_contig_kernel;
-
/*!
 * @brief Function to copy content of array with a shift.
 *
@@ -812,8 +882,8 @@ template <typename Ty> class copy_for_roll_contig_kernel;
 * @param shift The shift in flat indexing, must be non-negative.
 * @param nelems The number of elements to copy
 * @param nd Array dimensionality of the destination and source arrays
- * @param packed_shapes_and_strides Kernel accessible USM array of size
- * `3*nd` with content `[common_shape, src_strides, dst_strides]`.
+ * @param packed_shapes_and_strides Kernel accessible USM array
+ * of size `3*nd` with content `[common_shape, src_strides, dst_strides]`.
 * @param src_p Typeless USM pointer to the buffer of the source array
 * @param src_offset Displacement of the first element of src relative to
 * src_p in elements
@@ -849,21 +919,29 @@ copy_for_roll_strided_impl(sycl::queue q,
        // [ common_shape; src_strides; dst_strides ]

        StridedIndexer src_indexer{nd, src_offset, packed_shapes_and_strides};
+         LeftRolled1DTransformer left_roll_transformer{shift, nelems};
+
+         using CompositeIndexerT =
+             CompositionIndexer<StridedIndexer, LeftRolled1DTransformer>;
+
+         CompositeIndexerT rolled_src_indexer(src_indexer,
+                                              left_roll_transformer);
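+
+         // rolled_src_indexer(gid) == src_indexer((gid - shift) mod nelems),
+         // i.e. the flat id is cyclically shifted before the strided source
+         // address is computed.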
+
        UnpackedStridedIndexer dst_indexer{nd, dst_offset,
                                           packed_shapes_and_strides,
                                           packed_shapes_and_strides + 2 * nd};

-         using KernelName = copy_for_roll_strided_kernel<Ty, StridedIndexer,
+         using KernelName = copy_for_roll_strided_kernel<Ty, CompositeIndexerT,
                                                        UnpackedStridedIndexer>;

        const Ty *src_tp = reinterpret_cast<const Ty *>(src_p);
        Ty *dst_tp = reinterpret_cast<Ty *>(dst_p);

        cgh.parallel_for<KernelName>(
            sycl::range<1>(nelems),
-             StridedCopyForRollFunctor<Ty, StridedIndexer,
+             StridedCopyForRollFunctor<Ty, CompositeIndexerT,
                                      UnpackedStridedIndexer>(
-                 shift, nelems, src_tp, dst_tp, src_indexer, dst_indexer));
+                 src_tp, dst_tp, rolled_src_indexer, dst_indexer));
    });

    return copy_for_roll_ev;
@@ -880,6 +958,8 @@ typedef sycl::event (*copy_for_roll_contig_fn_ptr_t)(
    py::ssize_t, // dst_offset
    const std::vector<sycl::event> &);

+ template <typename Ty> class copy_for_roll_contig_kernel;
+
/*!
 * @brief Function to copy content of array with a shift.
 *
@@ -917,6 +997,10 @@ sycl::event copy_for_roll_contig_impl(sycl::queue q,
        cgh.depends_on(depends);

        NoOpIndexer src_indexer{};
+         LeftRolled1DTransformer roller{shift, nelems};
+
+         CompositionIndexer<NoOpIndexer, LeftRolled1DTransformer>
+             left_rolled_src_indexer{src_indexer, roller};
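+
+         // left_rolled_src_indexer composes the identity NoOpIndexer with the
+         // 1D roll, so the source is read at (gid - shift) mod nelems.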
        NoOpIndexer dst_indexer{};

        using KernelName = copy_for_roll_contig_kernel<Ty>;
@@ -926,8 +1010,10 @@ sycl::event copy_for_roll_contig_impl(sycl::queue q,

        cgh.parallel_for<KernelName>(
            sycl::range<1>(nelems),
-             StridedCopyForRollFunctor<Ty, NoOpIndexer, NoOpIndexer>(
-                 shift, nelems, src_tp, dst_tp, src_indexer, dst_indexer));
+             StridedCopyForRollFunctor<
+                 Ty, CompositionIndexer<NoOpIndexer, LeftRolled1DTransformer>,
+                 NoOpIndexer>(src_tp, dst_tp, left_rolled_src_indexer,
+                              dst_indexer));
    });

    return copy_for_roll_ev;
@@ -961,6 +1047,86 @@ template <typename fnT, typename Ty> struct CopyForRollContigFactory
    }
};

+ template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
+ class copy_for_roll_ndshift_strided_kernel;
+
+ // define function type
+ typedef sycl::event (*copy_for_roll_ndshift_strided_fn_ptr_t)(
+     sycl::queue,
+     size_t,              // num_elements
+     int,                 // common_nd
+     const py::ssize_t *, // packed shape, strides, shifts
+     const char *,        // src_data_ptr
+     py::ssize_t,         // src_offset
+     char *,              // dst_data_ptr
+     py::ssize_t,         // dst_offset
+     const std::vector<sycl::event> &);
+
+ template <typename Ty>
+ sycl::event copy_for_roll_ndshift_strided_impl(
+     sycl::queue q,
+     size_t nelems,
+     int nd,
+     const py::ssize_t *packed_shapes_and_strides_and_shifts,
+     const char *src_p,
+     py::ssize_t src_offset,
+     char *dst_p,
+     py::ssize_t dst_offset,
+     const std::vector<sycl::event> &depends)
+ {
+     dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
+
+     sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) {
+         cgh.depends_on(depends);
+
+         // packed_shapes_and_strides_and_shifts:
+         // USM array of size 4 * nd
+         // [ common_shape; src_strides; dst_strides; shifts ]
+
+         const py::ssize_t *shape_ptr = packed_shapes_and_strides_and_shifts;
+         const py::ssize_t *src_strides_ptr =
+             packed_shapes_and_strides_and_shifts + nd;
+         const py::ssize_t *dst_strides_ptr =
+             packed_shapes_and_strides_and_shifts + 2 * nd;
+         const py::ssize_t *shifts_ptr =
+             packed_shapes_and_strides_and_shifts + 3 * nd;
+
+         RolledNDIndexer src_indexer{nd, shape_ptr, src_strides_ptr, shifts_ptr,
+                                     src_offset};
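+
+         // src_indexer(gid) lifts gid to an nd-index, rolls each axis by the
+         // matching entry of shifts_ptr, and maps the shifted index through
+         // src_strides_ptr; dst_indexer below maps the unshifted gid.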
+
+         UnpackedStridedIndexer dst_indexer{nd, dst_offset, shape_ptr,
+                                            dst_strides_ptr};
+
+         using KernelName =
+             copy_for_roll_ndshift_strided_kernel<Ty, RolledNDIndexer,
+                                                  UnpackedStridedIndexer>;
+
+         const Ty *src_tp = reinterpret_cast<const Ty *>(src_p);
+         Ty *dst_tp = reinterpret_cast<Ty *>(dst_p);
+
+         cgh.parallel_for<KernelName>(
+             sycl::range<1>(nelems),
+             StridedCopyForRollFunctor<Ty, RolledNDIndexer,
+                                       UnpackedStridedIndexer>(
+                 src_tp, dst_tp, src_indexer, dst_indexer));
+     });
+
+     return copy_for_roll_ev;
+ }
+
+ /*!
+  * @brief Factory to get function pointer of type `fnT` for given array data
+  * type `Ty`.
+  * @ingroup CopyAndCastKernels
+  */
+ template <typename fnT, typename Ty> struct CopyForRollNDShiftFactory
+ {
+     fnT get()
+     {
+         fnT f = copy_for_roll_ndshift_strided_impl<Ty>;
+         return f;
+     }
+ };
+

} // namespace copy_and_cast
} // namespace kernels
} // namespace tensor