
Commit 522e5cb

Author: Jakub Bukaj (committed)
roll-up merge of #18941: reem/better-task-pool
2 parents 8306203 + 93c4942, commit 522e5cb

File tree: 1 file changed (+167, -63)


src/libstd/sync/task_pool.rs

Lines changed: 167 additions & 63 deletions
@@ -1,4 +1,4 @@
-// Copyright 2012 The Rust Project Developers. See the COPYRIGHT
+// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
 // file at the top-level directory of this distribution and at
 // http://rust-lang.org/COPYRIGHT.
 //
@@ -12,91 +12,195 @@

 use core::prelude::*;

-use task;
 use task::spawn;
-use vec::Vec;
-use comm::{channel, Sender};
+use comm::{channel, Sender, Receiver};
+use sync::{Arc, Mutex};

-enum Msg<T> {
-    Execute(proc(&T):Send),
-    Quit
+struct Sentinel<'a> {
+    jobs: &'a Arc<Mutex<Receiver<proc(): Send>>>,
+    active: bool
 }

-/// A task pool used to execute functions in parallel.
-pub struct TaskPool<T> {
-    channels: Vec<Sender<Msg<T>>>,
-    next_index: uint,
+impl<'a> Sentinel<'a> {
+    fn new(jobs: &Arc<Mutex<Receiver<proc(): Send>>>) -> Sentinel {
+        Sentinel {
+            jobs: jobs,
+            active: true
+        }
+    }
+
+    // Cancel and destroy this sentinel.
+    fn cancel(mut self) {
+        self.active = false;
+    }
 }

 #[unsafe_destructor]
-impl<T> Drop for TaskPool<T> {
+impl<'a> Drop for Sentinel<'a> {
     fn drop(&mut self) {
-        for channel in self.channels.iter_mut() {
-            channel.send(Quit);
+        if self.active {
+            spawn_in_pool(self.jobs.clone())
         }
     }
 }

-impl<T> TaskPool<T> {
-    /// Spawns a new task pool with `n_tasks` tasks. The provided
-    /// `init_fn_factory` returns a function which, given the index of the
-    /// task, should return local data to be kept around in that task.
+/// A task pool used to execute functions in parallel.
+///
+/// Spawns `n` worker tasks and replenishes the pool if any worker tasks
+/// panic.
+///
+/// # Example
+///
+/// ```rust
+/// # use sync::TaskPool;
+/// # use iter::AdditiveIterator;
+///
+/// let pool = TaskPool::new(4u);
+///
+/// let (tx, rx) = channel();
+/// for _ in range(0, 8u) {
+///     let tx = tx.clone();
+///     pool.execute(proc() {
+///         tx.send(1u);
+///     });
+/// }
+///
+/// assert_eq!(rx.iter().take(8u).sum(), 8u);
+/// ```
+pub struct TaskPool {
+    // How the taskpool communicates with subtasks.
+    //
+    // This is the only such Sender, so when it is dropped all subtasks will
+    // quit.
+    jobs: Sender<proc(): Send>
+}
+
+impl TaskPool {
+    /// Spawns a new task pool with `tasks` tasks.
     ///
     /// # Panics
     ///
-    /// This function will panic if `n_tasks` is less than 1.
-    pub fn new(n_tasks: uint,
-               init_fn_factory: || -> proc(uint):Send -> T)
-               -> TaskPool<T> {
-        assert!(n_tasks >= 1);
-
-        let channels = Vec::from_fn(n_tasks, |i| {
-            let (tx, rx) = channel::<Msg<T>>();
-            let init_fn = init_fn_factory();
-
-            let task_body = proc() {
-                let local_data = init_fn(i);
-                loop {
-                    match rx.recv() {
-                        Execute(f) => f(&local_data),
-                        Quit => break
-                    }
-                }
-            };
+    /// This function will panic if `tasks` is 0.
+    pub fn new(tasks: uint) -> TaskPool {
+        assert!(tasks >= 1);

-            // Run on this scheduler.
-            task::spawn(task_body);
+        let (tx, rx) = channel::<proc(): Send>();
+        let rx = Arc::new(Mutex::new(rx));

-            tx
-        });
+        // Taskpool tasks.
+        for _ in range(0, tasks) {
+            spawn_in_pool(rx.clone());
+        }

-        return TaskPool {
-            channels: channels,
-            next_index: 0,
-        };
+        TaskPool { jobs: tx }
     }

-    /// Executes the function `f` on a task in the pool. The function
-    /// receives a reference to the local data returned by the `init_fn`.
-    pub fn execute(&mut self, f: proc(&T):Send) {
-        self.channels[self.next_index].send(Execute(f));
-        self.next_index += 1;
-        if self.next_index == self.channels.len() { self.next_index = 0; }
+    /// Executes the function `job` on a task in the pool.
+    pub fn execute(&self, job: proc():Send) {
+        self.jobs.send(job);
     }
 }

-#[test]
-fn test_task_pool() {
-    let f: || -> proc(uint):Send -> uint = || { proc(i) i };
-    let mut pool = TaskPool::new(4, f);
-    for _ in range(0u, 8) {
-        pool.execute(proc(i) println!("Hello from thread {}!", *i));
-    }
+fn spawn_in_pool(jobs: Arc<Mutex<Receiver<proc(): Send>>>) {
+    spawn(proc() {
+        // Will spawn a new task on panic unless it is cancelled.
+        let sentinel = Sentinel::new(&jobs);
+
+        loop {
+            let message = {
+                // Only lock jobs for the time it takes
+                // to get a job, not run it.
+                let lock = jobs.lock();
+                lock.recv_opt()
+            };
+
+            match message {
+                Ok(job) => job(),
+
+                // The Taskpool was dropped.
+                Err(..) => break
+            }
+        }
+
+        sentinel.cancel();
+    })
 }

-#[test]
-#[should_fail]
-fn test_zero_tasks_panic() {
-    let f: || -> proc(uint):Send -> uint = || { proc(i) i };
-    TaskPool::new(0, f);
+#[cfg(test)]
+mod test {
+    use core::prelude::*;
+    use super::*;
+    use comm::channel;
+    use iter::range;
+
+    const TEST_TASKS: uint = 4u;
+
+    #[test]
+    fn test_works() {
+        use iter::AdditiveIterator;
+
+        let pool = TaskPool::new(TEST_TASKS);
+
+        let (tx, rx) = channel();
+        for _ in range(0, TEST_TASKS) {
+            let tx = tx.clone();
+            pool.execute(proc() {
+                tx.send(1u);
+            });
+        }
+
+        assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+    }
+
+    #[test]
+    #[should_fail]
+    fn test_zero_tasks_panic() {
+        TaskPool::new(0);
+    }
+
+    #[test]
+    fn test_recovery_from_subtask_panic() {
+        use iter::AdditiveIterator;
+
+        let pool = TaskPool::new(TEST_TASKS);
+
+        // Panic all the existing tasks.
+        for _ in range(0, TEST_TASKS) {
+            pool.execute(proc() { panic!() });
+        }
+
+        // Ensure new tasks were spawned to compensate.
+        let (tx, rx) = channel();
+        for _ in range(0, TEST_TASKS) {
+            let tx = tx.clone();
+            pool.execute(proc() {
+                tx.send(1u);
+            });
+        }
+
+        assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+    }
+
+    #[test]
+    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
+        use sync::{Arc, Barrier};
+
+        let pool = TaskPool::new(TEST_TASKS);
+        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
+
+        // Panic all the existing tasks in a bit.
+        for _ in range(0, TEST_TASKS) {
+            let waiter = waiter.clone();
+            pool.execute(proc() {
+                waiter.wait();
+                panic!();
+            });
+        }
+
+        drop(pool);
+
+        // Kick off the failure.
+        waiter.wait();
+    }
 }
+
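Note on the mechanism: the new pool replaces round-robin per-task channels with a single shared job channel, and each worker holds a Sentinel whose Drop impl respawns a replacement worker if the worker unwinds from a panicking job, unless the sentinel was cancelled on clean shutdown. The sketch below re-expresses that sentinel idea in modern Rust (std::thread and std::sync::mpsc rather than the 2014-era task/comm APIs above); it is illustrative only, and the names Pool, Job, and the standalone main are placeholders, not part of this commit.

use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread;

// Boxed closure used as the unit of work.
type Job = Box<dyn FnOnce() + Send + 'static>;

struct Pool {
    // The only Sender; dropping the Pool disconnects the channel and
    // lets the workers exit cleanly.
    jobs: Sender<Job>,
}

impl Pool {
    fn new(workers: usize) -> Pool {
        assert!(workers >= 1);
        let (tx, rx) = channel::<Job>();
        let rx = Arc::new(Mutex::new(rx));
        for _ in 0..workers {
            spawn_in_pool(rx.clone());
        }
        Pool { jobs: tx }
    }

    fn execute<F: FnOnce() + Send + 'static>(&self, job: F) {
        self.jobs.send(Box::new(job)).unwrap();
    }
}

// Guard owned by each worker. If the worker panics while running a job,
// the guard is dropped during unwinding with `active` still true and a
// replacement worker is spawned; a clean exit clears `active` first.
struct Sentinel {
    jobs: Arc<Mutex<Receiver<Job>>>,
    active: bool,
}

impl Drop for Sentinel {
    fn drop(&mut self) {
        if self.active {
            spawn_in_pool(self.jobs.clone());
        }
    }
}

fn spawn_in_pool(jobs: Arc<Mutex<Receiver<Job>>>) {
    thread::spawn(move || {
        let mut sentinel = Sentinel { jobs: jobs.clone(), active: true };
        loop {
            // Hold the lock only long enough to receive a job, not to run it.
            let message = jobs.lock().unwrap().recv();
            match message {
                Ok(job) => job(),
                // All Senders are gone: the Pool was dropped, so exit.
                Err(_) => break,
            }
        }
        sentinel.active = false;
    });
}

fn main() {
    let pool = Pool::new(4);
    // One job panics; its worker dies and its sentinel respawns a new one.
    pool.execute(|| panic!("job failed"));

    let (tx, rx) = channel();
    for _ in 0..4 {
        let tx = tx.clone();
        pool.execute(move || tx.send(1u32).unwrap());
    }
    drop(tx);
    assert_eq!(rx.iter().sum::<u32>(), 4);
}

The design point this illustrates is the same one the commit's tests check: panicking jobs cost a worker thread but not the pool's capacity, and dropping the pool (the only Sender) is what tells idle workers to shut down.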