Skip to content

Commit fc5e9d4

Browse files
committed
Export sampler stats as arrow object
1 parent b5054ca commit fc5e9d4

File tree

10 files changed

+437
-300
lines changed

10 files changed

+437
-300
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ crossbeam = "0.8.1"
2929
thiserror = "1.0.31"
3030
rayon = "1.5.3"
3131
ndarray = "0.15.4"
32+
arrow2 = { version = "0.17.0", optional = true }
3233

3334
[dev-dependencies]
3435
proptest = "1.0.0"
@@ -43,5 +44,7 @@ harness = false
4344

4445
[features]
4546
nightly = ["simd_support"]
47+
default = ["arrow"]
4648

4749
simd_support = []
50+
arrow = ["dep:arrow2"]

src/adapt_strategy.rs

Lines changed: 157 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
use std::{fmt::Debug, marker::PhantomData};
22

3+
#[cfg(feature = "arrow")]
4+
use arrow2::{
5+
array::{MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, StructArray, TryPush},
6+
datatypes::{DataType, Field},
7+
};
38
use itertools::izip;
49

510
use crate::{
611
cpu_potential::{CpuLogpFunc, EuclideanPotential},
7-
mass_matrix::{
8-
DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance,
9-
},
10-
nuts::{
11-
AdaptStrategy, AsSampleStatVec, Collector, Hamiltonian, NutsOptions, SampleStatItem,
12-
SampleStatValue,
13-
},
12+
mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance},
13+
nuts::{AdaptStrategy, Collector, Hamiltonian, NutsOptions},
1414
stepsize::{AcceptanceRateCollector, DualAverage, DualAverageOptions},
15+
DivergenceInfo,
1516
};
1617

18+
#[cfg(feature = "arrow")]
19+
use crate::nuts::{ArrowBuilder, ArrowRow};
20+
#[cfg(feature = "arrow")]
21+
use crate::SamplerArgs;
22+
1723
const LOWER_LIMIT: f64 = 1e-10f64;
1824
const UPPER_LIMIT: f64 = 1e10f64;
1925

@@ -36,22 +42,55 @@ impl<F, M> DualAverageStrategy<F, M> {
3642
}
3743
}
3844

39-
4045
#[derive(Debug, Clone, Copy)]
4146
pub struct DualAverageStats {
4247
pub step_size_bar: f64,
4348
pub mean_tree_accept: f64,
4449
pub n_steps: u64,
4550
}
4651

47-
impl AsSampleStatVec for DualAverageStats {
48-
fn add_to_vec(&self, vec: &mut Vec<SampleStatItem>) {
49-
vec.push(("step_size_bar", SampleStatValue::F64(self.step_size_bar)));
50-
vec.push((
51-
"mean_tree_accept",
52-
SampleStatValue::F64(self.mean_tree_accept),
53-
));
54-
vec.push(("n_steps", SampleStatValue::U64(self.n_steps)));
52+
#[cfg(feature = "arrow")]
53+
pub struct DualAverageStatsBuilder {
54+
step_size_bar: MutablePrimitiveArray<f64>,
55+
mean_tree_accept: MutablePrimitiveArray<f64>,
56+
n_steps: MutablePrimitiveArray<u64>,
57+
}
58+
59+
#[cfg(feature = "arrow")]
60+
impl ArrowBuilder<DualAverageStats> for DualAverageStatsBuilder {
61+
fn append_value(&mut self, value: &DualAverageStats) {
62+
self.step_size_bar.push(Some(value.step_size_bar));
63+
self.mean_tree_accept.push(Some(value.mean_tree_accept));
64+
self.n_steps.push(Some(value.n_steps));
65+
}
66+
67+
fn finalize(mut self) -> StructArray {
68+
let fields = vec![
69+
Field::new("step_size_bar", DataType::Float64, false),
70+
Field::new("mean_tree_accept", DataType::Float64, false),
71+
Field::new("n_steps", DataType::UInt64, false),
72+
];
73+
74+
let arrays = vec![
75+
self.step_size_bar.as_box(),
76+
self.mean_tree_accept.as_box(),
77+
self.n_steps.as_box(),
78+
];
79+
80+
StructArray::new(DataType::Struct(fields), arrays, None)
81+
}
82+
}
83+
84+
#[cfg(feature = "arrow")]
85+
impl ArrowRow for DualAverageStats {
86+
type Builder = DualAverageStatsBuilder;
87+
88+
fn new_builder(_dim: usize, _settings: &SamplerArgs) -> Self::Builder {
89+
Self::Builder {
90+
step_size_bar: MutablePrimitiveArray::new(),
91+
mean_tree_accept: MutablePrimitiveArray::new(),
92+
n_steps: MutablePrimitiveArray::new(),
93+
}
5594
}
5695
}
5796

@@ -220,18 +259,53 @@ impl<F: CpuLogpFunc> ExpWindowDiagAdapt<F> {
220259
}
221260
}
222261

223-
224262
#[derive(Clone, Debug)]
225263
pub struct ExpWindowDiagAdaptStats {
226264
pub mass_matrix_inv: Option<Box<[f64]>>,
227265
}
228266

229-
impl AsSampleStatVec for ExpWindowDiagAdaptStats {
230-
fn add_to_vec(&self, vec: &mut Vec<SampleStatItem>) {
231-
vec.push((
267+
#[cfg(feature = "arrow")]
268+
pub struct ExpWindowDiagAdaptStatsBuilder {
269+
mass_matrix_inv: MutableFixedSizeListArray<MutablePrimitiveArray<f64>>,
270+
}
271+
272+
#[cfg(feature = "arrow")]
273+
impl ArrowBuilder<ExpWindowDiagAdaptStats> for ExpWindowDiagAdaptStatsBuilder {
274+
fn append_value(&mut self, value: &ExpWindowDiagAdaptStats) {
275+
self.mass_matrix_inv
276+
.try_push(
277+
value
278+
.mass_matrix_inv
279+
.as_ref()
280+
.map(|vals| vals.iter().map(|&x| Some(x))),
281+
)
282+
.unwrap();
283+
}
284+
285+
fn finalize(mut self) -> StructArray {
286+
let fields = vec![Field::new(
232287
"mass_matrix_inv",
233-
SampleStatValue::OptionArray(self.mass_matrix_inv.clone()),
234-
));
288+
self.mass_matrix_inv.data_type().clone(),
289+
true,
290+
)];
291+
292+
let arrays = vec![self.mass_matrix_inv.as_box()];
293+
294+
StructArray::new(DataType::Struct(fields), arrays, None)
295+
}
296+
}
297+
298+
#[cfg(feature = "arrow")]
299+
impl ArrowRow for ExpWindowDiagAdaptStats {
300+
type Builder = ExpWindowDiagAdaptStatsBuilder;
301+
302+
fn new_builder(dim: usize, _settings: &SamplerArgs) -> Self::Builder {
303+
let items = MutablePrimitiveArray::new();
304+
// TODO Add only based on settings
305+
let values = MutableFixedSizeListArray::new_with_field(items, "item", false, dim);
306+
Self::Builder {
307+
mass_matrix_inv: values,
308+
}
235309
}
236310
}
237311

@@ -260,16 +334,19 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
260334
state: &<Self::Potential as Hamiltonian>::State,
261335
) {
262336
self.exp_variance_draw.add_sample(state.q.iter().copied());
263-
self.exp_variance_draw_bg.add_sample(state.q.iter().copied());
264-
self.exp_variance_grad.add_sample(state.grad.iter().copied());
265-
self.exp_variance_grad_bg.add_sample(state.grad.iter().copied());
337+
self.exp_variance_draw_bg
338+
.add_sample(state.q.iter().copied());
339+
self.exp_variance_grad
340+
.add_sample(state.grad.iter().copied());
341+
self.exp_variance_grad_bg
342+
.add_sample(state.grad.iter().copied());
266343

267344
potential.mass_matrix.update_diag(
268-
state.grad.iter().map(|&grad| {
269-
Some((grad).abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT))
270-
})
345+
state
346+
.grad
347+
.iter()
348+
.map(|&grad| Some((grad).abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT))),
271349
);
272-
273350
}
274351

275352
fn adapt(
@@ -303,7 +380,6 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
303380
}
304381
}
305382

306-
307383
pub(crate) struct GradDiagStrategy<F: CpuLogpFunc> {
308384
step_size: DualAverageStrategy<F, DiagMassMatrix>,
309385
mass_matrix: ExpWindowDiagAdapt<F>,
@@ -332,8 +408,6 @@ impl Default for GradDiagOptions {
332408
dual_average_options: DualAverageSettings::default(),
333409
mass_matrix_options: DiagAdaptExpSettings::default(),
334410
early_window: 0.3,
335-
//step_size_window: 0.08,
336-
//step_size_window: 0.15,
337411
step_size_window: 0.2,
338412
mass_matrix_switch_freq: 60,
339413
early_mass_matrix_switch_freq: 10,
@@ -345,7 +419,7 @@ impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
345419
type Potential = EuclideanPotential<F, DiagMassMatrix>;
346420
type Collector = CombinedCollector<
347421
AcceptanceRateCollector<<EuclideanPotential<F, DiagMassMatrix> as Hamiltonian>::State>,
348-
DrawGradCollector
422+
DrawGradCollector,
349423
>;
350424
type Stats = CombinedStats<DualAverageStats, ExpWindowDiagAdaptStats>;
351425
type Options = GradDiagOptions;
@@ -404,14 +478,16 @@ impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
404478
self.mass_matrix.update_estimators(&collector.collector2);
405479
}
406480
self.mass_matrix.update_potential(potential);
407-
self.step_size.adapt(options, potential, draw, &collector.collector1);
481+
self.step_size
482+
.adapt(options, potential, draw, &collector.collector1);
408483
return;
409484
}
410485

411486
if draw == self.num_tune - 1 {
412487
self.step_size.finalize();
413488
}
414-
self.step_size.adapt(options, potential, draw, &collector.collector1);
489+
self.step_size
490+
.adapt(options, potential, draw, &collector.collector1);
415491
}
416492

417493
fn new_collector(&self) -> Self::Collector {
@@ -438,17 +514,58 @@ impl<F: CpuLogpFunc> AdaptStrategy for GradDiagStrategy<F> {
438514
}
439515
}
440516

517+
#[cfg(feature = "arrow")]
518+
#[derive(Debug, Clone)]
519+
pub struct CombinedStats<D1: Debug + ArrowRow, D2: Debug + ArrowRow> {
520+
pub stats1: D1,
521+
pub stats2: D2,
522+
}
441523

524+
#[cfg(not(feature = "arrow"))]
442525
#[derive(Debug, Clone)]
443526
pub struct CombinedStats<D1: Debug, D2: Debug> {
444527
pub stats1: D1,
445528
pub stats2: D2,
446529
}
447530

448-
impl<D1: AsSampleStatVec, D2: AsSampleStatVec> AsSampleStatVec for CombinedStats<D1, D2> {
449-
fn add_to_vec(&self, vec: &mut Vec<SampleStatItem>) {
450-
self.stats1.add_to_vec(vec);
451-
self.stats2.add_to_vec(vec);
531+
#[cfg(feature = "arrow")]
532+
pub struct CombinedStatsBuilder<D1: ArrowRow, D2: ArrowRow> {
533+
stats1: D1::Builder,
534+
stats2: D2::Builder,
535+
}
536+
537+
#[cfg(feature = "arrow")]
538+
impl<D1: Debug + ArrowRow, D2: Debug + ArrowRow> ArrowRow for CombinedStats<D1, D2> {
539+
type Builder = CombinedStatsBuilder<D1, D2>;
540+
541+
fn new_builder(dim: usize, settings: &SamplerArgs) -> Self::Builder {
542+
Self::Builder {
543+
stats1: D1::new_builder(dim, settings),
544+
stats2: D2::new_builder(dim, settings),
545+
}
546+
}
547+
}
548+
549+
#[cfg(feature = "arrow")]
550+
impl<D1: Debug + ArrowRow, D2: Debug + ArrowRow> ArrowBuilder<CombinedStats<D1, D2>>
551+
for CombinedStatsBuilder<D1, D2>
552+
{
553+
fn append_value(&mut self, value: &CombinedStats<D1, D2>) {
554+
self.stats1.append_value(&value.stats1);
555+
self.stats2.append_value(&value.stats2);
556+
}
557+
558+
fn finalize(self) -> StructArray {
559+
let mut data1 = self.stats1.finalize().into_data();
560+
let data2 = self.stats2.finalize().into_data();
561+
562+
assert!(data1.2.is_none());
563+
assert!(data2.2.is_none());
564+
565+
data1.0.extend(data2.0);
566+
data1.1.extend(data2.1);
567+
568+
StructArray::new(DataType::Struct(data1.0), data1.1, None)
452569
}
453570
}
454571

@@ -468,7 +585,7 @@ where
468585
&mut self,
469586
start: &Self::State,
470587
end: &Self::State,
471-
divergence_info: Option<&dyn crate::nuts::DivergenceInfo>,
588+
divergence_info: Option<&DivergenceInfo>,
472589
) {
473590
self.collector1
474591
.register_leapfrog(start, end, divergence_info);

0 commit comments

Comments
 (0)