Skip to content

Commit 0e2ce97

Browse files
authored
Merge pull request #1803 from remlse/rocket-managed-state
Remove blocking IO
2 parents 4d33d5d + 6710fb0 commit 0e2ce97

File tree

5 files changed

+189
-227
lines changed

5 files changed

+189
-227
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ rocket_dyn_templates = { version = "=0.1.0-rc.3", features = ["handlebars"] }
1818
serde = { version = "1.0", features = ["derive"] }
1919
serde_yaml = "0.8.17"
2020
sass-rs = "0.2.1"
21-
reqwest = { version = "0.11.4", features = ["blocking", "json"] }
21+
reqwest = { version = "0.11.4", features = ["json"] }
2222
toml = "0.5"
2323
serde_json = "1.0"
2424
rust_team_data = { git = "https://github.com/rust-lang/team" }

src/cache.rs

Lines changed: 18 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -1,166 +1,32 @@
1-
use std::any::Any;
2-
use std::collections::HashMap;
31
use std::error::Error;
4-
use std::sync::RwLock;
2+
use std::sync::Arc;
53
use std::time::Instant;
64

5+
use rocket::tokio::sync::RwLock;
76
use rocket::tokio::task;
8-
9-
type CacheItem = (Box<dyn Any + Send + Sync>, Instant);
10-
type Generator = fn() -> Result<Box<dyn Any>, Box<dyn Error>>;
7+
use rocket::State;
118

129
const CACHE_TTL_SECS: u64 = 120;
1310

14-
lazy_static! {
15-
static ref CACHE: RwLock<HashMap<Generator, CacheItem>> = RwLock::new(HashMap::new());
16-
}
17-
18-
pub async fn get<T>(generator: Generator) -> Result<T, Box<dyn Error>>
19-
where
20-
T: Send + Sync + Clone + 'static,
21-
{
22-
if let Some(cached) = get_cached(generator) {
23-
Ok(cached)
24-
} else {
25-
task::spawn_blocking(move || {
26-
update_cache::<T>(generator)
27-
// stringify the error to make it Send
28-
.map_err(|e| e.to_string())
29-
})
30-
.await
31-
.map_err(Box::new)?
32-
// put the previously stringified error back in a box
33-
.map_err(|e| e.as_str().into())
34-
}
35-
}
11+
pub type Cache<T> = State<Arc<RwLock<T>>>;
3612

37-
fn get_cached<T>(generator: Generator) -> Option<T>
38-
where
39-
T: Send + Sync + Clone + 'static,
40-
{
41-
let cache = CACHE.read().unwrap();
42-
cache.get(&generator).map(|&(ref data, timestamp)| {
13+
#[async_trait]
14+
pub trait Cached: Send + Sync + Clone + 'static {
15+
fn get_timestamp(&self) -> Instant;
16+
async fn fetch() -> Result<Self, Box<dyn Error + Send + Sync>>;
17+
async fn get(cache: &Cache<Self>) -> Self {
18+
let cached = cache.read().await.clone();
19+
let timestamp = cached.get_timestamp();
4320
if timestamp.elapsed().as_secs() > CACHE_TTL_SECS {
4421
// Update the cache in the background
45-
task::spawn_blocking(move || {
46-
let _ = update_cache::<T>(generator);
22+
let cache: Arc<_> = cache.inner().clone();
23+
task::spawn(async move {
24+
match Self::fetch().await {
25+
Ok(data) => *cache.write().await = data,
26+
Err(e) => eprintln!("failed to update cache: {e}"),
27+
}
4728
});
4829
}
49-
data.downcast_ref::<T>().unwrap().clone()
50-
})
51-
}
52-
53-
fn update_cache<T>(generator: Generator) -> Result<T, Box<dyn Error>>
54-
where
55-
T: Send + Sync + Clone + 'static,
56-
{
57-
if let Ok(data) = generator()?.downcast::<T>() {
58-
let cloned: T = (*data).clone();
59-
CACHE
60-
.write()
61-
.unwrap()
62-
.insert(generator, (Box::new(cloned), Instant::now()));
63-
Ok(*data)
64-
} else {
65-
Err("the generator returned the wrong type".into())
66-
}
67-
}
68-
69-
#[cfg(test)]
70-
mod tests {
71-
use rocket::tokio;
72-
73-
use super::{get, Generator, CACHE, CACHE_TTL_SECS};
74-
use std::any::Any;
75-
use std::error::Error;
76-
use std::sync::atomic::{AtomicBool, Ordering};
77-
use std::thread;
78-
use std::time::{Duration, Instant};
79-
80-
#[tokio::test]
81-
async fn test_cache_basic() {
82-
static GENERATOR_CALLED: AtomicBool = AtomicBool::new(false);
83-
84-
fn generator() -> Result<Box<dyn Any>, Box<dyn Error>> {
85-
GENERATOR_CALLED.store(true, Ordering::SeqCst);
86-
Ok(Box::new("hello world"))
87-
}
88-
89-
// The first time it will call the generator
90-
GENERATOR_CALLED.store(false, Ordering::SeqCst);
91-
assert_eq!(get::<&'static str>(generator).await.unwrap(), "hello world");
92-
assert!(GENERATOR_CALLED.load(Ordering::SeqCst));
93-
94-
// The second time it won't call the generator, but reuse the latest value
95-
GENERATOR_CALLED.store(false, Ordering::SeqCst);
96-
assert_eq!(get::<&'static str>(generator).await.unwrap(), "hello world");
97-
assert!(!GENERATOR_CALLED.load(Ordering::SeqCst));
98-
}
99-
100-
#[tokio::test]
101-
async fn test_cache_refresh() {
102-
static GENERATOR_CALLED: AtomicBool = AtomicBool::new(false);
103-
104-
fn generator() -> Result<Box<dyn Any>, Box<dyn Error>> {
105-
GENERATOR_CALLED.store(true, Ordering::SeqCst);
106-
thread::sleep(Duration::from_millis(100));
107-
Ok(Box::new("hello world"))
108-
}
109-
110-
// Initialize the value in the cache
111-
GENERATOR_CALLED.store(false, Ordering::SeqCst);
112-
assert_eq!(get::<&'static str>(generator).await.unwrap(), "hello world");
113-
assert!(GENERATOR_CALLED.load(Ordering::SeqCst));
114-
115-
// Tweak the cache to fake an expired TTL
116-
let expired = Instant::now() - Duration::from_secs(CACHE_TTL_SECS * 2);
117-
CACHE
118-
.write()
119-
.unwrap()
120-
.get_mut(&(generator as Generator))
121-
.unwrap()
122-
.1 = expired;
123-
124-
// The second time it won't call the generator, but start another thread to refresh the
125-
// value in the background
126-
GENERATOR_CALLED.store(false, Ordering::SeqCst);
127-
assert_eq!(get::<&'static str>(generator).await.unwrap(), "hello world");
128-
assert!(!GENERATOR_CALLED.load(Ordering::SeqCst));
129-
130-
// Then the background updater thread will finish
131-
thread::sleep(Duration::from_millis(200));
132-
assert!(GENERATOR_CALLED.load(Ordering::SeqCst));
133-
}
134-
135-
#[tokio::test]
136-
async fn test_errors_skip_cache() {
137-
static GENERATOR_CALLED: AtomicBool = AtomicBool::new(false);
138-
139-
fn generator() -> Result<Box<dyn Any>, Box<dyn Error>> {
140-
GENERATOR_CALLED.store(true, Ordering::SeqCst);
141-
Err("an error".into())
142-
}
143-
144-
// The first time it will call the generator
145-
GENERATOR_CALLED.store(false, Ordering::SeqCst);
146-
assert_eq!(
147-
get::<&'static str>(generator)
148-
.await
149-
.unwrap_err()
150-
.to_string(),
151-
"an error"
152-
);
153-
assert!(GENERATOR_CALLED.load(Ordering::SeqCst));
154-
155-
// The second time it will also call the generator
156-
GENERATOR_CALLED.store(false, Ordering::SeqCst);
157-
assert_eq!(
158-
get::<&'static str>(generator)
159-
.await
160-
.unwrap_err()
161-
.to_string(),
162-
"an error"
163-
);
164-
assert!(GENERATOR_CALLED.load(Ordering::SeqCst));
30+
cached
16531
}
16632
}

src/main.rs

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,21 @@ mod redirect;
3030
mod rust_version;
3131
mod teams;
3232

33+
use cache::Cache;
34+
use cache::Cached;
3335
use production::User;
36+
use rocket::tokio::sync::RwLock;
37+
use rust_version::RustReleasePost;
38+
use rust_version::RustVersion;
3439
use teams::encode_zulip_stream;
40+
use teams::RustTeams;
3541

3642
use std::collections::hash_map::DefaultHasher;
3743
use std::env;
3844
use std::fs;
3945
use std::hash::Hasher;
4046
use std::path::{Path, PathBuf};
47+
use std::sync::Arc;
4148

4249
use rand::seq::SliceRandom;
4350

@@ -182,13 +189,20 @@ fn robots_txt() -> Option<content::RawText<&'static str>> {
182189
}
183190

184191
#[get("/")]
185-
async fn index() -> Template {
186-
render_index(ENGLISH.into()).await
192+
async fn index(
193+
version_cache: &Cache<RustVersion>,
194+
release_post_cache: &Cache<RustReleasePost>,
195+
) -> Template {
196+
render_index(ENGLISH.into(), version_cache, release_post_cache).await
187197
}
188198

189199
#[get("/<locale>", rank = 3)]
190-
async fn index_locale(locale: SupportedLocale) -> Template {
191-
render_index(locale.0).await
200+
async fn index_locale(
201+
locale: SupportedLocale,
202+
version_cache: &Cache<RustVersion>,
203+
release_post_cache: &Cache<RustReleasePost>,
204+
) -> Template {
205+
render_index(locale.0, version_cache, release_post_cache).await
192206
}
193207

194208
#[get("/<category>")]
@@ -202,27 +216,35 @@ fn category_locale(category: Category, locale: SupportedLocale) -> Template {
202216
}
203217

204218
#[get("/governance")]
205-
async fn governance() -> Result<Template, Status> {
206-
render_governance(ENGLISH.into()).await
219+
async fn governance(teams_cache: &Cache<RustTeams>) -> Result<Template, Status> {
220+
render_governance(ENGLISH.into(), teams_cache).await
207221
}
208222

209223
#[get("/governance/<section>/<team>", rank = 2)]
210-
async fn team(section: String, team: String) -> Result<Template, Result<Redirect, Status>> {
211-
render_team(section, team, ENGLISH.into()).await
224+
async fn team(
225+
section: String,
226+
team: String,
227+
teams_cache: &Cache<RustTeams>,
228+
) -> Result<Template, Result<Redirect, Status>> {
229+
render_team(section, team, ENGLISH.into(), teams_cache).await
212230
}
213231

214232
#[get("/<locale>/governance", rank = 8)]
215-
async fn governance_locale(locale: SupportedLocale) -> Result<Template, Status> {
216-
render_governance(locale.0).await
233+
async fn governance_locale(
234+
locale: SupportedLocale,
235+
teams_cache: &Cache<RustTeams>,
236+
) -> Result<Template, Status> {
237+
render_governance(locale.0, teams_cache).await
217238
}
218239

219240
#[get("/<locale>/governance/<section>/<team>", rank = 12)]
220241
async fn team_locale(
221242
section: String,
222243
team: String,
223244
locale: SupportedLocale,
245+
teams_cache: &Cache<RustTeams>,
224246
) -> Result<Template, Result<Redirect, Status>> {
225-
render_team(section, team, locale.0).await
247+
render_team(section, team, locale.0, teams_cache).await
226248
}
227249

228250
#[get("/production/users")]
@@ -344,19 +366,26 @@ fn concat_app_js(files: Vec<&str>) -> String {
344366
String::from(&js_path[1..])
345367
}
346368

347-
async fn render_index(lang: String) -> Template {
369+
async fn render_index(
370+
lang: String,
371+
version_cache: &Cache<RustVersion>,
372+
release_post_cache: &Cache<RustReleasePost>,
373+
) -> Template {
348374
#[derive(Serialize)]
349375
struct IndexData {
350376
rust_version: String,
351377
rust_release_post: String,
352378
}
353379

354380
let page = "index".to_string();
381+
let release_post = rust_version::rust_release_post(release_post_cache).await;
355382
let data = IndexData {
356-
rust_version: rust_version::rust_version().await.unwrap_or_default(),
357-
rust_release_post: rust_version::rust_release_post()
358-
.await
359-
.map_or_else(String::new, |v| format!("https://blog.rust-lang.org/{}", v)),
383+
rust_version: rust_version::rust_version(version_cache).await,
384+
rust_release_post: if !release_post.is_empty() {
385+
format!("https://blog.rust-lang.org/{}", release_post)
386+
} else {
387+
String::new()
388+
},
360389
};
361390
let context = Context::new(page.clone(), "", true, data, lang);
362391
Template::render(page, context)
@@ -383,8 +412,11 @@ fn render_production(lang: String) -> Template {
383412
Template::render(page, context)
384413
}
385414

386-
async fn render_governance(lang: String) -> Result<Template, Status> {
387-
match teams::index_data().await {
415+
async fn render_governance(
416+
lang: String,
417+
teams_cache: &Cache<RustTeams>,
418+
) -> Result<Template, Status> {
419+
match teams::index_data(teams_cache).await {
388420
Ok(data) => {
389421
let page = "governance/index".to_string();
390422
let context = Context::new(page.clone(), "governance-page-title", false, data, lang);
@@ -402,8 +434,9 @@ async fn render_team(
402434
section: String,
403435
team: String,
404436
lang: String,
437+
teams_cache: &Cache<RustTeams>,
405438
) -> Result<Template, Result<Redirect, Status>> {
406-
match teams::page_data(&section, &team).await {
439+
match teams::page_data(&section, &team, teams_cache).await {
407440
Ok(data) => {
408441
let page = "governance/group".to_string();
409442
let name = format!("governance-team-{}-name", data.team.name);
@@ -448,7 +481,7 @@ fn render_subject(category: Category, subject: String, lang: String) -> Result<T
448481
}
449482

450483
#[launch]
451-
fn rocket() -> _ {
484+
async fn rocket() -> _ {
452485
let templating = Template::custom(|engine| {
453486
engine
454487
.handlebars
@@ -461,9 +494,16 @@ fn rocket() -> _ {
461494
.register_helper("encode-zulip-stream", Box::new(encode_zulip_stream));
462495
});
463496

497+
let rust_version = RustVersion::fetch().await.unwrap_or_default();
498+
let rust_release_post = RustReleasePost::fetch().await.unwrap_or_default();
499+
let teams = RustTeams::fetch().await.unwrap_or_default();
500+
464501
rocket::build()
465502
.attach(templating)
466503
.attach(headers::InjectHeaders)
504+
.manage(Arc::new(RwLock::new(rust_version)))
505+
.manage(Arc::new(RwLock::new(rust_release_post)))
506+
.manage(Arc::new(RwLock::new(teams)))
467507
.mount(
468508
"/",
469509
routes![

0 commit comments

Comments
 (0)