Skip to content

Commit 95fd608

Browse files
committed
feat: Add option to use draw base mass matrix estimate
1 parent a29a56a commit 95fd608

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

python/nutpie/sample.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,10 @@ def sample(
428428
return_raw_trace: bool, default=False
429429
Return the raw trace object (an apache arrow structure)
430430
instead of converting to arviz.
431+
use_grad_based_mass_matrix: bool, default=True
432+
Use a mass matrix estimate that is based on draw and gradient
433+
variance. Set to `False` to get mass matrix adaptation more
434+
similar to PyMC and Stan.
431435
**kwargs
432436
Pass additional arguments to nutpie._lib.PySamplerArgs
433437

src/wrapper.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,16 @@ impl PyDiagGradNutsSettings {
222222
.mass_matrix_options
223223
.store_mass_matrix = val;
224224
}
225+
226+
#[getter]
227+
fn use_grad_based_mass_matrix(&self) -> bool {
228+
self.0.mass_matrix_adapt.mass_matrix_options.use_grad_based_estimate
229+
}
230+
231+
#[setter(use_grad_based_mass_matrix)]
232+
fn set_use_grad_based_mass_matrix(&mut self, val: bool) {
233+
self.0.mass_matrix_adapt.mass_matrix_options.use_grad_based_estimate = val
234+
}
225235
}
226236

227237
pub(crate) enum SamplerState {

0 commit comments

Comments
 (0)