Skip to content

Commit cfcce8e

Browse files
committed
specialize iter::ArrayChunks::fold for TrustedRandomAccess iters
This is fairly safe use of TRA since it consumes the iterator so no struct in an unsafe state will be left exposed to user code
1 parent eb3f001 commit cfcce8e

File tree

1 file changed

+86
-3
lines changed

1 file changed

+86
-3
lines changed

library/core/src/iter/adapters/array_chunks.rs

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use crate::array;
2-
use crate::iter::{ByRefSized, FusedIterator, Iterator};
3-
use crate::ops::{ControlFlow, Try};
2+
use crate::const_closure::ConstFnMutClosure;
3+
use crate::iter::{ByRefSized, FusedIterator, Iterator, TrustedRandomAccessNoCoerce};
4+
use crate::mem::{self, MaybeUninit};
5+
use crate::ops::{ControlFlow, NeverShortCircuit, Try};
46

57
/// An iterator over `N` elements of the iterator at a time.
68
///
@@ -82,7 +84,13 @@ where
8284
}
8385
}
8486

85-
impl_fold_via_try_fold! { fold -> try_fold }
87+
fn fold<B, F>(self, init: B, f: F) -> B
88+
where
89+
Self: Sized,
90+
F: FnMut(B, Self::Item) -> B,
91+
{
92+
<Self as SpecFold>::fold(self, init, f)
93+
}
8694
}
8795

8896
#[unstable(feature = "iter_array_chunks", reason = "recently added", issue = "100450")]
@@ -168,3 +176,78 @@ where
168176
self.iter.len() < N
169177
}
170178
}
179+
180+
trait SpecFold: Iterator {
181+
fn fold<B, F>(self, init: B, f: F) -> B
182+
where
183+
Self: Sized,
184+
F: FnMut(B, Self::Item) -> B;
185+
}
186+
187+
impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
188+
where
189+
I: Iterator,
190+
{
191+
#[inline]
192+
default fn fold<B, F>(mut self, init: B, mut f: F) -> B
193+
where
194+
Self: Sized,
195+
F: FnMut(B, Self::Item) -> B,
196+
{
197+
let fold = ConstFnMutClosure::new(&mut f, NeverShortCircuit::wrap_mut_2_imp);
198+
self.try_fold(init, fold).0
199+
}
200+
}
201+
202+
impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
203+
where
204+
I: Iterator + TrustedRandomAccessNoCoerce,
205+
{
206+
#[inline]
207+
fn fold<B, F>(mut self, init: B, mut f: F) -> B
208+
where
209+
Self: Sized,
210+
F: FnMut(B, Self::Item) -> B,
211+
{
212+
if self.remainder.is_some() {
213+
return init;
214+
}
215+
216+
let mut accum = init;
217+
let inner_len = self.iter.size();
218+
let mut i = 0;
219+
// Use a while loop because (0..len).step_by(N) doesn't optimize well.
220+
while inner_len - i >= N {
221+
let mut chunk = MaybeUninit::uninit_array();
222+
let mut guard = array::Guard { array_mut: &mut chunk, initialized: 0 };
223+
for j in 0..N {
224+
// SAFETY: The method consumes the iterator and the loop condition ensures that
225+
// all accesses are in bounds and only happen once.
226+
guard.array_mut[j].write(unsafe { self.iter.__iterator_get_unchecked(i + j) });
227+
guard.initialized = j + 1;
228+
}
229+
mem::forget(guard);
230+
// SAFETY: The loop above initialized all elements
231+
let chunk = unsafe { MaybeUninit::array_assume_init(chunk) };
232+
accum = f(accum, chunk);
233+
i += N;
234+
}
235+
236+
let remainder = inner_len % N;
237+
238+
let mut tail = MaybeUninit::uninit_array();
239+
let mut guard = array::Guard { array_mut: &mut tail, initialized: 0 };
240+
for i in 0..remainder {
241+
// SAFETY: the remainder was not visited by the previous loop, so we're still only
242+
// accessing each element once
243+
let val = unsafe { self.iter.__iterator_get_unchecked(inner_len - remainder + i) };
244+
guard.array_mut[i].write(val);
245+
guard.initialized = i + 1;
246+
}
247+
mem::forget(guard);
248+
// SAFETY: the loop above initialized elements up to the `remainder` index
249+
self.remainder = Some(unsafe { array::IntoIter::new_unchecked(tail, 0..remainder) });
250+
251+
accum
252+
}
253+
}

0 commit comments

Comments
 (0)