Export more details sample stats for divergences #9

Merged (3 commits) on Jul 21, 2023
Changes from all commits
8 changes: 4 additions & 4 deletions Cargo.toml
@@ -20,15 +20,15 @@ rand_distr = "0.4.3"
 multiversion = "0.7.2"
 itertools = "0.11.0"
 crossbeam = "0.8.2"
-thiserror = "1.0.40"
+thiserror = "1.0.43"
 rayon = "1.7.0"
-arrow2 = { version = "0.17.2", optional = true }
+arrow2 = { version = "0.17.3", optional = true }
 rand_chacha = "0.3.1"
-anyhow = "1.0.71"
+anyhow = "1.0.72"
 
 [dev-dependencies]
 proptest = "1.2.0"
-pretty_assertions = "1.3.0"
+pretty_assertions = "1.4.0"
 criterion = "0.5.1"
 nix = "0.26.2"
 approx = "0.5.1"
2 changes: 1 addition & 1 deletion src/adapt_strategy.rs
@@ -352,7 +352,7 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
                 state
                     .grad
                     .iter()
-                    .map(|&grad| grad.abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT))
+                    .map(|&grad| grad.abs().clamp(LOWER_LIMIT, UPPER_LIMIT).recip())
                     .map(|var| if var.is_finite() { Some(var) } else { Some(1.) }),
             );
         }
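
Note on the reordered call above: clamping the gradient magnitude before taking the reciprocal is only equivalent to clamping the reciprocal when the two limits are exact reciprocals of each other. A minimal standalone sketch of the difference, with placeholder limits chosen purely for illustration (they are not the constants defined in the crate):

// Illustrative only: these limit values are made up and are not the
// LOWER_LIMIT/UPPER_LIMIT constants from src/adapt_strategy.rs.
const LOWER_LIMIT: f64 = 1e-4;
const UPPER_LIMIT: f64 = 1e2;

fn main() {
    let grad: f64 = 1e-6; // a very small gradient component
    // Old order: invert first, then clamp the resulting variance estimate.
    let old = grad.abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT);
    // New order: clamp the gradient magnitude first, then invert.
    let new = grad.abs().clamp(LOWER_LIMIT, UPPER_LIMIT).recip();
    println!("old = {old}, new = {new}"); // old = 100, new = 10000
}
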
15 changes: 14 additions & 1 deletion src/cpu_potential.rs
@@ -116,6 +116,8 @@ impl<F: CpuLogpFunc, M: MassMatrix> Hamiltonian for EuclideanPotential<F, M> {
             let div_info = DivergenceInfo {
                 logp_function_error: Some(Box::new(logp_error)),
                 start_location: Some(start.q.clone()),
+                start_gradient: Some(start.grad.clone()),
+                start_momentum: Some(start.p.clone()),
                 end_location: None,
                 start_idx_in_trajectory: Some(start.idx_in_trajectory),
                 end_idx_in_trajectory: None,
@@ -142,7 +144,9 @@
             let divergence_info = DivergenceInfo {
                 logp_function_error: None,
                 start_location: Some(start.q.clone()),
+                start_gradient: Some(start.grad.clone()),
                 end_location: Some(out.q.clone()),
+                start_momentum: Some(out.p.clone()),
                 start_idx_in_trajectory: Some(start.index_in_trajectory()),
                 end_idx_in_trajectory: Some(out.index_in_trajectory()),
                 energy_error: Some(energy_error),
@@ -165,7 +169,16 @@
         }
         self.update_potential_gradient(&mut state)
             .map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
-        Ok(state)
+        if state
+            .grad
+            .iter()
+            .cloned()
+            .any(|val| (val == 0f64) | !val.is_finite())
+        {
+            Err(NutsError::BadInitGrad())
+        } else {
+            Ok(state)
+        }
     }
 
     fn randomize_momentum<R: rand::Rng + ?Sized>(&self, state: &mut Self::State, rng: &mut R) {
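
Note on the new early return in init_state above: a starting point whose gradient contains a zero or non-finite component is now rejected with NutsError::BadInitGrad, presumably because such a gradient cannot seed the reciprocal-gradient mass matrix estimate used during adaptation. A self-contained sketch of the predicate; the helper name is invented for illustration:

// Hypothetical helper mirroring the check added to init_state: any zero,
// NaN, or infinite entry marks the initial gradient as unusable.
fn has_bad_init_grad(grad: &[f64]) -> bool {
    grad.iter().cloned().any(|val| (val == 0f64) | !val.is_finite())
}

fn main() {
    assert!(has_bad_init_grad(&[1.0, 0.0, -2.0])); // zero entry
    assert!(has_bad_init_grad(&[1.0, f64::NAN])); // non-finite entry
    assert!(!has_bad_init_grad(&[1.0, -0.5, 3.0])); // acceptable gradient
}
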
4 changes: 4 additions & 0 deletions src/cpu_state.rs
@@ -243,6 +243,10 @@ impl crate::nuts::State for State {
         out.copy_from_slice(&self.grad);
     }
 
+    fn write_momentum(&self, out: &mut [f64]) {
+        out.copy_from_slice(&self.p);
+    }
+
     fn energy(&self) -> f64 {
         self.kinetic_energy + self.potential_energy
     }
152 changes: 151 additions & 1 deletion src/nuts.rs
@@ -1,4 +1,4 @@
-use arrow2::array::{MutableFixedSizeListArray, TryPush};
+use arrow2::array::{MutableFixedSizeListArray, MutableUtf8Array, TryPush};
 #[cfg(feature = "arrow")]
 use arrow2::{
     array::{MutableArray, MutableBooleanArray, MutablePrimitiveArray, StructArray},
@@ -13,13 +13,17 @@ use crate::math::logaddexp;
 #[cfg(feature = "arrow")]
 use crate::SamplerArgs;
 
+#[non_exhaustive]
 #[derive(Error, Debug)]
 pub enum NutsError {
     #[error("Logp function returned error: {0}")]
     LogpFailure(Box<dyn std::error::Error + Send + Sync>),
 
     #[error("Could not serialize sample stats")]
     SerializeFailure(),
+
+    #[error("Could not initialize state because of bad initial gradient.")]
+    BadInitGrad(),
 }
 
 pub type Result<T> = std::result::Result<T, NutsError>;
@@ -33,7 +37,9 @@ pub type Result<T> = std::result::Result<T, NutsError>;
 /// failed)
 #[derive(Debug)]
 pub struct DivergenceInfo {
+    pub start_momentum: Option<Box<[f64]>>,
     pub start_location: Option<Box<[f64]>>,
+    pub start_gradient: Option<Box<[f64]>>,
     pub end_location: Option<Box<[f64]>>,
     pub energy_error: Option<f64>,
     pub end_idx_in_trajectory: Option<i64>,
@@ -148,6 +154,9 @@ pub trait State: Clone + Debug {
     /// Write the gradient stored in the state to a different location
     fn write_gradient(&self, out: &mut [f64]);
 
+    /// Write the momentum in the state to a different location
+    fn write_momentum(&self, out: &mut [f64]);
+
     /// Compute the termination criterion for NUTS
     fn is_turning(&self, other: &Self) -> bool;
 
@@ -519,6 +528,11 @@ pub struct StatsBuilder<H: Hamiltonian, A: AdaptStrategy> {
     hamiltonian: <H::Stats as ArrowRow>::Builder,
     adapt: <A::Stats as ArrowRow>::Builder,
     diverging: MutableBooleanArray,
+    divergence_start: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
+    divergence_start_grad: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
+    divergence_end: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
+    divergence_momentum: Option<MutableFixedSizeListArray<MutablePrimitiveArray<f64>>>,
+    divergence_msg: Option<MutableUtf8Array<i64>>,
 }
 
 #[cfg(feature = "arrow")]
@@ -544,6 +558,40 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
             None
         };
 
+        let (div_start, div_start_grad, div_end, div_mom, div_msg) = if settings.store_divergences {
+            let start_location_prim = MutablePrimitiveArray::new();
+            let start_location_list =
+                MutableFixedSizeListArray::new_with_field(start_location_prim, "item", false, dim);
+
+            let start_grad_prim = MutablePrimitiveArray::new();
+            let start_grad_list =
+                MutableFixedSizeListArray::new_with_field(start_grad_prim, "item", false, dim);
+
+            let end_location_prim = MutablePrimitiveArray::new();
+            let end_location_list =
+                MutableFixedSizeListArray::new_with_field(end_location_prim, "item", false, dim);
+
+            let momentum_location_prim = MutablePrimitiveArray::new();
+            let momentum_location_list = MutableFixedSizeListArray::new_with_field(
+                momentum_location_prim,
+                "item",
+                false,
+                dim,
+            );
+
+            let msg_list = MutableUtf8Array::new();
+
+            (
+                Some(start_location_list),
+                Some(start_grad_list),
+                Some(end_location_list),
+                Some(momentum_location_list),
+                Some(msg_list),
+            )
+        } else {
+            (None, None, None, None, None)
+        };
+
         Self {
             depth: MutablePrimitiveArray::with_capacity(capacity),
             maxdepth_reached: MutableBooleanArray::with_capacity(capacity),
@@ -557,6 +605,11 @@ impl<H: Hamiltonian, A: AdaptStrategy> StatsBuilder<H, A> {
             hamiltonian: <H::Stats as ArrowRow>::new_builder(dim, settings),
             adapt: <A::Stats as ArrowRow>::new_builder(dim, settings),
             diverging: MutableBooleanArray::with_capacity(capacity),
+            divergence_start: div_start,
+            divergence_start_grad: div_start_grad,
+            divergence_end: div_end,
+            divergence_momentum: div_mom,
+            divergence_msg: div_msg,
         }
     }
 }
@@ -597,6 +650,58 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
                 .unwrap();
         }
 
+        let info_option = value.divergence_info();
+        if let Some(div_start) = self.divergence_start.as_mut() {
+            div_start
+                .try_push(info_option.and_then(|info| {
+                    info.start_location
+                        .as_ref()
+                        .map(|vals| vals.iter().map(|&x| Some(x)))
+                }))
+                .unwrap();
+        }
+
+        let info_option = value.divergence_info();
+        if let Some(div_grad) = self.divergence_start_grad.as_mut() {
+            div_grad
+                .try_push(info_option.and_then(|info| {
+                    info.start_gradient
+                        .as_ref()
+                        .map(|vals| vals.iter().map(|&x| Some(x)))
+                }))
+                .unwrap();
+        }
+
+        if let Some(div_end) = self.divergence_end.as_mut() {
+            div_end
+                .try_push(info_option.and_then(|info| {
+                    info.end_location
+                        .as_ref()
+                        .map(|vals| vals.iter().map(|&x| Some(x)))
+                }))
+                .unwrap();
+        }
+
+        if let Some(div_mom) = self.divergence_momentum.as_mut() {
+            div_mom
+                .try_push(info_option.and_then(|info| {
+                    info.start_momentum
+                        .as_ref()
+                        .map(|vals| vals.iter().map(|&x| Some(x)))
+                }))
+                .unwrap();
+        }
+
+        if let Some(div_msg) = self.divergence_msg.as_mut() {
+            div_msg
+                .try_push(info_option.and_then(|info| {
+                    info.logp_function_error
+                        .as_ref()
+                        .map(|err| format!("{}", err))
+                }))
+                .unwrap();
+        }
+
         self.hamiltonian.append_value(&value.potential_stats);
         self.adapt.append_value(&value.strategy_stats);
     }
@@ -651,6 +756,51 @@ impl<H: Hamiltonian, A: AdaptStrategy> ArrowBuilder<NutsSampleStats<H::Stats, A:
             arrays.push(unconstrained.as_box());
         }
 
+        if let Some(mut div_start) = self.divergence_start.take() {
+            fields.push(Field::new(
+                "divergence_start",
+                div_start.data_type().clone(),
+                true,
+            ));
+            arrays.push(div_start.as_box());
+        }
+
+        if let Some(mut div_start_grad) = self.divergence_start_grad.take() {
+            fields.push(Field::new(
+                "divergence_start_gradient",
+                div_start_grad.data_type().clone(),
+                true,
+            ));
+            arrays.push(div_start_grad.as_box());
+        }
+
+        if let Some(mut div_end) = self.divergence_end.take() {
+            fields.push(Field::new(
+                "divergence_end",
+                div_end.data_type().clone(),
+                true,
+            ));
+            arrays.push(div_end.as_box());
+        }
+
+        if let Some(mut div_mom) = self.divergence_momentum.take() {
+            fields.push(Field::new(
+                "divergence_momentum",
+                div_mom.data_type().clone(),
+                true,
+            ));
+            arrays.push(div_mom.as_box());
+        }
+
+        if let Some(mut div_msg) = self.divergence_msg.take() {
+            fields.push(Field::new(
+                "divergence_message",
+                div_msg.data_type().clone(),
+                true,
+            ));
+            arrays.push(div_msg.as_box());
+        }
+
         Some(StructArray::new(DataType::Struct(fields), arrays, None))
     }
 }
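
Note on consuming the new columns: divergence_start, divergence_start_gradient, divergence_end, divergence_momentum, and divergence_message are only added to the sample-stats StructArray when store_divergences is enabled in the sampler settings, and each row is null for non-divergent draws. A rough, hypothetical sketch (not code from this PR) of how a downstream reader might locate one of the new child arrays with arrow2 0.17; the function name and lookup strategy are assumptions for illustration:

use arrow2::array::{Array, FixedSizeListArray, StructArray};
use arrow2::datatypes::DataType;

// Hypothetical consumer: find the "divergence_start" child of the sample-stats
// struct array and downcast it to a fixed-size list of f64 positions.
fn divergence_starts(stats: &StructArray) -> Option<&FixedSizeListArray> {
    let DataType::Struct(fields) = stats.data_type() else {
        return None;
    };
    let idx = fields.iter().position(|f| f.name == "divergence_start")?;
    stats.values()[idx]
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
}

The same pattern applies to the other four fields; divergence_message is a string column rather than a fixed-size list.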