@@ -7,7 +7,9 @@ use crate::{
7
7
8
8
use anyhow:: { Context , Result } ;
9
9
use arrow2:: { array:: Array , datatypes:: Field } ;
10
- use nuts_rs:: { ChainProgress , DiagGradNutsSettings , ProgressCallback , Sampler , SamplerWaitResult , Trace } ;
10
+ use nuts_rs:: {
11
+ ChainProgress , DiagGradNutsSettings , ProgressCallback , Sampler , SamplerWaitResult , Trace ,
12
+ } ;
11
13
use pyo3:: {
12
14
exceptions:: PyTimeoutError ,
13
15
ffi:: Py_uintptr_t ,
@@ -19,7 +21,6 @@ use rand::{thread_rng, RngCore};
19
21
#[ pyclass]
20
22
struct PyChainProgress ( ChainProgress ) ;
21
23
22
-
23
24
#[ pymethods]
24
25
impl PyChainProgress {
25
26
#[ getter]
@@ -62,7 +63,6 @@ impl PyChainProgress {
62
63
#[ derive( Clone , Default ) ]
63
64
pub struct PyDiagGradNutsSettings ( DiagGradNutsSettings ) ;
64
65
65
-
66
66
#[ pymethods]
67
67
impl PyDiagGradNutsSettings {
68
68
#[ new]
@@ -71,8 +71,10 @@ impl PyDiagGradNutsSettings {
71
71
let mut rng = thread_rng ( ) ;
72
72
rng. next_u64 ( )
73
73
} ) ;
74
- let mut settings = DiagGradNutsSettings :: default ( ) ;
75
- settings. seed = seed;
74
+ let settings = DiagGradNutsSettings {
75
+ seed,
76
+ ..Default :: default ( )
77
+ } ;
76
78
PyDiagGradNutsSettings ( settings)
77
79
}
78
80
@@ -127,18 +129,12 @@ impl PyDiagGradNutsSettings {
127
129
}
128
130
#[ getter]
129
131
fn initial_step ( & self ) -> f64 {
130
- self . 0
131
- . mass_matrix_adapt
132
- . dual_average_options
133
- . initial_step
132
+ self . 0 . mass_matrix_adapt . dual_average_options . initial_step
134
133
}
135
134
136
135
#[ setter( initial_step) ]
137
136
fn set_initial_step ( & mut self , val : f64 ) {
138
- self . 0
139
- . mass_matrix_adapt
140
- . dual_average_options
141
- . initial_step = val
137
+ self . 0 . mass_matrix_adapt . dual_average_options . initial_step = val
142
138
}
143
139
144
140
#[ getter]
@@ -193,18 +189,12 @@ impl PyDiagGradNutsSettings {
193
189
194
190
#[ setter( target_accept) ]
195
191
fn set_target_accept ( & mut self , val : f64 ) {
196
- self . 0
197
- . mass_matrix_adapt
198
- . dual_average_options
199
- . target_accept = val;
192
+ self . 0 . mass_matrix_adapt . dual_average_options . target_accept = val;
200
193
}
201
194
202
195
#[ getter]
203
196
fn target_accept ( & self ) -> f64 {
204
- self . 0
205
- . mass_matrix_adapt
206
- . dual_average_options
207
- . target_accept
197
+ self . 0 . mass_matrix_adapt . dual_average_options . target_accept
208
198
}
209
199
210
200
#[ getter]
@@ -225,12 +215,18 @@ impl PyDiagGradNutsSettings {
225
215
226
216
#[ getter]
227
217
fn use_grad_based_mass_matrix ( & self ) -> bool {
228
- self . 0 . mass_matrix_adapt . mass_matrix_options . use_grad_based_estimate
218
+ self . 0
219
+ . mass_matrix_adapt
220
+ . mass_matrix_options
221
+ . use_grad_based_estimate
229
222
}
230
223
231
224
#[ setter( use_grad_based_mass_matrix) ]
232
225
fn set_use_grad_based_mass_matrix ( & mut self , val : bool ) {
233
- self . 0 . mass_matrix_adapt . mass_matrix_options . use_grad_based_estimate = val
226
+ self . 0
227
+ . mass_matrix_adapt
228
+ . mass_matrix_options
229
+ . use_grad_based_estimate = val
234
230
}
235
231
}
236
232
@@ -250,7 +246,10 @@ fn make_callback(callback: Option<Py<PyAny>>) -> Option<ProgressCallback> {
250
246
let _ = Python :: with_gil ( |py| {
251
247
let args = PyList :: new_bound (
252
248
py,
253
- stats. into_vec ( ) . into_iter ( ) . map ( |prog| PyChainProgress ( prog) . into_py ( py) )
249
+ stats
250
+ . into_vec ( )
251
+ . into_iter ( )
252
+ . map ( |prog| PyChainProgress ( prog) . into_py ( py) ) ,
254
253
) ;
255
254
callback. call1 ( py, ( args, ) )
256
255
} ) ;
@@ -259,8 +258,8 @@ fn make_callback(callback: Option<Py<PyAny>>) -> Option<ProgressCallback> {
259
258
callback,
260
259
rate : Duration :: from_millis ( 500 ) ,
261
260
} )
262
- } ,
263
- None => { None } ,
261
+ }
262
+ None => None ,
264
263
}
265
264
}
266
265
@@ -430,7 +429,9 @@ impl PySampler {
430
429
} ;
431
430
432
431
let Some ( trace) = trace else {
433
- return Err ( anyhow:: anyhow!( "Sampler failed and did not produce a trace" ) ) ?;
432
+ return Err ( anyhow:: anyhow!(
433
+ "Sampler failed and did not produce a trace"
434
+ ) ) ?;
434
435
} ;
435
436
436
437
trace_to_list ( trace, py)
@@ -456,8 +457,7 @@ impl PySampler {
456
457
}
457
458
}
458
459
459
-
460
- fn trace_to_list < ' py > ( trace : Trace , py : Python < ' py > ) -> PyResult < Bound < ' py , PyList > > {
460
+ fn trace_to_list ( trace : Trace , py : Python < ' _ > ) -> PyResult < Bound < ' _ , PyList > > {
461
461
let list = PyList :: new_bound (
462
462
py,
463
463
trace
@@ -498,7 +498,7 @@ fn export_array(py: Python<'_>, name: String, data: Box<dyn Array>) -> PyResult<
498
498
499
499
/// A Python module implemented in Rust.
500
500
#[ pymodule]
501
- pub fn _lib < ' py > ( m : & Bound < ' py , PyModule > ) -> PyResult < ( ) > {
501
+ pub fn _lib ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
502
502
m. add_class :: < PySampler > ( ) ?;
503
503
m. add_class :: < PyMcModel > ( ) ?;
504
504
m. add_class :: < LogpFunc > ( ) ?;
0 commit comments