Skip to content

Commit 8ca7d50

Browse files
[NFC][SYCL][Reduction] Increase scope of handler::withAuxHandler (#7240)
to cover more internals of the sycl::handler class so that reduction implementations can use fewer of the handler's private APIs.
1 parent a2f0003 commit 8ca7d50

File tree

1 file changed

+10
-17
lines changed

1 file changed

+10
-17
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -481,13 +481,13 @@ class __SYCL_EXPORT handler {
481481

482482
/// Helper utility for an operation widely used across different reduction
483483
/// implementations.
484-
template <class FunctorTy>
485-
event withAuxHandler(std::shared_ptr<detail::queue_impl> Queue,
486-
FunctorTy Func) {
487-
handler AuxHandler(Queue, MIsHost);
484+
template <class FunctorTy> void withAuxHandler(FunctorTy Func) {
485+
this->finalize();
486+
handler AuxHandler(MQueue, MIsHost);
488487
AuxHandler.saveCodeLoc(MCodeLoc);
489488
Func(AuxHandler);
490-
return AuxHandler.finalize();
489+
MLastEvent = AuxHandler.finalize();
490+
return;
491491
}
492492

493493
/// Saves buffers created by handling reduction feature in handler.
@@ -1761,8 +1761,6 @@ class __SYCL_EXPORT handler {
17611761
ext::oneapi::experimental::is_property_list<PropertiesT>::value>
17621762
parallel_for_impl(range<Dims> Range, PropertiesT Properties, Reduction Redu,
17631763
_KERNELFUNCPARAM(KernelFunc)) {
1764-
std::shared_ptr<detail::queue_impl> QueueCopy = MQueue;
1765-
17661764
// Before running the kernels, check that device has enough local memory
17671765
// to hold local arrays required for the tree-reduction algorithm.
17681766
constexpr bool IsTreeReduction =
@@ -1782,8 +1780,7 @@ class __SYCL_EXPORT handler {
17821780
if (detail::reduCGFuncForRange<KernelName>(
17831781
*this, KernelFunc, Range, PrefWGSize, NumConcurrentWorkGroups,
17841782
Properties, Redu)) {
1785-
this->finalize();
1786-
MLastEvent = withAuxHandler(QueueCopy, [&](handler &CopyHandler) {
1783+
withAuxHandler([&](handler &CopyHandler) {
17871784
detail::reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
17881785
});
17891786
}
@@ -1802,7 +1799,6 @@ class __SYCL_EXPORT handler {
18021799
parallel_for_basic_impl<KernelName>(Range, Properties, Redu, KernelFunc);
18031800
return;
18041801
} else { // Can't "early" return for "if constexpr".
1805-
std::shared_ptr<detail::queue_impl> QueueCopy = MQueue;
18061802
if constexpr (Reduction::has_float64_atomics) {
18071803
/// This version is a specialization for the add
18081804
/// operator. It performs runtime checks for device aspect "atomic64";
@@ -1837,8 +1833,7 @@ class __SYCL_EXPORT handler {
18371833
// the kernel would require creation of another variant of user's kernel,
18381834
// which does not seem efficient.
18391835
if (Reduction::is_usm || Redu.initializeToIdentity()) {
1840-
this->finalize();
1841-
MLastEvent = withAuxHandler(QueueCopy, [&](handler &CopyHandler) {
1836+
withAuxHandler([&](handler &CopyHandler) {
18421837
detail::reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
18431838
});
18441839
}
@@ -1888,7 +1883,6 @@ class __SYCL_EXPORT handler {
18881883

18891884
// 1. Call the kernel that includes user's lambda function.
18901885
detail::reduCGFunc<KernelName>(*this, KernelFunc, Range, Properties, Redu);
1891-
std::shared_ptr<detail::queue_impl> QueueCopy = MQueue;
18921886
this->finalize();
18931887

18941888
// 2. Run the additional kernel as many times as needed to reduce
@@ -1906,14 +1900,14 @@ class __SYCL_EXPORT handler {
19061900
PI_ERROR_INVALID_WORK_GROUP_SIZE);
19071901
size_t NWorkItems = Range.get_group_range().size();
19081902
while (NWorkItems > 1) {
1909-
MLastEvent = withAuxHandler(QueueCopy, [&](handler &AuxHandler) {
1903+
withAuxHandler([&](handler &AuxHandler) {
19101904
NWorkItems = detail::reduAuxCGFunc<KernelName, KernelType>(
19111905
AuxHandler, NWorkItems, MaxWGSize, Redu);
19121906
});
19131907
} // end while (NWorkItems > 1)
19141908

19151909
if (Reduction::is_usm) {
1916-
MLastEvent = withAuxHandler(QueueCopy, [&](handler &CopyHandler) {
1910+
withAuxHandler([&](handler &CopyHandler) {
19171911
detail::reduSaveFinalResultToUserMem<KernelName>(CopyHandler, Redu);
19181912
});
19191913
}
@@ -1957,12 +1951,11 @@ class __SYCL_EXPORT handler {
19571951

19581952
detail::reduCGFuncMulti<KernelName>(*this, KernelFunc, Range, Properties,
19591953
ReduTuple, ReduIndices);
1960-
std::shared_ptr<detail::queue_impl> QueueCopy = MQueue;
19611954
this->finalize();
19621955

19631956
size_t NWorkItems = Range.get_group_range().size();
19641957
while (NWorkItems > 1) {
1965-
MLastEvent = withAuxHandler(QueueCopy, [&](handler &AuxHandler) {
1958+
withAuxHandler([&](handler &AuxHandler) {
19661959
NWorkItems = detail::reduAuxCGFunc<KernelName, decltype(KernelFunc)>(
19671960
AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
19681961
});

0 commit comments

Comments
 (0)