From 01e97981488ddb0b8194b6f4e27c3592bcd2c8d1 Mon Sep 17 00:00:00 2001
From: Ben Kimock <kimockb@gmail.com>
Date: Mon, 4 Sep 2023 12:40:08 -0400
Subject: [PATCH] Reimplement FileEncoder with a small-write optimization

---
 compiler/rustc_metadata/src/rmeta/table.rs    |   6 +-
 .../src/dep_graph/serialized.rs               |  11 +-
 compiler/rustc_serialize/src/leb128.rs        |  20 +-
 compiler/rustc_serialize/src/lib.rs           |   3 +
 compiler/rustc_serialize/src/opaque.rs        | 270 +++++-------------
 compiler/rustc_serialize/tests/leb128.rs      |  14 +-
 6 files changed, 102 insertions(+), 222 deletions(-)

diff --git a/compiler/rustc_metadata/src/rmeta/table.rs b/compiler/rustc_metadata/src/rmeta/table.rs
index e906c906347..35987072ed6 100644
--- a/compiler/rustc_metadata/src/rmeta/table.rs
+++ b/compiler/rustc_metadata/src/rmeta/table.rs
@@ -5,7 +5,6 @@ use rustc_hir::def::{CtorKind, CtorOf};
 use rustc_index::Idx;
 use rustc_middle::ty::{ParameterizedOverTcx, UnusedGenericParams};
 use rustc_serialize::opaque::FileEncoder;
-use rustc_serialize::Encoder as _;
 use rustc_span::hygiene::MacroKind;
 use std::marker::PhantomData;
 use std::num::NonZeroUsize;
@@ -468,7 +467,10 @@ impl<I: Idx, const N: usize, T: FixedSizeEncoding<ByteArray = [u8; N]>> TableBui
 
         let width = self.width;
         for block in &self.blocks {
-            buf.emit_raw_bytes(&block[..width]);
+            buf.write_with(|dest| {
+                *dest = *block;
+                width
+            });
         }
 
         LazyTable::from_position_and_encoded_size(
diff --git a/compiler/rustc_query_system/src/dep_graph/serialized.rs b/compiler/rustc_query_system/src/dep_graph/serialized.rs
index 4ba0cb31d0b..3fd83f79a48 100644
--- a/compiler/rustc_query_system/src/dep_graph/serialized.rs
+++ b/compiler/rustc_query_system/src/dep_graph/serialized.rs
@@ -394,7 +394,10 @@ struct NodeInfo<K: DepKind> {
 impl<K: DepKind> Encodable<FileEncoder> for NodeInfo<K> {
     fn encode(&self, e: &mut FileEncoder) {
         let header = SerializedNodeHeader::new(self);
-        e.emit_raw_bytes(&header.bytes);
+        e.write_with(|dest| {
+            *dest = header.bytes;
+            header.bytes.len()
+        });
 
         if header.len().is_none() {
             e.emit_usize(self.edges.len());
@@ -402,8 +405,10 @@ impl<K: DepKind> Encodable<FileEncoder> for NodeInfo<K> {
 
         let bytes_per_index = header.bytes_per_index();
         for node_index in self.edges.iter() {
-            let bytes = node_index.as_u32().to_le_bytes();
-            e.emit_raw_bytes(&bytes[..bytes_per_index]);
+            e.write_with(|dest| {
+                *dest = node_index.as_u32().to_le_bytes();
+                bytes_per_index
+            });
         }
     }
 }
diff --git a/compiler/rustc_serialize/src/leb128.rs b/compiler/rustc_serialize/src/leb128.rs
index e568b9e6786..ca661bac78c 100644
--- a/compiler/rustc_serialize/src/leb128.rs
+++ b/compiler/rustc_serialize/src/leb128.rs
@@ -15,23 +15,20 @@ pub const fn largest_max_leb128_len() -> usize {
 macro_rules! impl_write_unsigned_leb128 {
     ($fn_name:ident, $int_ty:ty) => {
         #[inline]
-        pub fn $fn_name(
-            out: &mut [::std::mem::MaybeUninit<u8>; max_leb128_len::<$int_ty>()],
-            mut value: $int_ty,
-        ) -> &[u8] {
+        pub fn $fn_name(out: &mut [u8; max_leb128_len::<$int_ty>()], mut value: $int_ty) -> usize {
             let mut i = 0;
 
             loop {
                 if value < 0x80 {
                     unsafe {
-                        *out.get_unchecked_mut(i).as_mut_ptr() = value as u8;
+                        *out.get_unchecked_mut(i) = value as u8;
                     }
 
                     i += 1;
                     break;
                 } else {
                     unsafe {
-                        *out.get_unchecked_mut(i).as_mut_ptr() = ((value & 0x7f) | 0x80) as u8;
+                        *out.get_unchecked_mut(i) = ((value & 0x7f) | 0x80) as u8;
                     }
 
                     value >>= 7;
@@ -39,7 +36,7 @@ macro_rules! impl_write_unsigned_leb128 {
                 }
             }
 
-            unsafe { ::std::mem::MaybeUninit::slice_assume_init_ref(&out.get_unchecked(..i)) }
+            i
         }
     };
 }
@@ -87,10 +84,7 @@ impl_read_unsigned_leb128!(read_usize_leb128, usize);
 macro_rules! impl_write_signed_leb128 {
     ($fn_name:ident, $int_ty:ty) => {
         #[inline]
-        pub fn $fn_name(
-            out: &mut [::std::mem::MaybeUninit<u8>; max_leb128_len::<$int_ty>()],
-            mut value: $int_ty,
-        ) -> &[u8] {
+        pub fn $fn_name(out: &mut [u8; max_leb128_len::<$int_ty>()], mut value: $int_ty) -> usize {
             let mut i = 0;
 
             loop {
@@ -104,7 +98,7 @@ macro_rules! impl_write_signed_leb128 {
                 }
 
                 unsafe {
-                    *out.get_unchecked_mut(i).as_mut_ptr() = byte;
+                    *out.get_unchecked_mut(i) = byte;
                 }
 
                 i += 1;
@@ -114,7 +108,7 @@ macro_rules! impl_write_signed_leb128 {
                 }
             }
 
-            unsafe { ::std::mem::MaybeUninit::slice_assume_init_ref(&out.get_unchecked(..i)) }
+            i
         }
     };
 }
diff --git a/compiler/rustc_serialize/src/lib.rs b/compiler/rustc_serialize/src/lib.rs
index ce8503918b4..dd40b3cf028 100644
--- a/compiler/rustc_serialize/src/lib.rs
+++ b/compiler/rustc_serialize/src/lib.rs
@@ -17,6 +17,9 @@ Core encoding and decoding interfaces.
 #![feature(new_uninit)]
 #![feature(allocator_api)]
 #![feature(ptr_sub_ptr)]
+#![feature(slice_first_last_chunk)]
+#![feature(inline_const)]
+#![feature(const_option)]
 #![cfg_attr(test, feature(test))]
 #![allow(rustc::internal)]
 #![deny(rustc::untranslatable_diagnostic)]
diff --git a/compiler/rustc_serialize/src/opaque.rs b/compiler/rustc_serialize/src/opaque.rs
index f1b7e8d9ae0..fcd35f2ea57 100644
--- a/compiler/rustc_serialize/src/opaque.rs
+++ b/compiler/rustc_serialize/src/opaque.rs
@@ -3,10 +3,8 @@ use crate::serialize::{Decodable, Decoder, Encodable, Encoder};
 use std::fs::File;
 use std::io::{self, Write};
 use std::marker::PhantomData;
-use std::mem::MaybeUninit;
 use std::ops::Range;
 use std::path::Path;
-use std::ptr;
 
 // -----------------------------------------------------------------------------
 // Encoder
@@ -24,10 +22,9 @@ const BUF_SIZE: usize = 8192;
 /// size of the buffer, rather than the full length of the encoded data, and
 /// because it doesn't need to reallocate memory along the way.
 pub struct FileEncoder {
-    /// The input buffer. For adequate performance, we need more control over
-    /// buffering than `BufWriter` offers. If `BufWriter` ever offers a raw
-    /// buffer access API, we can use it, and remove `buf` and `buffered`.
-    buf: Box<[MaybeUninit<u8>]>,
+    /// The input buffer. For adequate performance, we need to be able to write
+    /// directly to the unwritten region of the buffer, without calling copy_from_slice.
+    buf: Box<[u8; BUF_SIZE]>,
     buffered: usize,
     flushed: usize,
     file: File,
@@ -38,15 +35,11 @@ pub struct FileEncoder {
 
 impl FileEncoder {
     pub fn new<P: AsRef<Path>>(path: P) -> io::Result<Self> {
-        // Create the file for reading and writing, because some encoders do both
-        // (e.g. the metadata encoder when -Zmeta-stats is enabled)
-        let file = File::options().read(true).write(true).create(true).truncate(true).open(path)?;
-
         Ok(FileEncoder {
-            buf: Box::new_uninit_slice(BUF_SIZE),
+            buf: vec![0u8; BUF_SIZE].into_boxed_slice().try_into().unwrap(),
             buffered: 0,
             flushed: 0,
-            file,
+            file: File::create(path)?,
             res: Ok(()),
         })
     }
@@ -54,94 +47,20 @@ impl FileEncoder {
     #[inline]
     pub fn position(&self) -> usize {
         // Tracking position this way instead of having a `self.position` field
-        // means that we don't have to update the position on every write call.
+        // means that we only need to update `self.buffered` on a write call,
+        // as opposed to updating `self.position` and `self.buffered`.
         self.flushed + self.buffered
     }
 
-    pub fn flush(&mut self) {
-        // This is basically a copy of `BufWriter::flush`. If `BufWriter` ever
-        // offers a raw buffer access API, we can use it, and remove this.
-
-        /// Helper struct to ensure the buffer is updated after all the writes
-        /// are complete. It tracks the number of written bytes and drains them
-        /// all from the front of the buffer when dropped.
-        struct BufGuard<'a> {
-            buffer: &'a mut [u8],
-            encoder_buffered: &'a mut usize,
-            encoder_flushed: &'a mut usize,
-            flushed: usize,
-        }
-
-        impl<'a> BufGuard<'a> {
-            fn new(
-                buffer: &'a mut [u8],
-                encoder_buffered: &'a mut usize,
-                encoder_flushed: &'a mut usize,
-            ) -> Self {
-                assert_eq!(buffer.len(), *encoder_buffered);
-                Self { buffer, encoder_buffered, encoder_flushed, flushed: 0 }
-            }
-
-            /// The unwritten part of the buffer
-            fn remaining(&self) -> &[u8] {
-                &self.buffer[self.flushed..]
-            }
-
-            /// Flag some bytes as removed from the front of the buffer
-            fn consume(&mut self, amt: usize) {
-                self.flushed += amt;
-            }
-
-            /// true if all of the bytes have been written
-            fn done(&self) -> bool {
-                self.flushed >= *self.encoder_buffered
-            }
-        }
-
-        impl Drop for BufGuard<'_> {
-            fn drop(&mut self) {
-                if self.flushed > 0 {
-                    if self.done() {
-                        *self.encoder_flushed += *self.encoder_buffered;
-                        *self.encoder_buffered = 0;
-                    } else {
-                        self.buffer.copy_within(self.flushed.., 0);
-                        *self.encoder_flushed += self.flushed;
-                        *self.encoder_buffered -= self.flushed;
-                    }
-                }
-            }
-        }
-
-        // If we've already had an error, do nothing. It'll get reported after
-        // `finish` is called.
-        if self.res.is_err() {
-            return;
-        }
-
-        let mut guard = BufGuard::new(
-            unsafe { MaybeUninit::slice_assume_init_mut(&mut self.buf[..self.buffered]) },
-            &mut self.buffered,
-            &mut self.flushed,
-        );
-
-        while !guard.done() {
-            match self.file.write(guard.remaining()) {
-                Ok(0) => {
-                    self.res = Err(io::Error::new(
-                        io::ErrorKind::WriteZero,
-                        "failed to write the buffered data",
-                    ));
-                    return;
-                }
-                Ok(n) => guard.consume(n),
-                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
-                Err(e) => {
-                    self.res = Err(e);
-                    return;
-                }
-            }
+    #[cold]
+    #[inline(never)]
+    pub fn flush(&mut self) -> &mut [u8; BUF_SIZE] {
+        if self.res.is_ok() {
+            self.res = self.file.write_all(&self.buf[..self.buffered]);
         }
+        self.flushed += self.buffered;
+        self.buffered = 0;
+        &mut self.buf
     }
 
     pub fn file(&self) -> &File {
@@ -149,99 +68,64 @@ impl FileEncoder {
     }
 
     #[inline]
-    fn write_one(&mut self, value: u8) {
-        let mut buffered = self.buffered;
+    fn buffer_empty(&mut self) -> &mut [u8] {
+        // SAFETY: self.buffered is inbounds as an invariant of the type
+        unsafe { self.buf.get_unchecked_mut(self.buffered..) }
+    }
 
-        if std::intrinsics::unlikely(buffered + 1 > BUF_SIZE) {
-            self.flush();
-            buffered = 0;
+    #[cold]
+    #[inline(never)]
+    fn write_all_cold_path(&mut self, buf: &[u8]) {
+        if let Some(dest) = self.flush().get_mut(..buf.len()) {
+            dest.copy_from_slice(buf);
+            self.buffered += buf.len();
+        } else {
+            if self.res.is_ok() {
+                self.res = self.file.write_all(buf);
+            }
+            self.flushed += buf.len();
         }
-
-        // SAFETY: The above check and `flush` ensures that there is enough
-        // room to write the input to the buffer.
-        unsafe {
-            *MaybeUninit::slice_as_mut_ptr(&mut self.buf).add(buffered) = value;
-        }
-
-        self.buffered = buffered + 1;
     }
 
     #[inline]
     fn write_all(&mut self, buf: &[u8]) {
-        let buf_len = buf.len();
-
-        if std::intrinsics::likely(buf_len <= BUF_SIZE) {
-            let mut buffered = self.buffered;
-
-            if std::intrinsics::unlikely(buffered + buf_len > BUF_SIZE) {
-                self.flush();
-                buffered = 0;
-            }
-
-            // SAFETY: The above check and `flush` ensures that there is enough
-            // room to write the input to the buffer.
-            unsafe {
-                let src = buf.as_ptr();
-                let dst = MaybeUninit::slice_as_mut_ptr(&mut self.buf).add(buffered);
-                ptr::copy_nonoverlapping(src, dst, buf_len);
-            }
-
-            self.buffered = buffered + buf_len;
+        if let Some(dest) = self.buffer_empty().get_mut(..buf.len()) {
+            dest.copy_from_slice(buf);
+            self.buffered += buf.len();
         } else {
-            self.write_all_unbuffered(buf);
+            self.write_all_cold_path(buf);
         }
     }
 
-    fn write_all_unbuffered(&mut self, mut buf: &[u8]) {
-        // If we've already had an error, do nothing. It'll get reported after
-        // `finish` is called.
-        if self.res.is_err() {
-            return;
-        }
-
-        if self.buffered > 0 {
+    /// Write up to `N` bytes to this encoder.
+    ///
+    /// Whenever possible, use this function to do writes whose length has a small and
+    /// compile-time constant upper bound.
+    #[inline]
+    pub fn write_with<const N: usize, V>(&mut self, mut visitor: V)
+    where
+        V: FnMut(&mut [u8; N]) -> usize,
+    {
+        let flush_threshold = const { BUF_SIZE.checked_sub(N).unwrap() };
+        if std::intrinsics::unlikely(self.buffered > flush_threshold) {
             self.flush();
         }
-
-        // This is basically a copy of `Write::write_all` but also updates our
-        // `self.flushed`. It's necessary because `Write::write_all` does not
-        // return the number of bytes written when an error is encountered, and
-        // without that, we cannot accurately update `self.flushed` on error.
-        while !buf.is_empty() {
-            match self.file.write(buf) {
-                Ok(0) => {
-                    self.res = Err(io::Error::new(
-                        io::ErrorKind::WriteZero,
-                        "failed to write whole buffer",
-                    ));
-                    return;
-                }
-                Ok(n) => {
-                    buf = &buf[n..];
-                    self.flushed += n;
-                }
-                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
-                Err(e) => {
-                    self.res = Err(e);
-                    return;
-                }
-            }
-        }
+        // SAFETY: We checked above that that N < self.buffer_empty().len(),
+        // and if isn't, flush ensures that our empty buffer is now BUF_SIZE.
+        // We produce a post-mono error if N > BUF_SIZE.
+        let buf = unsafe { self.buffer_empty().first_chunk_mut::<N>().unwrap_unchecked() };
+        let written = visitor(buf);
+        debug_assert!(written <= N);
+        // We have to ensure that an errant visitor cannot cause self.buffered to exeed BUF_SIZE.
+        self.buffered += written.min(N);
     }
 
     pub fn finish(mut self) -> Result<usize, io::Error> {
         self.flush();
-
-        let res = std::mem::replace(&mut self.res, Ok(()));
-        res.map(|()| self.position())
-    }
-}
-
-impl Drop for FileEncoder {
-    fn drop(&mut self) {
-        // Likely to be a no-op, because `finish` should have been called and
-        // it also flushes. But do it just in case.
-        let _result = self.flush();
+        match self.res {
+            Ok(()) => Ok(self.position()),
+            Err(e) => Err(e),
+        }
     }
 }
 
@@ -250,25 +134,7 @@ macro_rules! write_leb128 {
         #[inline]
         fn $this_fn(&mut self, v: $int_ty) {
             const MAX_ENCODED_LEN: usize = $crate::leb128::max_leb128_len::<$int_ty>();
-
-            let mut buffered = self.buffered;
-
-            // This can't overflow because BUF_SIZE and MAX_ENCODED_LEN are both
-            // quite small.
-            if std::intrinsics::unlikely(buffered + MAX_ENCODED_LEN > BUF_SIZE) {
-                self.flush();
-                buffered = 0;
-            }
-
-            // SAFETY: The above check and flush ensures that there is enough
-            // room to write the encoded value to the buffer.
-            let buf = unsafe {
-                &mut *(self.buf.as_mut_ptr().add(buffered)
-                    as *mut [MaybeUninit<u8>; MAX_ENCODED_LEN])
-            };
-
-            let encoded = leb128::$write_leb_fn(buf, v);
-            self.buffered = buffered + encoded.len();
+            self.write_with::<MAX_ENCODED_LEN, _>(|buf| leb128::$write_leb_fn(buf, v))
         }
     };
 }
@@ -281,12 +147,18 @@ impl Encoder for FileEncoder {
 
     #[inline]
     fn emit_u16(&mut self, v: u16) {
-        self.write_all(&v.to_le_bytes());
+        self.write_with(|buf| {
+            *buf = v.to_le_bytes();
+            2
+        });
     }
 
     #[inline]
     fn emit_u8(&mut self, v: u8) {
-        self.write_one(v);
+        self.write_with(|buf: &mut [u8; 1]| {
+            buf[0] = v;
+            1
+        });
     }
 
     write_leb128!(emit_isize, isize, write_isize_leb128);
@@ -296,7 +168,10 @@ impl Encoder for FileEncoder {
 
     #[inline]
     fn emit_i16(&mut self, v: i16) {
-        self.write_all(&v.to_le_bytes());
+        self.write_with(|buf| {
+            *buf = v.to_le_bytes();
+            2
+        });
     }
 
     #[inline]
@@ -495,7 +370,10 @@ impl Encodable<FileEncoder> for IntEncodedWithFixedSize {
     #[inline]
     fn encode(&self, e: &mut FileEncoder) {
         let _start_pos = e.position();
-        e.emit_raw_bytes(&self.0.to_le_bytes());
+        e.write_with(|buf| {
+            *buf = self.0.to_le_bytes();
+            buf.len()
+        });
         let _end_pos = e.position();
         debug_assert_eq!((_end_pos - _start_pos), IntEncodedWithFixedSize::ENCODED_SIZE);
     }
diff --git a/compiler/rustc_serialize/tests/leb128.rs b/compiler/rustc_serialize/tests/leb128.rs
index 7872e778431..dc9b32a968b 100644
--- a/compiler/rustc_serialize/tests/leb128.rs
+++ b/compiler/rustc_serialize/tests/leb128.rs
@@ -1,8 +1,4 @@
-#![feature(maybe_uninit_slice)]
-#![feature(maybe_uninit_uninit_array)]
-
 use rustc_serialize::leb128::*;
-use std::mem::MaybeUninit;
 use rustc_serialize::Decoder;
 
 macro_rules! impl_test_unsigned_leb128 {
@@ -24,9 +20,10 @@ macro_rules! impl_test_unsigned_leb128 {
 
             let mut stream = Vec::new();
 
+            let mut buf = Default::default();
             for &x in &values {
-                let mut buf = MaybeUninit::uninit_array();
-                stream.extend($write_fn_name(&mut buf, x));
+                let n = $write_fn_name(&mut buf, x);
+                stream.extend(&buf[..n]);
             }
 
             let mut decoder = rustc_serialize::opaque::MemDecoder::new(&stream, 0);
@@ -70,9 +67,10 @@ macro_rules! impl_test_signed_leb128 {
 
             let mut stream = Vec::new();
 
+            let mut buf = Default::default();
             for &x in &values {
-                let mut buf = MaybeUninit::uninit_array();
-                stream.extend($write_fn_name(&mut buf, x));
+                let n = $write_fn_name(&mut buf, x);
+                stream.extend(&buf[..n]);
             }
 
             let mut decoder = rustc_serialize::opaque::MemDecoder::new(&stream, 0);