Rollup merge of #103446 - the8472:tra-array-chunks, r=Mark-Simulacrum

Specialize `iter::ArrayChunks::fold` for TrustedRandomAccess iterators

```
OLD:
test iter::bench_trusted_random_access_chunks                      ... bench:         368 ns/iter (+/- 4)
NEW:
test iter::bench_trusted_random_access_chunks                      ... bench:          30 ns/iter (+/- 0)
```

The resulting assembly is similar to #103166, but the specialization kicks in under different (partially overlapping) conditions compared to that PR. They're complementary.

In principle a TRA-based specialization could be applied to all `ArrayChunks` methods, including `next()`, as we do for `Zip`, but that would carry all the same hazards as the `Zip` specialization. Doing it only for `fold` is far less hazardous. The downside is that it only helps with internal, exhaustive iteration, i.e. `for _ in` or `try_fold` will not benefit.
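For illustration only (not part of this PR's diff; the helper names are made up), a sketch on a nightly toolchain with `#![feature(iter_array_chunks)]` of the two consumption styles: the `fold` call can dispatch to the specialized impl because the underlying iterator implements `TrustedRandomAccessNoCoerce`, while the `for` loop drives the adapter through `next()` and stays on the generic path.

```rust
#![feature(iter_array_chunks)]

// Internal, exhaustive iteration: `fold` can hit the TRA-based
// specialization because `Copied<slice::Iter<u8>>` is a
// TrustedRandomAccessNoCoerce iterator.
fn sum_chunks_fold(v: &[u8]) -> u64 {
    v.iter()
        .copied()
        .array_chunks::<8>()
        .fold(0u64, |acc, chunk| acc.wrapping_add(u64::from_ne_bytes(chunk)))
}

// External iteration: the `for` loop desugars to repeated `next()`
// calls, so it cannot benefit from the `fold`-only specialization.
fn sum_chunks_loop(v: &[u8]) -> u64 {
    let mut acc = 0u64;
    for chunk in v.iter().copied().array_chunks::<8>() {
        acc = acc.wrapping_add(u64::from_ne_bytes(chunk));
    }
    acc
}
```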

Note that the regular, `try_fold`-based `fold()` and the specialized `fold()` impl have slightly different observable behavior: the specialized variant does not fetch the remainder elements from the underlying iterator. We already have a few other places in the standard library where beyond-the-end-of-iteration side effects are elided under some circumstances but not others.
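As a hedged sketch of that difference (mirroring the relaxed assertion in the test at the bottom of this diff), count how many elements a side-effecting `map` closure observes when ten items are consumed in chunks of three: the generic `try_fold`-based `fold` also pulls the tenth (remainder) element, whereas the specialized `fold` can stop after the ninth.

```rust
#![feature(iter_array_chunks)]

use std::cell::Cell;

fn main() {
    let pulled = Cell::new(0usize);
    let chunks = (0..10)
        .map(|x| {
            // Side effect per element fetched from the underlying iterator.
            pulled.set(pulled.get() + 1);
            x
        })
        .array_chunks::<3>()
        .fold(0, |acc, _chunk| acc + 1);
    assert_eq!(chunks, 3);
    // Which `fold` impl gets selected determines whether the remainder
    // element was fetched, so this prints 9 or 10.
    println!("elements pulled: {}", pulled.get());
}
```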

Inspired by https://old.reddit.com/r/rust/comments/yaft60/zerocost_iterator_abstractionsnot_so_zerocost/
This commit is contained in:
Dylan DPC 2022-11-08 11:23:50 +05:30 committed by GitHub
commit b695ed3f20
5 changed files with 148 additions and 29 deletions


@@ -1,3 +1,4 @@
use core::borrow::Borrow;
use core::iter::*;
use core::mem;
use core::num::Wrapping;
@@ -403,13 +404,31 @@ fn bench_trusted_random_access_adapters(b: &mut Bencher) {
/// Exercises the iter::Copied specialization for slice::Iter
#[bench]
fn bench_copied_array_chunks(b: &mut Bencher) {
fn bench_copied_chunks(b: &mut Bencher) {
let v = vec![1u8; 1024];
b.iter(|| {
let mut iter = black_box(&v).iter().copied();
let mut acc = Wrapping(0);
// This uses a while-let loop to side-step the TRA specialization in ArrayChunks
while let Ok(chunk) = iter.next_chunk::<{ mem::size_of::<u64>() }>() {
let d = u64::from_ne_bytes(chunk);
acc += Wrapping(d.rotate_left(7).wrapping_add(1));
}
acc
})
}
/// Exercises the TrustedRandomAccess specialization in ArrayChunks
#[bench]
fn bench_trusted_random_access_chunks(b: &mut Bencher) {
let v = vec![1u8; 1024];
b.iter(|| {
black_box(&v)
.iter()
.copied()
// this shows that we're not relying on the slice::Iter specialization in Copied
.map(|b| *b.borrow())
.array_chunks::<{ mem::size_of::<u64>() }>()
.map(|ary| {
let d = u64::from_ne_bytes(ary);


@@ -5,6 +5,7 @@
#![feature(test)]
#![feature(trusted_random_access)]
#![feature(iter_array_chunks)]
#![feature(iter_next_chunk)]
extern crate test;


@@ -865,24 +865,6 @@ where
return Ok(Try::from_output(unsafe { mem::zeroed() }));
}
struct Guard<'a, T, const N: usize> {
array_mut: &'a mut [MaybeUninit<T>; N],
initialized: usize,
}
impl<T, const N: usize> Drop for Guard<'_, T, N> {
fn drop(&mut self) {
debug_assert!(self.initialized <= N);
// SAFETY: this slice will contain only initialized objects.
unsafe {
crate::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(
&mut self.array_mut.get_unchecked_mut(..self.initialized),
));
}
}
}
let mut array = MaybeUninit::uninit_array::<N>();
let mut guard = Guard { array_mut: &mut array, initialized: 0 };
@@ -896,13 +878,11 @@ where
ControlFlow::Continue(elem) => elem,
};
// SAFETY: `guard.initialized` starts at 0, is increased by one in the
// loop and the loop is aborted once it reaches N (which is
// `array.len()`).
// SAFETY: `guard.initialized` starts at 0, which means push can be called
// at most N times, which this loop does.
unsafe {
guard.array_mut.get_unchecked_mut(guard.initialized).write(item);
guard.push_unchecked(item);
}
guard.initialized += 1;
}
None => {
let alive = 0..guard.initialized;
@@ -920,6 +900,55 @@ where
Ok(Try::from_output(output))
}
/// Panic guard for incremental initialization of arrays.
///
/// Disarm the guard with `mem::forget` once the array has been initialized.
///
/// # Safety
///
/// All write accesses to this structure are unsafe and must maintain a correct
/// count of `initialized` elements.
///
/// To minimize indirection fields are still pub but callers should at least use
/// `push_unchecked` to signal that something unsafe is going on.
pub(crate) struct Guard<'a, T, const N: usize> {
/// The array to be initialized.
pub array_mut: &'a mut [MaybeUninit<T>; N],
/// The number of items that have been initialized so far.
pub initialized: usize,
}
impl<T, const N: usize> Guard<'_, T, N> {
/// Adds an item to the array and updates the initialized item counter.
///
/// # Safety
///
/// No more than N elements must be initialized.
#[inline]
pub unsafe fn push_unchecked(&mut self, item: T) {
// SAFETY: If `initialized` was correct before and the caller does not
// invoke this method more than N times then writes will be in-bounds
// and slots will not be initialized more than once.
unsafe {
self.array_mut.get_unchecked_mut(self.initialized).write(item);
self.initialized = self.initialized.unchecked_add(1);
}
}
}
impl<T, const N: usize> Drop for Guard<'_, T, N> {
fn drop(&mut self) {
debug_assert!(self.initialized <= N);
// SAFETY: this slice will contain only initialized objects.
unsafe {
crate::ptr::drop_in_place(MaybeUninit::slice_assume_init_mut(
&mut self.array_mut.get_unchecked_mut(..self.initialized),
));
}
}
}
/// Returns the next chunk of `N` items from the iterator or errors with an
/// iterator over the remainder. Used for `Iterator::next_chunk`.
#[inline]


@@ -1,6 +1,8 @@
use crate::array;
use crate::iter::{ByRefSized, FusedIterator, Iterator};
use crate::ops::{ControlFlow, Try};
use crate::const_closure::ConstFnMutClosure;
use crate::iter::{ByRefSized, FusedIterator, Iterator, TrustedRandomAccessNoCoerce};
use crate::mem::{self, MaybeUninit};
use crate::ops::{ControlFlow, NeverShortCircuit, Try};
/// An iterator over `N` elements of the iterator at a time.
///
@@ -82,7 +84,13 @@ where
}
}
impl_fold_via_try_fold! { fold -> try_fold }
fn fold<B, F>(self, init: B, f: F) -> B
where
Self: Sized,
F: FnMut(B, Self::Item) -> B,
{
<Self as SpecFold>::fold(self, init, f)
}
}
#[unstable(feature = "iter_array_chunks", reason = "recently added", issue = "100450")]
@@ -168,3 +176,64 @@ where
self.iter.len() < N
}
}
trait SpecFold: Iterator {
fn fold<B, F>(self, init: B, f: F) -> B
where
Self: Sized,
F: FnMut(B, Self::Item) -> B;
}
impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
where
I: Iterator,
{
#[inline]
default fn fold<B, F>(mut self, init: B, mut f: F) -> B
where
Self: Sized,
F: FnMut(B, Self::Item) -> B,
{
let fold = ConstFnMutClosure::new(&mut f, NeverShortCircuit::wrap_mut_2_imp);
self.try_fold(init, fold).0
}
}
impl<I, const N: usize> SpecFold for ArrayChunks<I, N>
where
I: Iterator + TrustedRandomAccessNoCoerce,
{
#[inline]
fn fold<B, F>(mut self, init: B, mut f: F) -> B
where
Self: Sized,
F: FnMut(B, Self::Item) -> B,
{
let mut accum = init;
let inner_len = self.iter.size();
let mut i = 0;
// Use a while loop because (0..len).step_by(N) doesn't optimize well.
while inner_len - i >= N {
let mut chunk = MaybeUninit::uninit_array();
let mut guard = array::Guard { array_mut: &mut chunk, initialized: 0 };
while guard.initialized < N {
// SAFETY: The method consumes the iterator and the loop condition ensures that
// all accesses are in bounds and only happen once.
unsafe {
let idx = i + guard.initialized;
guard.push_unchecked(self.iter.__iterator_get_unchecked(idx));
}
}
mem::forget(guard);
// SAFETY: The loop above initialized all elements
let chunk = unsafe { MaybeUninit::array_assume_init(chunk) };
accum = f(accum, chunk);
i += N;
}
// unlike try_fold this method does not need to take care of the remainder
// since `self` will be dropped
accum
}
}


@@ -139,7 +139,8 @@ fn test_iterator_array_chunks_fold() {
let result =
(0..10).map(|_| CountDrop::new(&count)).array_chunks::<3>().fold(0, |acc, _item| acc + 1);
assert_eq!(result, 3);
assert_eq!(count.get(), 10);
// fold impls may or may not process the remainder
assert!(count.get() <= 10 && count.get() >= 9);
}
#[test]