@@ -1571,48 +1571,6 @@ template <> struct NDRangeReduction<reduction::strategy::basic> {
1571
1571
}
1572
1572
};
1573
1573
1574
- // Auto-dispatch. Must be the last one.
1575
- template <> struct NDRangeReduction <reduction::strategy::auto_select> {
1576
- // Some readability aliases, to increase signal/noise ratio below.
1577
- template <reduction::strategy Strategy>
1578
- using Impl = NDRangeReduction<Strategy>;
1579
- using S = reduction::strategy;
1580
-
1581
- template <typename KernelName, int Dims, typename PropertiesT,
1582
- typename KernelType, typename Reduction>
1583
- static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
1584
- nd_range<Dims> NDRange, PropertiesT &Properties,
1585
- Reduction &Redu, KernelType &KernelFunc) {
1586
- auto Delegate = [&](auto Impl) {
1587
- Impl.template run <KernelName>(CGH, Queue, NDRange, Properties, Redu,
1588
- KernelFunc);
1589
- };
1590
-
1591
- if constexpr (Reduction::has_float64_atomics) {
1592
- if (getDeviceFromHandler (CGH).has (aspect::atomic64))
1593
- return Delegate (Impl<S::group_reduce_and_atomic_cross_wg>{});
1594
-
1595
- if constexpr (Reduction::has_fast_reduce)
1596
- return Delegate (Impl<S::group_reduce_and_multiple_kernels>{});
1597
- else
1598
- return Delegate (Impl<S::basic>{});
1599
- } else if constexpr (Reduction::has_fast_atomics) {
1600
- if constexpr (Reduction::has_fast_reduce) {
1601
- return Delegate (Impl<S::group_reduce_and_atomic_cross_wg>{});
1602
- } else {
1603
- return Delegate (Impl<S::local_mem_tree_and_atomic_cross_wg>{});
1604
- }
1605
- } else {
1606
- if constexpr (Reduction::has_fast_reduce)
1607
- return Delegate (Impl<S::group_reduce_and_multiple_kernels>{});
1608
- else
1609
- return Delegate (Impl<S::basic>{});
1610
- }
1611
-
1612
- assert (false && " Must be unreachable!" );
1613
- }
1614
- };
1615
-
1616
1574
// / For the given 'Reductions' types pack and indices enumerating them this
1617
1575
// / function either creates new temporary accessors for partial sums (if IsOneWG
1618
1576
// / is false) or returns user's accessor/USM-pointer if (IsOneWG is true).
@@ -2230,21 +2188,109 @@ tuple_select_elements(TupleT Tuple, std::index_sequence<Is...>) {
2230
2188
return {std::get<Is>(std::move (Tuple))...};
2231
2189
}
2232
2190
2191
+ template <> struct NDRangeReduction <reduction::strategy::multi> {
2192
+ template <typename KernelName, int Dims, typename PropertiesT,
2193
+ typename ... RestT>
2194
+ static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2195
+ nd_range<Dims> NDRange, PropertiesT &Properties,
2196
+ RestT... Rest) {
2197
+ std::tuple<RestT...> ArgsTuple (Rest...);
2198
+ constexpr size_t NumArgs = sizeof ...(RestT);
2199
+ auto KernelFunc = std::get<NumArgs - 1 >(ArgsTuple);
2200
+ auto ReduIndices = std::make_index_sequence<NumArgs - 1 >();
2201
+ auto ReduTuple = detail::tuple_select_elements (ArgsTuple, ReduIndices);
2202
+
2203
+ size_t LocalMemPerWorkItem = reduGetMemPerWorkItem (ReduTuple, ReduIndices);
2204
+ // TODO: currently the maximal work group size is determined for the given
2205
+ // queue/device, while it is safer to use queries to the kernel compiled
2206
+ // for the device.
2207
+ size_t MaxWGSize = reduGetMaxWGSize (Queue, LocalMemPerWorkItem);
2208
+ if (NDRange.get_local_range ().size () > MaxWGSize)
2209
+ throw sycl::runtime_error (" The implementation handling parallel_for with"
2210
+ " reduction requires work group size not bigger"
2211
+ " than " +
2212
+ std::to_string (MaxWGSize),
2213
+ PI_ERROR_INVALID_WORK_GROUP_SIZE);
2214
+
2215
+ reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
2216
+ ReduIndices);
2217
+ reduction::finalizeHandler (CGH);
2218
+
2219
+ size_t NWorkItems = NDRange.get_group_range ().size ();
2220
+ while (NWorkItems > 1 ) {
2221
+ reduction::withAuxHandler (CGH, [&](handler &AuxHandler) {
2222
+ NWorkItems = reduAuxCGFunc<KernelName, decltype (KernelFunc)>(
2223
+ AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
2224
+ });
2225
+ } // end while (NWorkItems > 1)
2226
+ }
2227
+ };
2228
+
2229
+ // Auto-dispatch. Must be the last one.
2230
+ template <> struct NDRangeReduction <reduction::strategy::auto_select> {
2231
+ // Some readability aliases, to increase signal/noise ratio below.
2232
+ template <reduction::strategy Strategy>
2233
+ using Impl = NDRangeReduction<Strategy>;
2234
+ using Strat = reduction::strategy;
2235
+
2236
+ template <typename KernelName, int Dims, typename PropertiesT,
2237
+ typename KernelType, typename Reduction>
2238
+ static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2239
+ nd_range<Dims> NDRange, PropertiesT &Properties,
2240
+ Reduction &Redu, KernelType &KernelFunc) {
2241
+ auto Delegate = [&](auto Impl) {
2242
+ Impl.template run <KernelName>(CGH, Queue, NDRange, Properties, Redu,
2243
+ KernelFunc);
2244
+ };
2245
+
2246
+ if constexpr (Reduction::has_float64_atomics) {
2247
+ if (getDeviceFromHandler (CGH).has (aspect::atomic64))
2248
+ return Delegate (Impl<Strat::group_reduce_and_atomic_cross_wg>{});
2249
+
2250
+ if constexpr (Reduction::has_fast_reduce)
2251
+ return Delegate (Impl<Strat::group_reduce_and_multiple_kernels>{});
2252
+ else
2253
+ return Delegate (Impl<Strat::basic>{});
2254
+ } else if constexpr (Reduction::has_fast_atomics) {
2255
+ if constexpr (Reduction::has_fast_reduce) {
2256
+ return Delegate (Impl<Strat::group_reduce_and_atomic_cross_wg>{});
2257
+ } else {
2258
+ return Delegate (Impl<Strat::local_mem_tree_and_atomic_cross_wg>{});
2259
+ }
2260
+ } else {
2261
+ if constexpr (Reduction::has_fast_reduce)
2262
+ return Delegate (Impl<Strat::group_reduce_and_multiple_kernels>{});
2263
+ else
2264
+ return Delegate (Impl<Strat::basic>{});
2265
+ }
2266
+
2267
+ assert (false && " Must be unreachable!" );
2268
+ }
2269
+ template <typename KernelName, int Dims, typename PropertiesT,
2270
+ typename ... RestT>
2271
+ static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2272
+ nd_range<Dims> NDRange, PropertiesT &Properties,
2273
+ RestT... Rest) {
2274
+ return Impl<Strat::multi>::run<KernelName>(CGH, Queue, NDRange, Properties,
2275
+ Rest...);
2276
+ }
2277
+ };
2278
+
2233
2279
template <typename KernelName, reduction::strategy Strategy, int Dims,
2234
- typename PropertiesT, typename KernelType, typename Reduction >
2280
+ typename PropertiesT, typename ... RestT >
2235
2281
void reduction_parallel_for (handler &CGH,
2236
2282
std::shared_ptr<detail::queue_impl> Queue,
2237
2283
nd_range<Dims> NDRange, PropertiesT Properties,
2238
- Reduction Redu, KernelType KernelFunc ) {
2239
- NDRangeReduction<Strategy>::template run<KernelName>(
2240
- CGH, Queue, NDRange, Properties, Redu, KernelFunc );
2284
+ RestT... Rest ) {
2285
+ NDRangeReduction<Strategy>::template run<KernelName>(CGH, Queue, NDRange,
2286
+ Properties, Rest... );
2241
2287
}
2242
2288
2243
2289
__SYCL_EXPORT uint32_t
2244
2290
reduGetMaxNumConcurrentWorkGroups (std::shared_ptr<queue_impl> Queue);
2245
2291
2246
- template <typename KernelName, int Dims, typename PropertiesT ,
2247
- typename KernelType, typename Reduction>
2292
+ template <typename KernelName, reduction::strategy Strategy, int Dims ,
2293
+ typename PropertiesT, typename KernelType, typename Reduction>
2248
2294
void reduction_parallel_for (handler &CGH,
2249
2295
std::shared_ptr<detail::queue_impl> Queue,
2250
2296
range<Dims> Range, PropertiesT Properties,
@@ -2303,7 +2349,10 @@ void reduction_parallel_for(handler &CGH,
2303
2349
KernelFunc (getDelinearizedId (Range, I), Reducer);
2304
2350
};
2305
2351
2306
- constexpr auto Strategy = [&]() {
2352
+ constexpr auto StrategyToUse = [&]() {
2353
+ if constexpr (Strategy != reduction::strategy::auto_select)
2354
+ return Strategy;
2355
+
2307
2356
if constexpr (Reduction::has_fast_reduce)
2308
2357
return reduction::strategy::group_reduce_and_last_wg_detection;
2309
2358
else if constexpr (Reduction::has_fast_atomics)
@@ -2312,57 +2361,8 @@ void reduction_parallel_for(handler &CGH,
2312
2361
return reduction::strategy::range_basic;
2313
2362
}();
2314
2363
2315
- reduction_parallel_for<KernelName, Strategy>(CGH, Queue, NDRange, Properties,
2316
- Redu, UpdatedKernelFunc);
2317
- }
2318
-
2319
- template <> struct NDRangeReduction <reduction::strategy::multi> {
2320
- template <typename KernelName, int Dims, typename PropertiesT,
2321
- typename ... RestT>
2322
- static void run (handler &CGH, std::shared_ptr<detail::queue_impl> &Queue,
2323
- nd_range<Dims> NDRange, PropertiesT &Properties,
2324
- RestT... Rest) {
2325
- std::tuple<RestT...> ArgsTuple (Rest...);
2326
- constexpr size_t NumArgs = sizeof ...(RestT);
2327
- auto KernelFunc = std::get<NumArgs - 1 >(ArgsTuple);
2328
- auto ReduIndices = std::make_index_sequence<NumArgs - 1 >();
2329
- auto ReduTuple = detail::tuple_select_elements (ArgsTuple, ReduIndices);
2330
-
2331
- size_t LocalMemPerWorkItem = reduGetMemPerWorkItem (ReduTuple, ReduIndices);
2332
- // TODO: currently the maximal work group size is determined for the given
2333
- // queue/device, while it is safer to use queries to the kernel compiled
2334
- // for the device.
2335
- size_t MaxWGSize = reduGetMaxWGSize (Queue, LocalMemPerWorkItem);
2336
- if (NDRange.get_local_range ().size () > MaxWGSize)
2337
- throw sycl::runtime_error (" The implementation handling parallel_for with"
2338
- " reduction requires work group size not bigger"
2339
- " than " +
2340
- std::to_string (MaxWGSize),
2341
- PI_ERROR_INVALID_WORK_GROUP_SIZE);
2342
-
2343
- reduCGFuncMulti<KernelName>(CGH, KernelFunc, NDRange, Properties, ReduTuple,
2344
- ReduIndices);
2345
- reduction::finalizeHandler (CGH);
2346
-
2347
- size_t NWorkItems = NDRange.get_group_range ().size ();
2348
- while (NWorkItems > 1 ) {
2349
- reduction::withAuxHandler (CGH, [&](handler &AuxHandler) {
2350
- NWorkItems = reduAuxCGFunc<KernelName, decltype (KernelFunc)>(
2351
- AuxHandler, NWorkItems, MaxWGSize, ReduTuple, ReduIndices);
2352
- });
2353
- } // end while (NWorkItems > 1)
2354
- }
2355
- };
2356
-
2357
- template <typename KernelName, int Dims, typename PropertiesT,
2358
- typename ... RestT>
2359
- void reduction_parallel_for (handler &CGH,
2360
- std::shared_ptr<detail::queue_impl> Queue,
2361
- nd_range<Dims> NDRange, PropertiesT Properties,
2362
- RestT... Rest) {
2363
- constexpr auto Strategy = reduction::strategy::multi;
2364
- NDRangeReduction<Strategy>::template run<KernelName>(CGH, Queue, NDRange,
2365
- Properties, Rest...);
2364
+ reduction_parallel_for<KernelName, StrategyToUse>(
2365
+ CGH, Queue, NDRange, Properties, Redu, UpdatedKernelFunc);
2366
2366
}
2367
2367
} // namespace detail
2368
2368
0 commit comments