1
1
use std:: { fmt:: Debug , marker:: PhantomData } ;
2
2
3
+ #[ cfg( feature = "arrow" ) ]
4
+ use arrow2:: {
5
+ array:: { MutableArray , MutableFixedSizeListArray , MutablePrimitiveArray , StructArray , TryPush } ,
6
+ datatypes:: { DataType , Field } ,
7
+ } ;
3
8
use itertools:: izip;
4
9
5
10
use crate :: {
6
11
cpu_potential:: { CpuLogpFunc , EuclideanPotential } ,
7
- mass_matrix:: {
8
- DiagMassMatrix , DrawGradCollector , MassMatrix , RunningVariance ,
9
- } ,
10
- nuts:: {
11
- AdaptStrategy , AsSampleStatVec , Collector , Hamiltonian , NutsOptions , SampleStatItem ,
12
- SampleStatValue ,
13
- } ,
12
+ mass_matrix:: { DiagMassMatrix , DrawGradCollector , MassMatrix , RunningVariance } ,
13
+ nuts:: { AdaptStrategy , Collector , Hamiltonian , NutsOptions } ,
14
14
stepsize:: { AcceptanceRateCollector , DualAverage , DualAverageOptions } ,
15
+ DivergenceInfo ,
15
16
} ;
16
17
18
+ #[ cfg( feature = "arrow" ) ]
19
+ use crate :: nuts:: { ArrowBuilder , ArrowRow } ;
20
+ #[ cfg( feature = "arrow" ) ]
21
+ use crate :: SamplerArgs ;
22
+
17
23
const LOWER_LIMIT : f64 = 1e-10f64 ;
18
24
const UPPER_LIMIT : f64 = 1e10f64 ;
19
25
@@ -36,22 +42,55 @@ impl<F, M> DualAverageStrategy<F, M> {
36
42
}
37
43
}
38
44
39
-
40
45
#[ derive( Debug , Clone , Copy ) ]
41
46
pub struct DualAverageStats {
42
47
pub step_size_bar : f64 ,
43
48
pub mean_tree_accept : f64 ,
44
49
pub n_steps : u64 ,
45
50
}
46
51
47
- impl AsSampleStatVec for DualAverageStats {
48
- fn add_to_vec ( & self , vec : & mut Vec < SampleStatItem > ) {
49
- vec. push ( ( "step_size_bar" , SampleStatValue :: F64 ( self . step_size_bar ) ) ) ;
50
- vec. push ( (
51
- "mean_tree_accept" ,
52
- SampleStatValue :: F64 ( self . mean_tree_accept ) ,
53
- ) ) ;
54
- vec. push ( ( "n_steps" , SampleStatValue :: U64 ( self . n_steps ) ) ) ;
52
+ #[ cfg( feature = "arrow" ) ]
53
+ pub struct DualAverageStatsBuilder {
54
+ step_size_bar : MutablePrimitiveArray < f64 > ,
55
+ mean_tree_accept : MutablePrimitiveArray < f64 > ,
56
+ n_steps : MutablePrimitiveArray < u64 > ,
57
+ }
58
+
59
+ #[ cfg( feature = "arrow" ) ]
60
+ impl ArrowBuilder < DualAverageStats > for DualAverageStatsBuilder {
61
+ fn append_value ( & mut self , value : & DualAverageStats ) {
62
+ self . step_size_bar . push ( Some ( value. step_size_bar ) ) ;
63
+ self . mean_tree_accept . push ( Some ( value. mean_tree_accept ) ) ;
64
+ self . n_steps . push ( Some ( value. n_steps ) ) ;
65
+ }
66
+
67
+ fn finalize ( mut self ) -> StructArray {
68
+ let fields = vec ! [
69
+ Field :: new( "step_size_bar" , DataType :: Float64 , false ) ,
70
+ Field :: new( "mean_tree_accept" , DataType :: Float64 , false ) ,
71
+ Field :: new( "n_steps" , DataType :: UInt64 , false ) ,
72
+ ] ;
73
+
74
+ let arrays = vec ! [
75
+ self . step_size_bar. as_box( ) ,
76
+ self . mean_tree_accept. as_box( ) ,
77
+ self . n_steps. as_box( ) ,
78
+ ] ;
79
+
80
+ StructArray :: new ( DataType :: Struct ( fields) , arrays, None )
81
+ }
82
+ }
83
+
84
+ #[ cfg( feature = "arrow" ) ]
85
+ impl ArrowRow for DualAverageStats {
86
+ type Builder = DualAverageStatsBuilder ;
87
+
88
+ fn new_builder ( _dim : usize , _settings : & SamplerArgs ) -> Self :: Builder {
89
+ Self :: Builder {
90
+ step_size_bar : MutablePrimitiveArray :: new ( ) ,
91
+ mean_tree_accept : MutablePrimitiveArray :: new ( ) ,
92
+ n_steps : MutablePrimitiveArray :: new ( ) ,
93
+ }
55
94
}
56
95
}
57
96
@@ -220,18 +259,53 @@ impl<F: CpuLogpFunc> ExpWindowDiagAdapt<F> {
220
259
}
221
260
}
222
261
223
-
224
262
#[ derive( Clone , Debug ) ]
225
263
pub struct ExpWindowDiagAdaptStats {
226
264
pub mass_matrix_inv : Option < Box < [ f64 ] > > ,
227
265
}
228
266
229
- impl AsSampleStatVec for ExpWindowDiagAdaptStats {
230
- fn add_to_vec ( & self , vec : & mut Vec < SampleStatItem > ) {
231
- vec. push ( (
267
+ #[ cfg( feature = "arrow" ) ]
268
+ pub struct ExpWindowDiagAdaptStatsBuilder {
269
+ mass_matrix_inv : MutableFixedSizeListArray < MutablePrimitiveArray < f64 > > ,
270
+ }
271
+
272
+ #[ cfg( feature = "arrow" ) ]
273
+ impl ArrowBuilder < ExpWindowDiagAdaptStats > for ExpWindowDiagAdaptStatsBuilder {
274
+ fn append_value ( & mut self , value : & ExpWindowDiagAdaptStats ) {
275
+ self . mass_matrix_inv
276
+ . try_push (
277
+ value
278
+ . mass_matrix_inv
279
+ . as_ref ( )
280
+ . map ( |vals| vals. iter ( ) . map ( |& x| Some ( x) ) ) ,
281
+ )
282
+ . unwrap ( ) ;
283
+ }
284
+
285
+ fn finalize ( mut self ) -> StructArray {
286
+ let fields = vec ! [ Field :: new(
232
287
"mass_matrix_inv" ,
233
- SampleStatValue :: OptionArray ( self . mass_matrix_inv . clone ( ) ) ,
234
- ) ) ;
288
+ self . mass_matrix_inv. data_type( ) . clone( ) ,
289
+ true ,
290
+ ) ] ;
291
+
292
+ let arrays = vec ! [ self . mass_matrix_inv. as_box( ) ] ;
293
+
294
+ StructArray :: new ( DataType :: Struct ( fields) , arrays, None )
295
+ }
296
+ }
297
+
298
+ #[ cfg( feature = "arrow" ) ]
299
+ impl ArrowRow for ExpWindowDiagAdaptStats {
300
+ type Builder = ExpWindowDiagAdaptStatsBuilder ;
301
+
302
+ fn new_builder ( dim : usize , _settings : & SamplerArgs ) -> Self :: Builder {
303
+ let items = MutablePrimitiveArray :: new ( ) ;
304
+ // TODO Add only based on settings
305
+ let values = MutableFixedSizeListArray :: new_with_field ( items, "item" , false , dim) ;
306
+ Self :: Builder {
307
+ mass_matrix_inv : values,
308
+ }
235
309
}
236
310
}
237
311
@@ -260,16 +334,19 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
260
334
state : & <Self :: Potential as Hamiltonian >:: State ,
261
335
) {
262
336
self . exp_variance_draw . add_sample ( state. q . iter ( ) . copied ( ) ) ;
263
- self . exp_variance_draw_bg . add_sample ( state. q . iter ( ) . copied ( ) ) ;
264
- self . exp_variance_grad . add_sample ( state. grad . iter ( ) . copied ( ) ) ;
265
- self . exp_variance_grad_bg . add_sample ( state. grad . iter ( ) . copied ( ) ) ;
337
+ self . exp_variance_draw_bg
338
+ . add_sample ( state. q . iter ( ) . copied ( ) ) ;
339
+ self . exp_variance_grad
340
+ . add_sample ( state. grad . iter ( ) . copied ( ) ) ;
341
+ self . exp_variance_grad_bg
342
+ . add_sample ( state. grad . iter ( ) . copied ( ) ) ;
266
343
267
344
potential. mass_matrix . update_diag (
268
- state. grad . iter ( ) . map ( |& grad| {
269
- Some ( ( grad) . abs ( ) . recip ( ) . clamp ( LOWER_LIMIT , UPPER_LIMIT ) )
270
- } )
345
+ state
346
+ . grad
347
+ . iter ( )
348
+ . map ( |& grad| Some ( ( grad) . abs ( ) . recip ( ) . clamp ( LOWER_LIMIT , UPPER_LIMIT ) ) ) ,
271
349
) ;
272
-
273
350
}
274
351
275
352
fn adapt (
@@ -303,7 +380,6 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
303
380
}
304
381
}
305
382
306
-
307
383
pub ( crate ) struct GradDiagStrategy < F : CpuLogpFunc > {
308
384
step_size : DualAverageStrategy < F , DiagMassMatrix > ,
309
385
mass_matrix : ExpWindowDiagAdapt < F > ,
@@ -332,8 +408,6 @@ impl Default for GradDiagOptions {
332
408
dual_average_options : DualAverageSettings :: default ( ) ,
333
409
mass_matrix_options : DiagAdaptExpSettings :: default ( ) ,
334
410
early_window : 0.3 ,
335
- //step_size_window: 0.08,
336
- //step_size_window: 0.15,
337
411
step_size_window : 0.2 ,
338
412
mass_matrix_switch_freq : 60 ,
339
413
early_mass_matrix_switch_freq : 10 ,
@@ -345,7 +419,7 @@ impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
345
419
type Potential = EuclideanPotential < F , DiagMassMatrix > ;
346
420
type Collector = CombinedCollector <
347
421
AcceptanceRateCollector < <EuclideanPotential < F , DiagMassMatrix > as Hamiltonian >:: State > ,
348
- DrawGradCollector
422
+ DrawGradCollector ,
349
423
> ;
350
424
type Stats = CombinedStats < DualAverageStats , ExpWindowDiagAdaptStats > ;
351
425
type Options = GradDiagOptions ;
@@ -404,14 +478,16 @@ impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
404
478
self . mass_matrix . update_estimators ( & collector. collector2 ) ;
405
479
}
406
480
self . mass_matrix . update_potential ( potential) ;
407
- self . step_size . adapt ( options, potential, draw, & collector. collector1 ) ;
481
+ self . step_size
482
+ . adapt ( options, potential, draw, & collector. collector1 ) ;
408
483
return ;
409
484
}
410
485
411
486
if draw == self . num_tune - 1 {
412
487
self . step_size . finalize ( ) ;
413
488
}
414
- self . step_size . adapt ( options, potential, draw, & collector. collector1 ) ;
489
+ self . step_size
490
+ . adapt ( options, potential, draw, & collector. collector1 ) ;
415
491
}
416
492
417
493
fn new_collector ( & self ) -> Self :: Collector {
@@ -438,17 +514,58 @@ impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
438
514
}
439
515
}
440
516
517
+ #[ cfg( feature = "arrow" ) ]
518
+ #[ derive( Debug , Clone ) ]
519
+ pub struct CombinedStats < D1 : Debug + ArrowRow , D2 : Debug + ArrowRow > {
520
+ pub stats1 : D1 ,
521
+ pub stats2 : D2 ,
522
+ }
441
523
524
+ #[ cfg( not( feature = "arrow" ) ) ]
442
525
#[ derive( Debug , Clone ) ]
443
526
pub struct CombinedStats < D1 : Debug , D2 : Debug > {
444
527
pub stats1 : D1 ,
445
528
pub stats2 : D2 ,
446
529
}
447
530
448
- impl < D1 : AsSampleStatVec , D2 : AsSampleStatVec > AsSampleStatVec for CombinedStats < D1 , D2 > {
449
- fn add_to_vec ( & self , vec : & mut Vec < SampleStatItem > ) {
450
- self . stats1 . add_to_vec ( vec) ;
451
- self . stats2 . add_to_vec ( vec) ;
531
+ #[ cfg( feature = "arrow" ) ]
532
+ pub struct CombinedStatsBuilder < D1 : ArrowRow , D2 : ArrowRow > {
533
+ stats1 : D1 :: Builder ,
534
+ stats2 : D2 :: Builder ,
535
+ }
536
+
537
+ #[ cfg( feature = "arrow" ) ]
538
+ impl < D1 : Debug + ArrowRow , D2 : Debug + ArrowRow > ArrowRow for CombinedStats < D1 , D2 > {
539
+ type Builder = CombinedStatsBuilder < D1 , D2 > ;
540
+
541
+ fn new_builder ( dim : usize , settings : & SamplerArgs ) -> Self :: Builder {
542
+ Self :: Builder {
543
+ stats1 : D1 :: new_builder ( dim, settings) ,
544
+ stats2 : D2 :: new_builder ( dim, settings) ,
545
+ }
546
+ }
547
+ }
548
+
549
+ #[ cfg( feature = "arrow" ) ]
550
+ impl < D1 : Debug + ArrowRow , D2 : Debug + ArrowRow > ArrowBuilder < CombinedStats < D1 , D2 > >
551
+ for CombinedStatsBuilder < D1 , D2 >
552
+ {
553
+ fn append_value ( & mut self , value : & CombinedStats < D1 , D2 > ) {
554
+ self . stats1 . append_value ( & value. stats1 ) ;
555
+ self . stats2 . append_value ( & value. stats2 ) ;
556
+ }
557
+
558
+ fn finalize ( self ) -> StructArray {
559
+ let mut data1 = self . stats1 . finalize ( ) . into_data ( ) ;
560
+ let data2 = self . stats2 . finalize ( ) . into_data ( ) ;
561
+
562
+ assert ! ( data1. 2 . is_none( ) ) ;
563
+ assert ! ( data2. 2 . is_none( ) ) ;
564
+
565
+ data1. 0 . extend ( data2. 0 ) ;
566
+ data1. 1 . extend ( data2. 1 ) ;
567
+
568
+ StructArray :: new ( DataType :: Struct ( data1. 0 ) , data1. 1 , None )
452
569
}
453
570
}
454
571
@@ -468,7 +585,7 @@ where
468
585
& mut self ,
469
586
start : & Self :: State ,
470
587
end : & Self :: State ,
471
- divergence_info : Option < & dyn crate :: nuts :: DivergenceInfo > ,
588
+ divergence_info : Option < & DivergenceInfo > ,
472
589
) {
473
590
self . collector1
474
591
. register_leapfrog ( start, end, divergence_info) ;
0 commit comments