Skip to content

Commit 93d40b5

Browse files
Zoxccuviper
authored andcommitted
Add a main_handler which is passed an argument to run the proper main loop
1 parent b04baa6 commit 93d40b5

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

rayon-core/src/lib.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ pub struct ThreadPoolBuilder<S = DefaultSpawn> {
156156
/// Closure invoked to spawn threads.
157157
spawn_handler: S,
158158

159+
/// Closure invoked on worker thread start.
160+
main_handler: Option<Box<MainHandler>>,
161+
159162
/// If false, worker threads will execute spawned jobs in a
160163
/// "depth-first" fashion. If true, they will do a "breadth-first"
161164
/// fashion. Depth-first is the default.
@@ -194,12 +197,19 @@ impl Default for ThreadPoolBuilder {
194197
stack_size: None,
195198
start_handler: None,
196199
exit_handler: None,
200+
main_handler: None,
197201
spawn_handler: DefaultSpawn,
198202
breadth_first: false,
199203
}
200204
}
201205
}
202206

207+
/// The type for a closure that gets invoked with a
208+
/// function which runs rayon tasks.
209+
/// The closure is passed the index of the thread on which it is invoked.
210+
/// Note that this same closure may be invoked multiple times in parallel.
211+
type MainHandler = Fn(usize, &mut FnMut()) + Send + Sync;
212+
203213
impl ThreadPoolBuilder {
204214
/// Creates and returns a valid rayon thread pool builder, but does not initialize it.
205215
pub fn new() -> Self {
@@ -380,6 +390,7 @@ impl<S> ThreadPoolBuilder<S> {
380390
stack_size: self.stack_size,
381391
start_handler: self.start_handler,
382392
exit_handler: self.exit_handler,
393+
main_handler: self.main_handler,
383394
breadth_first: self.breadth_first,
384395
}
385396
}
@@ -575,6 +586,24 @@ impl<S> ThreadPoolBuilder<S> {
575586
self.exit_handler = Some(Box::new(exit_handler));
576587
self
577588
}
589+
590+
/// Takes the current thread main callback, leaving `None`.
591+
fn take_main_handler(&mut self) -> Option<Box<MainHandler>> {
592+
self.main_handler.take()
593+
}
594+
595+
/// Set a callback to be invoked on thread main.
596+
///
597+
/// The closure is passed the index of the thread on which it is invoked.
598+
/// Note that this same closure may be invoked multiple times in parallel.
599+
/// If this closure panics, the panic will be passed to the panic handler.
600+
pub fn main_handler<H>(mut self, main_handler: H) -> Self
601+
where
602+
H: Fn(usize, &mut FnMut()) + Send + Sync + 'static,
603+
{
604+
self.main_handler = Some(Box::new(main_handler));
605+
self
606+
}
578607
}
579608

580609
#[allow(deprecated)]
@@ -692,6 +721,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
692721
ref panic_handler,
693722
ref stack_size,
694723
ref start_handler,
724+
ref main_handler,
695725
ref exit_handler,
696726
spawn_handler: _,
697727
ref breadth_first,
@@ -709,6 +739,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
709739
let panic_handler = panic_handler.as_ref().map(|_| ClosurePlaceholder);
710740
let start_handler = start_handler.as_ref().map(|_| ClosurePlaceholder);
711741
let exit_handler = exit_handler.as_ref().map(|_| ClosurePlaceholder);
742+
let main_handler = main_handler.as_ref().map(|_| ClosurePlaceholder);
712743

713744
f.debug_struct("ThreadPoolBuilder")
714745
.field("num_threads", num_threads)
@@ -717,6 +748,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
717748
.field("stack_size", &stack_size)
718749
.field("start_handler", &start_handler)
719750
.field("exit_handler", &exit_handler)
751+
.field("main_handler", &main_handler)
720752
.field("breadth_first", &breadth_first)
721753
.finish()
722754
}

rayon-core/src/registry.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ use std::thread;
2424
use std::usize;
2525
use unwind;
2626
use util::leak;
27-
use {ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder};
27+
use {
28+
ErrorKind, ExitHandler, MainHandler, PanicHandler, StartHandler, ThreadPoolBuildError,
29+
ThreadPoolBuilder,
30+
};
2831

2932
/// Thread builder used for customization via
3033
/// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler).
@@ -140,6 +143,7 @@ pub(super) struct Registry {
140143
panic_handler: Option<Box<PanicHandler>>,
141144
start_handler: Option<Box<StartHandler>>,
142145
exit_handler: Option<Box<ExitHandler>>,
146+
main_handler: Option<Box<MainHandler>>,
143147

144148
// When this latch reaches 0, it means that all work on this
145149
// registry must be complete. This is ensured in the following ways:
@@ -239,6 +243,7 @@ impl Registry {
239243
terminate_latch: CountLatch::new(),
240244
panic_handler: builder.take_panic_handler(),
241245
start_handler: builder.take_start_handler(),
246+
main_handler: builder.take_main_handler(),
242247
exit_handler: builder.take_exit_handler(),
243248
});
244249

@@ -811,7 +816,20 @@ unsafe fn main_loop(worker: Worker<JobRef>, registry: Arc<Registry>, index: usiz
811816
}
812817
}
813818

814-
worker_thread.wait_until(&registry.terminate_latch);
819+
let mut work = || {
820+
worker_thread.wait_until(&registry.terminate_latch);
821+
};
822+
823+
if let Some(ref handler) = registry.main_handler {
824+
match unwind::halt_unwinding(|| handler(index, &mut work)) {
825+
Ok(()) => {}
826+
Err(err) => {
827+
registry.handle_panic(err);
828+
}
829+
}
830+
} else {
831+
work();
832+
}
815833

816834
// Should not be any work left in our queue.
817835
debug_assert!(worker_thread.take_local_job().is_none());

0 commit comments

Comments
 (0)