Commit 93c4942

Rewrite std::sync::TaskPool to be load balancing and panic-resistant
The previous implementation was very likely to cause panics during unwinding, through this process:

- a child panics and drops its receiver
- the taskpool comes back around and sends another job to that child
- the child's receiver has hung up, so the taskpool panics on send
- during unwinding, the taskpool attempts to send a quit message to the child, causing a panic during unwinding
- a panic during unwinding causes a process abort

This meant that TaskPool upgraded any child panic to a full process abort. It came up in Iron, where it caused crashes in long-running servers.

The new implementation uses a single channel to communicate between spawned tasks and the TaskPool, which significantly reduces the complexity of the implementation and cuts down on allocation. The TaskPool uses the channel as a single-producer-multiple-consumer queue. Additionally, by using `send_opt` and `recv_opt` instead of `send` and `recv`, this TaskPool is robust in the face of child panics before, during, and after the TaskPool itself is dropped.

Because the TaskPool no longer takes an `init_fn_factory`, this is a [breaking-change]; otherwise, the API has not changed. If you used `init_fn_factory` in your code and this change breaks it, you can instead use an `AtomicUint` counter and a channel to move information into child tasks.
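The replacement for `init_fn_factory` suggested above is not spelled out anywhere in the diff, so the following is only a hypothetical sketch against the new API in this commit, written in the same 2014-era Rust as the patch: an `AtomicUint` hands each job an index, and shared read-only data rides in through an `Arc` (owned per-task values could likewise be sent over a channel that the jobs drain, as the message suggests). The `names` vector and the result strings are made-up illustration data.

// Sketch only: one possible migration away from `init_fn_factory`,
// assuming the TaskPool API introduced by this commit.
use std::comm::channel;
use std::sync::{Arc, TaskPool};
use std::sync::atomic::{AtomicUint, SeqCst};

fn main() {
    let pool = TaskPool::new(4u);

    // Replaces the task index that `init_fn` used to receive.
    let counter = Arc::new(AtomicUint::new(0));
    // Hypothetical shared "local data", previously built by `init_fn_factory`.
    let names = Arc::new(vec!["alpha", "beta", "gamma", "delta"]);

    let (tx, rx) = channel();
    for _ in range(0u, 8u) {
        let counter = counter.clone();
        let names = names.clone();
        let tx = tx.clone();
        pool.execute(proc() {
            // Each job grabs a unique index from the shared counter.
            let i = counter.fetch_add(1, SeqCst);
            tx.send(format!("job {} used {}", i, names[i % names.len()]));
        });
    }

    // Drain the results so main does not exit before the jobs run.
    for result in rx.iter().take(8u) {
        println!("{}", result);
    }
}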
1 parent 15ba87f commit 93c4942

File tree

1 file changed: +167, -63 lines
src/libstd/sync/task_pool.rs

Lines changed: 167 additions & 63 deletions
@@ -1,4 +1,4 @@
-// Copyright 2012 The Rust Project Developers. See the COPYRIGHT
+// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
 // file at the top-level directory of this distribution and at
 // http://rust-lang.org/COPYRIGHT.
 //
@@ -12,91 +12,195 @@
 
 use core::prelude::*;
 
-use task;
 use task::spawn;
-use vec::Vec;
-use comm::{channel, Sender};
+use comm::{channel, Sender, Receiver};
+use sync::{Arc, Mutex};
 
-enum Msg<T> {
-    Execute(proc(&T):Send),
-    Quit
+struct Sentinel<'a> {
+    jobs: &'a Arc<Mutex<Receiver<proc(): Send>>>,
+    active: bool
 }
 
-/// A task pool used to execute functions in parallel.
-pub struct TaskPool<T> {
-    channels: Vec<Sender<Msg<T>>>,
-    next_index: uint,
+impl<'a> Sentinel<'a> {
+    fn new(jobs: &Arc<Mutex<Receiver<proc(): Send>>>) -> Sentinel {
+        Sentinel {
+            jobs: jobs,
+            active: true
+        }
+    }
+
+    // Cancel and destroy this sentinel.
+    fn cancel(mut self) {
+        self.active = false;
+    }
 }
 
 #[unsafe_destructor]
-impl<T> Drop for TaskPool<T> {
+impl<'a> Drop for Sentinel<'a> {
     fn drop(&mut self) {
-        for channel in self.channels.iter_mut() {
-            channel.send(Quit);
+        if self.active {
+            spawn_in_pool(self.jobs.clone())
         }
     }
 }
 
-impl<T> TaskPool<T> {
-    /// Spawns a new task pool with `n_tasks` tasks. The provided
-    /// `init_fn_factory` returns a function which, given the index of the
-    /// task, should return local data to be kept around in that task.
+/// A task pool used to execute functions in parallel.
+///
+/// Spawns `n` worker tasks and replenishes the pool if any worker tasks
+/// panic.
+///
+/// # Example
+///
+/// ```rust
+/// # use sync::TaskPool;
+/// # use iter::AdditiveIterator;
+///
+/// let pool = TaskPool::new(4u);
+///
+/// let (tx, rx) = channel();
+/// for _ in range(0, 8u) {
+///     let tx = tx.clone();
+///     pool.execute(proc() {
+///         tx.send(1u);
+///     });
+/// }
+///
+/// assert_eq!(rx.iter().take(8u).sum(), 8u);
+/// ```
+pub struct TaskPool {
+    // How the taskpool communicates with subtasks.
+    //
+    // This is the only such Sender, so when it is dropped all subtasks will
+    // quit.
+    jobs: Sender<proc(): Send>
+}
+
+impl TaskPool {
+    /// Spawns a new task pool with `tasks` tasks.
     ///
     /// # Panics
    ///
-    /// This function will panic if `n_tasks` is less than 1.
-    pub fn new(n_tasks: uint,
-               init_fn_factory: || -> proc(uint):Send -> T)
-               -> TaskPool<T> {
-        assert!(n_tasks >= 1);
-
-        let channels = Vec::from_fn(n_tasks, |i| {
-            let (tx, rx) = channel::<Msg<T>>();
-            let init_fn = init_fn_factory();
-
-            let task_body = proc() {
-                let local_data = init_fn(i);
-                loop {
-                    match rx.recv() {
-                        Execute(f) => f(&local_data),
-                        Quit => break
-                    }
-                }
-            };
+    /// This function will panic if `tasks` is 0.
+    pub fn new(tasks: uint) -> TaskPool {
+        assert!(tasks >= 1);
 
-            // Run on this scheduler.
-            task::spawn(task_body);
+        let (tx, rx) = channel::<proc(): Send>();
+        let rx = Arc::new(Mutex::new(rx));
 
-            tx
-        });
+        // Taskpool tasks.
+        for _ in range(0, tasks) {
+            spawn_in_pool(rx.clone());
+        }
 
-        return TaskPool {
-            channels: channels,
-            next_index: 0,
-        };
+        TaskPool { jobs: tx }
     }
 
-    /// Executes the function `f` on a task in the pool. The function
-    /// receives a reference to the local data returned by the `init_fn`.
-    pub fn execute(&mut self, f: proc(&T):Send) {
-        self.channels[self.next_index].send(Execute(f));
-        self.next_index += 1;
-        if self.next_index == self.channels.len() { self.next_index = 0; }
+    /// Executes the function `job` on a task in the pool.
+    pub fn execute(&self, job: proc():Send) {
+        self.jobs.send(job);
     }
 }
 
-#[test]
-fn test_task_pool() {
-    let f: || -> proc(uint):Send -> uint = || { proc(i) i };
-    let mut pool = TaskPool::new(4, f);
-    for _ in range(0u, 8) {
-        pool.execute(proc(i) println!("Hello from thread {}!", *i));
-    }
+fn spawn_in_pool(jobs: Arc<Mutex<Receiver<proc(): Send>>>) {
+    spawn(proc() {
+        // Will spawn a new task on panic unless it is cancelled.
+        let sentinel = Sentinel::new(&jobs);
+
+        loop {
+            let message = {
+                // Only lock jobs for the time it takes
+                // to get a job, not run it.
+                let lock = jobs.lock();
+                lock.recv_opt()
+            };
+
+            match message {
+                Ok(job) => job(),
+
+                // The Taskpool was dropped.
+                Err(..) => break
+            }
+        }
+
+        sentinel.cancel();
+    })
 }
 
-#[test]
-#[should_fail]
-fn test_zero_tasks_panic() {
-    let f: || -> proc(uint):Send -> uint = || { proc(i) i };
-    TaskPool::new(0, f);
+#[cfg(test)]
+mod test {
+    use core::prelude::*;
+    use super::*;
+    use comm::channel;
+    use iter::range;
+
+    const TEST_TASKS: uint = 4u;
+
+    #[test]
+    fn test_works() {
+        use iter::AdditiveIterator;
+
+        let pool = TaskPool::new(TEST_TASKS);
+
+        let (tx, rx) = channel();
+        for _ in range(0, TEST_TASKS) {
+            let tx = tx.clone();
+            pool.execute(proc() {
+                tx.send(1u);
+            });
+        }
+
+        assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+    }
+
+    #[test]
+    #[should_fail]
+    fn test_zero_tasks_panic() {
+        TaskPool::new(0);
+    }
+
+    #[test]
+    fn test_recovery_from_subtask_panic() {
+        use iter::AdditiveIterator;
+
+        let pool = TaskPool::new(TEST_TASKS);
+
+        // Panic all the existing tasks.
+        for _ in range(0, TEST_TASKS) {
+            pool.execute(proc() { panic!() });
+        }
+
+        // Ensure new tasks were spawned to compensate.
+        let (tx, rx) = channel();
+        for _ in range(0, TEST_TASKS) {
+            let tx = tx.clone();
+            pool.execute(proc() {
+                tx.send(1u);
+            });
+        }
+
+        assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+    }
+
+    #[test]
+    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
+        use sync::{Arc, Barrier};
+
+        let pool = TaskPool::new(TEST_TASKS);
+        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
+
+        // Panic all the existing tasks in a bit.
+        for _ in range(0, TEST_TASKS) {
+            let waiter = waiter.clone();
+            pool.execute(proc() {
+                waiter.wait();
+                panic!();
+            });
+        }
+
+        drop(pool);
+
+        // Kick off the failure.
+        waiter.wait();
+    }
 }