9
9
#pragma once
10
10
11
11
#include < CL/sycl/accessor.hpp>
12
+ #include < CL/sycl/handler.hpp>
12
13
#include < CL/sycl/intel/group_algorithm.hpp>
13
14
14
15
__SYCL_INLINE_NAMESPACE (cl) {
@@ -17,6 +18,11 @@ namespace intel {
17
18
18
19
namespace detail {
19
20
21
+ __SYCL_EXPORT size_t reduGetMaxWGSize (shared_ptr_class<queue_impl> Queue,
22
+ size_t LocalMemBytesPerWorkItem);
23
+ __SYCL_EXPORT size_t reduComputeWGSize (size_t NWorkItems, size_t MaxWGSize,
24
+ size_t &NWorkGroups);
25
+
20
26
using cl::sycl::detail::bool_constant;
21
27
using cl::sycl::detail::enable_if_t ;
22
28
using cl::sycl::detail::is_geninteger16bit;
@@ -867,19 +873,19 @@ reduCGFunc(handler &CGH, KernelType KernelFunc, const nd_range<Dims> &Range,
867
873
// / of work-groups. At the end of each work-groups the partial sum is written
868
874
// / to a global buffer.
869
875
// /
870
- // / Briefly: aux kernel, intel:reduce(), reproducible results,FP + ADD/MIN/MAX
871
- template <typename KernelName, typename KernelType, int Dims, class Reduction ,
872
- bool UniformWG , typename InputT, typename OutputT>
876
+ // / Briefly: aux kernel, intel:reduce(), reproducible results, FP + ADD/MIN/MAX
877
+ template <typename KernelName, typename KernelType, bool UniformWG ,
878
+ class Reduction , typename InputT, typename OutputT>
873
879
enable_if_t <Reduction::has_fast_reduce && !Reduction::has_fast_atomics>
874
- reduAuxCGFuncImpl (handler &CGH, const nd_range<Dims> &Range, size_t NWorkItems,
875
- Reduction &, InputT In, OutputT Out) {
876
- size_t NWorkGroups = Range.get_group_range ().size ();
877
- bool IsUpdateOfUserVar =
878
- Reduction::accessor_mode == access::mode::read_write && NWorkGroups == 1 ;
879
-
880
+ reduAuxCGFuncImpl (handler &CGH, size_t NWorkItems, size_t NWorkGroups,
881
+ size_t WGSize, Reduction &, InputT In, OutputT Out) {
880
882
using Name = typename get_reduction_aux_kernel_name_t <
881
883
KernelName, KernelType, Reduction::is_usm, UniformWG, OutputT>::name;
882
- CGH.parallel_for <Name>(Range, [=](nd_item<Dims> NDIt) {
884
+
885
+ bool IsUpdateOfUserVar =
886
+ Reduction::accessor_mode == access::mode::read_write && NWorkGroups == 1 ;
887
+ nd_range<1 > Range{range<1 >(NWorkItems), range<1 >(WGSize)};
888
+ CGH.parallel_for <Name>(Range, [=](nd_item<1 > NDIt) {
883
889
typename Reduction::binary_operation BOp;
884
890
size_t WGID = NDIt.get_group_linear_id ();
885
891
size_t GID = NDIt.get_global_linear_id ();
@@ -903,14 +909,11 @@ reduAuxCGFuncImpl(handler &CGH, const nd_range<Dims> &Range, size_t NWorkItems,
903
909
// / to a global buffer.
904
910
// /
905
911
// / Briefly: aux kernel, tree-reduction, CUSTOM types/ops.
906
- template <typename KernelName, typename KernelType, int Dims, class Reduction ,
907
- bool UniformPow2WG , typename InputT, typename OutputT>
912
+ template <typename KernelName, typename KernelType, bool UniformPow2WG ,
913
+ class Reduction , typename InputT, typename OutputT>
908
914
enable_if_t <!Reduction::has_fast_reduce && !Reduction::has_fast_atomics>
909
- reduAuxCGFuncImpl (handler &CGH, const nd_range<Dims> &Range, size_t NWorkItems,
910
- Reduction &Redu, InputT In, OutputT Out) {
911
- size_t WGSize = Range.get_local_range ().size ();
912
- size_t NWorkGroups = Range.get_group_range ().size ();
913
-
915
+ reduAuxCGFuncImpl (handler &CGH, size_t NWorkItems, size_t NWorkGroups,
916
+ size_t WGSize, Reduction &Redu, InputT In, OutputT Out) {
914
917
bool IsUpdateOfUserVar =
915
918
Reduction::accessor_mode == access::mode::read_write && NWorkGroups == 1 ;
916
919
@@ -924,7 +927,8 @@ reduAuxCGFuncImpl(handler &CGH, const nd_range<Dims> &Range, size_t NWorkItems,
924
927
auto ReduIdentity = Redu.getIdentity ();
925
928
using Name = typename get_reduction_aux_kernel_name_t <
926
929
KernelName, KernelType, Reduction::is_usm, UniformPow2WG, OutputT>::name;
927
- CGH.parallel_for <Name>(Range, [=](nd_item<Dims> NDIt) {
930
+ nd_range<1 > Range{range<1 >(NWorkItems), range<1 >(WGSize)};
931
+ CGH.parallel_for <Name>(Range, [=](nd_item<1 > NDIt) {
928
932
size_t WGSize = NDIt.get_local_range ().size ();
929
933
size_t LID = NDIt.get_local_linear_id ();
930
934
size_t GID = NDIt.get_global_linear_id ();
@@ -962,12 +966,22 @@ reduAuxCGFuncImpl(handler &CGH, const nd_range<Dims> &Range, size_t NWorkItems,
962
966
});
963
967
}
964
968
965
- template <typename KernelName, typename KernelType, int Dims, class Reduction >
966
- enable_if_t <!Reduction::has_fast_atomics>
967
- reduAuxCGFunc (handler &CGH, const nd_range<Dims> &Range, size_t NWorkItems,
969
+ // / Implements a command group function that enqueues a kernel that does one
970
+ // / iteration of reduction of elements in each of work-groups.
971
+ // / At the end of each work-group the partial sum is written to a global buffer.
972
+ // / The function returns the number of the newly generated partial sums.
973
+ template <typename KernelName, typename KernelType, class Reduction >
974
+ enable_if_t <!Reduction::has_fast_atomics, size_t >
975
+ reduAuxCGFunc (handler &CGH, size_t NWorkItems, size_t MaxWGSize,
968
976
Reduction &Redu) {
969
- size_t WGSize = Range.get_local_range ().size ();
970
- size_t NWorkGroups = Range.get_group_range ().size ();
977
+
978
+ size_t NWorkGroups;
979
+ size_t WGSize = reduComputeWGSize (NWorkItems, MaxWGSize, NWorkGroups);
980
+
981
+ // The last kernel DOES write to user's accessor passed to reduction.
982
+ // Associate it with handler manually.
983
+ if (NWorkGroups == 1 && !Reduction::is_usm)
984
+ Redu.associateWithHandler (CGH);
971
985
972
986
// The last work-group may be not fully loaded with work, or the work group
973
987
// size may be not power of two. Those two cases considered inefficient
@@ -981,20 +995,21 @@ reduAuxCGFunc(handler &CGH, const nd_range<Dims> &Range, size_t NWorkItems,
981
995
auto In = Redu.getReadAccToPreviousPartialReds (CGH);
982
996
if (Reduction::is_usm && NWorkGroups == 1 ) {
983
997
if (HasUniformWG)
984
- reduAuxCGFuncImpl<KernelName, KernelType, Dims, Reduction, true >(
985
- CGH, Range, NWorkItems , Redu, In, Redu.getUSMPointer ());
998
+ reduAuxCGFuncImpl<KernelName, KernelType, true >(
999
+ CGH, NWorkItems, NWorkGroups, WGSize , Redu, In, Redu.getUSMPointer ());
986
1000
else
987
- reduAuxCGFuncImpl<KernelName, KernelType, Dims, Reduction, false >(
988
- CGH, Range, NWorkItems , Redu, In, Redu.getUSMPointer ());
1001
+ reduAuxCGFuncImpl<KernelName, KernelType, false >(
1002
+ CGH, NWorkItems, NWorkGroups, WGSize , Redu, In, Redu.getUSMPointer ());
989
1003
} else {
990
1004
auto Out = Redu.getWriteAccForPartialReds (NWorkGroups, CGH);
991
1005
if (HasUniformWG)
992
- reduAuxCGFuncImpl<KernelName, KernelType, Dims, Reduction, true >(
993
- CGH, Range, NWorkItems , Redu, In, Out);
1006
+ reduAuxCGFuncImpl<KernelName, KernelType, true >(
1007
+ CGH, NWorkItems, NWorkGroups, WGSize , Redu, In, Out);
994
1008
else
995
- reduAuxCGFuncImpl<KernelName, KernelType, Dims, Reduction, false >(
996
- CGH, Range, NWorkItems , Redu, In, Out);
1009
+ reduAuxCGFuncImpl<KernelName, KernelType, false >(
1010
+ CGH, NWorkItems, NWorkGroups, WGSize , Redu, In, Out);
997
1011
}
1012
+ return NWorkGroups;
998
1013
}
999
1014
1000
1015
} // namespace detail
0 commit comments