From 00bc9e8ac4cc6fba98eef558554e7fcd747e49c1 Mon Sep 17 00:00:00 2001 From: austinabell Date: Sun, 14 Aug 2022 13:25:13 -0400 Subject: [PATCH] fix(iter::skip): Optimize `next` and `nth` implementations of `Skip` --- library/core/src/iter/adapters/skip.rs | 27 +++++++++++++++------ library/core/tests/iter/adapters/skip.rs | 31 ++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/library/core/src/iter/adapters/skip.rs b/library/core/src/iter/adapters/skip.rs index 2c283100f07..dbf0ae9eca3 100644 --- a/library/core/src/iter/adapters/skip.rs +++ b/library/core/src/iter/adapters/skip.rs @@ -33,21 +33,32 @@ where #[inline] fn next(&mut self) -> Option { if unlikely(self.n > 0) { - self.iter.nth(crate::mem::take(&mut self.n) - 1)?; + self.iter.nth(crate::mem::take(&mut self.n)) + } else { + self.iter.next() } - self.iter.next() } #[inline] fn nth(&mut self, n: usize) -> Option { - // Can't just add n + self.n due to overflow. if self.n > 0 { - let to_skip = self.n; - self.n = 0; - // nth(n) skips n+1 - self.iter.nth(to_skip - 1)?; + let skip: usize = crate::mem::take(&mut self.n); + // Checked add to handle overflow case. + let n = match skip.checked_add(n) { + Some(nth) => nth, + None => { + // In case of overflow, load skip value, before loading `n`. + // Because the amount of elements to iterate is beyond `usize::MAX`, this + // is split into two `nth` calls where the `skip` `nth` call is discarded. + self.iter.nth(skip - 1)?; + n + } + }; + // Load nth element including skip. + self.iter.nth(n) + } else { + self.iter.nth(n) } - self.iter.nth(n) } #[inline] diff --git a/library/core/tests/iter/adapters/skip.rs b/library/core/tests/iter/adapters/skip.rs index 65f235e86aa..754641834e8 100644 --- a/library/core/tests/iter/adapters/skip.rs +++ b/library/core/tests/iter/adapters/skip.rs @@ -201,3 +201,34 @@ fn test_skip_non_fused() { // advance it further. `Unfuse` tests that this doesn't happen by panicking in that scenario. let _ = non_fused.skip(20).next(); } + +#[test] +fn test_skip_non_fused_nth_overflow() { + let non_fused = Unfuse::new(0..10); + + // Ensures that calling skip and `nth` where the sum would overflow does not fail for non-fused + // iterators. + let _ = non_fused.skip(20).nth(usize::MAX); +} + +#[test] +fn test_skip_overflow_wrapping() { + // Test to ensure even on overflowing on `skip+nth` the correct amount of elements are yielded. + struct WrappingIterator(usize); + + impl Iterator for WrappingIterator { + type Item = usize; + + fn next(&mut self) -> core::option::Option { + ::nth(self, 0) + } + + fn nth(&mut self, nth: usize) -> core::option::Option { + self.0 = self.0.wrapping_add(nth.wrapping_add(1)); + Some(self.0) + } + } + + let wrap = WrappingIterator(0); + assert_eq!(wrap.skip(20).nth(usize::MAX), Some(20)); +}