
Commit 558765f

Introduced dedicated function for rolling with nd-shift
The function for flattened rolling is renamed:
_copy_usm_ndarray_for_roll -> _copy_usm_ndarray_for_roll_1d. It keeps the same
signature:

    _copy_usm_ndarray_for_roll_1d(
        src: usm_ndarray,
        dst: usm_ndarray,
        shift: int,
        sycl_queue: dpctl.SyclQueue,
    ) -> Tuple[dpctl.SyclEvent, dpctl.SyclEvent]

Introduced _copy_usm_ndarray_for_roll_nd:

    _copy_usm_ndarray_for_roll_nd(
        src: usm_ndarray,
        dst: usm_ndarray,
        shifts: Tuple[int],
        sycl_queue: dpctl.SyclQueue,
    ) -> Tuple[dpctl.SyclEvent, dpctl.SyclEvent]

The length of the shifts tuple must equal the dimensionality of the src and
dst arrays, which are expected to have the same shape and the same data type.
1 parent aea79dd commit 558765f

File tree: 4 files changed, +369 / -35 lines

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 186 additions & 20 deletions
@@ -746,7 +746,85 @@ template <typename fnT, typename Ty> struct CopyForReshapeGenericFactory
     }
 };
 
-// =============== Copying for reshape ================== //
+// ================== Copying for roll ================== //
+
+/*! @brief Functor to cyclically roll global_id to the left */
+struct LeftRolled1DTransformer
+{
+    LeftRolled1DTransformer(size_t offset, size_t size)
+        : offset_(offset), size_(size)
+    {
+    }
+
+    size_t operator()(size_t gid) const
+    {
+        const size_t shifted_gid =
+            ((gid < offset_) ? gid + size_ - offset_ : gid - offset_);
+        return shifted_gid;
+    }
+
+private:
+    size_t offset_ = 0;
+    size_t size_ = 1;
+};
+
+/*! @brief Indexer functor to compose indexer and transformer */
+template <typename IndexerT, typename TransformerT> struct CompositionIndexer
+{
+    CompositionIndexer(IndexerT f, TransformerT t) : f_(f), t_(t) {}
+
+    auto operator()(size_t gid) const
+    {
+        return f_(t_(gid));
+    }
+
+private:
+    IndexerT f_;
+    TransformerT t_;
+};
+
+/*! @brief Indexer functor to find offset for nd-shifted indices lifted from
+ *  iteration id */
+struct RolledNDIndexer
+{
+    RolledNDIndexer(int nd,
+                    const py::ssize_t *shape,
+                    const py::ssize_t *strides,
+                    const py::ssize_t *ndshifts,
+                    py::ssize_t starting_offset)
+        : nd_(nd), shape_(shape), strides_(strides), ndshifts_(ndshifts),
+          starting_offset_(starting_offset)
+    {
+    }
+
+    py::ssize_t operator()(size_t gid) const
+    {
+        return compute_offset(gid);
+    }
+
+private:
+    int nd_ = -1;
+    const py::ssize_t *shape_ = nullptr;
+    const py::ssize_t *strides_ = nullptr;
+    const py::ssize_t *ndshifts_ = nullptr;
+    py::ssize_t starting_offset_ = 0;
+
+    py::ssize_t compute_offset(py::ssize_t gid) const
+    {
+        using dpctl::tensor::strides::CIndexer_vector;
+
+        CIndexer_vector _ind(nd_);
+        py::ssize_t relative_offset_(0);
+        _ind.get_left_rolled_displacement<const py::ssize_t *,
+                                          const py::ssize_t *>(
+            gid,
+            shape_,    // shape ptr
+            strides_,  // strides ptr
+            ndshifts_, // shifts ptr
+            relative_offset_);
+        return starting_offset_ + relative_offset_;
+    }
+};
 
 template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
 class copy_for_roll_strided_kernel;
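
The three functors above carry all of the roll logic. As a sanity check on the
arithmetic, here is a minimal standalone sketch (not part of the commit; plain
host C++ with no SYCL or pybind11 dependencies) that reimplements the
LeftRolled1DTransformer formula and verifies it reproduces a flat roll:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Same formula as LeftRolled1DTransformer::operator():
    // map gid to (gid - offset) mod size without signed arithmetic.
    std::size_t roll_left(std::size_t gid, std::size_t offset, std::size_t size)
    {
        return (gid < offset) ? gid + size - offset : gid - offset;
    }

    int main()
    {
        const std::vector<int> src{10, 20, 30, 40, 50};
        const std::size_t shift = 2; // roll by two positions
        std::vector<int> dst(src.size());

        for (std::size_t gid = 0; gid < src.size(); ++gid)
            dst[gid] = src[roll_left(gid, shift, src.size())];

        for (int v : dst)
            std::cout << v << ' '; // prints: 40 50 10 20 30
        std::cout << '\n';
    }

With offset equal to the requested shift, dst[gid] = src[(gid - shift) mod size],
which matches np.roll semantics for a positive flat shift.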
@@ -755,32 +833,26 @@ template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
 class StridedCopyForRollFunctor
 {
 private:
-    size_t offset = 0;
-    size_t size = 1;
     const Ty *src_p = nullptr;
     Ty *dst_p = nullptr;
     SrcIndexerT src_indexer_;
     DstIndexerT dst_indexer_;
 
 public:
-    StridedCopyForRollFunctor(size_t shift,
-                              size_t nelems,
-                              const Ty *src_ptr,
+    StridedCopyForRollFunctor(const Ty *src_ptr,
                               Ty *dst_ptr,
                               SrcIndexerT src_indexer,
                               DstIndexerT dst_indexer)
-        : offset(shift), size(nelems), src_p(src_ptr), dst_p(dst_ptr),
-          src_indexer_(src_indexer), dst_indexer_(dst_indexer)
+        : src_p(src_ptr), dst_p(dst_ptr), src_indexer_(src_indexer),
+          dst_indexer_(dst_indexer)
     {
     }
 
     void operator()(sycl::id<1> wiid) const
     {
         const size_t gid = wiid.get(0);
-        const size_t shifted_gid =
-            ((gid < offset) ? gid + size - offset : gid - offset);
 
-        const py::ssize_t src_offset = src_indexer_(shifted_gid);
+        const py::ssize_t src_offset = src_indexer_(gid);
         const py::ssize_t dst_offset = dst_indexer_(gid);
 
         dst_p[dst_offset] = src_p[src_offset];
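
After this change the functor is indexer-agnostic: the shift bookkeeping moves
out of the kernel body and into the source indexer, so the body reduces to
dst[dst_indexer(gid)] = src[src_indexer(gid)]. A hedged host-side sketch of
the same composition pattern, with stand-in lambdas instead of the real
indexer types:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Same shape as CompositionIndexer in the header:
    // apply the transformer first, then the indexer.
    template <typename IndexerT, typename TransformerT> struct Composition
    {
        IndexerT f;
        TransformerT t;
        auto operator()(std::size_t gid) const { return f(t(gid)); }
    };

    int main()
    {
        const std::vector<int> src{1, 2, 3, 4};
        std::vector<int> dst(src.size());
        const std::size_t n = src.size(), shift = 1;

        auto identity = [](std::size_t i) { return i; }; // stands in for NoOpIndexer
        auto roll = [n, shift](std::size_t gid) { // stands in for LeftRolled1DTransformer
            return (gid < shift) ? gid + n - shift : gid - shift;
        };

        Composition<decltype(identity), decltype(roll)> src_indexer{identity, roll};

        // The whole kernel body after the refactor:
        for (std::size_t gid = 0; gid < n; ++gid)
            dst[gid] = src[src_indexer(gid)];

        for (int v : dst)
            std::cout << v << ' '; // prints: 4 1 2 3
        std::cout << '\n';
    }

The design pay-off is that the same functor now serves the 1D-rolled, contiguous,
and nd-shifted paths; only the indexer type changes.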
@@ -800,8 +872,6 @@ typedef sycl::event (*copy_for_roll_strided_fn_ptr_t)(
     py::ssize_t, // dst_offset
     const std::vector<sycl::event> &);
 
-template <typename Ty> class copy_for_roll_contig_kernel;
-
 /*!
  * @brief Function to copy content of array with a shift.
  *
@@ -812,8 +882,8 @@ template <typename Ty> class copy_for_roll_contig_kernel;
  * @param shift The shift in flat indexing, must be non-negative.
  * @param nelems The number of elements to copy
  * @param nd Array dimensionality of the destination and source arrays
- * @param packed_shapes_and_strides Kernel accessible USM array of size
- *     `3*nd` with content `[common_shape, src_strides, dst_strides]`.
+ * @param packed_shapes_and_strides Kernel accessible USM array
+ *     of size `3*nd` with content `[common_shape, src_strides, dst_strides]`.
  * @param src_p Typeless USM pointer to the buffer of the source array
  * @param src_offset Displacement of first element of src relative src_p in
  *     elements
@@ -849,21 +919,29 @@ copy_for_roll_strided_impl(sycl::queue q,
         // [ common_shape; src_strides; dst_strides ]
 
         StridedIndexer src_indexer{nd, src_offset, packed_shapes_and_strides};
+        LeftRolled1DTransformer left_roll_transformer{shift, nelems};
+
+        using CompositeIndexerT =
+            CompositionIndexer<StridedIndexer, LeftRolled1DTransformer>;
+
+        CompositeIndexerT rolled_src_indexer(src_indexer,
+                                             left_roll_transformer);
+
         UnpackedStridedIndexer dst_indexer{nd, dst_offset,
                                            packed_shapes_and_strides,
                                            packed_shapes_and_strides + 2 * nd};
 
-        using KernelName = copy_for_roll_strided_kernel<Ty, StridedIndexer,
+        using KernelName = copy_for_roll_strided_kernel<Ty, CompositeIndexerT,
                                                         UnpackedStridedIndexer>;
 
         const Ty *src_tp = reinterpret_cast<const Ty *>(src_p);
         Ty *dst_tp = reinterpret_cast<Ty *>(dst_p);
 
         cgh.parallel_for<KernelName>(
             sycl::range<1>(nelems),
-            StridedCopyForRollFunctor<Ty, StridedIndexer,
+            StridedCopyForRollFunctor<Ty, CompositeIndexerT,
                                       UnpackedStridedIndexer>(
-                shift, nelems, src_tp, dst_tp, src_indexer, dst_indexer));
+                src_tp, dst_tp, rolled_src_indexer, dst_indexer));
     });
 
     return copy_for_roll_ev;
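
In this strided path the flat roll and the strided unraveling compose: the
rolled gid is produced first by LeftRolled1DTransformer, then StridedIndexer
converts it to a memory offset through the packed shape and strides. A host-side
sketch of that two-step lookup, assuming a 2 x 3 array with Fortran-order
source strides (illustrative values only, not taken from the commit):

    #include <array>
    #include <cstddef>
    #include <iostream>

    int main()
    {
        const std::array<std::size_t, 2> shape{2, 3}; // common shape (2, 3)
        const std::array<long, 2> strides{1, 2};      // F-order src strides, in elements
        const std::size_t nelems = 6, shift = 2;

        for (std::size_t gid = 0; gid < nelems; ++gid) {
            // step 1: LeftRolled1DTransformer -- roll the flat id
            const std::size_t rolled =
                (gid < shift) ? gid + nelems - shift : gid - shift;
            // step 2: StridedIndexer -- unravel in C order, dot with strides
            long offset = 0;
            std::size_t rem = rolled;
            for (int d = 1; d >= 0; --d) {
                offset += static_cast<long>(rem % shape[d]) * strides[d];
                rem /= shape[d];
            }
            std::cout << "dst flat id " << gid << " <- src offset " << offset << '\n';
        }
    }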
@@ -880,6 +958,8 @@ typedef sycl::event (*copy_for_roll_contig_fn_ptr_t)(
     py::ssize_t, // dst_offset
     const std::vector<sycl::event> &);
 
+template <typename Ty> class copy_for_roll_contig_kernel;
+
 /*!
  * @brief Function to copy content of array with a shift.
  *
@@ -917,6 +997,10 @@ sycl::event copy_for_roll_contig_impl(sycl::queue q,
         cgh.depends_on(depends);
 
         NoOpIndexer src_indexer{};
+        LeftRolled1DTransformer roller{shift, nelems};
+
+        CompositionIndexer<NoOpIndexer, LeftRolled1DTransformer>
+            left_rolled_src_indexer{src_indexer, roller};
         NoOpIndexer dst_indexer{};
 
         using KernelName = copy_for_roll_contig_kernel<Ty>;
@@ -926,8 +1010,10 @@ sycl::event copy_for_roll_contig_impl(sycl::queue q,
 
         cgh.parallel_for<KernelName>(
             sycl::range<1>(nelems),
-            StridedCopyForRollFunctor<Ty, NoOpIndexer, NoOpIndexer>(
-                shift, nelems, src_tp, dst_tp, src_indexer, dst_indexer));
+            StridedCopyForRollFunctor<
+                Ty, CompositionIndexer<NoOpIndexer, LeftRolled1DTransformer>,
+                NoOpIndexer>(src_tp, dst_tp, left_rolled_src_indexer,
+                             dst_indexer));
     });
 
     return copy_for_roll_ev;
@@ -961,6 +1047,86 @@ template <typename fnT, typename Ty> struct CopyForRollContigFactory
     }
 };
 
+template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
+class copy_for_roll_ndshift_strided_kernel;
+
+// define function type
+typedef sycl::event (*copy_for_roll_ndshift_strided_fn_ptr_t)(
+    sycl::queue,
+    size_t,              // num_elements
+    int,                 // common_nd
+    const py::ssize_t *, // packed shape, strides, shifts
+    const char *,        // src_data_ptr
+    py::ssize_t,         // src_offset
+    char *,              // dst_data_ptr
+    py::ssize_t,         // dst_offset
+    const std::vector<sycl::event> &);
+
+template <typename Ty>
+sycl::event copy_for_roll_ndshift_strided_impl(
+    sycl::queue q,
+    size_t nelems,
+    int nd,
+    const py::ssize_t *packed_shapes_and_strides_and_shifts,
+    const char *src_p,
+    py::ssize_t src_offset,
+    char *dst_p,
+    py::ssize_t dst_offset,
+    const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
+
+    sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        // packed_shapes_and_strides_and_shifts:
+        //     USM array of size 4 * nd
+        //     [ common_shape; src_strides; dst_strides; shifts ]
+
+        const py::ssize_t *shape_ptr = packed_shapes_and_strides_and_shifts;
+        const py::ssize_t *src_strides_ptr =
+            packed_shapes_and_strides_and_shifts + nd;
+        const py::ssize_t *dst_strides_ptr =
+            packed_shapes_and_strides_and_shifts + 2 * nd;
+        const py::ssize_t *shifts_ptr =
+            packed_shapes_and_strides_and_shifts + 3 * nd;
+
+        RolledNDIndexer src_indexer{nd, shape_ptr, src_strides_ptr, shifts_ptr,
+                                    src_offset};
+
+        UnpackedStridedIndexer dst_indexer{nd, dst_offset, shape_ptr,
+                                           dst_strides_ptr};
+
+        using KernelName = copy_for_roll_strided_kernel<Ty, RolledNDIndexer,
+                                                        UnpackedStridedIndexer>;
+
+        const Ty *src_tp = reinterpret_cast<const Ty *>(src_p);
+        Ty *dst_tp = reinterpret_cast<Ty *>(dst_p);
+
+        cgh.parallel_for<KernelName>(
+            sycl::range<1>(nelems),
+            StridedCopyForRollFunctor<Ty, RolledNDIndexer,
+                                      UnpackedStridedIndexer>(
+                src_tp, dst_tp, src_indexer, dst_indexer));
+    });
+
+    return copy_for_roll_ev;
+}
+
+/*!
+ * @brief Factory to get function pointer of type `fnT` for given array data
+ * type `Ty`.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename fnT, typename Ty> struct CopyForRollNDShiftFactory
+{
+    fnT get()
+    {
+        fnT f = copy_for_roll_ndshift_strided_impl<Ty>;
+        return f;
+    }
+};
+
 } // namespace copy_and_cast
 } // namespace kernels
 } // namespace tensor

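Unlike the 1D path, the nd-shift kernel rolls each axis independently:
RolledNDIndexer unravels the flat iteration id over the common shape, displaces
each axis index by its shift modulo the axis extent, and accumulates the result
against the source strides. The helper it delegates to,
CIndexer_vector::get_left_rolled_displacement, lives elsewhere in dpctl and is
not shown in this diff; the sketch below reimplements the assumed semantics on
the host for a 2 x 3 example, with shifts taken as already normalized into
[0, shape[d]):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        const int nd = 2;
        const std::vector<std::size_t> shape{2, 3};  // common shape
        const std::vector<long> strides{3, 1};       // src strides (C order), in elements
        const std::vector<std::size_t> shifts{1, 2}; // per-axis left-roll offsets

        const std::size_t nelems = 6;
        for (std::size_t gid = 0; gid < nelems; ++gid) {
            long offset = 0;
            std::size_t rem = gid;
            // unravel gid in C order, rolling each axis index independently
            for (int d = nd - 1; d >= 0; --d) {
                const std::size_t i = rem % shape[d];
                rem /= shape[d];
                const std::size_t rolled_i =
                    (i < shifts[d]) ? i + shape[d] - shifts[d] : i - shifts[d];
                offset += static_cast<long>(rolled_i) * strides[d];
            }
            std::cout << "dst id " << gid << " <- src offset " << offset << '\n';
        }
    }

For instance, dst id 0 maps to src offset 4, i.e. dst[0, 0] = src[1, 1], in
agreement with np.roll(src, (1, 2), axis=(0, 1)) for a 2 x 3 C-ordered array.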