Skip to content

Commit 37161ba

Browse files
committed
Make stats optional
1 parent 241a580 commit 37161ba

File tree

3 files changed

+38
-38
lines changed

3 files changed

+38
-38
lines changed

src/adapt_strategy.rs

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ impl ArrowBuilder<DualAverageStats> for DualAverageStatsBuilder {
6464
self.n_steps.push(Some(value.n_steps));
6565
}
6666

67-
fn finalize(mut self) -> StructArray {
67+
fn finalize(mut self) -> Option<StructArray> {
6868
let fields = vec![
6969
Field::new("step_size_bar", DataType::Float64, false),
7070
Field::new("mean_tree_accept", DataType::Float64, false),
@@ -77,7 +77,7 @@ impl ArrowBuilder<DualAverageStats> for DualAverageStatsBuilder {
7777
self.n_steps.as_box(),
7878
];
7979

80-
StructArray::new(DataType::Struct(fields), arrays, None)
80+
Some(StructArray::new(DataType::Struct(fields), arrays, None))
8181
}
8282
}
8383

@@ -177,19 +177,11 @@ impl<F: CpuLogpFunc, M: MassMatrix> AdaptStrategy for DualAverageStrategy<F, M>
177177
}
178178

179179
/// Settings for mass matrix adaptation
180-
#[derive(Clone, Copy, Debug)]
180+
#[derive(Clone, Copy, Debug, Default)]
181181
pub struct DiagAdaptExpSettings {
182182
pub store_mass_matrix: bool,
183183
}
184184

185-
impl Default for DiagAdaptExpSettings {
186-
fn default() -> Self {
187-
Self {
188-
store_mass_matrix: false,
189-
}
190-
}
191-
}
192-
193185
pub(crate) struct ExpWindowDiagAdapt<F> {
194186
dim: usize,
195187
exp_variance_draw: RunningVariance,
@@ -284,7 +276,7 @@ impl ArrowBuilder<ExpWindowDiagAdaptStats> for ExpWindowDiagAdaptStatsBuilder {
284276
}
285277
}
286278

287-
fn finalize(self) -> StructArray {
279+
fn finalize(self) -> Option<StructArray> {
288280
if let Some(mut store) = self.mass_matrix_inv {
289281
let fields = vec![Field::new(
290282
"mass_matrix_inv",
@@ -294,9 +286,9 @@ impl ArrowBuilder<ExpWindowDiagAdaptStats> for ExpWindowDiagAdaptStatsBuilder {
294286

295287
let arrays = vec![store.as_box()];
296288

297-
StructArray::new(DataType::Struct(fields), arrays, None)
289+
Some(StructArray::new(DataType::Struct(fields), arrays, None))
298290
} else {
299-
StructArray::new(DataType::Struct(vec![]), vec![], None)
291+
None
300292
}
301293
}
302294
}
@@ -570,17 +562,24 @@ impl<D1: Debug + ArrowRow, D2: Debug + ArrowRow> ArrowBuilder<CombinedStats<D1,
570562
self.stats2.append_value(&value.stats2);
571563
}
572564

573-
fn finalize(self) -> StructArray {
574-
let mut data1 = self.stats1.finalize().into_data();
575-
let data2 = self.stats2.finalize().into_data();
565+
fn finalize(self) -> Option<StructArray> {
566+
match (self.stats1.finalize(), self.stats2.finalize()) {
567+
(None, None) => None,
568+
(Some(stats1), None) => Some(stats1),
569+
(None, Some(stats2)) => Some(stats2),
570+
(Some(stats1), Some(stats2)) => {
571+
let mut data1 = stats1.into_data();
572+
let data2 = stats2.into_data();
576573

577-
assert!(data1.2.is_none());
578-
assert!(data2.2.is_none());
574+
assert!(data1.2.is_none());
575+
assert!(data2.2.is_none());
579576

580-
data1.0.extend(data2.0);
581-
data1.1.extend(data2.1);
577+
data1.0.extend(data2.0);
578+
data1.1.extend(data2.1);
582579

583-
StructArray::new(DataType::Struct(data1.0), data1.1, None)
580+
Some(StructArray::new(DataType::Struct(data1.0), data1.1, None))
581+
}
582+
}
584583
}
585584
}
586585

src/cpu_potential.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ impl ArrowBuilder<PotentialStats> for PotentialStatsBuilder {
6363
self.step_size.push(Some(value.step_size));
6464
}
6565

66-
fn finalize(mut self) -> StructArray {
66+
fn finalize(mut self) -> Option<StructArray> {
6767
let fields = vec![Field::new("step_size", DataType::Float64, false)];
6868

6969
let arrays = vec![self.step_size.as_box()];
7070

71-
StructArray::new(DataType::Struct(fields), arrays, None)
71+
Some(StructArray::new(DataType::Struct(fields), arrays, None))
7272
}
7373
}
7474

src/nuts.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ pub trait ArrowRow {
421421
#[cfg(feature = "arrow")]
422422
pub trait ArrowBuilder<T: ?Sized> {
423423
fn append_value(&mut self, value: &T);
424-
fn finalize(self) -> StructArray;
424+
fn finalize(self) -> Option<StructArray>;
425425
}
426426

427427
#[derive(Debug)]
@@ -545,13 +545,7 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
545545
self.adapt.append_value(&value.strategy_stats);
546546
}
547547

548-
fn finalize(mut self) -> StructArray {
549-
let hamiltonian = self.hamiltonian.finalize().into_data();
550-
let adapt = self.adapt.finalize().into_data();
551-
552-
assert!(hamiltonian.2.is_none());
553-
assert!(adapt.2.is_none());
554-
548+
fn finalize(mut self) -> Option<StructArray> {
555549
let mut fields = vec![
556550
Field::new("depth", DataType::UInt64, false),
557551
Field::new("maxdepth_reached", DataType::Boolean, false),
@@ -562,9 +556,6 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
562556
Field::new("draw", DataType::UInt64, false),
563557
];
564558

565-
fields.extend(hamiltonian.0);
566-
fields.extend(adapt.0);
567-
568559
let mut arrays = vec![
569560
self.depth.as_box(),
570561
self.maxdepth_reached.as_box(),
@@ -575,10 +566,20 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
575566
self.draw.as_box(),
576567
];
577568

578-
arrays.extend(hamiltonian.1);
579-
arrays.extend(adapt.1);
569+
if let Some(hamiltonian) = self.hamiltonian.finalize() {
570+
let hamiltonian = hamiltonian.into_data();
571+
assert!(hamiltonian.2.is_none());
572+
fields.extend(hamiltonian.0);
573+
arrays.extend(hamiltonian.1);
574+
}
575+
if let Some(adapt) = self.adapt.finalize() {
576+
let adapt = adapt.into_data();
577+
assert!(adapt.2.is_none());
578+
fields.extend(adapt.0);
579+
arrays.extend(adapt.1);
580+
}
580581

581-
StructArray::new(DataType::Struct(fields), arrays, None)
582+
Some(StructArray::new(DataType::Struct(fields), arrays, None))
582583
}
583584
}
584585

0 commit comments

Comments
 (0)