Skip to content

Commit b04baa6

Browse files
Zoxccuviper
authored andcommitted
Add thread-local values which are preserved with jobs
1 parent b8b97a1 commit b04baa6

File tree

7 files changed

+71
-6
lines changed

7 files changed

+71
-6
lines changed

rayon-core/src/job.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use latch::Latch;
33
use std::any::Any;
44
use std::cell::UnsafeCell;
55
use std::mem;
6+
use tlv;
67
use unwind;
78

89
pub(super) enum JobResult<T> {
@@ -73,6 +74,7 @@ where
7374
pub(super) latch: L,
7475
func: UnsafeCell<Option<F>>,
7576
result: UnsafeCell<JobResult<R>>,
77+
tlv: usize,
7678
}
7779

7880
impl<L, F, R> StackJob<L, F, R>
@@ -81,11 +83,12 @@ where
8183
F: FnOnce(bool) -> R + Send,
8284
R: Send,
8385
{
84-
pub(super) fn new(func: F, latch: L) -> StackJob<L, F, R> {
86+
pub(super) fn new(tlv: usize, func: F, latch: L) -> StackJob<L, F, R> {
8587
StackJob {
8688
latch,
8789
func: UnsafeCell::new(Some(func)),
8890
result: UnsafeCell::new(JobResult::None),
91+
tlv,
8992
}
9093
}
9194

@@ -110,6 +113,7 @@ where
110113
{
111114
unsafe fn execute(this: *const Self) {
112115
let this = &*this;
116+
tlv::set(this.tlv);
113117
let abort = unwind::AbortIfPanic;
114118
let func = (*this.func.get()).take().unwrap();
115119
(*this.result.get()) = match unwind::halt_unwinding(|| func(true)) {
@@ -132,15 +136,17 @@ where
132136
BODY: FnOnce() + Send,
133137
{
134138
job: UnsafeCell<Option<BODY>>,
139+
tlv: usize,
135140
}
136141

137142
impl<BODY> HeapJob<BODY>
138143
where
139144
BODY: FnOnce() + Send,
140145
{
141-
pub(super) fn new(func: BODY) -> Self {
146+
pub(super) fn new(tlv: usize, func: BODY) -> Self {
142147
HeapJob {
143148
job: UnsafeCell::new(Some(func)),
149+
tlv,
144150
}
145151
}
146152

@@ -159,6 +165,7 @@ where
159165
{
160166
unsafe fn execute(this: *const Self) {
161167
let this: Box<Self> = mem::transmute(this);
168+
tlv::set(this.tlv);
162169
let job = (*this.job.get()).take().unwrap();
163170
job();
164171
}

rayon-core/src/join/mod.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use latch::{LatchProbe, SpinLatch};
33
use log::Event::*;
44
use registry::{self, WorkerThread};
55
use std::any::Any;
6+
use tlv;
67
use unwind;
78

89
use FnContext;
@@ -120,10 +121,12 @@ where
120121
worker: worker_thread.index()
121122
});
122123

124+
let tlv = tlv::get();
123125
// Create virtual wrapper for task b; this all has to be
124126
// done here so that the stack frame can keep it all live
125127
// long enough.
126128
let job_b = StackJob::new(
129+
tlv,
127130
|migrated| oper_b(FnContext::new(migrated)),
128131
SpinLatch::new(),
129132
);
@@ -134,7 +137,7 @@ where
134137
let status_a = unwind::halt_unwinding(move || oper_a(FnContext::new(injected)));
135138
let result_a = match status_a {
136139
Ok(v) => v,
137-
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err),
140+
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
138141
};
139142

140143
// Now that task A has finished, try to pop job B from the
@@ -151,7 +154,11 @@ where
151154
log!(PoppedRhs {
152155
worker: worker_thread.index()
153156
});
157+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
158+
tlv::set(tlv);
159+
154160
let result_b = job_b.run_inline(injected);
161+
155162
return (result_a, result_b);
156163
} else {
157164
log!(PoppedJob {
@@ -171,6 +178,9 @@ where
171178
}
172179
}
173180

181+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
182+
tlv::set(tlv);
183+
174184
(result_a, job_b.into_result())
175185
})
176186
}
@@ -183,7 +193,12 @@ unsafe fn join_recover_from_panic(
183193
worker_thread: &WorkerThread,
184194
job_b_latch: &SpinLatch,
185195
err: Box<Any + Send>,
196+
tlv: usize,
186197
) -> ! {
187198
worker_thread.wait_until(job_b_latch);
199+
200+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
201+
tlv::set(tlv);
202+
188203
unwind::resume_unwinding(err)
189204
}

rayon-core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ mod util;
6464
mod compile_fail;
6565
mod test;
6666

67+
pub mod tlv;
68+
6769
#[cfg(rayon_unstable)]
6870
pub mod internal;
6971
pub use join::{join, join_context};

rayon-core/src/registry.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ impl Registry {
489489
// This thread isn't a member of *any* thread pool, so just block.
490490
debug_assert!(WorkerThread::current().is_null());
491491
let job = StackJob::new(
492+
0,
492493
|injected| {
493494
let worker_thread = WorkerThread::current();
494495
assert!(injected && !worker_thread.is_null());
@@ -512,6 +513,7 @@ impl Registry {
512513
debug_assert!(current_thread.registry().id() != self.id());
513514
let latch = TickleLatch::new(SpinLatch::new(), &current_thread.registry().sleep);
514515
let job = StackJob::new(
516+
0,
515517
|injected| {
516518
let worker_thread = WorkerThread::current();
517519
assert!(injected && !worker_thread.is_null());

rayon-core/src/scope/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use std::mem;
1515
use std::ptr;
1616
use std::sync::atomic::{AtomicPtr, Ordering};
1717
use std::sync::Arc;
18+
use tlv;
1819
use unwind;
1920

2021
mod internal;
@@ -60,6 +61,9 @@ struct ScopeBase<'scope> {
6061
/// `Sync`, but it's still safe to let the `Scope` implement `Sync` because
6162
/// the closures are only *moved* across threads to be executed.
6263
marker: PhantomData<Box<FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
64+
65+
/// The TLV at the scope's creation. Used to set the TLV for spawned jobs.
66+
tlv: usize,
6367
}
6468

6569
/// Create a "fork-join" scope `s` and invokes the closure with a
@@ -452,7 +456,7 @@ impl<'scope> Scope<'scope> {
452456
{
453457
self.base.increment();
454458
unsafe {
455-
let job_ref = Box::new(HeapJob::new(move || {
459+
let job_ref = Box::new(HeapJob::new(self.base.tlv, move || {
456460
self.base.execute_job(move || body(self))
457461
}))
458462
.as_job_ref();
@@ -493,7 +497,7 @@ impl<'scope> ScopeFifo<'scope> {
493497
{
494498
self.base.increment();
495499
unsafe {
496-
let job_ref = Box::new(HeapJob::new(move || {
500+
let job_ref = Box::new(HeapJob::new(self.base.tlv, move || {
497501
self.base.execute_job(move || body(self))
498502
}))
499503
.as_job_ref();
@@ -520,6 +524,7 @@ impl<'scope> ScopeBase<'scope> {
520524
panic: AtomicPtr::new(ptr::null_mut()),
521525
job_completed_latch: CountLatch::new(),
522526
marker: PhantomData,
527+
tlv: tlv::get(),
523528
}
524529
}
525530

@@ -537,6 +542,8 @@ impl<'scope> ScopeBase<'scope> {
537542
{
538543
let result = self.execute_job_closure(func);
539544
self.steal_till_jobs_complete(owner_thread);
545+
// Restore the TLV if we ran some jobs while waiting
546+
tlv::set(self.tlv);
540547
result.unwrap() // only None if `op` panicked, and that would have been propagated
541548
}
542549

@@ -613,6 +620,8 @@ impl<'scope> ScopeBase<'scope> {
613620
log!(ScopeCompletePanicked {
614621
owner_thread: owner_thread.index()
615622
});
623+
// Restore the TLV if we ran some jobs while waiting
624+
tlv::set(self.tlv);
616625
let value: Box<Box<Any + Send + 'static>> = mem::transmute(panic);
617626
unwind::resume_unwinding(*value);
618627
} else {

rayon-core/src/spawn/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ where
9292
// executed. This ref is decremented at the (*) below.
9393
registry.increment_terminate_count();
9494

95-
Box::new(HeapJob::new({
95+
Box::new(HeapJob::new(0, {
9696
let registry = registry.clone();
9797
move || {
9898
match unwind::halt_unwinding(func) {

rayon-core/src/tlv.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//! Allows access to the Rayon's thread local value
2+
//! which is preserved when moving jobs across threads
3+
4+
use std::cell::Cell;
5+
6+
thread_local!(pub(crate) static TLV: Cell<usize> = Cell::new(0));
7+
8+
/// Sets the current thread-local value to `value` inside the closure.
9+
/// The old value is restored when the closure ends
10+
pub fn with<F: FnOnce() -> R, R>(value: usize, f: F) -> R {
11+
struct Reset(usize);
12+
impl Drop for Reset {
13+
fn drop(&mut self) {
14+
TLV.with(|tlv| tlv.set(self.0));
15+
}
16+
}
17+
let _reset = Reset(get());
18+
TLV.with(|tlv| tlv.set(value));
19+
f()
20+
}
21+
22+
/// Sets the current thread-local value
23+
pub fn set(value: usize) {
24+
TLV.with(|tlv| tlv.set(value));
25+
}
26+
27+
/// Returns the current thread-local value
28+
pub fn get() -> usize {
29+
TLV.with(|tlv| tlv.get())
30+
}

0 commit comments

Comments
 (0)