Skip to content

Commit 5b167be

Browse files
committed
Add thread-local values which are preserved with jobs
1 parent e6186ea commit 5b167be

File tree

8 files changed

+81
-9
lines changed

8 files changed

+81
-9
lines changed

rayon-core/src/broadcast/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,10 @@ where
107107

108108
let n_threads = registry.num_threads();
109109
let current_thread = WorkerThread::current().as_ref();
110+
let tlv = crate::tlv::get();
110111
let latch = ScopeLatch::with_count(n_threads, current_thread);
111112
let jobs: Vec<_> = (0..n_threads)
112-
.map(|_| StackJob::new(&f, LatchRef::new(&latch)))
113+
.map(|_| StackJob::new(tlv, &f, LatchRef::new(&latch)))
113114
.collect();
114115
let job_refs = jobs.iter().map(|job| job.as_job_ref());
115116

rayon-core/src/job.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use crate::latch::Latch;
2+
use crate::tlv;
3+
use crate::tlv::Tlv;
24
use crate::unwind;
35
use crossbeam_deque::{Injector, Steal};
46
use std::any::Any;
@@ -78,6 +80,7 @@ where
7880
pub(super) latch: L,
7981
func: UnsafeCell<Option<F>>,
8082
result: UnsafeCell<JobResult<R>>,
83+
tlv: Tlv,
8184
}
8285

8386
impl<L, F, R> StackJob<L, F, R>
@@ -86,11 +89,12 @@ where
8689
F: FnOnce(bool) -> R + Send,
8790
R: Send,
8891
{
89-
pub(super) fn new(func: F, latch: L) -> StackJob<L, F, R> {
92+
pub(super) fn new(tlv: Tlv, func: F, latch: L) -> StackJob<L, F, R> {
9093
StackJob {
9194
latch,
9295
func: UnsafeCell::new(Some(func)),
9396
result: UnsafeCell::new(JobResult::None),
97+
tlv,
9498
}
9599
}
96100

@@ -115,6 +119,7 @@ where
115119
{
116120
unsafe fn execute(this: *const ()) {
117121
let this = &*(this as *const Self);
122+
tlv::set(this.tlv);
118123
let abort = unwind::AbortIfPanic;
119124
let func = (*this.func.get()).take().unwrap();
120125
(*this.result.get()) = JobResult::call(func);
@@ -134,14 +139,15 @@ where
134139
BODY: FnOnce() + Send,
135140
{
136141
job: BODY,
142+
tlv: Tlv,
137143
}
138144

139145
impl<BODY> HeapJob<BODY>
140146
where
141147
BODY: FnOnce() + Send,
142148
{
143-
pub(super) fn new(job: BODY) -> Box<Self> {
144-
Box::new(HeapJob { job })
149+
pub(super) fn new(tlv: Tlv, job: BODY) -> Box<Self> {
150+
Box::new(HeapJob { job, tlv })
145151
}
146152

147153
/// Creates a `JobRef` from this job -- note that this hides all
@@ -166,6 +172,7 @@ where
166172
{
167173
unsafe fn execute(this: *const ()) {
168174
let this = Box::from_raw(this as *mut Self);
175+
tlv::set(this.tlv);
169176
(this.job)();
170177
}
171178
}

rayon-core/src/join/mod.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::job::StackJob;
22
use crate::latch::SpinLatch;
33
use crate::registry::{self, WorkerThread};
4+
use crate::tlv::{self, Tlv};
45
use crate::unwind;
56
use std::any::Any;
67

@@ -130,10 +131,11 @@ where
130131
}
131132

132133
registry::in_worker(|worker_thread, injected| unsafe {
134+
let tlv = tlv::get();
133135
// Create virtual wrapper for task b; this all has to be
134136
// done here so that the stack frame can keep it all live
135137
// long enough.
136-
let job_b = StackJob::new(call_b(oper_b), SpinLatch::new(worker_thread));
138+
let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new(worker_thread));
137139
let job_b_ref = job_b.as_job_ref();
138140
let job_b_id = job_b_ref.id();
139141
worker_thread.push(job_b_ref);
@@ -142,7 +144,7 @@ where
142144
let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
143145
let result_a = match status_a {
144146
Ok(v) => v,
145-
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err),
147+
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
146148
};
147149

148150
// Now that task A has finished, try to pop job B from the
@@ -156,6 +158,10 @@ where
156158
// Found it! Let's run it.
157159
//
158160
// Note that this could panic, but it's ok if we unwind here.
161+
162+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
163+
tlv::set(tlv);
164+
159165
let result_b = job_b.run_inline(injected);
160166
return (result_a, result_b);
161167
} else {
@@ -170,6 +176,9 @@ where
170176
}
171177
}
172178

179+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
180+
tlv::set(tlv);
181+
173182
(result_a, job_b.into_result())
174183
})
175184
}
@@ -182,7 +191,12 @@ unsafe fn join_recover_from_panic(
182191
worker_thread: &WorkerThread,
183192
job_b_latch: &SpinLatch<'_>,
184193
err: Box<dyn Any + Send>,
194+
tlv: Tlv,
185195
) -> ! {
186196
worker_thread.wait_until(job_b_latch);
197+
198+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
199+
tlv::set(tlv);
200+
187201
unwind::resume_unwinding(err)
188202
}

rayon-core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ mod unwind;
9393
mod compile_fail;
9494
mod test;
9595

96+
pub mod tlv;
97+
9698
pub use self::broadcast::{broadcast, spawn_broadcast, BroadcastContext};
9799
pub use self::join::{join, join_context};
98100
pub use self::registry::ThreadBuilder;

rayon-core/src/registry.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LatchRef, LockLatc
33
use crate::log::Event::*;
44
use crate::log::Logger;
55
use crate::sleep::Sleep;
6+
use crate::tlv::Tlv;
67
use crate::unwind;
78
use crate::{
89
ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder,
@@ -535,6 +536,7 @@ impl Registry {
535536
// This thread isn't a member of *any* thread pool, so just block.
536537
debug_assert!(WorkerThread::current().is_null());
537538
let job = StackJob::new(
539+
Tlv::null(),
538540
|injected| {
539541
let worker_thread = WorkerThread::current();
540542
assert!(injected && !worker_thread.is_null());
@@ -563,6 +565,7 @@ impl Registry {
563565
debug_assert!(current_thread.registry().id() != self.id());
564566
let latch = SpinLatch::cross(current_thread);
565567
let job = StackJob::new(
568+
Tlv::null(),
566569
|injected| {
567570
let worker_thread = WorkerThread::current();
568571
assert!(injected && !worker_thread.is_null());

rayon-core/src/scope/mod.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::broadcast::BroadcastContext;
99
use crate::job::{ArcJob, HeapJob, JobFifo, JobRef};
1010
use crate::latch::{CountLatch, CountLockLatch, Latch};
1111
use crate::registry::{global_registry, in_worker, Registry, WorkerThread};
12+
use crate::tlv::{self, Tlv};
1213
use crate::unwind;
1314
use std::any::Any;
1415
use std::fmt;
@@ -76,6 +77,9 @@ struct ScopeBase<'scope> {
7677
/// `Sync`, but it's still safe to let the `Scope` implement `Sync` because
7778
/// the closures are only *moved* across threads to be executed.
7879
marker: PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
80+
81+
/// The TLV at the scope's creation. Used to set the TLV for spawned jobs.
82+
tlv: Tlv,
7983
}
8084

8185
/// Creates a "fork-join" scope `s` and invokes the closure with a
@@ -540,7 +544,7 @@ impl<'scope> Scope<'scope> {
540544
BODY: FnOnce(&Scope<'scope>) + Send + 'scope,
541545
{
542546
let scope_ptr = ScopePtr(self);
543-
let job = HeapJob::new(move || unsafe {
547+
let job = HeapJob::new(self.base.tlv, move || unsafe {
544548
// SAFETY: this job will execute before the scope ends.
545549
let scope = scope_ptr.as_ref();
546550
ScopeBase::execute_job(&scope.base, move || body(scope))
@@ -600,7 +604,7 @@ impl<'scope> ScopeFifo<'scope> {
600604
BODY: FnOnce(&ScopeFifo<'scope>) + Send + 'scope,
601605
{
602606
let scope_ptr = ScopePtr(self);
603-
let job = HeapJob::new(move || unsafe {
607+
let job = HeapJob::new(self.base.tlv, move || unsafe {
604608
// SAFETY: this job will execute before the scope ends.
605609
let scope = scope_ptr.as_ref();
606610
ScopeBase::execute_job(&scope.base, move || body(scope))
@@ -652,6 +656,7 @@ impl<'scope> ScopeBase<'scope> {
652656
panic: AtomicPtr::new(ptr::null_mut()),
653657
job_completed_latch: ScopeLatch::new(owner),
654658
marker: PhantomData,
659+
tlv: tlv::get(),
655660
}
656661
}
657662

@@ -690,6 +695,10 @@ impl<'scope> ScopeBase<'scope> {
690695
{
691696
let result = unsafe { Self::execute_job_closure(self, func) };
692697
self.job_completed_latch.wait(owner);
698+
699+
// Restore the TLV if we ran some jobs while waiting
700+
tlv::set(self.tlv);
701+
693702
self.maybe_propagate_panic();
694703
result.unwrap() // only None if `op` panicked, and that would have been propagated
695704
}
@@ -749,6 +758,10 @@ impl<'scope> ScopeBase<'scope> {
749758
let panic = self.panic.swap(ptr::null_mut(), Ordering::Relaxed);
750759
if !panic.is_null() {
751760
let value = unsafe { Box::from_raw(panic) };
761+
762+
// Restore the TLV if we ran some jobs while waiting
763+
tlv::set(self.tlv);
764+
752765
unwind::resume_unwinding(*value);
753766
}
754767
}

rayon-core/src/spawn/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::job::*;
22
use crate::registry::Registry;
3+
use crate::tlv::Tlv;
34
use crate::unwind;
45
use std::mem;
56
use std::sync::Arc;
@@ -91,7 +92,7 @@ where
9192
// executed. This ref is decremented at the (*) below.
9293
registry.increment_terminate_count();
9394

94-
HeapJob::new({
95+
HeapJob::new(Tlv::null(), {
9596
let registry = Arc::clone(registry);
9697
move || {
9798
registry.catch_unwind(func);

rayon-core/src/tlv.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//! Allows access to the Rayon's thread local value
2+
//! which is preserved when moving jobs across threads
3+
4+
use std::{cell::Cell, ptr};
5+
6+
thread_local!(pub static TLV: Cell<*const ()> = const { Cell::new(ptr::null()) });
7+
8+
#[derive(Copy, Clone)]
9+
pub(crate) struct Tlv(pub(crate) *const ());
10+
11+
impl Tlv {
12+
#[inline]
13+
pub(crate) fn null() -> Self {
14+
Self(ptr::null())
15+
}
16+
}
17+
18+
unsafe impl Sync for Tlv {}
19+
unsafe impl Send for Tlv {}
20+
21+
/// Sets the current thread-local value
22+
#[inline]
23+
pub(crate) fn set(value: Tlv) {
24+
TLV.with(|tlv| tlv.set(value.0));
25+
}
26+
27+
/// Returns the current thread-local value
28+
#[inline]
29+
pub(crate) fn get() -> Tlv {
30+
TLV.with(|tlv| Tlv(tlv.get()))
31+
}

0 commit comments

Comments
 (0)