Skip to content

Commit 241a580

Browse files
committed
Add option to store divs and unconstrained draws
1 parent 5f33010 commit 241a580

File tree

2 files changed

+44
-23
lines changed

2 files changed

+44
-23
lines changed

src/adapt_strategy.rs

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -266,45 +266,60 @@ pub struct ExpWindowDiagAdaptStats {
266266

267267
#[cfg(feature = "arrow")]
268268
pub struct ExpWindowDiagAdaptStatsBuilder {
269-
mass_matrix_inv: MutableFixedSizeListArray<MutablePrimitiveArray<f64>>,
269+
mass_matrix_inv: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
270270
}
271271

272272
#[cfg(feature = "arrow")]
273273
impl ArrowBuilder<ExpWindowDiagAdaptStats> for ExpWindowDiagAdaptStatsBuilder {
274274
fn append_value(&mut self, value: &ExpWindowDiagAdaptStats) {
275-
self.mass_matrix_inv
276-
.try_push(
277-
value
278-
.mass_matrix_inv
279-
.as_ref()
280-
.map(|vals| vals.iter().map(|&x| Some(x))),
281-
)
282-
.unwrap();
275+
if let Some(store) = self.mass_matrix_inv.as_mut() {
276+
store
277+
.try_push(
278+
value
279+
.mass_matrix_inv
280+
.as_ref()
281+
.map(|vals| vals.iter().map(|&x| Some(x))),
282+
)
283+
.unwrap();
284+
}
283285
}
284286

285-
fn finalize(mut self) -> StructArray {
286-
let fields = vec![Field::new(
287-
"mass_matrix_inv",
288-
self.mass_matrix_inv.data_type().clone(),
289-
true,
290-
)];
287+
fn finalize(self) -> StructArray {
288+
if let Some(mut store) = self.mass_matrix_inv {
289+
let fields = vec![Field::new(
290+
"mass_matrix_inv",
291+
store.data_type().clone(),
292+
true,
293+
)];
291294

292-
let arrays = vec![self.mass_matrix_inv.as_box()];
295+
let arrays = vec![store.as_box()];
293296

294-
StructArray::new(DataType::Struct(fields), arrays, None)
297+
StructArray::new(DataType::Struct(fields), arrays, None)
298+
} else {
299+
StructArray::new(DataType::Struct(vec![]), vec![], None)
300+
}
295301
}
296302
}
297303

298304
#[cfg(feature = "arrow")]
299305
impl ArrowRow for ExpWindowDiagAdaptStats {
300306
type Builder = ExpWindowDiagAdaptStatsBuilder;
301307

302-
fn new_builder(dim: usize, _settings: &SamplerArgs) -> Self::Builder {
303-
let items = MutablePrimitiveArray::new();
304-
// TODO Add only based on settings
305-
let values = MutableFixedSizeListArray::new_with_field(items, "item", false, dim);
306-
Self::Builder {
307-
mass_matrix_inv: values,
308+
fn new_builder(dim: usize, settings: &SamplerArgs) -> Self::Builder {
309+
if settings
310+
.mass_matrix_adapt
311+
.mass_matrix_options
312+
.store_mass_matrix
313+
{
314+
let items = MutablePrimitiveArray::new();
315+
let values = MutableFixedSizeListArray::new_with_field(items, "item", false, dim);
316+
Self::Builder {
317+
mass_matrix_inv: Some(values),
318+
}
319+
} else {
320+
Self::Builder {
321+
mass_matrix_inv: None,
322+
}
308323
}
309324
}
310325
}

src/cpu_sampler.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,13 @@ pub struct SamplerArgs {
2323
pub maxdepth: u64,
2424
/// Store the gradient in the SampleStats
2525
pub store_gradient: bool,
26+
/// Store each unconstrained parameter vector in the sampler stats
27+
pub store_unconstrained: bool,
2628
/// If the energy error is larger than this threshold we treat the leapfrog
2729
/// step as a divergence.
2830
pub max_energy_error: f64,
31+
/// Store detailed information about each divergence in the sampler stats
32+
pub store_divergences: bool,
2933
/// Settings for step size adaptation.
3034
pub step_size_adapt: DualAverageSettings,
3135
/// Settings for mass matrix adaptation.
@@ -40,6 +44,8 @@ impl Default for SamplerArgs {
4044
maxdepth: 10,
4145
max_energy_error: 1000f64,
4246
store_gradient: false,
47+
store_unconstrained: false,
48+
store_divergences: true,
4349
step_size_adapt: DualAverageSettings::default(),
4450
mass_matrix_adapt: GradDiagOptions::default(),
4551
}

0 commit comments

Comments
 (0)