@@ -73,7 +73,7 @@ pub enum ParallelSamplingError {
73
73
#[ error( "Creating a logp function failed" ) ]
74
74
LogpFuncCreation {
75
75
#[ from]
76
- source : Box < dyn std :: error :: Error + Send + Sync > ,
76
+ source : anyhow :: Error ,
77
77
} ,
78
78
}
79
79
@@ -85,20 +85,22 @@ where
85
85
{
86
86
//type Func: CpuLogpFunc;
87
87
88
- fn make_logp_func (
89
- & self ,
90
- chain : usize ,
91
- ) -> Result < Func , Box < dyn std:: error:: Error + Send + Sync > > ;
88
+ fn make_logp_func ( & self , chain : usize ) -> Result < Func , anyhow:: Error > ;
92
89
fn dim ( & self ) -> usize ;
93
90
}
94
91
95
92
/// Sample several chains in parallel and return all of the samples live in a channel
96
- pub fn sample_parallel < M : CpuLogpFuncMaker < F > + ' static , F : CpuLogpFunc , I : InitPointFunc > (
93
+ pub fn sample_parallel <
94
+ M : CpuLogpFuncMaker < F > + ' static ,
95
+ F : CpuLogpFunc ,
96
+ I : InitPointFunc ,
97
+ R : Rng + ?Sized ,
98
+ > (
97
99
logp_func_maker : M ,
98
100
init_point_func : & mut I ,
99
101
settings : SamplerArgs ,
100
102
n_chains : u64 ,
101
- seed : u64 ,
103
+ rng : & mut R ,
102
104
n_try_init : u64 ,
103
105
) -> Result <
104
106
(
@@ -111,7 +113,9 @@ pub fn sample_parallel<M: CpuLogpFuncMaker<F> + 'static, F: CpuLogpFunc, I: Init
111
113
let mut func = logp_func_maker. make_logp_func ( 0 ) ?;
112
114
assert ! ( ndim == func. dim( ) ) ;
113
115
let draws = settings. num_tune + settings. num_draws ;
114
- let mut rng = StdRng :: seed_from_u64 ( seed. wrapping_sub ( 1 ) ) ;
116
+ //let mut rng = StdRng::from_rng(rng).expect("Could not seed rng");
117
+ let mut rng = rand_chacha:: ChaCha8Rng :: from_rng ( rng) . unwrap ( ) ;
118
+
115
119
let mut points: Vec < Result < ( Box < [ f64 ] > , Box < [ f64 ] > ) , <F as CpuLogpFunc >:: Err > > = ( 0 ..n_chains)
116
120
. map ( |_| {
117
121
let mut position = vec ! [ 0. ; ndim] ;
@@ -143,17 +147,21 @@ pub fn sample_parallel<M: CpuLogpFuncMaker<F> + 'static, F: CpuLogpFunc, I: Init
143
147
let ( sender, receiver) = crossbeam:: channel:: bounded ( 128 ) ;
144
148
145
149
let handle = std:: thread:: spawn ( move || {
150
+ let rng = rng. clone ( ) ;
146
151
let results: Vec < Result < ( ) , ParallelSamplingError > > = points
147
152
. into_par_iter ( )
148
153
. with_max_len ( 1 )
149
154
. enumerate ( )
150
155
. map_with ( sender, |sender, ( chain, point) | {
151
156
let func = logp_func_maker. make_logp_func ( chain) ?;
157
+ let mut rng = rng. clone ( ) ;
158
+ rng. set_stream ( chain as u64 ) ;
152
159
let mut sampler = new_sampler (
153
160
func,
154
161
settings,
155
162
chain as u64 ,
156
- seed. wrapping_add ( chain as u64 ) ,
163
+ //seed.wrapping_add(chain as u64),
164
+ & mut rng,
157
165
) ;
158
166
sampler. set_position ( & point. 0 ) ?;
159
167
for _ in 0 ..draws {
@@ -172,11 +180,11 @@ pub fn sample_parallel<M: CpuLogpFuncMaker<F> + 'static, F: CpuLogpFunc, I: Init
172
180
}
173
181
174
182
/// Create a new sampler
175
- pub fn new_sampler < F : CpuLogpFunc > (
183
+ pub fn new_sampler < F : CpuLogpFunc , R : Rng + ? Sized > (
176
184
logp : F ,
177
185
settings : SamplerArgs ,
178
186
chain : u64 ,
179
- seed : u64 ,
187
+ rng : & mut R ,
180
188
) -> impl Chain {
181
189
use crate :: nuts:: AdaptStrategy ;
182
190
let num_tune = settings. num_tune ;
@@ -195,21 +203,20 @@ pub fn new_sampler<F: CpuLogpFunc>(
195
203
store_gradient : settings. store_gradient ,
196
204
} ;
197
205
198
- //let rng = { rand::rngs::StdRng::seed_from_u64(seed) };
199
- let rng = rand:: rngs:: SmallRng :: seed_from_u64 ( seed) ;
206
+ let rng = rand:: rngs:: SmallRng :: from_rng ( rng) . expect ( "Could not seed rng" ) ;
200
207
201
208
NutsChain :: new ( potential, strategy, options, rng, chain)
202
209
}
203
210
204
- pub fn sample_sequentially < F : CpuLogpFunc > (
211
+ pub fn sample_sequentially < F : CpuLogpFunc , R : Rng + ? Sized > (
205
212
logp : F ,
206
213
settings : SamplerArgs ,
207
214
start : & [ f64 ] ,
208
215
draws : u64 ,
209
216
chain : u64 ,
210
- seed : u64 ,
217
+ rng : & mut R ,
211
218
) -> Result < impl Iterator < Item = Result < ( Box < [ f64 ] > , impl SampleStats ) , NutsError > > , NutsError > {
212
- let mut sampler = new_sampler ( logp, settings, chain, seed ) ;
219
+ let mut sampler = new_sampler ( logp, settings, chain, rng ) ;
213
220
sampler. set_position ( start) ?;
214
221
Ok ( ( 0 ..draws) . into_iter ( ) . map ( move |_| sampler. draw ( ) ) )
215
222
}
@@ -259,10 +266,7 @@ pub mod test_logps {
259
266
impl CpuLogpFuncMaker < NormalLogp > for NormalLogp {
260
267
//type Func = Self;
261
268
262
- fn make_logp_func (
263
- & self ,
264
- chain : usize ,
265
- ) -> Result < NormalLogp , Box < dyn std:: error:: Error + Send + Sync > > {
269
+ fn make_logp_func ( & self , chain : usize ) -> Result < NormalLogp , anyhow:: Error > {
266
270
Ok ( self . clone ( ) )
267
271
}
268
272
@@ -362,6 +366,7 @@ mod tests {
362
366
363
367
use itertools:: Itertools ;
364
368
use pretty_assertions:: assert_eq;
369
+ use rand:: { rngs:: StdRng , SeedableRng } ;
365
370
366
371
#[ test]
367
372
fn sample_seq ( ) {
@@ -371,7 +376,9 @@ mod tests {
371
376
settings. num_draws = 100 ;
372
377
let start = vec ! [ 0.2 ; 10 ] ;
373
378
374
- let chain = sample_sequentially ( logp. clone ( ) , settings, & start, 200 , 1 , 42 ) . unwrap ( ) ;
379
+ let mut rng = StdRng :: seed_from_u64 ( 42 ) ;
380
+
381
+ let chain = sample_sequentially ( logp. clone ( ) , settings, & start, 200 , 1 , & mut rng) . unwrap ( ) ;
375
382
let mut draws = chain. collect_vec ( ) ;
376
383
assert_eq ! ( draws. len( ) , 200 ) ;
377
384
@@ -384,7 +391,7 @@ mod tests {
384
391
let maker = logp;
385
392
386
393
let ( handles, chains) =
387
- sample_parallel ( maker, & mut JitterInitFunc :: new ( ) , settings, 4 , 42 , 10 ) . unwrap ( ) ;
394
+ sample_parallel ( maker, & mut JitterInitFunc :: new ( ) , settings, 4 , & mut rng , 10 ) . unwrap ( ) ;
388
395
let mut draws = chains. iter ( ) . collect_vec ( ) ;
389
396
assert_eq ! ( draws. len( ) , 800 ) ;
390
397
assert ! ( handles. join( ) . is_ok( ) ) ;
0 commit comments