 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Config/llvm-config.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/Threading.h"
@@ -120,13 +121,17 @@ void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
                               llvm::Log2_64(std::distance(Start, End)) + 1);
 }
 
+// TaskGroup has a relatively high overhead, so we want to reduce
+// the number of spawn() calls. We'll create up to 1024 tasks here.
+// (Note that 1024 is an arbitrary number. This code probably needs
+// improving to take the number of available cores into account.)
+enum { MaxTasksPerGroup = 1024 };
+
 template <class IterTy, class FuncTy>
 void parallel_for_each(IterTy Begin, IterTy End, FuncTy Fn) {
-  // TaskGroup has a relatively high overhead, so we want to reduce
-  // the number of spawn() calls. We'll create up to 1024 tasks here.
-  // (Note that 1024 is an arbitrary number. This code probably needs
-  // improving to take the number of available cores into account.)
-  ptrdiff_t TaskSize = std::distance(Begin, End) / 1024;
+  // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling
+  // overhead on large inputs.
+  ptrdiff_t TaskSize = std::distance(Begin, End) / MaxTasksPerGroup;
   if (TaskSize == 0)
     TaskSize = 1;
 
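As an illustration (not part of the patch), the effect of the new cap can be checked in isolation. The sketch below assumes a hypothetical input of one million elements; MaxTasksPerGroup mirrors the enum introduced above, everything else is illustrative caller-side arithmetic.

#include <cstddef>
#include <cstdio>

enum { MaxTasksPerGroup = 1024 }; // mirrors the constant introduced above

int main() {
  std::ptrdiff_t N = 1000000;                     // hypothetical element count
  std::ptrdiff_t TaskSize = N / MaxTasksPerGroup; // 976 elements per task
  if (TaskSize == 0)
    TaskSize = 1;
  // Roughly 1025 spawn() calls instead of one million.
  std::printf("%td tasks of ~%td elements each\n",
              (N + TaskSize - 1) / TaskSize, TaskSize);
}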
@@ -140,7 +145,9 @@ void parallel_for_each(IterTy Begin, IterTy End, FuncTy Fn) {
 
 template <class IndexTy, class FuncTy>
 void parallel_for_each_n(IndexTy Begin, IndexTy End, FuncTy Fn) {
-  ptrdiff_t TaskSize = (End - Begin) / 1024;
+  // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling
+  // overhead on large inputs.
+  ptrdiff_t TaskSize = (End - Begin) / MaxTasksPerGroup;
   if (TaskSize == 0)
     TaskSize = 1;
 
@@ -156,6 +163,50 @@ void parallel_for_each_n(IndexTy Begin, IndexTy End, FuncTy Fn) {
     Fn(J);
 }
 
+template <class IterTy, class ResultTy, class ReduceFuncTy,
+          class TransformFuncTy>
+ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init,
+                                   ReduceFuncTy Reduce,
+                                   TransformFuncTy Transform) {
+  // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling
+  // overhead on large inputs.
+  size_t NumInputs = std::distance(Begin, End);
+  if (NumInputs == 0)
+    return std::move(Init);
+  size_t NumTasks = std::min(static_cast<size_t>(MaxTasksPerGroup), NumInputs);
+  std::vector<ResultTy> Results(NumTasks, Init);
+  {
+    // Each task processes either TaskSize or TaskSize+1 inputs. Any inputs
+    // remaining after dividing them equally amongst tasks are distributed as
+    // one extra input over the first tasks.
+    TaskGroup TG;
+    size_t TaskSize = NumInputs / NumTasks;
+    size_t RemainingInputs = NumInputs % NumTasks;
+    IterTy TBegin = Begin;
+    for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) {
+      IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0);
+      TG.spawn([=, &Transform, &Reduce, &Results] {
+        // Reduce the result of transformation eagerly within each task.
+        ResultTy R = Init;
+        for (IterTy It = TBegin; It != TEnd; ++It)
+          R = Reduce(R, Transform(*It));
+        Results[TaskId] = R;
+      });
+      TBegin = TEnd;
+    }
+    assert(TBegin == End);
+  }
+
+  // Do a final reduction. There are at most 1024 tasks, so this only adds
+  // constant single-threaded overhead for large inputs. Hopefully most
+  // reductions are cheaper than the transformation.
+  ResultTy FinalResult = std::move(Results.front());
+  for (ResultTy &PartialResult :
+       makeMutableArrayRef(Results.data() + 1, Results.size() - 1))
+    FinalResult = Reduce(FinalResult, std::move(PartialResult));
+  return std::move(FinalResult);
+}
+
 #endif
 
 } // namespace detail
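As an illustration (not part of the patch), the "TaskSize or TaskSize+1 inputs" comment describes the standard balanced-partition trick. This standalone sketch, with hypothetical numbers (10 inputs over 4 tasks yields chunks of 3, 3, 2, 2), checks that the chunks cover every input exactly once:

#include <cassert>
#include <cstddef>

int main() {
  std::size_t NumInputs = 10, NumTasks = 4;           // hypothetical values
  std::size_t TaskSize = NumInputs / NumTasks;        // 2
  std::size_t RemainingInputs = NumInputs % NumTasks; // 2
  std::size_t Total = 0;
  for (std::size_t TaskId = 0; TaskId < NumTasks; ++TaskId)
    Total += TaskSize + (TaskId < RemainingInputs ? 1 : 0); // 3, 3, 2, 2
  assert(Total == NumInputs); // every input is assigned exactly once
  return 0;
}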
@@ -198,6 +249,22 @@ void parallelForEachN(size_t Begin, size_t End, FuncTy Fn) {
     Fn(I);
 }
 
+template <class IterTy, class ResultTy, class ReduceFuncTy,
+          class TransformFuncTy>
+ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
+                                 ReduceFuncTy Reduce,
+                                 TransformFuncTy Transform) {
+#if LLVM_ENABLE_THREADS
+  if (parallel::strategy.ThreadsRequested != 1) {
+    return parallel::detail::parallel_transform_reduce(Begin, End, Init,
+                                                       Reduce, Transform);
+  }
+#endif
+  for (IterTy I = Begin; I != End; ++I)
+    Init = Reduce(std::move(Init), Transform(*I));
+  return std::move(Init);
+}
+
 // Range wrappers.
 template <class RangeTy,
           class Comparator = std::less<decltype(*std::begin(RangeTy()))>>
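As an illustration (not part of the patch), here is a minimal usage sketch for the new public entry point; the caller below, which sums string lengths, is hypothetical. Since partial results are combined across tasks in an unspecified grouping, Reduce should be associative (and ideally commutative) for a deterministic answer.

#include "llvm/Support/Parallel.h"
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical caller: total length of all strings, computed in parallel.
std::size_t totalLength(const std::vector<std::string> &Strs) {
  return llvm::parallelTransformReduce(
      Strs.begin(), Strs.end(), std::size_t(0),
      [](std::size_t A, std::size_t B) { return A + B; }, // Reduce
      [](const std::string &S) { return S.size(); });     // Transform
}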
@@ -210,6 +277,31 @@ void parallelForEach(RangeTy &&R, FuncTy Fn) {
   parallelForEach(std::begin(R), std::end(R), Fn);
 }
 
+template <class RangeTy, class ResultTy, class ReduceFuncTy,
+          class TransformFuncTy>
+ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
+                                 ReduceFuncTy Reduce,
+                                 TransformFuncTy Transform) {
+  return parallelTransformReduce(std::begin(R), std::end(R), Init, Reduce,
+                                 Transform);
+}
+
+// Parallel for-each, but with error handling.
+template <class RangeTy, class FuncTy>
+Error parallelForEachError(RangeTy &&R, FuncTy Fn) {
+  // The transform_reduce algorithm requires that the initial value be
+  // copyable. Error objects are uncopyable. We only need to copy initial
+  // success values, so work around this mismatch via the C API. The C API
+  // represents success values with a null pointer. The joinErrors discards
+  // null values and joins multiple errors into an ErrorList.
+  return unwrap(parallelTransformReduce(
+      std::begin(R), std::end(R), wrap(Error::success()),
+      [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) {
+        return wrap(joinErrors(unwrap(Lhs), unwrap(Rhs)));
+      },
+      [&Fn](auto &&V) { return wrap(Fn(V)); }));
+}
+
 } // namespace llvm
 
 #endif // LLVM_SUPPORT_PARALLEL_H
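As an illustration (not part of the patch), a minimal usage sketch for parallelForEachError with a hypothetical validation callback. Each failing element contributes one Error; joinErrors folds them all into a single ErrorList that the caller can handle or report.

#include "llvm/Support/Error.h"
#include "llvm/Support/Parallel.h"
#include <vector>

// Hypothetical caller: validate all inputs in parallel; failures are joined.
llvm::Error validateAll(const std::vector<int> &Inputs) {
  return llvm::parallelForEachError(Inputs, [](int V) -> llvm::Error {
    if (V < 0)
      return llvm::createStringError(llvm::inconvertibleErrorCode(),
                                     "negative input: %d", V);
    return llvm::Error::success();
  });
}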