Skip to content

Commit 6fe7cc1

Browse files
[NFC][SYCL] Use sycl::local_accessor in reduction implementation (#7182)
1 parent c69e5ce commit 6fe7cc1

File tree

1 file changed

+13
-24
lines changed

1 file changed

+13
-24
lines changed

sycl/include/sycl/reduction.hpp

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -561,19 +561,6 @@ class reduction_impl_algo : public reduction_impl_common<T, BinaryOperation> {
561561
RedOutVar RedOut)
562562
: base(Identity, BinaryOp, Init), MRedOut(std::move(RedOut)){};
563563

564-
/// Creates and returns a local accessor with \p Size elements.
565-
/// By default the local accessor elements are of the same type as the
566-
/// elements processed by the reduction, but it may be altered by specifying
567-
/// \p _T explicitly if an accessor with elements of a different type is needed.
568-
///
569-
/// For array reductions we process them one element at a time to avoid stack
570-
/// growth, so the dimensionality of the temporary buffer is always one.
571-
template <class _T = result_type>
572-
static accessor<_T, 1, access::mode::read_write, access::target::local>
573-
getReadWriteLocalAcc(size_t Size, handler &CGH) {
574-
return {Size, CGH};
575-
}
576-
577564
auto getReadAccToPreviousPartialReds(handler &CGH) const {
578565
CGH.addReduction(MOutBufPtr);
579566
return accessor{*MOutBufPtr, CGH, sycl::read_only};
@@ -849,7 +836,7 @@ bool reduCGFuncForRangeFastAtomics(handler &CGH, KernelType KernelFunc,
849836
PropertiesT Properties, Reduction &Redu) {
850837
size_t NElements = Reduction::num_elements;
851838
auto Out = Redu.getReadWriteAccessorToInitializedMem(CGH);
852-
auto GroupSum = Reduction::getReadWriteLocalAcc(NElements, CGH);
839+
local_accessor<typename Reduction::result_type, 1> GroupSum{NElements, CGH};
853840
using Name = __sycl_reduction_kernel<reduction::main_krn::RangeFastAtomics,
854841
KernelName>;
855842
size_t NWorkGroups = NDRange.get_group_range().size();
@@ -907,8 +894,7 @@ bool reduCGFuncForRangeFastReduce(handler &CGH, KernelType KernelFunc,
907894

908895
bool IsUpdateOfUserVar = !Reduction::is_usm && !Redu.initializeToIdentity();
909896
auto Rest = [&](auto NWorkGroupsFinished) {
910-
auto DoReducePartialSumsInLastWG =
911-
Reduction::template getReadWriteLocalAcc<int>(1, CGH);
897+
local_accessor<int, 1> DoReducePartialSumsInLastWG{1, CGH};
912898

913899
using Name = __sycl_reduction_kernel<reduction::main_krn::RangeFastReduce,
914900
KernelName, decltype(NWorkGroupsFinished)>;
@@ -1008,11 +994,10 @@ bool reduCGFuncForRangeBasic(handler &CGH, KernelType KernelFunc,
1008994
auto Out = (NWorkGroups == 1)
1009995
? PartialSums
1010996
: Redu.getWriteAccForPartialReds(NElements, CGH);
1011-
auto LocalReds = Reduction::getReadWriteLocalAcc(WGSize + 1, CGH);
997+
local_accessor<typename Reduction::result_type, 1> LocalReds{WGSize + 1, CGH};
1012998
auto NWorkGroupsFinished =
1013999
Redu.getReadWriteAccessorToInitializedGroupsCounter(CGH);
1014-
auto DoReducePartialSumsInLastWG =
1015-
Reduction::template getReadWriteLocalAcc<int>(1, CGH);
1000+
local_accessor<int, 1> DoReducePartialSumsInLastWG{1, CGH};
10161001

10171002
auto Identity = Redu.getIdentity();
10181003
auto BOp = Redu.getBinaryOperation();
@@ -1198,7 +1183,8 @@ void reduCGFuncForNDRangeFastAtomicsOnly(handler &CGH, bool IsPow2WG,
11981183
// The additional last element is used to catch reduce elements that could
11991184
// otherwise be lost in the tree-reduction algorithm used in the kernel.
12001185
size_t NLocalElements = WGSize + (IsPow2WG ? 0 : 1);
1201-
auto LocalReds = Reduction::getReadWriteLocalAcc(NLocalElements, CGH);
1186+
local_accessor<typename Reduction::result_type, 1> LocalReds{NLocalElements,
1187+
CGH};
12021188

12031189
using Name =
12041190
__sycl_reduction_kernel<reduction::main_krn::NDRangeFastAtomicsOnly,
@@ -1332,7 +1318,8 @@ void reduCGFuncForNDRangeBasic(handler &CGH, bool IsPow2WG,
13321318
// The additional last element is used to catch elements that could
13331319
// otherwise be lost in the tree-reduction algorithm.
13341320
size_t NumLocalElements = WGSize + (IsPow2WG ? 0 : 1);
1335-
auto LocalReds = Reduction::getReadWriteLocalAcc(NumLocalElements, CGH);
1321+
local_accessor<typename Reduction::result_type, 1> LocalReds{NumLocalElements,
1322+
CGH};
13361323
typename Reduction::result_type ReduIdentity = Redu.getIdentity();
13371324
using Name =
13381325
__sycl_reduction_kernel<reduction::main_krn::NDRangeBasic, KernelName>;
@@ -1460,7 +1447,8 @@ void reduAuxCGFuncNoFastReduceNorAtomicImpl(handler &CGH, bool UniformPow2WG,
14601447
// The additional last element is used to catch elements that could
14611448
// otherwise be lost in the tree-reduction algorithm.
14621449
size_t NumLocalElements = WGSize + (UniformPow2WG ? 0 : 1);
1463-
auto LocalReds = Reduction::getReadWriteLocalAcc(NumLocalElements, CGH);
1450+
local_accessor<typename Reduction::result_type, 1> LocalReds{NumLocalElements,
1451+
CGH};
14641452

14651453
auto ReduIdentity = Redu.getIdentity();
14661454
auto BOp = Redu.getBinaryOperation();
@@ -1592,8 +1580,9 @@ template <typename... Reductions, size_t... Is>
15921580
auto createReduLocalAccs(size_t Size, handler &CGH,
15931581
std::index_sequence<Is...>) {
15941582
return makeReduTupleT(
1595-
std::tuple_element_t<Is, std::tuple<Reductions...>>::getReadWriteLocalAcc(
1596-
Size, CGH)...);
1583+
local_accessor<typename std::tuple_element_t<
1584+
Is, std::tuple<Reductions...>>::result_type,
1585+
1>{Size, CGH}...);
15971586
}
15981587

15991588
/// For the given 'Reductions' types pack and indices enumerating them this

0 commit comments

Comments
 (0)