Skip to content

Commit 6af13ea

Browse files
committed
Store gradient and unconstrained draw if requested
1 parent 37161ba commit 6af13ea

File tree

3 files changed

+69
-11
lines changed

3 files changed

+69
-11
lines changed

src/adapt_strategy.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ mod test {
686686
let options = NutsOptions {
687687
maxdepth: 10u64,
688688
store_gradient: true,
689+
store_unconstrained: true,
689690
};
690691

691692
let rng = {

src/cpu_sampler.rs

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::thread::JoinHandle;
44
use thiserror::Error;
55

66
use crate::{
7-
adapt_strategy::{DualAverageSettings, GradDiagOptions, GradDiagStrategy},
7+
adapt_strategy::{GradDiagOptions, GradDiagStrategy},
88
cpu_potential::EuclideanPotential,
99
mass_matrix::DiagMassMatrix,
1010
nuts::{Chain, NutsChain, NutsError, NutsOptions, SampleStats},
@@ -30,8 +30,6 @@ pub struct SamplerArgs {
3030
pub max_energy_error: f64,
3131
/// Store detailed information about each divergence in the sampler stats
3232
pub store_divergences: bool,
33-
/// Settings for step size adaptation.
34-
pub step_size_adapt: DualAverageSettings,
3533
/// Settings for mass matrix adaptation.
3634
pub mass_matrix_adapt: GradDiagOptions,
3735
}
@@ -46,7 +44,6 @@ impl Default for SamplerArgs {
4644
store_gradient: false,
4745
store_unconstrained: false,
4846
store_divergences: true,
49-
step_size_adapt: DualAverageSettings::default(),
5047
mass_matrix_adapt: GradDiagOptions::default(),
5148
}
5249
}
@@ -89,8 +86,6 @@ pub trait CpuLogpFuncMaker<Func>: Send + Sync
8986
where
9087
Func: CpuLogpFunc,
9188
{
92-
//type Func: CpuLogpFunc;
93-
9489
fn make_logp_func(&self, chain: usize) -> Result<Func, anyhow::Error>;
9590
fn dim(&self) -> usize;
9691
}
@@ -194,19 +189,15 @@ pub fn new_sampler<F: CpuLogpFunc, R: Rng + ?Sized>(
194189
) -> impl Chain {
195190
use crate::nuts::AdaptStrategy;
196191
let num_tune = settings.num_tune;
197-
//let step_size_adapt = DualAverageStrategy::new(settings.step_size_adapt, num_tune, logp.dim());
198-
//let mass_matrix_adapt =
199-
// ExpWindowDiagAdapt::new(settings.mass_matrix_adapt, num_tune, logp.dim());
200-
201192
let strategy = GradDiagStrategy::new(settings.mass_matrix_adapt, num_tune, logp.dim());
202-
203193
let mass_matrix = DiagMassMatrix::new(logp.dim());
204194
let max_energy_error = settings.max_energy_error;
205195
let potential = EuclideanPotential::new(logp, mass_matrix, max_energy_error, 1f64);
206196

207197
let options = NutsOptions {
208198
maxdepth: settings.maxdepth,
209199
store_gradient: settings.store_gradient,
200+
store_unconstrained: settings.store_unconstrained,
210201
};
211202

212203
let rng = rand::rngs::SmallRng::from_rng(rng).expect("Could not seed rng");

src/nuts.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use arrow2::array::{MutableFixedSizeListArray, TryPush};
12
#[cfg(feature = "arrow")]
23
use arrow2::{
34
array::{MutableArray, MutableBooleanArray, MutablePrimitiveArray, StructArray},
@@ -368,6 +369,7 @@ impl<P: Hamiltonian, C: Collector<State = P::State>> NutsTree<P, C> {
368369
pub struct NutsOptions {
369370
pub maxdepth: u64,
370371
pub store_gradient: bool,
372+
pub store_unconstrained: bool,
371373
}
372374

373375
pub(crate) fn draw<P, R, C>(
@@ -435,6 +437,7 @@ pub(crate) struct NutsSampleStats<HStats: Send + Debug, AdaptStats: Send + Debug
435437
pub chain: u64,
436438
pub draw: u64,
437439
pub gradient: Option<Box<[f64]>>,
440+
pub unconstrained: Option<Box<[f64]>>,
438441
pub potential_stats: HStats,
439442
pub strategy_stats: AdaptStats,
440443
}
@@ -461,6 +464,8 @@ pub trait SampleStats: Send + Debug {
461464
/// The logp gradient at the location of the draw. This is only stored
462465
/// if NutsOptions.store_gradient is `true`.
463466
fn gradient(&self) -> Option<&[f64]>;
467+
/// The draw in the unconstrained space.
468+
fn unconstrained(&self) -> Option<&[f64]>;
464469
}
465470

466471
impl<H, A> SampleStats for NutsSampleStats<H, A>
@@ -495,6 +500,9 @@ where
495500
fn gradient(&self) -> Option<&[f64]> {
496501
self.gradient.as_ref().map(|x| &x[..])
497502
}
503+
fn unconstrained(&self) -> Option<&[f64]> {
504+
self.unconstrained.as_ref().map(|x| &x[..])
505+
}
498506
}
499507

500508
#[cfg(feature = "arrow")]
@@ -506,6 +514,8 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
506514
energy: MutablePrimitiveArray<f64>,
507515
chain: MutablePrimitiveArray<u64>,
508516
draw: MutablePrimitiveArray<u64>,
517+
unconstrained: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
518+
gradient: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
509519
hamiltonian: <H::Stats as ArrowRow>::Builder,
510520
adapt: <A::Stats as ArrowRow>::Builder,
511521
}
@@ -514,6 +524,21 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
514524
impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
515525
fn new_with_capacity(dim: usize, settings: &SamplerArgs) -> Self {
516526
let capacity = (settings.num_tune + settings.num_draws) as usize;
527+
528+
let gradient = if settings.store_gradient {
529+
let items = MutablePrimitiveArray::new();
530+
Some(MutableFixedSizeListArray::new_with_field(items, "item", false, dim))
531+
} else {
532+
None
533+
};
534+
535+
let unconstrained = if settings.store_unconstrained {
536+
let items = MutablePrimitiveArray::new();
537+
Some(MutableFixedSizeListArray::new_with_field(items, "item", false, dim))
538+
} else {
539+
None
540+
};
541+
517542
Self {
518543
depth: MutablePrimitiveArray::with_capacity(capacity),
519544
maxdepth_reached: MutableBooleanArray::with_capacity(capacity),
@@ -522,6 +547,8 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
522547
energy: MutablePrimitiveArray::with_capacity(capacity),
523548
chain: MutablePrimitiveArray::with_capacity(capacity),
524549
draw: MutablePrimitiveArray::with_capacity(capacity),
550+
gradient,
551+
unconstrained,
525552
hamiltonian: <H::Stats as ArrowRow>::new_builder(dim, settings),
526553
adapt: <A::Stats as ArrowRow>::new_builder(dim, settings),
527554
}
@@ -541,6 +568,28 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
541568
self.chain.push(Some(value.chain));
542569
self.draw.push(Some(value.draw));
543570

571+
if let Some(store) = self.gradient.as_mut() {
572+
store
573+
.try_push(
574+
value
575+
.gradient()
576+
.as_ref()
577+
.map(|vals| vals.iter().map(|&x| Some(x)))
578+
)
579+
.unwrap();
580+
}
581+
582+
if let Some(store) = self.unconstrained.as_mut() {
583+
store
584+
.try_push(
585+
value
586+
.unconstrained()
587+
.as_ref()
588+
.map(|vals| vals.iter().map(|&x| Some(x)))
589+
)
590+
.unwrap();
591+
}
592+
544593
self.hamiltonian.append_value(&value.potential_stats);
545594
self.adapt.append_value(&value.strategy_stats);
546595
}
@@ -579,6 +628,16 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
579628
arrays.extend(adapt.1);
580629
}
581630

631+
if let Some(mut gradient) = self.gradient.take() {
632+
fields.push(Field::new("gradient", gradient.data_type().clone(), true));
633+
arrays.push(gradient.as_box());
634+
}
635+
636+
if let Some(mut unconstrained) = self.unconstrained.take() {
637+
fields.push(Field::new("unconstrained", unconstrained.data_type().clone(), true));
638+
arrays.push(unconstrained.as_box());
639+
}
640+
582641
Some(StructArray::new(DataType::Struct(fields), arrays, None))
583642
}
584643
}
@@ -737,6 +796,13 @@ where
737796
} else {
738797
None
739798
},
799+
unconstrained: if self.options.store_unconstrained {
800+
let mut unconstrained: Box<[f64]> = vec![0f64; self.potential.dim()].into();
801+
state.write_position(&mut unconstrained);
802+
Some(unconstrained)
803+
} else {
804+
None
805+
},
740806
};
741807
self.strategy.adapt(
742808
&mut self.options,

0 commit comments

Comments (0)