Skip to content

Commit 65eeb5c

Browse files
committed
Some clippy fixes
1 parent fa136df commit 65eeb5c

File tree

4 files changed

+19
-22
lines changed

4 files changed

+19
-22
lines changed

src/cpu_sampler.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use rand::{prelude::StdRng, Rng, SeedableRng};
1+
use rand::{Rng, SeedableRng};
22
use rayon::prelude::*;
33
use std::thread::JoinHandle;
44
use thiserror::Error;
@@ -124,7 +124,7 @@ pub fn sample_parallel<
124124

125125
let mut error = None;
126126
for _ in 0..n_try_init {
127-
match func.logp(&mut position, &mut grad) {
127+
match func.logp(&position, &mut grad) {
128128
Err(e) => error = Some(e),
129129
Ok(_) => {
130130
error = None;
@@ -218,7 +218,7 @@ pub fn sample_sequentially<F: CpuLogpFunc, R: Rng + ?Sized>(
218218
) -> Result<impl Iterator<Item = Result<(Box<[f64]>, impl SampleStats), NutsError>>, NutsError> {
219219
let mut sampler = new_sampler(logp, settings, chain, rng);
220220
sampler.set_position(start)?;
221-
Ok((0..draws).into_iter().map(move |_| sampler.draw()))
221+
Ok((0..draws).map(move |_| sampler.draw()))
222222
}
223223

224224
/// Initialize chains using uniform jitter around zero or some other provided value
@@ -264,9 +264,7 @@ pub mod test_logps {
264264
}
265265

266266
impl CpuLogpFuncMaker<NormalLogp> for NormalLogp {
267-
//type Func = Self;
268-
269-
fn make_logp_func(&self, chain: usize) -> Result<NormalLogp, anyhow::Error> {
267+
fn make_logp_func(&self, _chain: usize) -> Result<NormalLogp, anyhow::Error> {
270268
Ok(self.clone())
271269
}
272270

@@ -397,7 +395,7 @@ mod tests {
397395
assert!(handles.join().is_ok());
398396

399397
let draw0 = draws.remove(100);
400-
let (vals, stats) = draw0;
398+
let (vals, _) = draw0;
401399
assert_eq!(vals.len(), 10);
402400
}
403401
}

src/cpu_state.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ impl Drop for AlignedArray {
119119
impl Clone for AlignedArray {
120120
fn clone(&self) -> Self {
121121
let mut new = AlignedArray::new(self.size);
122-
new.copy_from_slice(&self);
122+
new.copy_from_slice(self);
123123
new
124124
}
125125
}
@@ -214,9 +214,9 @@ impl crate::nuts::State for State {
214214

215215
fn is_turning(&self, other: &Self) -> bool {
216216
let (start, end) = if self.idx_in_trajectory < other.idx_in_trajectory {
217-
(&*self, other)
217+
(self, other)
218218
} else {
219-
(other, &*self)
219+
(other, self)
220220
};
221221

222222
let a = start.idx_in_trajectory;

src/lib.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
//! ```
1515
//! use nuts_rs::{CpuLogpFunc, LogpError, new_sampler, SamplerArgs, Chain, SampleStats};
1616
//! use thiserror::Error;
17+
//! use rand::thread_rng;
1718
//!
1819
//! // Define a function that computes the unnormalized posterior density
1920
//! // and its gradient.
@@ -60,8 +61,8 @@
6061
//! let logp_func = PosteriorDensity {};
6162
//!
6263
//! let chain = 0;
63-
//! let seed = 42;
64-
//! let mut sampler = new_sampler(logp_func, sampler_args, chain, seed);
64+
//! let mut rng = thread_rng();
65+
//! let mut sampler = new_sampler(logp_func, sampler_args, chain, &mut rng);
6566
//!
6667
//! // Set to some initial position and start drawing samples.
6768
//! sampler.set_position(&vec![0f64; 10]).expect("Unrecoverable error during init");
@@ -108,5 +109,5 @@ pub use cpu_sampler::{
108109
JitterInitFunc, ParallelChainResult, ParallelSamplingError, SamplerArgs,
109110
};
110111
#[cfg(feature = "arrow")]
111-
pub use nuts::ArrowBuilder;
112+
pub use nuts::{ArrowBuilder, ArrowRow};
112113
pub use nuts::{Chain, DivergenceInfo, LogpError, NutsError, SampleStats};

src/nuts.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ impl<P: Hamiltonian, C: Collector<State = P::State>> NutsTree<P, C> {
238238
rng: &mut R,
239239
potential: &mut P,
240240
direction: Direction,
241-
options: &NutsOptions,
242241
collector: &mut C,
243242
) -> ExtendResult<P, C>
244243
where
@@ -253,7 +252,7 @@ impl<P: Hamiltonian, C: Collector<State = P::State>> NutsTree<P, C> {
253252

254253
while other.depth < self.depth {
255254
use ExtendResult::*;
256-
other = match other.extend(pool, rng, potential, direction, options, collector) {
255+
other = match other.extend(pool, rng, potential, direction, collector) {
257256
Ok(tree) => tree,
258257
Turning(_) => {
259258
return Turning(self);
@@ -358,13 +357,9 @@ impl<P: Hamiltonian, C: Collector<State = P::State>> NutsTree<P, C> {
358357
}
359358

360359
fn info(&self, maxdepth: bool, divergence_info: Option<DivergenceInfo>) -> SampleInfo {
361-
let info: Option<DivergenceInfo> = match divergence_info {
362-
Some(info) => Some(info),
363-
None => None,
364-
};
365360
SampleInfo {
366361
depth: self.depth,
367-
divergence_info: info,
362+
divergence_info,
368363
reached_maxdepth: maxdepth,
369364
}
370365
}
@@ -395,7 +390,7 @@ where
395390
let mut tree = NutsTree::new(init.clone());
396391
while tree.depth < options.maxdepth {
397392
let direction: Direction = rng.gen();
398-
tree = match tree.extend(pool, rng, potential, direction, options, collector) {
393+
tree = match tree.extend(pool, rng, potential, direction, collector) {
399394
ExtendResult::Ok(tree) => tree,
400395
ExtendResult::Turning(tree) => {
401396
let info = tree.info(false, None);
@@ -769,6 +764,8 @@ where
769764
#[cfg(test)]
770765
#[cfg(feature = "arrow")]
771766
mod tests {
767+
use rand::thread_rng;
768+
772769
use crate::{adapt_strategy::test_logps::NormalLogp, new_sampler, Chain, SamplerArgs};
773770

774771
use super::ArrowBuilder;
@@ -779,8 +776,9 @@ mod tests {
779776
let func = NormalLogp::new(ndim, 3.);
780777

781778
let settings = SamplerArgs::default();
779+
let mut rng = thread_rng();
782780

783-
let mut chain = new_sampler(func, settings, 0, 0);
781+
let mut chain = new_sampler(func, settings, 0, &mut rng);
784782

785783
let mut builder = chain.stats_builder(ndim, &settings);
786784

0 commit comments

Comments
 (0)