Skip to content

Commit 43d6af7

Browse files
[NFC][SYCL] Unify single/multi reduction_parallel_for for nd_range (#7346)
This also directs reduction::strategy::multi through the dispatcher for other single-reduction strategies, making it possible to invoke vararg implementation for a single reduction (through internal APIs, of course).
1 parent 6923e96 commit 43d6af7

File tree

2 files changed

+104
-110
lines changed

2 files changed

+104
-110
lines changed

sycl/include/sycl/reduction.hpp

Lines changed: 100 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,48 +1571,6 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
15711571
}
15721572
};
15731573

1574-
// Auto-dispatch. Must be the last one.
1575-
template <> struct NDRangeReduction<reduction::strategy::auto_select> {
1576-
// Some readability aliases, to increase signal/noise ratio below.
1577-
template <reduction::strategy Strategy>
1578-
using Impl = NDRangeReduction<Strategy>;
1579-
using S = reduction::strategy;
1580-
1581-
template <typename KernelName, int Dims, typename PropertiesT,
1582-
typename KernelType, typename Reduction>
1583-
static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
1584-
nd_range<Dims> NDRange, PropertiesT &Properties,
1585-
Reduction &Redu, KernelType &KernelFunc) {
1586-
auto Delegate = [&](auto Impl) {
1587-
Impl.template run<KernelName>(CGH, Queue, NDRange, Properties, Redu,
1588-
KernelFunc);
1589-
};
1590-
1591-
if constexpr (Reduction::has_float64_atomics) {
1592-
if (getDeviceFromHandler(CGH).has(aspect::atomic64))
1593-
return Delegate(Impl<S::group_reduce_and_atomic_cross_wg>{});
1594-
1595-
if constexpr (Reduction::has_fast_reduce)
1596-
return Delegate(Impl<S::group_reduce_and_multiple_kernels>{});
1597-
else
1598-
return Delegate(Impl<S::basic>{});
1599-
} else if constexpr (Reduction::has_fast_atomics) {
1600-
if constexpr (Reduction::has_fast_reduce) {
1601-
return Delegate(Impl<S::group_reduce_and_atomic_cross_wg>{});
1602-
} else {
1603-
return Delegate(Impl<S::local_mem_tree_and_atomic_cross_wg>{});
1604-
}
1605-
} else {
1606-
if constexpr (Reduction::has_fast_reduce)
1607-
return Delegate(Impl<S::group_reduce_and_multiple_kernels>{});
1608-
else
1609-
return Delegate(Impl<S::basic>{});
1610-
}
1611-
1612-
assert(false && "Must be unreachable!");
1613-
}
1614-
};
1615-
16161574
/// For the given 'Reductions' types pack and indices enumerating them this
16171575
/// function either creates new temporary accessors for partial sums (if IsOneWG
16181576
/// is false) or returns user's accessor/USM-pointer if (IsOneWG is true).
@@ -2230,21 +2188,109 @@ tuple_select_elements(TupleT Tuple, std::index_sequence<Is...>) {
22302188
return {std::get<Is>(std::move(Tuple))...};
22312189
}
22322190

2191+
template <> struct NDRangeReduction<reduction::strategy::multi> {
2192+
template <typename KernelName, int Dims, typename PropertiesT,
2193+
typename... RestT>
2194+
static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2195+
nd_range<Dims> NDRange, PropertiesT &Properties,
2196+
RestT... Rest) {
2197+
std::tuple<RestT...> ArgsTuple(Rest...);
2198+
constexpr size_t NumArgs = sizeof...(RestT);
2199+
auto KernelFunc = std::get<NumArgs - 1>(ArgsTuple);
2200+
auto ReduIndices = std::make_index_sequence<NumArgs - 1>();
2201+
auto ReduTuple = detail::tuple_select_elements(ArgsTuple, ReduIndices);
2202+
2203+
size_t LocalMemPerWorkItem = reduGetMemPerWorkItem(ReduTuple, ReduIndices);
2204+
// TODO: currently the maximal work group size is determined for the given
2205+
// queue/device, while it is safer to use queries to the kernel compiled
2206+
// for the device.
2207+
size_t MaxWGSize = reduGetMaxWGSize(Queue, LocalMemPerWorkItem);
2208+
if (NDRange.get_local_range().size() > MaxWGSize)
2209+
throw sycl::runtime_error("The implementation handling parallel_for with"
2210+
" reduction requires work group size not bigger"
2211+
" than " +
2212+
std::to_string(MaxWGSize),
2213+
PI_ERROR_INVALID_WORK_GROUP_SIZE);
2214+
2215+
reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
2216+
ReduIndices);
2217+
reduction::finalizeHandler(CGH);
2218+
2219+
size_t NWorkItems = NDRange.get_group_range().size();
2220+
while (NWorkItems > 1) {
2221+
reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
2222+
NWorkItems = reduAuxCGFunc<KernelName, decltype(KernelFunc)>(
2223+
AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
2224+
});
2225+
} // end while (NWorkItems > 1)
2226+
}
2227+
};
2228+
2229+
// Auto-dispatch. Must be the last one.
2230+
template <> struct NDRangeReduction<reduction::strategy::auto_select> {
2231+
// Some readability aliases, to increase signal/noise ratio below.
2232+
template <reduction::strategy Strategy>
2233+
using Impl = NDRangeReduction<Strategy>;
2234+
using Strat = reduction::strategy;
2235+
2236+
template <typename KernelName, int Dims, typename PropertiesT,
2237+
typename KernelType, typename Reduction>
2238+
static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2239+
nd_range<Dims> NDRange, PropertiesT &Properties,
2240+
Reduction &Redu, KernelType &KernelFunc) {
2241+
auto Delegate = [&](auto Impl) {
2242+
Impl.template run<KernelName>(CGH, Queue, NDRange, Properties, Redu,
2243+
KernelFunc);
2244+
};
2245+
2246+
if constexpr (Reduction::has_float64_atomics) {
2247+
if (getDeviceFromHandler(CGH).has(aspect::atomic64))
2248+
return Delegate(Impl<Strat::group_reduce_and_atomic_cross_wg>{});
2249+
2250+
if constexpr (Reduction::has_fast_reduce)
2251+
return Delegate(Impl<Strat::group_reduce_and_multiple_kernels>{});
2252+
else
2253+
return Delegate(Impl<Strat::basic>{});
2254+
} else if constexpr (Reduction::has_fast_atomics) {
2255+
if constexpr (Reduction::has_fast_reduce) {
2256+
return Delegate(Impl<Strat::group_reduce_and_atomic_cross_wg>{});
2257+
} else {
2258+
return Delegate(Impl<Strat::local_mem_tree_and_atomic_cross_wg>{});
2259+
}
2260+
} else {
2261+
if constexpr (Reduction::has_fast_reduce)
2262+
return Delegate(Impl<Strat::group_reduce_and_multiple_kernels>{});
2263+
else
2264+
return Delegate(Impl<Strat::basic>{});
2265+
}
2266+
2267+
assert(false && "Must be unreachable!");
2268+
}
2269+
template <typename KernelName, int Dims, typename PropertiesT,
2270+
typename... RestT>
2271+
static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2272+
nd_range<Dims> NDRange, PropertiesT &Properties,
2273+
RestT... Rest) {
2274+
return Impl<Strat::multi>::run<KernelName>(CGH, Queue, NDRange, Properties,
2275+
Rest...);
2276+
}
2277+
};
2278+
22332279
template <typename KernelName, reduction::strategy Strategy, int Dims,
2234-
typename PropertiesT, typename KernelType, typename Reduction>
2280+
typename PropertiesT, typename... RestT>
22352281
void reduction_parallel_for(handler &CGH,
22362282
std::shared_ptr<detail::queue_impl> Queue,
22372283
nd_range<Dims> NDRange, PropertiesT Properties,
2238-
Reduction Redu, KernelType KernelFunc) {
2239-
NDRangeReduction<Strategy>::template run<KernelName>(
2240-
CGH, Queue, NDRange, Properties, Redu, KernelFunc);
2284+
RestT... Rest) {
2285+
NDRangeReduction<Strategy>::template run<KernelName>(CGH, Queue, NDRange,
2286+
Properties, Rest...);
22412287
}
22422288

22432289
__SYCL_EXPORT uint32_t
22442290
reduGetMaxNumConcurrentWorkGroups(std::shared_ptr<queue_impl> Queue);
22452291

2246-
template <typename KernelName, int Dims, typename PropertiesT,
2247-
typename KernelType, typename Reduction>
2292+
template <typename KernelName, reduction::strategy Strategy, int Dims,
2293+
typename PropertiesT, typename KernelType, typename Reduction>
22482294
void reduction_parallel_for(handler &CGH,
22492295
std::shared_ptr<detail::queue_impl> Queue,
22502296
range<Dims> Range, PropertiesT Properties,
@@ -2303,7 +2349,10 @@ void reduction_parallel_for(handler &CGH,
23032349
KernelFunc(getDelinearizedId(Range, I), Reducer);
23042350
};
23052351

2306-
constexpr auto Strategy = [&]() {
2352+
constexpr auto StrategyToUse = [&]() {
2353+
if constexpr (Strategy != reduction::strategy::auto_select)
2354+
return Strategy;
2355+
23072356
if constexpr (Reduction::has_fast_reduce)
23082357
return reduction::strategy::group_reduce_and_last_wg_detection;
23092358
else if constexpr (Reduction::has_fast_atomics)
@@ -2312,57 +2361,8 @@ void reduction_parallel_for(handler &CGH,
23122361
return reduction::strategy::range_basic;
23132362
}();
23142363

2315-
reduction_parallel_for<KernelName, Strategy>(CGH, Queue, NDRange, Properties,
2316-
Redu, UpdatedKernelFunc);
2317-
}
2318-
2319-
template <> struct NDRangeReduction<reduction::strategy::multi> {
2320-
template <typename KernelName, int Dims, typename PropertiesT,
2321-
typename... RestT>
2322-
static void run(handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2323-
nd_range<Dims> NDRange, PropertiesT &Properties,
2324-
RestT... Rest) {
2325-
std::tuple<RestT...> ArgsTuple(Rest...);
2326-
constexpr size_t NumArgs = sizeof...(RestT);
2327-
auto KernelFunc = std::get<NumArgs - 1>(ArgsTuple);
2328-
auto ReduIndices = std::make_index_sequence<NumArgs - 1>();
2329-
auto ReduTuple = detail::tuple_select_elements(ArgsTuple, ReduIndices);
2330-
2331-
size_t LocalMemPerWorkItem = reduGetMemPerWorkItem(ReduTuple, ReduIndices);
2332-
// TODO: currently the maximal work group size is determined for the given
2333-
// queue/device, while it is safer to use queries to the kernel compiled
2334-
// for the device.
2335-
size_t MaxWGSize = reduGetMaxWGSize(Queue, LocalMemPerWorkItem);
2336-
if (NDRange.get_local_range().size() > MaxWGSize)
2337-
throw sycl::runtime_error("The implementation handling parallel_for with"
2338-
" reduction requires work group size not bigger"
2339-
" than " +
2340-
std::to_string(MaxWGSize),
2341-
PI_ERROR_INVALID_WORK_GROUP_SIZE);
2342-
2343-
reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
2344-
ReduIndices);
2345-
reduction::finalizeHandler(CGH);
2346-
2347-
size_t NWorkItems = NDRange.get_group_range().size();
2348-
while (NWorkItems > 1) {
2349-
reduction::withAuxHandler(CGH, [&](handler &AuxHandler) {
2350-
NWorkItems = reduAuxCGFunc<KernelName, decltype(KernelFunc)>(
2351-
AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
2352-
});
2353-
} // end while (NWorkItems > 1)
2354-
}
2355-
};
2356-
2357-
template <typename KernelName, int Dims, typename PropertiesT,
2358-
typename... RestT>
2359-
void reduction_parallel_for(handler &CGH,
2360-
std::shared_ptr<detail::queue_impl> Queue,
2361-
nd_range<Dims> NDRange, PropertiesT Properties,
2362-
RestT... Rest) {
2363-
constexpr auto Strategy = reduction::strategy::multi;
2364-
NDRangeReduction<Strategy>::template run<KernelName>(CGH, Queue, NDRange,
2365-
Properties, Rest...);
2364+
reduction_parallel_for<KernelName, StrategyToUse>(
2365+
CGH, Queue, NDRange, Properties, Redu, UpdatedKernelFunc);
23662366
}
23672367
} // namespace detail
23682368

sycl/include/sycl/reduction_forward.hpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,24 +44,18 @@ inline void finalizeHandler(handler &CGH);
4444
template <class FunctorTy> void withAuxHandler(handler &CGH, FunctorTy Func);
4545
} // namespace reduction
4646

47-
template <typename KernelName, int Dims, typename PropertiesT,
48-
typename KernelType, typename Reduction>
49-
void reduction_parallel_for(handler &CGH,
50-
std::shared_ptr<detail::queue_impl> Queue,
51-
range<Dims> Range, PropertiesT Properties,
52-
Reduction Redu, KernelType KernelFunc);
53-
5447
template <typename KernelName,
5548
reduction::strategy Strategy = reduction::strategy::auto_select,
5649
int Dims, typename PropertiesT, typename KernelType,
5750
typename Reduction>
5851
void reduction_parallel_for(handler &CGH,
5952
std::shared_ptr<detail::queue_impl> Queue,
60-
nd_range<Dims> NDRange, PropertiesT Properties,
53+
range<Dims> Range, PropertiesT Properties,
6154
Reduction Redu, KernelType KernelFunc);
6255

63-
template <typename KernelName, int Dims, typename PropertiesT,
64-
typename... RestT>
56+
template <typename KernelName,
57+
reduction::strategy Strategy = reduction::strategy::auto_select,
58+
int Dims, typename PropertiesT, typename... RestT>
6559
void reduction_parallel_for(handler &CGH,
6660
std::shared_ptr<detail::queue_impl> Queue,
6761
nd_range<Dims> NDRange, PropertiesT Properties,

0 commit comments

Comments
 (0)