Skip to content

Commit 7232b84

Browse files
committed
Use actual thread-local queues instead of using a RwLock
1 parent 0baba46 commit 7232b84

File tree

2 files changed

+35
-26
lines changed

2 files changed

+35
-26
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ concurrent-queue = "2.0.0"
2121
fastrand = "2.0.0"
2222
futures-lite = { version = "2.0.0", default-features = false }
2323
slab = "0.4.4"
24+
thread_local = "1.0"
2425

2526
[target.'cfg(target_family = "wasm")'.dependencies]
2627
futures-lite = { version = "2.0.0", default-features = false, features = ["std"] }

src/lib.rs

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
use std::fmt;
3737
use std::future::Future;
3838
use std::marker::PhantomData;
39+
use std::ops::Deref;
3940
use std::panic::{RefUnwindSafe, UnwindSafe};
4041
use std::rc::Rc;
4142
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
@@ -47,6 +48,7 @@ use async_task::{Builder, Runnable};
4748
use concurrent_queue::ConcurrentQueue;
4849
use futures_lite::{future, prelude::*};
4950
use slab::Slab;
51+
use thread_local::ThreadLocal;
5052

5153
#[doc(no_inline)]
5254
pub use async_task::Task;
@@ -508,7 +510,7 @@ struct State {
508510
queue: ConcurrentQueue<Runnable>,
509511

510512
/// Local queues created by runners.
511-
local_queues: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
513+
local_queues: ThreadLocal<LocalQueue>,
512514

513515
/// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
514516
notified: AtomicBool,
@@ -525,7 +527,7 @@ impl State {
525527
fn new() -> State {
526528
State {
527529
queue: ConcurrentQueue::unbounded(),
528-
local_queues: RwLock::new(Vec::new()),
530+
local_queues: ThreadLocal::new(),
529531
notified: AtomicBool::new(true),
530532
sleepers: Mutex::new(Sleepers {
531533
count: 0,
@@ -756,9 +758,6 @@ struct Runner<'a> {
756758
/// Inner ticker.
757759
ticker: Ticker<'a>,
758760

759-
/// The local queue.
760-
local: Arc<ConcurrentQueue<Runnable>>,
761-
762761
/// Bumped every time a runnable task is found.
763762
ticks: AtomicUsize,
764763
}
@@ -769,38 +768,36 @@ impl Runner<'_> {
769768
let runner = Runner {
770769
state,
771770
ticker: Ticker::new(state),
772-
local: Arc::new(ConcurrentQueue::bounded(512)),
773771
ticks: AtomicUsize::new(0),
774772
};
775-
state
776-
.local_queues
777-
.write()
778-
.unwrap()
779-
.push(runner.local.clone());
780773
runner
781774
}
782775

783776
/// Waits for the next runnable task to run.
784777
async fn runnable(&self, rng: &mut fastrand::Rng) -> Runnable {
778+
let local_queue = self.state.local_queues.get_or_default();
779+
785780
let runnable = self
786781
.ticker
787782
.runnable_with(|| {
783+
let local_queue = self.state.local_queues.get_or_default();
784+
788785
// Try the local queue.
789-
if let Ok(r) = self.local.pop() {
786+
if let Ok(r) = local_queue.pop() {
790787
return Some(r);
791788
}
792789

793790
// Try stealing from the global queue.
794791
if let Ok(r) = self.state.queue.pop() {
795-
steal(&self.state.queue, &self.local);
792+
steal(&self.state.queue, &local_queue);
796793
return Some(r);
797794
}
798795

799796
// Try stealing from other runners.
800-
let local_queues = self.state.local_queues.read().unwrap();
797+
let local_queues = &self.state.local_queues;
801798

802799
// Pick a random starting point in the iterator list and rotate the list.
803-
let n = local_queues.len();
800+
let n = local_queues.iter().count();
804801
let start = rng.usize(..n);
805802
let iter = local_queues
806803
.iter()
@@ -809,12 +806,12 @@ impl Runner<'_> {
809806
.take(n);
810807

811808
// Remove this runner's local queue.
812-
let iter = iter.filter(|local| !Arc::ptr_eq(local, &self.local));
809+
let iter = iter.filter(|local| !core::ptr::eq(local, &local_queue));
813810

814811
// Try stealing from each local queue in the list.
815812
for local in iter {
816-
steal(local, &self.local);
817-
if let Ok(r) = self.local.pop() {
813+
steal(local, &local_queue);
814+
if let Ok(r) = local_queue.pop() {
818815
return Some(r);
819816
}
820817
}
@@ -828,7 +825,7 @@ impl Runner<'_> {
828825

829826
if ticks % 64 == 0 {
830827
// Steal tasks from the global queue to ensure fair task scheduling.
831-
steal(&self.state.queue, &self.local);
828+
steal(&self.state.queue, &local_queue);
832829
}
833830

834831
runnable
@@ -838,14 +835,10 @@ impl Runner<'_> {
838835
impl Drop for Runner<'_> {
839836
fn drop(&mut self) {
840837
// Remove the local queue.
841-
self.state
842-
.local_queues
843-
.write()
844-
.unwrap()
845-
.retain(|local| !Arc::ptr_eq(local, &self.local));
838+
let local_queue = self.state.local_queues.get_or_default();
846839

847840
// Re-schedule remaining tasks in the local queue.
848-
while let Ok(r) = self.local.pop() {
841+
while let Ok(r) = local_queue.pop() {
849842
r.schedule();
850843
}
851844
}
@@ -937,11 +930,26 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
937930
f.debug_struct(name)
938931
.field("active", &ActiveTasks(&state.active))
939932
.field("global_tasks", &state.queue.len())
940-
.field("local_runners", &LocalRunners(&state.local_queues))
941933
.field("sleepers", &SleepCount(&state.sleepers))
942934
.finish()
943935
}
944936

937+
struct LocalQueue(ConcurrentQueue<Runnable>);
938+
939+
impl Default for LocalQueue {
940+
fn default() -> Self {
941+
Self(ConcurrentQueue::bounded(512))
942+
}
943+
}
944+
945+
impl Deref for LocalQueue {
946+
type Target = ConcurrentQueue<Runnable>;
947+
948+
fn deref(&self) -> &Self::Target {
949+
&self.0
950+
}
951+
}
952+
945953
/// Runs a closure when dropped.
946954
struct CallOnDrop<F: FnMut()>(F);
947955

0 commit comments

Comments
 (0)