Skip to content

Commit c7ba282

Browse files
Zoxccuviper
authored andcommitted
Add thread-local values which are preserved with jobs
1 parent 60cdb43 commit c7ba282

File tree

7 files changed

+71
-7
lines changed

7 files changed

+71
-7
lines changed

rayon-core/src/job.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use latch::Latch;
33
use std::any::Any;
44
use std::cell::UnsafeCell;
55
use std::mem;
6+
use tlv;
67
use unwind;
78

89
pub(super) enum JobResult<T> {
@@ -73,6 +74,7 @@ where
7374
pub(super) latch: L,
7475
func: UnsafeCell<Option<F>>,
7576
result: UnsafeCell<JobResult<R>>,
77+
tlv: usize,
7678
}
7779

7880
impl<L, F, R> StackJob<L, F, R>
@@ -81,11 +83,12 @@ where
8183
F: FnOnce(bool) -> R + Send,
8284
R: Send,
8385
{
84-
pub(super) fn new(func: F, latch: L) -> StackJob<L, F, R> {
86+
pub(super) fn new(tlv: usize, func: F, latch: L) -> StackJob<L, F, R> {
8587
StackJob {
8688
latch,
8789
func: UnsafeCell::new(Some(func)),
8890
result: UnsafeCell::new(JobResult::None),
91+
tlv,
8992
}
9093
}
9194

@@ -114,6 +117,7 @@ where
114117
}
115118

116119
let this = &*this;
120+
tlv::set(this.tlv);
117121
let abort = unwind::AbortIfPanic;
118122
let func = (*this.func.get()).take().unwrap();
119123
(*this.result.get()) = match unwind::halt_unwinding(call(func)) {
@@ -136,15 +140,17 @@ where
136140
BODY: FnOnce() + Send,
137141
{
138142
job: UnsafeCell<Option<BODY>>,
143+
tlv: usize,
139144
}
140145

141146
impl<BODY> HeapJob<BODY>
142147
where
143148
BODY: FnOnce() + Send,
144149
{
145-
pub(super) fn new(func: BODY) -> Self {
150+
pub(super) fn new(tlv: usize, func: BODY) -> Self {
146151
HeapJob {
147152
job: UnsafeCell::new(Some(func)),
153+
tlv,
148154
}
149155
}
150156

@@ -163,6 +169,7 @@ where
163169
{
164170
unsafe fn execute(this: *const Self) {
165171
let this: Box<Self> = mem::transmute(this);
172+
tlv::set(this.tlv);
166173
let job = (*this.job.get()).take().unwrap();
167174
job();
168175
}

rayon-core/src/join/mod.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use latch::{LatchProbe, SpinLatch};
33
use log::Event::*;
44
use registry::{self, WorkerThread};
55
use std::any::Any;
6+
use tlv;
67
use unwind;
78

89
use FnContext;
@@ -135,18 +136,19 @@ where
135136
worker: worker_thread.index()
136137
});
137138

139+
let tlv = tlv::get();
138140
// Create virtual wrapper for task b; this all has to be
139141
// done here so that the stack frame can keep it all live
140142
// long enough.
141-
let job_b = StackJob::new(call_b(oper_b), SpinLatch::new());
143+
let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new());
142144
let job_b_ref = job_b.as_job_ref();
143145
worker_thread.push(job_b_ref);
144146

145147
// Execute task a; hopefully b gets stolen in the meantime.
146148
let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
147149
let result_a = match status_a {
148150
Ok(v) => v,
149-
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err),
151+
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
150152
};
151153

152154
// Now that task A has finished, try to pop job B from the
@@ -163,7 +165,11 @@ where
163165
log!(PoppedRhs {
164166
worker: worker_thread.index()
165167
});
168+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
169+
tlv::set(tlv);
170+
166171
let result_b = job_b.run_inline(injected);
172+
167173
return (result_a, result_b);
168174
} else {
169175
log!(PoppedJob {
@@ -183,6 +189,9 @@ where
183189
}
184190
}
185191

192+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
193+
tlv::set(tlv);
194+
186195
(result_a, job_b.into_result())
187196
})
188197
}
@@ -195,7 +204,12 @@ unsafe fn join_recover_from_panic(
195204
worker_thread: &WorkerThread,
196205
job_b_latch: &SpinLatch,
197206
err: Box<dyn Any + Send>,
207+
tlv: usize,
198208
) -> ! {
199209
worker_thread.wait_until(job_b_latch);
210+
211+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
212+
tlv::set(tlv);
213+
200214
unwind::resume_unwinding(err)
201215
}

rayon-core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ mod util;
6464
mod compile_fail;
6565
mod test;
6666

67+
pub mod tlv;
68+
6769
#[cfg(rayon_unstable)]
6870
pub mod internal;
6971
pub use join::{join, join_context};

rayon-core/src/registry.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ impl Registry {
495495
// This thread isn't a member of *any* thread pool, so just block.
496496
debug_assert!(WorkerThread::current().is_null());
497497
let job = StackJob::new(
498+
0,
498499
|injected| {
499500
let worker_thread = WorkerThread::current();
500501
assert!(injected && !worker_thread.is_null());
@@ -519,6 +520,7 @@ impl Registry {
519520
debug_assert!(current_thread.registry().id() != self.id());
520521
let latch = TickleLatch::new(SpinLatch::new(), &current_thread.registry().sleep);
521522
let job = StackJob::new(
523+
0,
522524
|injected| {
523525
let worker_thread = WorkerThread::current();
524526
assert!(injected && !worker_thread.is_null());

rayon-core/src/scope/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use std::mem;
1515
use std::ptr;
1616
use std::sync::atomic::{AtomicPtr, Ordering};
1717
use std::sync::Arc;
18+
use tlv;
1819
use unwind;
1920

2021
mod internal;
@@ -60,6 +61,9 @@ struct ScopeBase<'scope> {
6061
/// `Sync`, but it's still safe to let the `Scope` implement `Sync` because
6162
/// the closures are only *moved* across threads to be executed.
6263
marker: PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
64+
65+
/// The TLV at the scope's creation. Used to set the TLV for spawned jobs.
66+
tlv: usize,
6367
}
6468

6569
/// Create a "fork-join" scope `s` and invokes the closure with a
@@ -452,7 +456,7 @@ impl<'scope> Scope<'scope> {
452456
{
453457
self.base.increment();
454458
unsafe {
455-
let job_ref = Box::new(HeapJob::new(move || {
459+
let job_ref = Box::new(HeapJob::new(self.base.tlv, move || {
456460
self.base.execute_job(move || body(self))
457461
}))
458462
.as_job_ref();
@@ -493,7 +497,7 @@ impl<'scope> ScopeFifo<'scope> {
493497
{
494498
self.base.increment();
495499
unsafe {
496-
let job_ref = Box::new(HeapJob::new(move || {
500+
let job_ref = Box::new(HeapJob::new(self.base.tlv, move || {
497501
self.base.execute_job(move || body(self))
498502
}))
499503
.as_job_ref();
@@ -520,6 +524,7 @@ impl<'scope> ScopeBase<'scope> {
520524
panic: AtomicPtr::new(ptr::null_mut()),
521525
job_completed_latch: CountLatch::new(),
522526
marker: PhantomData,
527+
tlv: tlv::get(),
523528
}
524529
}
525530

@@ -537,6 +542,8 @@ impl<'scope> ScopeBase<'scope> {
537542
{
538543
let result = self.execute_job_closure(func);
539544
self.steal_till_jobs_complete(owner_thread);
545+
// Restore the TLV if we ran some jobs while waiting
546+
tlv::set(self.tlv);
540547
result.unwrap() // only None if `op` panicked, and that would have been propagated
541548
}
542549

@@ -613,6 +620,8 @@ impl<'scope> ScopeBase<'scope> {
613620
log!(ScopeCompletePanicked {
614621
owner_thread: owner_thread.index()
615622
});
623+
// Restore the TLV if we ran some jobs while waiting
624+
tlv::set(self.tlv);
616625
let value: Box<Box<dyn Any + Send + 'static>> = mem::transmute(panic);
617626
unwind::resume_unwinding(*value);
618627
} else {

rayon-core/src/spawn/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ where
9292
// executed. This ref is decremented at the (*) below.
9393
registry.increment_terminate_count();
9494

95-
Box::new(HeapJob::new({
95+
Box::new(HeapJob::new(0, {
9696
let registry = registry.clone();
9797
move || {
9898
match unwind::halt_unwinding(func) {

rayon-core/src/tlv.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//! Allows access to the Rayon's thread local value
2+
//! which is preserved when moving jobs across threads
3+
4+
use std::cell::Cell;
5+
6+
thread_local!(pub(crate) static TLV: Cell<usize> = Cell::new(0));
7+
8+
/// Sets the current thread-local value to `value` inside the closure.
9+
/// The old value is restored when the closure ends
10+
pub fn with<F: FnOnce() -> R, R>(value: usize, f: F) -> R {
11+
struct Reset(usize);
12+
impl Drop for Reset {
13+
fn drop(&mut self) {
14+
TLV.with(|tlv| tlv.set(self.0));
15+
}
16+
}
17+
let _reset = Reset(get());
18+
TLV.with(|tlv| tlv.set(value));
19+
f()
20+
}
21+
22+
/// Sets the current thread-local value
23+
pub fn set(value: usize) {
24+
TLV.with(|tlv| tlv.set(value));
25+
}
26+
27+
/// Returns the current thread-local value
28+
pub fn get() -> usize {
29+
TLV.with(|tlv| tlv.get())
30+
}

0 commit comments

Comments
 (0)