File tree Expand file tree Collapse file tree 2 files changed +14
-0
lines changed Expand file tree Collapse file tree 2 files changed +14
-0
lines changed Original file line number Diff line number Diff line change @@ -428,6 +428,10 @@ def sample(
428
428
return_raw_trace: bool, default=False
429
429
Return the raw trace object (an apache arrow structure)
430
430
instead of converting to arviz.
431
+ use_grad_based_mass_matrix: bool, default=True
432
+ Use a mass matrix estimate that is based on draw and gradient
433
+ variance. Set to `False` to get mass matrix adaptation more
434
+ similar to PyMC and Stan.
431
435
**kwargs
432
436
Pass additional arguments to nutpie._lib.PySamplerArgs
433
437
Original file line number Diff line number Diff line change @@ -222,6 +222,16 @@ impl PyDiagGradNutsSettings {
222
222
. mass_matrix_options
223
223
. store_mass_matrix = val;
224
224
}
225
+
226
+ #[ getter]
227
+ fn use_grad_based_mass_matrix ( & self ) -> bool {
228
+ self . 0 . mass_matrix_adapt . mass_matrix_options . use_grad_based_estimate
229
+ }
230
+
231
+ #[ setter( use_grad_based_mass_matrix) ]
232
+ fn set_use_grad_based_mass_matrix ( & mut self , val : bool ) {
233
+ self . 0 . mass_matrix_adapt . mass_matrix_options . use_grad_based_estimate = val
234
+ }
225
235
}
226
236
227
237
pub ( crate ) enum SamplerState {
You can’t perform that action at this time.
0 commit comments