
Commit 99bdc82

[SYCL][Reduction] Improve getGroupsCounterAccDiscrete performance (#6858)
Parent: 71bdc1f


sycl/include/sycl/reduction.hpp

Lines changed: 79 additions & 78 deletions
@@ -18,6 +18,7 @@
 #include <sycl/kernel.hpp>
 #include <sycl/known_identity.hpp>
 #include <sycl/properties/reduction_properties.hpp>
+#include <sycl/usm.hpp>
 
 #include <tuple>
 
@@ -666,15 +667,18 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
 
   // On discrete (vs. integrated) GPUs it's faster to initialize memory with an
   // extra kernel than copy it from the host.
-  template <typename Name> auto getGroupsCounterAccDiscrete(handler &CGH) {
-    auto &Buf = getTempBuffer<int>(1, CGH);
-    std::shared_ptr<detail::queue_impl> QueueCopy = CGH.MQueue;
-    auto Event = CGH.withAuxHandler(QueueCopy, [&](handler &InitHandler) {
-      auto Acc = accessor{Buf, InitHandler, sycl::write_only, sycl::no_init};
-      InitHandler.single_task<Name>([=]() { Acc[0] = 0; });
-    });
+  auto getGroupsCounterAccDiscrete(handler &CGH) {
+    queue q = createSyclObjFromImpl<queue>(CGH.MQueue);
+    device Dev = q.get_device();
+    auto Deleter = [=](auto *Ptr) { free(Ptr, q); };
+
+    std::shared_ptr<int> Counter(malloc_device<int>(1, q), Deleter);
+    CGH.addReduction(Counter);
+
+    auto Event = q.memset(Counter.get(), 0, sizeof(int));
     CGH.depends_on(Event);
-    return accessor{Buf, CGH};
+
+    return Counter.get();
   }
 
   RedOutVar &getUserRedVar() { return MRedOut; }
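
Note: the new path replaces a temporary buffer plus a dedicated init kernel (submitted through an auxiliary handler) with a device-USM allocation zeroed by a queue-level memset. The pattern can be reproduced standalone; below is a minimal sketch assuming only a default-constructed queue, with illustrative names (q, Counter, Init) rather than code from the patch:

    #include <sycl/sycl.hpp>

    #include <memory>

    int main() {
      sycl::queue q;

      // Device-USM counter; the shared_ptr's deleter releases the allocation
      // back through the same queue, mirroring the Deleter lambda above.
      std::shared_ptr<int> Counter(sycl::malloc_device<int>(1, q),
                                   [q](int *Ptr) { sycl::free(Ptr, q); });

      // Zero the counter directly on the device; no host memory or
      // host-to-device copy is involved.
      sycl::event Init = q.memset(Counter.get(), 0, sizeof(int));
      Init.wait();
    }

In the patch itself, the memset event goes to CGH.depends_on and the shared_ptr to CGH.addReduction, so the allocation stays alive until the command group that uses it completes.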
@@ -895,11 +899,8 @@ bool reduCGFuncForRangeFastAtomics(handler &CGH, KernelType KernelFunc,
 
 namespace reduction {
 namespace main_krn {
-template <class KernelName> struct RangeFastReduce;
+template <class KernelName, class NWorkGroupsFinished> struct RangeFastReduce;
 } // namespace main_krn
-namespace init_krn {
-template <class KernelName> struct GroupCounter;
-}
 } // namespace reduction
 template <typename KernelName, typename KernelType, int Dims, class Reduction>
 bool reduCGFuncForRangeFastReduce(handler &CGH, KernelType KernelFunc,
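
Note: RangeFastReduce picks up a second template parameter because the main kernel is now instantiated once per counter type (accessor on integrated GPUs, raw USM pointer on discrete ones), and each instantiation of a SYCL kernel needs its own unique name. A hypothetical illustration of that idiom, with made-up names (NamePerCounter, submitZero, ZeroKrn), not code from the patch:

    #include <sycl/sycl.hpp>

    // Kernel name templated on the counter type, so each instantiation of
    // the generic body below gets a distinct name.
    template <class KernelName, class CounterT> struct NamePerCounter;

    template <class KernelName, class CounterT>
    void submitZero(sycl::queue &Q, CounterT Counter) {
      Q.submit([&](sycl::handler &CGH) {
        CGH.single_task<NamePerCounter<KernelName, CounterT>>(
            [=] { Counter[0] = 0; });
      });
    }

    class ZeroKrn;

    int main() {
      sycl::queue Q;
      int *P = sycl::malloc_shared<int>(1, Q);
      submitZero<ZeroKrn>(Q, P);
      Q.wait();
      sycl::free(P, Q);
    }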
@@ -917,81 +918,81 @@ bool reduCGFuncForRangeFastReduce(handler &CGH, KernelType KernelFunc,
   accessor PartialSums(PartialSumsBuf, CGH, sycl::read_write, sycl::no_init);
 
   bool IsUpdateOfUserVar = !Reduction::is_usm && !Redu.initializeToIdentity();
-  using InitName =
-      __sycl_reduction_kernel<reduction::init_krn::GroupCounter, KernelName>;
-
-  // Integrated/discrete GPUs have different faster path.
-  auto NWorkGroupsFinished =
-      getDeviceFromHandler(CGH).get_info<info::device::host_unified_memory>()
-          ? Redu.getReadWriteAccessorToInitializedGroupsCounter(CGH)
-          : Redu.template getGroupsCounterAccDiscrete<InitName>(CGH);
-
-  auto DoReducePartialSumsInLastWG =
-      Reduction::template getReadWriteLocalAcc<int>(1, CGH);
-
-  using Name =
-      __sycl_reduction_kernel<reduction::main_krn::RangeFastReduce, KernelName>;
-  size_t PerGroup = Range.size() / NWorkGroups;
-  CGH.parallel_for<Name>(NDRange, [=](nd_item<1> NDId) {
-    // Call user's functions. Reducer.MValue gets initialized there.
-    typename Reduction::reducer_type Reducer;
-    reductionLoop(Range, PerGroup, Reducer, NDId, KernelFunc);
+  auto Rest = [&](auto NWorkGroupsFinished) {
+    auto DoReducePartialSumsInLastWG =
+        Reduction::template getReadWriteLocalAcc<int>(1, CGH);
+
+    using Name = __sycl_reduction_kernel<reduction::main_krn::RangeFastReduce,
+                                         KernelName, decltype(NWorkGroupsFinished)>;
+    size_t PerGroup = Range.size() / NWorkGroups;
+    CGH.parallel_for<Name>(NDRange, [=](nd_item<1> NDId) {
+      // Call user's functions. Reducer.MValue gets initialized there.
+      typename Reduction::reducer_type Reducer;
+      reductionLoop(Range, PerGroup, Reducer, NDId, KernelFunc);
 
-    typename Reduction::binary_operation BOp;
-    auto Group = NDId.get_group();
+      typename Reduction::binary_operation BOp;
+      auto Group = NDId.get_group();
 
-    // If there are multiple values, reduce each separately
-    // reduce_over_group is only defined for each T, not for span<T, ...>
-    size_t LID = NDId.get_local_id(0);
-    for (int E = 0; E < NElements; ++E) {
-      auto &RedElem = Reducer.getElement(E);
-      RedElem = reduce_over_group(Group, RedElem, BOp);
-      if (LID == 0) {
-        if (NWorkGroups == 1) {
-          auto &OutElem = Reduction::getOutPointer(Out)[E];
-          // Can avoid using partial sum and write the final result immediately.
-          if (IsUpdateOfUserVar)
-            RedElem = BOp(RedElem, OutElem);
-          OutElem = RedElem;
-        } else {
-          PartialSums[NDId.get_group_linear_id() * NElements + E] =
-              Reducer.getElement(E);
+      // If there are multiple values, reduce each separately
+      // reduce_over_group is only defined for each T, not for span<T, ...>
+      size_t LID = NDId.get_local_id(0);
+      for (int E = 0; E < NElements; ++E) {
+        auto &RedElem = Reducer.getElement(E);
+        RedElem = reduce_over_group(Group, RedElem, BOp);
+        if (LID == 0) {
+          if (NWorkGroups == 1) {
+            auto &OutElem = Reduction::getOutPointer(Out)[E];
+            // Can avoid using partial sum and write the final result
+            // immediately.
+            if (IsUpdateOfUserVar)
+              RedElem = BOp(RedElem, OutElem);
+            OutElem = RedElem;
+          } else {
+            PartialSums[NDId.get_group_linear_id() * NElements + E] =
+                Reducer.getElement(E);
+          }
         }
       }
-    }
-
-    if (NWorkGroups == 1)
-      // We're done.
-      return;
 
-    // Signal this work-group has finished after all values are reduced
-    if (LID == 0) {
-      auto NFinished =
-          sycl::atomic_ref<int, memory_order::relaxed, memory_scope::device,
-                           access::address_space::global_space>(
-              NWorkGroupsFinished[0]);
-      DoReducePartialSumsInLastWG[0] = ++NFinished == NWorkGroups;
-    }
+      if (NWorkGroups == 1)
+        // We're done.
+        return;
 
-    workGroupBarrier();
-    if (DoReducePartialSumsInLastWG[0]) {
-      // Reduce each result separately
-      // TODO: Opportunity to parallelize across elements.
-      for (int E = 0; E < NElements; ++E) {
-        auto &OutElem = Reduction::getOutPointer(Out)[E];
-        auto LocalSum = Reducer.getIdentity();
-        for (size_t I = LID; I < NWorkGroups; I += WGSize)
-          LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
-        auto Result = reduce_over_group(Group, LocalSum, BOp);
+      // Signal this work-group has finished after all values are reduced
+      if (LID == 0) {
+        auto NFinished =
+            sycl::atomic_ref<int, memory_order::relaxed, memory_scope::device,
+                             access::address_space::global_space>(
+                NWorkGroupsFinished[0]);
+        DoReducePartialSumsInLastWG[0] = ++NFinished == NWorkGroups;
+      }
 
-        if (LID == 0) {
-          if (IsUpdateOfUserVar)
-            Result = BOp(Result, OutElem);
-          OutElem = Result;
+      workGroupBarrier();
+      if (DoReducePartialSumsInLastWG[0]) {
+        // Reduce each result separately
+        // TODO: Opportunity to parallelize across elements.
+        for (int E = 0; E < NElements; ++E) {
+          auto &OutElem = Reduction::getOutPointer(Out)[E];
+          auto LocalSum = Reducer.getIdentity();
+          for (size_t I = LID; I < NWorkGroups; I += WGSize)
+            LocalSum = BOp(LocalSum, PartialSums[I * NElements + E]);
+          auto Result = reduce_over_group(Group, LocalSum, BOp);
+
+          if (LID == 0) {
+            if (IsUpdateOfUserVar)
+              Result = BOp(Result, OutElem);
+            OutElem = Result;
+          }
         }
       }
-    }
-  });
+    });
+  };
+
+  // Integrated/discrete GPUs have different faster path.
+  if (getDeviceFromHandler(CGH).get_info<info::device::host_unified_memory>())
+    Rest(Redu.getReadWriteAccessorToInitializedGroupsCounter(CGH));
+  else
+    Rest(Redu.getGroupsCounterAccDiscrete(CGH));
 
   // We've updated user's variable, no extra work needed.
   return false;
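
Note: hoisting the tail of reduCGFuncForRangeFastReduce into the generic lambda Rest lets one body serve both counter types with no duplication and no type erasure; each call instantiates the lambda (and, via decltype(NWorkGroupsFinished), a distinctly named kernel) for its argument type. A minimal sketch of the generic-lambda dispatch alone, with made-up stand-ins (HostUnifiedMemory, Arr, Raw):

    #include <array>
    #include <cstdio>

    int main() {
      bool HostUnifiedMemory = true; // stand-in for the host_unified_memory query

      // Generic lambda: one body, compiled once per argument type, just as
      // Rest is instantiated for the accessor and for the raw int* counter.
      auto Rest = [&](auto NWorkGroupsFinished) {
        NWorkGroupsFinished[0] = 0;
        std::printf("counter zeroed\n");
      };

      std::array<int, 1> Arr{}; // plays the role of the accessor
      int Raw[1];               // decays to int*, like the malloc_device pointer
      if (HostUnifiedMemory)
        Rest(Arr);
      else
        Rest(Raw);
    }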
