Skip to content

Commit fa136df

Browse files
committed
WIP for new sampler interface
1 parent fc5e9d4 commit fa136df

File tree

5 files changed

+35
-24
lines changed

5 files changed

+35
-24
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ thiserror = "1.0.31"
3030
rayon = "1.5.3"
3131
ndarray = "0.15.4"
3232
arrow2 = { version = "0.17.0", optional = true }
33+
rand_chacha = "0.3.1"
34+
anyhow = "1.0.70"
3335

3436
[dev-dependencies]
3537
proptest = "1.0.0"

src/cpu_potential.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use crate::nuts::{ArrowBuilder, ArrowRow};
2323
/// If a non-recoverable error occurs during sampling, the sampler will
2424
/// stop and return an error.
2525
pub trait CpuLogpFunc {
26-
type Err: Debug + Send + LogpError + 'static;
26+
type Err: Debug + Send + Sync + LogpError + 'static;
2727

2828
fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::Err>;
2929
fn dim(&self) -> usize;

src/cpu_sampler.rs

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ pub enum ParallelSamplingError {
7373
#[error("Creating a logp function failed")]
7474
LogpFuncCreation {
7575
#[from]
76-
source: Box<dyn std::error::Error + Send + Sync>,
76+
source: anyhow::Error,
7777
},
7878
}
7979

@@ -85,20 +85,22 @@ where
8585
{
8686
//type Func: CpuLogpFunc;
8787

88-
fn make_logp_func(
89-
&self,
90-
chain: usize,
91-
) -> Result<Func, Box<dyn std::error::Error + Send + Sync>>;
88+
fn make_logp_func(&self, chain: usize) -> Result<Func, anyhow::Error>;
9289
fn dim(&self) -> usize;
9390
}
9491

9592
/// Sample several chains in parallel and return all of the samples live in a channel
96-
pub fn sample_parallel<M: CpuLogpFuncMaker<F> + 'static, F: CpuLogpFunc, I: InitPointFunc>(
93+
pub fn sample_parallel<
94+
M: CpuLogpFuncMaker<F> + 'static,
95+
F: CpuLogpFunc,
96+
I: InitPointFunc,
97+
R: Rng + ?Sized,
98+
>(
9799
logp_func_maker: M,
98100
init_point_func: &mut I,
99101
settings: SamplerArgs,
100102
n_chains: u64,
101-
seed: u64,
103+
rng: &mut R,
102104
n_try_init: u64,
103105
) -> Result<
104106
(
@@ -111,7 +113,9 @@ pub fn sample_parallel<M: CpuLogpFuncMaker<F> + 'static, F: CpuLogpFunc, I: Init
111113
let mut func = logp_func_maker.make_logp_func(0)?;
112114
assert!(ndim == func.dim());
113115
let draws = settings.num_tune + settings.num_draws;
114-
let mut rng = StdRng::seed_from_u64(seed.wrapping_sub(1));
116+
//let mut rng = StdRng::from_rng(rng).expect("Could not seed rng");
117+
let mut rng = rand_chacha::ChaCha8Rng::from_rng(rng).unwrap();
118+
115119
let mut points: Vec<Result<(Box<[f64]>, Box<[f64]>), <F as CpuLogpFunc>::Err>> = (0..n_chains)
116120
.map(|_| {
117121
let mut position = vec![0.; ndim];
@@ -143,17 +147,21 @@ pub fn sample_parallel<M: CpuLogpFuncMaker<F> + 'static, F: CpuLogpFunc, I: Init
143147
let (sender, receiver) = crossbeam::channel::bounded(128);
144148

145149
let handle = std::thread::spawn(move || {
150+
let rng = rng.clone();
146151
let results: Vec<Result<(), ParallelSamplingError>> = points
147152
.into_par_iter()
148153
.with_max_len(1)
149154
.enumerate()
150155
.map_with(sender, |sender, (chain, point)| {
151156
let func = logp_func_maker.make_logp_func(chain)?;
157+
let mut rng = rng.clone();
158+
rng.set_stream(chain as u64);
152159
let mut sampler = new_sampler(
153160
func,
154161
settings,
155162
chain as u64,
156-
seed.wrapping_add(chain as u64),
163+
//seed.wrapping_add(chain as u64),
164+
&mut rng,
157165
);
158166
sampler.set_position(&point.0)?;
159167
for _ in 0..draws {
@@ -172,11 +180,11 @@ pub fn sample_parallel<M: CpuLogpFuncMaker<F> + 'static, F: CpuLogpFunc, I: Init
172180
}
173181

174182
/// Create a new sampler
175-
pub fn new_sampler<F: CpuLogpFunc>(
183+
pub fn new_sampler<F: CpuLogpFunc, R: Rng + ?Sized>(
176184
logp: F,
177185
settings: SamplerArgs,
178186
chain: u64,
179-
seed: u64,
187+
rng: &mut R,
180188
) -> impl Chain {
181189
use crate::nuts::AdaptStrategy;
182190
let num_tune = settings.num_tune;
@@ -195,21 +203,20 @@ pub fn new_sampler<F: CpuLogpFunc>(
195203
store_gradient: settings.store_gradient,
196204
};
197205

198-
//let rng = { rand::rngs::StdRng::seed_from_u64(seed) };
199-
let rng = rand::rngs::SmallRng::seed_from_u64(seed);
206+
let rng = rand::rngs::SmallRng::from_rng(rng).expect("Could not seed rng");
200207

201208
NutsChain::new(potential, strategy, options, rng, chain)
202209
}
203210

204-
pub fn sample_sequentially<F: CpuLogpFunc>(
211+
pub fn sample_sequentially<F: CpuLogpFunc, R: Rng + ?Sized>(
205212
logp: F,
206213
settings: SamplerArgs,
207214
start: &[f64],
208215
draws: u64,
209216
chain: u64,
210-
seed: u64,
217+
rng: &mut R,
211218
) -> Result<impl Iterator<Item = Result<(Box<[f64]>, impl SampleStats), NutsError>>, NutsError> {
212-
let mut sampler = new_sampler(logp, settings, chain, seed);
219+
let mut sampler = new_sampler(logp, settings, chain, rng);
213220
sampler.set_position(start)?;
214221
Ok((0..draws).into_iter().map(move |_| sampler.draw()))
215222
}
@@ -259,10 +266,7 @@ pub mod test_logps {
259266
impl CpuLogpFuncMaker<NormalLogp> for NormalLogp {
260267
//type Func = Self;
261268

262-
fn make_logp_func(
263-
&self,
264-
chain: usize,
265-
) -> Result<NormalLogp, Box<dyn std::error::Error + Send + Sync>> {
269+
fn make_logp_func(&self, chain: usize) -> Result<NormalLogp, anyhow::Error> {
266270
Ok(self.clone())
267271
}
268272

@@ -362,6 +366,7 @@ mod tests {
362366

363367
use itertools::Itertools;
364368
use pretty_assertions::assert_eq;
369+
use rand::{rngs::StdRng, SeedableRng};
365370

366371
#[test]
367372
fn sample_seq() {
@@ -371,7 +376,9 @@ mod tests {
371376
settings.num_draws = 100;
372377
let start = vec![0.2; 10];
373378

374-
let chain = sample_sequentially(logp.clone(), settings, &start, 200, 1, 42).unwrap();
379+
let mut rng = StdRng::seed_from_u64(42);
380+
381+
let chain = sample_sequentially(logp.clone(), settings, &start, 200, 1, &mut rng).unwrap();
375382
let mut draws = chain.collect_vec();
376383
assert_eq!(draws.len(), 200);
377384

@@ -384,7 +391,7 @@ mod tests {
384391
let maker = logp;
385392

386393
let (handles, chains) =
387-
sample_parallel(maker, &mut JitterInitFunc::new(), settings, 4, 42, 10).unwrap();
394+
sample_parallel(maker, &mut JitterInitFunc::new(), settings, 4, &mut rng, 10).unwrap();
388395
let mut draws = chains.iter().collect_vec();
389396
assert_eq!(draws.len(), 800);
390397
assert!(handles.join().is_ok());

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,6 @@ pub use cpu_sampler::{
107107
new_sampler, sample_parallel, sample_sequentially, CpuLogpFuncMaker, InitPointFunc,
108108
JitterInitFunc, ParallelChainResult, ParallelSamplingError, SamplerArgs,
109109
};
110+
#[cfg(feature = "arrow")]
111+
pub use nuts::ArrowBuilder;
110112
pub use nuts::{Chain, DivergenceInfo, LogpError, NutsError, SampleStats};

src/nuts.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::SamplerArgs;
1515
#[derive(Error, Debug)]
1616
pub enum NutsError {
1717
#[error("Logp function returned error: {0}")]
18-
LogpFailure(Box<dyn std::error::Error + Send>),
18+
LogpFailure(Box<dyn std::error::Error + Send + Sync>),
1919

2020
#[error("Could not serialize sample stats")]
2121
SerializeFailure(),

0 commit comments

Comments
 (0)