@@ -56,9 +56,6 @@ class copy_cast_contig_kernel;
 template <typename srcT, typename dstT, typename IndexerT>
 class copy_cast_from_host_kernel;
 
-template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
-class copy_for_reshape_generic_kernel;
-
 template <typename srcTy, typename dstTy> class Caster
 {
 public:
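
The two declarations removed above are not dropped; a later hunk re-adds them next to the kernel launch that uses them. For context, SYCL kernels defined by lambdas are conventionally given stable names through forward-declared class templates like these. A minimal sketch of the idiom, assuming a SYCL 2020 compiler (`fill_kernel` and `fill` are hypothetical names, not part of this file):

```cpp
// Minimal sketch of the SYCL kernel-naming idiom (illustrative only).
// The forward-declared class template is never defined; it exists solely
// to give the lambda submitted via parallel_for a unique, mangleable name.
#include <sycl/sycl.hpp>

template <typename T> class fill_kernel; // kernel name, intentionally undefined

template <typename T>
sycl::event fill(sycl::queue q, T *data, size_t n, T value)
{
    return q.submit([&](sycl::handler &cgh) {
        cgh.parallel_for<fill_kernel<T>>(
            sycl::range<1>(n), [=](sycl::id<1> i) { data[i] = value; });
    });
}
```

Keeping the name declaration adjacent to the `parallel_for` that instantiates it, as the hunks below do, makes that association easier to audit.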
@@ -629,68 +626,56 @@ struct CopyAndCastFromHostFactory
 
 // =============== Copying for reshape ================== //
 
+template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
+class copy_for_reshape_generic_kernel;
+
 template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
 class GenericCopyForReshapeFunctor
 {
 private:
-    py::ssize_t offset = 0;
-    py::ssize_t size = 1;
-    // USM array of size 2*(src_nd + dst_nd)
-    // [ src_shape; src_strides; dst_shape; dst_strides ]
-    Ty *src_p = nullptr;
+    const Ty *src_p = nullptr;
     Ty *dst_p = nullptr;
     SrcIndexerT src_indexer_;
     DstIndexerT dst_indexer_;
 
 public:
-    GenericCopyForReshapeFunctor(py::ssize_t shift,
-                                 py::ssize_t nelems,
-                                 char *src_ptr,
+    GenericCopyForReshapeFunctor(const char *src_ptr,
                                  char *dst_ptr,
                                  SrcIndexerT src_indexer,
                                  DstIndexerT dst_indexer)
-        : offset(shift), size(nelems), src_p(reinterpret_cast<Ty *>(src_ptr)),
+        : src_p(reinterpret_cast<const Ty *>(src_ptr)),
           dst_p(reinterpret_cast<Ty *>(dst_ptr)), src_indexer_(src_indexer),
           dst_indexer_(dst_indexer)
     {
     }
 
     void operator()(sycl::id<1> wiid) const
     {
-        py::ssize_t this_src_offset = src_indexer_(wiid.get(0));
-        const Ty *in = src_p + this_src_offset;
-
-        py::ssize_t shifted_wiid =
-            (static_cast<py::ssize_t>(wiid.get(0)) + offset) % size;
-        shifted_wiid = (shifted_wiid >= 0) ? shifted_wiid : shifted_wiid + size;
+        const py::ssize_t src_offset = src_indexer_(wiid.get(0));
+        const py::ssize_t dst_offset = dst_indexer_(wiid.get(0));
 
-        py::ssize_t this_dst_offset = dst_indexer_(shifted_wiid);
-
-        Ty *out = dst_p + this_dst_offset;
-        *out = *in;
+        dst_p[dst_offset] = src_p[src_offset];
     }
 };
 
 // define function type
 typedef sycl::event (*copy_for_reshape_fn_ptr_t)(
     sycl::queue,
-    py::ssize_t,   // shift
-    size_t,        // num_elements
-    int,
-    int,           // src_nd, dst_nd
+    size_t,        // num_elements
+    int,           // src_nd
+    int,           // dst_nd
     py::ssize_t *, // packed shapes and strides
-    char *,        // src_data_ptr
+    const char *,  // src_data_ptr
     char *,        // dst_data_ptr
     const std::vector<sycl::event> &);
 
 /*!
  * @brief Function to copy content of array while reshaping.
  *
- * Submits a kernel to perform a copy `dst[unravel_index((i + shift) % nelems,
+ * Submits a kernel to perform a copy `dst[unravel_index(i,
  * dst.shape)] = src[unravel_index(i, src.shape)]`.
  *
  * @param q The execution queue where kernel is submitted.
- * @param shift The shift in flat indexing.
  * @param nelems The number of elements to copy
  * @param src_nd Array dimension of the source array
  * @param dst_nd Array dimension of the destination array
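
With the shift logic gone, the functor is a pure element-wise copy: work-item `i` reads the source element whose multi-index is the unraveling of `i` through the source shape, and writes the destination element unraveled through the destination shape. A host-side model of what each indexer computes, as a self-contained sketch (helper names are hypothetical, not from this file):

```cpp
// Host-side model of the per-work-item index mapping (illustrative sketch).
// A flat index i is unraveled through `shape` (last dimension fastest) and
// mapped through `strides`, independently for src and dst.
#include <cstddef>
#include <iostream>
#include <vector>

using ssize_vec = std::vector<std::ptrdiff_t>;

// Equivalent of what StridedIndexer computes on-device
std::ptrdiff_t strided_offset(std::size_t i, const ssize_vec &shape,
                              const ssize_vec &strides)
{
    std::ptrdiff_t offset = 0;
    for (std::size_t d = shape.size(); d-- > 0;) {
        offset += static_cast<std::ptrdiff_t>(i % shape[d]) * strides[d];
        i /= shape[d];
    }
    return offset;
}

int main()
{
    // Copy a 2x3 F-contiguous src into a 3x2 C-contiguous dst:
    // dst element i = src element i, each unraveled through its own layout.
    ssize_vec src_shape{2, 3}, src_strides{1, 2}; // column-major
    ssize_vec dst_shape{3, 2}, dst_strides{2, 1}; // row-major
    for (std::size_t i = 0; i < 6; ++i)
        std::cout << i << " -> src elem "
                  << strided_offset(i, src_shape, src_strides) << ", dst elem "
                  << strided_offset(i, dst_shape, dst_strides) << "\n";
}
```

For C-contiguous layouts on both sides, both offsets reduce to `i` itself; the indexers only matter for strided views.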
@@ -708,31 +693,40 @@ typedef sycl::event (*copy_for_reshape_fn_ptr_t)(
 template <typename Ty>
 sycl::event
 copy_for_reshape_generic_impl(sycl::queue q,
-                              py::ssize_t shift,
                               size_t nelems,
                               int src_nd,
                               int dst_nd,
                               py::ssize_t *packed_shapes_and_strides,
-                              char *src_p,
+                              const char *src_p,
                               char *dst_p,
                               const std::vector<sycl::event> &depends)
 {
     dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
 
     sycl::event copy_for_reshape_ev = q.submit([&](sycl::handler &cgh) {
-        StridedIndexer src_indexer{
-            src_nd, 0,
-            const_cast<const py::ssize_t *>(packed_shapes_and_strides)};
-        StridedIndexer dst_indexer{
-            dst_nd, 0,
-            const_cast<const py::ssize_t *>(packed_shapes_and_strides +
-                                            (2 * src_nd))};
         cgh.depends_on(depends);
-        cgh.parallel_for<copy_for_reshape_generic_kernel<Ty, StridedIndexer,
-                                                         StridedIndexer>>(
+
+        // packed_shapes_and_strides:
+        // USM array of size 2*(src_nd + dst_nd)
+        // [ src_shape; src_strides; dst_shape; dst_strides ]
+
+        const py::ssize_t *src_shape_and_strides =
+            const_cast<const py::ssize_t *>(packed_shapes_and_strides);
+
+        const py::ssize_t *dst_shape_and_strides =
+            const_cast<const py::ssize_t *>(packed_shapes_and_strides +
+                                            (2 * src_nd));
+
+        StridedIndexer src_indexer{src_nd, 0, src_shape_and_strides};
+        StridedIndexer dst_indexer{dst_nd, 0, dst_shape_and_strides};
+
+        using KernelName =
+            copy_for_reshape_generic_kernel<Ty, StridedIndexer, StridedIndexer>;
+
+        cgh.parallel_for<KernelName>(
             sycl::range<1>(nelems),
             GenericCopyForReshapeFunctor<Ty, StridedIndexer, StridedIndexer>(
-                shift, nelems, src_p, dst_p, src_indexer, dst_indexer));
+                src_p, dst_p, src_indexer, dst_indexer));
     });
 
     return copy_for_reshape_ev;
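
A caller is expected to pack the four index arrays into one device-accessible USM allocation before submitting. One way a host might drive this function, as a hedged sketch (the helper below is hypothetical and assumes the SYCL and pybind11 headers this file already includes; a production caller would free the temporary asynchronously rather than blocking):

```cpp
// Hypothetical usage sketch for copy_for_reshape_generic_impl; not part of
// this commit. Packs [src_shape; src_strides; dst_shape; dst_strides],
// 2*(src_nd + dst_nd) entries, into USM and submits the copy for float data.
sycl::event submit_reshape_copy(sycl::queue q,
                                size_t nelems,
                                const std::vector<py::ssize_t> &src_shape,
                                const std::vector<py::ssize_t> &src_strides,
                                const std::vector<py::ssize_t> &dst_shape,
                                const std::vector<py::ssize_t> &dst_strides,
                                const char *src_p,
                                char *dst_p)
{
    const int src_nd = static_cast<int>(src_shape.size());
    const int dst_nd = static_cast<int>(dst_shape.size());

    std::vector<py::ssize_t> host_packed;
    host_packed.reserve(2 * (src_nd + dst_nd));
    for (const auto *v : {&src_shape, &src_strides, &dst_shape, &dst_strides})
        host_packed.insert(host_packed.end(), v->begin(), v->end());

    py::ssize_t *packed =
        sycl::malloc_device<py::ssize_t>(host_packed.size(), q);
    q.copy<py::ssize_t>(host_packed.data(), packed, host_packed.size()).wait();

    sycl::event ev = copy_for_reshape_generic_impl<float>(
        q, nelems, src_nd, dst_nd, packed, src_p, dst_p, {});

    // Simplification: block and free here; real callers would tie the free
    // to event completion instead of waiting on the host.
    ev.wait();
    sycl::free(packed, q);
    return ev;
}
```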
@@ -752,6 +746,221 @@ template <typename fnT, typename Ty> struct CopyForReshapeGenericFactory
     }
 };
 
+// =============== Copying for roll ================== //
+
+template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
+class copy_for_roll_strided_kernel;
+
+template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
+class StridedCopyForRollFunctor
+{
+private:
+    size_t offset = 0;
+    size_t size = 1;
+    const Ty *src_p = nullptr;
+    Ty *dst_p = nullptr;
+    SrcIndexerT src_indexer_;
+    DstIndexerT dst_indexer_;
+
+public:
+    StridedCopyForRollFunctor(size_t shift,
+                              size_t nelems,
+                              const Ty *src_ptr,
+                              Ty *dst_ptr,
+                              SrcIndexerT src_indexer,
+                              DstIndexerT dst_indexer)
+        : offset(shift), size(nelems), src_p(src_ptr), dst_p(dst_ptr),
+          src_indexer_(src_indexer), dst_indexer_(dst_indexer)
+    {
+    }
+
+    void operator()(sycl::id<1> wiid) const
+    {
+        const size_t gid = wiid.get(0);
+        const size_t shifted_gid =
+            ((gid < offset) ? gid + size - offset : gid - offset);
+
+        const py::ssize_t src_offset = src_indexer_(shifted_gid);
+        const py::ssize_t dst_offset = dst_indexer_(gid);
+
+        dst_p[dst_offset] = src_p[src_offset];
+    }
+};
+
+// define function type
+typedef sycl::event (*copy_for_roll_strided_fn_ptr_t)(
+    sycl::queue,
+    size_t,              // shift
+    size_t,              // num_elements
+    int,                 // common_nd
+    const py::ssize_t *, // packed shapes and strides
+    const char *,        // src_data_ptr
+    py::ssize_t,         // src_offset
+    char *,              // dst_data_ptr
+    py::ssize_t,         // dst_offset
+    const std::vector<sycl::event> &);
+
+template <typename Ty> class copy_for_roll_contig_kernel;
+
+/*!
+ * @brief Function to copy content of array with a shift.
+ *
+ * Submits a kernel to perform a copy `dst[unravel_index((i + shift) % nelems,
+ * dst.shape)] = src[unravel_index(i, src.shape)]`.
+ *
+ * @param q The execution queue where kernel is submitted.
+ * @param shift The shift in flat indexing, must be non-negative.
+ * @param nelems The number of elements to copy
+ * @param nd Array dimensionality of the destination and source arrays
+ * @param packed_shapes_and_strides Kernel accessible USM array of size
+ * `3*nd` with content `[common_shape, src_strides, dst_strides]`.
+ * @param src_p Typeless USM pointer to the buffer of the source array
+ * @param src_offset Displacement of the first element of src relative to
+ * src_p in elements
+ * @param dst_p Typeless USM pointer to the buffer of the destination array
+ * @param dst_offset Displacement of the first element of dst relative to
+ * dst_p in elements
+ * @param depends List of events to wait for before starting computations, if
+ * any.
+ *
+ * @return Event to wait on to ensure that computation completes.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename Ty>
+sycl::event
+copy_for_roll_strided_impl(sycl::queue q,
+                           size_t shift,
+                           size_t nelems,
+                           int nd,
+                           const py::ssize_t *packed_shapes_and_strides,
+                           const char *src_p,
+                           py::ssize_t src_offset,
+                           char *dst_p,
+                           py::ssize_t dst_offset,
+                           const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
+
+    sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        // packed_shapes_and_strides:
+        // USM array of size 3 * nd
+        // [ common_shape; src_strides; dst_strides ]
+
+        StridedIndexer src_indexer{nd, src_offset, packed_shapes_and_strides};
+        UnpackedStridedIndexer dst_indexer{nd, dst_offset,
+                                           packed_shapes_and_strides,
+                                           packed_shapes_and_strides + 2 * nd};
+
+        using KernelName = copy_for_roll_strided_kernel<Ty, StridedIndexer,
+                                                        UnpackedStridedIndexer>;
+
+        const Ty *src_tp = reinterpret_cast<const Ty *>(src_p);
+        Ty *dst_tp = reinterpret_cast<Ty *>(dst_p);
+
+        cgh.parallel_for<KernelName>(
+            sycl::range<1>(nelems),
+            StridedCopyForRollFunctor<Ty, StridedIndexer,
+                                      UnpackedStridedIndexer>(
+                shift, nelems, src_tp, dst_tp, src_indexer, dst_indexer));
+    });
+
+    return copy_for_roll_ev;
+}
+
+// define function type
+typedef sycl::event (*copy_for_roll_contig_fn_ptr_t)(
+    sycl::queue,
+    size_t,       // shift
+    size_t,       // num_elements
+    const char *, // src_data_ptr
+    py::ssize_t,  // src_offset
+    char *,       // dst_data_ptr
+    py::ssize_t,  // dst_offset
+    const std::vector<sycl::event> &);
+
+/*!
+ * @brief Function to copy content of array with a shift.
+ *
+ * Submits a kernel to perform a copy `dst[unravel_index((i + shift) % nelems,
+ * dst.shape)] = src[unravel_index(i, src.shape)]`.
+ *
+ * @param q The execution queue where kernel is submitted.
+ * @param shift The shift in flat indexing, must be non-negative.
+ * @param nelems The number of elements to copy
+ * @param src_p Typeless USM pointer to the buffer of the source array
+ * @param src_offset Displacement of the start of array src relative to src_p
+ * in elements
+ * @param dst_p Typeless USM pointer to the buffer of the destination array
+ * @param dst_offset Displacement of the start of array dst relative to dst_p
+ * in elements
+ * @param depends List of events to wait for before starting computations, if
+ * any.
+ *
+ * @return Event to wait on to ensure that computation completes.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename Ty>
+sycl::event copy_for_roll_contig_impl(sycl::queue q,
+                                      size_t shift,
+                                      size_t nelems,
+                                      const char *src_p,
+                                      py::ssize_t src_offset,
+                                      char *dst_p,
+                                      py::ssize_t dst_offset,
+                                      const std::vector<sycl::event> &depends)
+{
+    dpctl::tensor::type_utils::validate_type_for_device<Ty>(q);
+
+    sycl::event copy_for_roll_ev = q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        NoOpIndexer src_indexer{};
+        NoOpIndexer dst_indexer{};
+
+        using KernelName = copy_for_roll_contig_kernel<Ty>;
+
+        const Ty *src_tp = reinterpret_cast<const Ty *>(src_p) + src_offset;
+        Ty *dst_tp = reinterpret_cast<Ty *>(dst_p) + dst_offset;
+
+        cgh.parallel_for<KernelName>(
+            sycl::range<1>(nelems),
+            StridedCopyForRollFunctor<Ty, NoOpIndexer, NoOpIndexer>(
+                shift, nelems, src_tp, dst_tp, src_indexer, dst_indexer));
+    });
+
+    return copy_for_roll_ev;
+}
+
+/*!
+ * @brief Factory to get function pointer of type `fnT` for given array data
+ * type `Ty`.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename fnT, typename Ty> struct CopyForRollStridedFactory
+{
+    fnT get()
+    {
+        fnT f = copy_for_roll_strided_impl<Ty>;
+        return f;
+    }
+};
+
+/*!
+ * @brief Factory to get function pointer of type `fnT` for given array data
+ * type `Ty`.
+ * @ingroup CopyAndCastKernels
+ */
+template <typename fnT, typename Ty> struct CopyForRollContigFactory
+{
+    fnT get()
+    {
+        fnT f = copy_for_roll_contig_impl<Ty>;
+        return f;
+    }
+};
+
 
 } // namespace copy_and_cast
 } // namespace kernels
 } // namespace tensor
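
The branchless shift in `StridedCopyForRollFunctor` replaces a signed `%` with a single comparison, avoiding the negative-remainder pitfall the old reshape functor had to patch up with an extra wraparound step. A quick host-side check of the mapping, as a standalone sketch (illustrative, not code from the commit):

```cpp
// Host-side check of the branchless shift used by StridedCopyForRollFunctor.
// For a roll by `shift`, destination element gid reads source element
// (gid - shift) mod size, computed without signed remainder.
#include <cassert>
#include <cstddef>

static std::size_t rolled_src_gid(std::size_t gid, std::size_t shift,
                                  std::size_t size)
{
    return (gid < shift) ? gid + size - shift : gid - shift;
}

int main()
{
    // Rolling [0,1,2,3,4] by shift=2 yields [3,4,0,1,2]
    const std::size_t n = 5, s = 2;
    const int src[n] = {0, 1, 2, 3, 4};
    int dst[n];
    for (std::size_t i = 0; i < n; ++i)
        dst[i] = src[rolled_src_gid(i, s, n)];
    assert(dst[0] == 3 && dst[1] == 4 && dst[2] == 0);
    return 0;
}
```

This matches `np.roll` semantics: destination index `gid` reads source index `(gid - shift) mod nelems`, with `shift` required to be non-negative, as the docstrings above state.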