Skip to content

Commit 5ef548b

Browse files
committed
style: Fix formatting and clippy
1 parent 2dd5e92 commit 5ef548b

File tree

3 files changed

+48
-42
lines changed

3 files changed

+48
-42
lines changed

src/pymc.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ use arrow2::{
88
use itertools::{izip, Itertools};
99
use numpy::PyReadonlyArray1;
1010
use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings};
11-
use pyo3::{pyclass, pymethods, types::{PyAnyMethods, PyList}, Bound, PyObject, PyResult};
11+
use pyo3::{
12+
pyclass, pymethods,
13+
types::{PyAnyMethods, PyList},
14+
Bound, PyObject, PyResult,
15+
};
1216
use rand::{distributions::Uniform, prelude::Distribution};
1317

1418
use thiserror::Error;
@@ -188,7 +192,7 @@ impl<'model> DrawStorage for PyMcTrace<'model> {
188192

189193
impl<'model> PyMcTrace<'model> {
190194
fn new(model: &'model PyMcModel, settings: &impl Settings) -> Self {
191-
let draws = (settings.hint_num_draws() + settings.hint_num_tune()) as usize;
195+
let draws = settings.hint_num_draws() + settings.hint_num_tune();
192196
Self {
193197
dim: model.dim,
194198
data: model

src/stan.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,7 @@ fn params(
117117
Some(shape)
118118
})
119119
.unwrap_or(vec![]);
120-
shape
121-
.iter_mut()
122-
.for_each(|max_idx| *max_idx = (*max_idx) + 1);
120+
shape.iter_mut().for_each(|max_idx| *max_idx += 1);
123121
let size = shape.iter().product();
124122
let end_idx = start_idx + size;
125123
variables.push(Parameter {
@@ -170,7 +168,7 @@ impl StanModel {
170168
.ok_or_else(|| anyhow::format_err!("Model is currently in use"))
171169
.context("Failed to access the names of unconstrained parameters")?
172170
.param_unc_names()
173-
.split(",")
171+
.split(',')
174172
.map(|name| name.to_string())
175173
.collect())
176174
}
@@ -261,8 +259,8 @@ fn fortran_to_c_order(data: &[f64], shape: &[usize], out: &mut Vec<f64>) {
261259
}
262260

263261
idx[axis] = 0;
264-
position = position - shape[axis] * strides[axis];
265-
axis = axis + 1;
262+
position -= shape[axis] * strides[axis];
263+
axis += 1;
266264
if axis == rank {
267265
break 'iterate;
268266
}
@@ -284,7 +282,11 @@ impl<'model> Clone for StanTrace<'model> {
284282
// We only need it for `StanTrace.inspect`, which
285283
// doesn't need rng, so we could avoid this strange
286284
// seed of zeros.
287-
let rng = self.model.model.new_rng(0).expect("Could not create stan rng");
285+
let rng = self
286+
.model
287+
.model
288+
.new_rng(0)
289+
.expect("Could not create stan rng");
288290
Self {
289291
inner: self.inner,
290292
model: self.model,
@@ -360,7 +362,7 @@ impl Model for StanModel {
360362
chain: u64,
361363
settings: &S,
362364
) -> anyhow::Result<Self::DrawStorage<'a, S>> {
363-
let draws = (settings.hint_num_tune() + settings.hint_num_draws()) as usize;
365+
let draws = settings.hint_num_tune() + settings.hint_num_draws();
364366
let trace = self
365367
.variables
366368
.iter()
@@ -408,13 +410,13 @@ mod tests {
408410
let data = vec![0., 1., 2., 3., 4., 5.];
409411
let mut out = vec![];
410412
fortran_to_c_order(&data, &[2, 3], &mut out);
411-
let expect = vec![0., 2., 4., 1., 3., 5.];
413+
let expect = [0., 2., 4., 1., 3., 5.];
412414
assert!(expect.iter().zip_eq(out.iter()).all(|(a, b)| a == b));
413415

414416
let data = vec![0., 1., 2., 3., 4., 5.];
415417
let mut out = vec![];
416418
fortran_to_c_order(&data, &[3, 2], &mut out);
417-
let expect = vec![0., 3., 1., 4., 2., 5.];
419+
let expect = [0., 3., 1., 4., 2., 5.];
418420
assert!(expect.iter().zip_eq(out.iter()).all(|(a, b)| a == b));
419421

420422
let data = vec![

src/wrapper.rs

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ use crate::{
77

88
use anyhow::{Context, Result};
99
use arrow2::{array::Array, datatypes::Field};
10-
use nuts_rs::{ChainProgress, DiagGradNutsSettings, ProgressCallback, Sampler, SamplerWaitResult, Trace};
10+
use nuts_rs::{
11+
ChainProgress, DiagGradNutsSettings, ProgressCallback, Sampler, SamplerWaitResult, Trace,
12+
};
1113
use pyo3::{
1214
exceptions::PyTimeoutError,
1315
ffi::Py_uintptr_t,
@@ -19,7 +21,6 @@ use rand::{thread_rng, RngCore};
1921
#[pyclass]
2022
struct PyChainProgress(ChainProgress);
2123

22-
2324
#[pymethods]
2425
impl PyChainProgress {
2526
#[getter]
@@ -62,7 +63,6 @@ impl PyChainProgress {
6263
#[derive(Clone, Default)]
6364
pub struct PyDiagGradNutsSettings(DiagGradNutsSettings);
6465

65-
6666
#[pymethods]
6767
impl PyDiagGradNutsSettings {
6868
#[new]
@@ -71,8 +71,10 @@ impl PyDiagGradNutsSettings {
7171
let mut rng = thread_rng();
7272
rng.next_u64()
7373
});
74-
let mut settings = DiagGradNutsSettings::default();
75-
settings.seed = seed;
74+
let settings = DiagGradNutsSettings {
75+
seed,
76+
..Default::default()
77+
};
7678
PyDiagGradNutsSettings(settings)
7779
}
7880

@@ -127,18 +129,12 @@ impl PyDiagGradNutsSettings {
127129
}
128130
#[getter]
129131
fn initial_step(&self) -> f64 {
130-
self.0
131-
.mass_matrix_adapt
132-
.dual_average_options
133-
.initial_step
132+
self.0.mass_matrix_adapt.dual_average_options.initial_step
134133
}
135134

136135
#[setter(initial_step)]
137136
fn set_initial_step(&mut self, val: f64) {
138-
self.0
139-
.mass_matrix_adapt
140-
.dual_average_options
141-
.initial_step = val
137+
self.0.mass_matrix_adapt.dual_average_options.initial_step = val
142138
}
143139

144140
#[getter]
@@ -193,18 +189,12 @@ impl PyDiagGradNutsSettings {
193189

194190
#[setter(target_accept)]
195191
fn set_target_accept(&mut self, val: f64) {
196-
self.0
197-
.mass_matrix_adapt
198-
.dual_average_options
199-
.target_accept = val;
192+
self.0.mass_matrix_adapt.dual_average_options.target_accept = val;
200193
}
201194

202195
#[getter]
203196
fn target_accept(&self) -> f64 {
204-
self.0
205-
.mass_matrix_adapt
206-
.dual_average_options
207-
.target_accept
197+
self.0.mass_matrix_adapt.dual_average_options.target_accept
208198
}
209199

210200
#[getter]
@@ -225,12 +215,18 @@ impl PyDiagGradNutsSettings {
225215

226216
#[getter]
227217
fn use_grad_based_mass_matrix(&self) -> bool {
228-
self.0.mass_matrix_adapt.mass_matrix_options.use_grad_based_estimate
218+
self.0
219+
.mass_matrix_adapt
220+
.mass_matrix_options
221+
.use_grad_based_estimate
229222
}
230223

231224
#[setter(use_grad_based_mass_matrix)]
232225
fn set_use_grad_based_mass_matrix(&mut self, val: bool) {
233-
self.0.mass_matrix_adapt.mass_matrix_options.use_grad_based_estimate = val
226+
self.0
227+
.mass_matrix_adapt
228+
.mass_matrix_options
229+
.use_grad_based_estimate = val
234230
}
235231
}
236232

@@ -250,7 +246,10 @@ fn make_callback(callback: Option<Py<PyAny>>) -> Option<ProgressCallback> {
250246
let _ = Python::with_gil(|py| {
251247
let args = PyList::new_bound(
252248
py,
253-
stats.into_vec().into_iter().map(|prog| PyChainProgress(prog).into_py(py))
249+
stats
250+
.into_vec()
251+
.into_iter()
252+
.map(|prog| PyChainProgress(prog).into_py(py)),
254253
);
255254
callback.call1(py, (args,))
256255
});
@@ -259,8 +258,8 @@ fn make_callback(callback: Option<Py<PyAny>>) -> Option<ProgressCallback> {
259258
callback,
260259
rate: Duration::from_millis(500),
261260
})
262-
},
263-
None => { None },
261+
}
262+
None => None,
264263
}
265264
}
266265

@@ -430,7 +429,9 @@ impl PySampler {
430429
};
431430

432431
let Some(trace) = trace else {
433-
return Err(anyhow::anyhow!("Sampler failed and did not produce a trace"))?;
432+
return Err(anyhow::anyhow!(
433+
"Sampler failed and did not produce a trace"
434+
))?;
434435
};
435436

436437
trace_to_list(trace, py)
@@ -456,8 +457,7 @@ impl PySampler {
456457
}
457458
}
458459

459-
460-
fn trace_to_list<'py>(trace: Trace, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
460+
fn trace_to_list(trace: Trace, py: Python<'_>) -> PyResult<Bound<'_, PyList>> {
461461
let list = PyList::new_bound(
462462
py,
463463
trace
@@ -498,7 +498,7 @@ fn export_array(py: Python<'_>, name: String, data: Box<dyn Array>) -> PyResult<
498498

499499
/// A Python module implemented in Rust.
500500
#[pymodule]
501-
pub fn _lib<'py>(m: &Bound<'py, PyModule>) -> PyResult<()> {
501+
pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> {
502502
m.add_class::<PySampler>()?;
503503
m.add_class::<PyMcModel>()?;
504504
m.add_class::<LogpFunc>()?;

0 commit comments

Comments
 (0)