mirror of
https://github.com/rust-lang/rust.git
synced 2024-11-22 23:04:33 +00:00
Auto merge of #115273 - the8472:take-fold, r=cuviper
Optimize Take::{fold, for_each} when wrapping TrustedRandomAccess iterators
This commit is contained in:
commit
c4f25777a0
@ -1,5 +1,7 @@
|
||||
use crate::cmp;
|
||||
use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen};
|
||||
use crate::iter::{
|
||||
adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen, TrustedRandomAccess,
|
||||
};
|
||||
use crate::num::NonZeroUsize;
|
||||
use crate::ops::{ControlFlow, Try};
|
||||
|
||||
@ -98,26 +100,18 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl_fold_via_try_fold! { fold -> try_fold }
|
||||
#[inline]
|
||||
fn fold<B, F>(self, init: B, f: F) -> B
|
||||
where
|
||||
Self: Sized,
|
||||
F: FnMut(B, Self::Item) -> B,
|
||||
{
|
||||
Self::spec_fold(self, init, f)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
|
||||
// The default implementation would use a unit accumulator, so we can
|
||||
// avoid a stateful closure by folding over the remaining number
|
||||
// of items we wish to return instead.
|
||||
fn check<'a, Item>(
|
||||
mut action: impl FnMut(Item) + 'a,
|
||||
) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
|
||||
move |more, x| {
|
||||
action(x);
|
||||
more.checked_sub(1)
|
||||
}
|
||||
}
|
||||
|
||||
let remaining = self.n;
|
||||
if remaining > 0 {
|
||||
self.iter.try_fold(remaining - 1, check(f));
|
||||
}
|
||||
fn for_each<F: FnMut(Self::Item)>(self, f: F) {
|
||||
Self::spec_for_each(self, f)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@ -249,3 +243,72 @@ impl<I> FusedIterator for Take<I> where I: FusedIterator {}
|
||||
|
||||
#[unstable(feature = "trusted_len", issue = "37572")]
|
||||
unsafe impl<I: TrustedLen> TrustedLen for Take<I> {}
|
||||
|
||||
trait SpecTake: Iterator {
|
||||
fn spec_fold<B, F>(self, init: B, f: F) -> B
|
||||
where
|
||||
Self: Sized,
|
||||
F: FnMut(B, Self::Item) -> B;
|
||||
|
||||
fn spec_for_each<F: FnMut(Self::Item)>(self, f: F);
|
||||
}
|
||||
|
||||
impl<I: Iterator> SpecTake for Take<I> {
|
||||
#[inline]
|
||||
default fn spec_fold<B, F>(mut self, init: B, f: F) -> B
|
||||
where
|
||||
Self: Sized,
|
||||
F: FnMut(B, Self::Item) -> B,
|
||||
{
|
||||
use crate::ops::NeverShortCircuit;
|
||||
self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
default fn spec_for_each<F: FnMut(Self::Item)>(mut self, f: F) {
|
||||
// The default implementation would use a unit accumulator, so we can
|
||||
// avoid a stateful closure by folding over the remaining number
|
||||
// of items we wish to return instead.
|
||||
fn check<'a, Item>(
|
||||
mut action: impl FnMut(Item) + 'a,
|
||||
) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
|
||||
move |more, x| {
|
||||
action(x);
|
||||
more.checked_sub(1)
|
||||
}
|
||||
}
|
||||
|
||||
let remaining = self.n;
|
||||
if remaining > 0 {
|
||||
self.iter.try_fold(remaining - 1, check(f));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: Iterator + TrustedRandomAccess> SpecTake for Take<I> {
|
||||
#[inline]
|
||||
fn spec_fold<B, F>(mut self, init: B, mut f: F) -> B
|
||||
where
|
||||
Self: Sized,
|
||||
F: FnMut(B, Self::Item) -> B,
|
||||
{
|
||||
let mut acc = init;
|
||||
let end = self.n.min(self.iter.size());
|
||||
for i in 0..end {
|
||||
// SAFETY: i < end <= self.iter.size() and we discard the iterator at the end
|
||||
let val = unsafe { self.iter.__iterator_get_unchecked(i) };
|
||||
acc = f(acc, val);
|
||||
}
|
||||
acc
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn spec_for_each<F: FnMut(Self::Item)>(mut self, mut f: F) {
|
||||
let end = self.n.min(self.iter.size());
|
||||
for i in 0..end {
|
||||
// SAFETY: i < end <= self.iter.size() and we discard the iterator at the end
|
||||
let val = unsafe { self.iter.__iterator_get_unchecked(i) };
|
||||
f(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
15
tests/codegen/lib-optimizations/iter-sum.rs
Normal file
15
tests/codegen/lib-optimizations/iter-sum.rs
Normal file
@ -0,0 +1,15 @@
|
||||
// ignore-debug: the debug assertions get in the way
|
||||
// compile-flags: -O
|
||||
// only-x86_64 (vectorization varies between architectures)
|
||||
#![crate_type = "lib"]
|
||||
|
||||
|
||||
// Ensure that slice + take + sum gets vectorized.
|
||||
// Currently this relies on the slice::Iter::try_fold implementation
|
||||
// CHECK-LABEL: @slice_take_sum
|
||||
#[no_mangle]
|
||||
pub fn slice_take_sum(s: &[u64], l: usize) -> u64 {
|
||||
// CHECK: vector.body:
|
||||
// CHECK: ret
|
||||
s.iter().take(l).sum()
|
||||
}
|
Loading…
Reference in New Issue
Block a user