1
+ use arrow2:: array:: { MutableFixedSizeListArray , TryPush } ;
1
2
#[ cfg( feature = "arrow" ) ]
2
3
use arrow2:: {
3
4
array:: { MutableArray , MutableBooleanArray , MutablePrimitiveArray , StructArray } ,
@@ -368,6 +369,7 @@ impl<P: Hamiltonian, C: Collector<State = P::State>> NutsTree<P, C> {
368
369
pub struct NutsOptions {
369
370
pub maxdepth : u64 ,
370
371
pub store_gradient : bool ,
372
+ pub store_unconstrained : bool ,
371
373
}
372
374
373
375
pub ( crate ) fn draw < P , R , C > (
@@ -435,6 +437,7 @@ pub(crate) struct NutsSampleStats<HStats: Send + Debug, AdaptStats: Send + Debug
435
437
pub chain : u64 ,
436
438
pub draw : u64 ,
437
439
pub gradient : Option < Box < [ f64 ] > > ,
440
+ pub unconstrained : Option < Box < [ f64 ] > > ,
438
441
pub potential_stats : HStats ,
439
442
pub strategy_stats : AdaptStats ,
440
443
}
@@ -461,6 +464,8 @@ pub trait SampleStats: Send + Debug {
461
464
/// The logp gradient at the location of the draw. This is only stored
462
465
/// if NutsOptions.store_gradient is `true`.
463
466
fn gradient ( & self ) -> Option < & [ f64 ] > ;
467
+ /// The draw in the unconstrained space.
468
+ fn unconstrained ( & self ) -> Option < & [ f64 ] > ;
464
469
}
465
470
466
471
impl < H , A > SampleStats for NutsSampleStats < H , A >
@@ -495,6 +500,9 @@ where
495
500
fn gradient ( & self ) -> Option < & [ f64 ] > {
496
501
self . gradient . as_ref ( ) . map ( |x| & x[ ..] )
497
502
}
503
+ fn unconstrained ( & self ) -> Option < & [ f64 ] > {
504
+ self . unconstrained . as_ref ( ) . map ( |x| & x[ ..] )
505
+ }
498
506
}
499
507
500
508
#[ cfg( feature = "arrow" ) ]
@@ -506,6 +514,8 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
506
514
energy : MutablePrimitiveArray < f64 > ,
507
515
chain : MutablePrimitiveArray < u64 > ,
508
516
draw : MutablePrimitiveArray < u64 > ,
517
+ unconstrained : Option < MutableFixedSizeListArray < MutablePrimitiveArray < f64 > > > ,
518
+ gradient : Option < MutableFixedSizeListArray < MutablePrimitiveArray < f64 > > > ,
509
519
hamiltonian : <H :: Stats as ArrowRow >:: Builder ,
510
520
adapt : <A :: Stats as ArrowRow >:: Builder ,
511
521
}
@@ -514,6 +524,21 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
514
524
impl < H : Hamiltonian , A : AdaptStrategy > StatsBuilder < H , A > {
515
525
fn new_with_capacity ( dim : usize , settings : & SamplerArgs ) -> Self {
516
526
let capacity = ( settings. num_tune + settings. num_draws ) as usize ;
527
+
528
+ let gradient = if settings. store_gradient {
529
+ let items = MutablePrimitiveArray :: new ( ) ;
530
+ Some ( MutableFixedSizeListArray :: new_with_field ( items, "item" , false , dim) )
531
+ } else {
532
+ None
533
+ } ;
534
+
535
+ let unconstrained = if settings. store_gradient {
536
+ let items = MutablePrimitiveArray :: new ( ) ;
537
+ Some ( MutableFixedSizeListArray :: new_with_field ( items, "item" , false , dim) )
538
+ } else {
539
+ None
540
+ } ;
541
+
517
542
Self {
518
543
depth : MutablePrimitiveArray :: with_capacity ( capacity) ,
519
544
maxdepth_reached : MutableBooleanArray :: with_capacity ( capacity) ,
@@ -522,6 +547,8 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
522
547
energy : MutablePrimitiveArray :: with_capacity ( capacity) ,
523
548
chain : MutablePrimitiveArray :: with_capacity ( capacity) ,
524
549
draw : MutablePrimitiveArray :: with_capacity ( capacity) ,
550
+ gradient,
551
+ unconstrained,
525
552
hamiltonian : <H :: Stats as ArrowRow >:: new_builder ( dim, settings) ,
526
553
adapt : <A :: Stats as ArrowRow >:: new_builder ( dim, settings) ,
527
554
}
@@ -541,6 +568,28 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
541
568
self . chain . push ( Some ( value. chain ) ) ;
542
569
self . draw . push ( Some ( value. draw ) ) ;
543
570
571
+ if let Some ( store) = self . gradient . as_mut ( ) {
572
+ store
573
+ . try_push (
574
+ value
575
+ . gradient ( )
576
+ . as_ref ( )
577
+ . map ( |vals| vals. iter ( ) . map ( |& x| Some ( x) ) )
578
+ )
579
+ . unwrap ( ) ;
580
+ }
581
+
582
+ if let Some ( store) = self . unconstrained . as_mut ( ) {
583
+ store
584
+ . try_push (
585
+ value
586
+ . unconstrained ( )
587
+ . as_ref ( )
588
+ . map ( |vals| vals. iter ( ) . map ( |& x| Some ( x) ) )
589
+ )
590
+ . unwrap ( ) ;
591
+ }
592
+
544
593
self . hamiltonian . append_value ( & value. potential_stats ) ;
545
594
self . adapt . append_value ( & value. strategy_stats ) ;
546
595
}
@@ -579,6 +628,16 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
579
628
arrays. extend ( adapt. 1 ) ;
580
629
}
581
630
631
+ if let Some ( mut gradient) = self . gradient . take ( ) {
632
+ fields. push ( Field :: new ( "gradient" , gradient. data_type ( ) . clone ( ) , true ) ) ;
633
+ arrays. push ( gradient. as_box ( ) ) ;
634
+ }
635
+
636
+ if let Some ( mut unconstrained) = self . unconstrained . take ( ) {
637
+ fields. push ( Field :: new ( "unconstrained" , unconstrained. data_type ( ) . clone ( ) , true ) ) ;
638
+ arrays. push ( unconstrained. as_box ( ) ) ;
639
+ }
640
+
582
641
Some ( StructArray :: new ( DataType :: Struct ( fields) , arrays, None ) )
583
642
}
584
643
}
@@ -737,6 +796,13 @@ where
737
796
} else {
738
797
None
739
798
} ,
799
+ unconstrained : if self . options . store_unconstrained {
800
+ let mut unconstrained: Box < [ f64 ] > = vec ! [ 0f64 ; self . potential. dim( ) ] . into ( ) ;
801
+ state. write_position ( & mut unconstrained) ;
802
+ Some ( unconstrained)
803
+ } else {
804
+ None
805
+ } ,
740
806
} ;
741
807
self . strategy . adapt (
742
808
& mut self . options ,
0 commit comments