Auto merge of #107007 - TDecking:float_parsing_improvments, r=Mark-Simulacrum

Improve the floating point parser in dec2flt.

Greetings everyone,

I've benn studying the rust floating point parser recently and made the following tweaks:

* Remove all remaining traces of `unsafe`. The parser is now 100% safe Rust.
* The trick in which eight digits are processed in parallel is now in a loop.
* Parsing of inf/NaN values has been reworked.

On my system, the changes result in performance improvements for some input values.
This commit is contained in:
bors 2023-04-10 14:09:09 +00:00
commit a73288371e
5 changed files with 196 additions and 296 deletions

View File

@ -1,165 +1,60 @@
//! Common utilities, for internal use only.
use crate::ptr;
/// Helper methods to process immutable bytes.
pub(crate) trait ByteSlice: AsRef<[u8]> {
unsafe fn first_unchecked(&self) -> u8 {
debug_assert!(!self.is_empty());
// SAFETY: safe as long as self is not empty
unsafe { *self.as_ref().get_unchecked(0) }
}
/// Get if the slice contains no elements.
fn is_empty(&self) -> bool {
self.as_ref().is_empty()
}
/// Check if the slice at least `n` length.
fn check_len(&self, n: usize) -> bool {
n <= self.as_ref().len()
}
/// Check if the first character in the slice is equal to c.
fn first_is(&self, c: u8) -> bool {
self.as_ref().first() == Some(&c)
}
/// Check if the first character in the slice is equal to c1 or c2.
fn first_is2(&self, c1: u8, c2: u8) -> bool {
if let Some(&c) = self.as_ref().first() { c == c1 || c == c2 } else { false }
}
/// Bounds-checked test if the first character in the slice is a digit.
fn first_isdigit(&self) -> bool {
if let Some(&c) = self.as_ref().first() { c.is_ascii_digit() } else { false }
}
/// Check if self starts with u with a case-insensitive comparison.
fn starts_with_ignore_case(&self, u: &[u8]) -> bool {
debug_assert!(self.as_ref().len() >= u.len());
let iter = self.as_ref().iter().zip(u.iter());
let d = iter.fold(0, |i, (&x, &y)| i | (x ^ y));
d == 0 || d == 32
}
/// Get the remaining slice after the first N elements.
fn advance(&self, n: usize) -> &[u8] {
&self.as_ref()[n..]
}
/// Get the slice after skipping all leading characters equal c.
fn skip_chars(&self, c: u8) -> &[u8] {
let mut s = self.as_ref();
while s.first_is(c) {
s = s.advance(1);
}
s
}
/// Get the slice after skipping all leading characters equal c1 or c2.
fn skip_chars2(&self, c1: u8, c2: u8) -> &[u8] {
let mut s = self.as_ref();
while s.first_is2(c1, c2) {
s = s.advance(1);
}
s
}
pub(crate) trait ByteSlice {
/// Read 8 bytes as a 64-bit integer in little-endian order.
unsafe fn read_u64_unchecked(&self) -> u64 {
debug_assert!(self.check_len(8));
let src = self.as_ref().as_ptr() as *const u64;
// SAFETY: safe as long as self is at least 8 bytes
u64::from_le(unsafe { ptr::read_unaligned(src) })
}
fn read_u64(&self) -> u64;
/// Try to read the next 8 bytes from the slice.
fn read_u64(&self) -> Option<u64> {
if self.check_len(8) {
// SAFETY: self must be at least 8 bytes.
Some(unsafe { self.read_u64_unchecked() })
} else {
None
}
}
/// Calculate the offset of slice from another.
fn offset_from(&self, other: &Self) -> isize {
other.as_ref().len() as isize - self.as_ref().len() as isize
}
}
impl ByteSlice for [u8] {}
/// Helper methods to process mutable bytes.
pub(crate) trait ByteSliceMut: AsMut<[u8]> {
/// Write a 64-bit integer as 8 bytes in little-endian order.
unsafe fn write_u64_unchecked(&mut self, value: u64) {
debug_assert!(self.as_mut().len() >= 8);
let dst = self.as_mut().as_mut_ptr() as *mut u64;
// NOTE: we must use `write_unaligned`, since dst is not
// guaranteed to be properly aligned. Miri will warn us
// if we use `write` instead of `write_unaligned`, as expected.
// SAFETY: safe as long as self is at least 8 bytes
unsafe {
ptr::write_unaligned(dst, u64::to_le(value));
}
}
}
fn write_u64(&mut self, value: u64);
impl ByteSliceMut for [u8] {}
/// Bytes wrapper with specialized methods for ASCII characters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct AsciiStr<'a> {
slc: &'a [u8],
}
impl<'a> AsciiStr<'a> {
pub fn new(slc: &'a [u8]) -> Self {
Self { slc }
}
/// Advance the view by n, advancing it in-place to (n..).
pub unsafe fn step_by(&mut self, n: usize) -> &mut Self {
// SAFETY: safe as long n is less than the buffer length
self.slc = unsafe { self.slc.get_unchecked(n..) };
self
}
/// Advance the view by n, advancing it in-place to (1..).
pub unsafe fn step(&mut self) -> &mut Self {
// SAFETY: safe as long as self is not empty
unsafe { self.step_by(1) }
}
/// Calculate the offset of a slice from another.
fn offset_from(&self, other: &Self) -> isize;
/// Iteratively parse and consume digits from bytes.
pub fn parse_digits(&mut self, mut func: impl FnMut(u8)) {
while let Some(&c) = self.as_ref().first() {
/// Returns the same bytes with consumed digits being
/// elided.
fn parse_digits(&self, func: impl FnMut(u8)) -> &Self;
}
impl ByteSlice for [u8] {
#[inline(always)] // inlining this is crucial to remove bound checks
fn read_u64(&self) -> u64 {
let mut tmp = [0; 8];
tmp.copy_from_slice(&self[..8]);
u64::from_le_bytes(tmp)
}
#[inline(always)] // inlining this is crucial to remove bound checks
fn write_u64(&mut self, value: u64) {
self[..8].copy_from_slice(&value.to_le_bytes())
}
#[inline]
fn offset_from(&self, other: &Self) -> isize {
other.len() as isize - self.len() as isize
}
#[inline]
fn parse_digits(&self, mut func: impl FnMut(u8)) -> &Self {
let mut s = self;
// FIXME: Can't use s.split_first() here yet,
// see https://github.com/rust-lang/rust/issues/109328
while let [c, s_next @ ..] = s {
let c = c.wrapping_sub(b'0');
if c < 10 {
func(c);
// SAFETY: self cannot be empty
unsafe {
self.step();
}
s = s_next;
} else {
break;
}
}
}
}
impl<'a> AsRef<[u8]> for AsciiStr<'a> {
#[inline]
fn as_ref(&self) -> &[u8] {
self.slc
s
}
}
impl<'a> ByteSlice for AsciiStr<'a> {}
/// Determine if 8 bytes are all decimal digits.
/// This does not care about the order in which the bytes were loaded.
pub(crate) fn is_8digits(v: u64) -> bool {
@ -168,19 +63,6 @@ pub(crate) fn is_8digits(v: u64) -> bool {
(a | b) & 0x8080_8080_8080_8080 == 0
}
/// Iteratively parse and consume digits from bytes.
pub(crate) fn parse_digits(s: &mut &[u8], mut f: impl FnMut(u8)) {
while let Some(&c) = s.get(0) {
let c = c.wrapping_sub(b'0');
if c < 10 {
f(c);
*s = s.advance(1);
} else {
break;
}
}
}
/// A custom 64-bit floating point type, representing `f * 2^e`.
/// e is biased, so it be directly shifted into the exponent bits.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]

View File

@ -9,7 +9,7 @@
//! algorithm can be found in "ParseNumberF64 by Simple Decimal Conversion",
//! available online: <https://nigeltao.github.io/blog/2020/parse-number-f64-simple.html>.
use crate::num::dec2flt::common::{is_8digits, parse_digits, ByteSlice, ByteSliceMut};
use crate::num::dec2flt::common::{is_8digits, ByteSlice};
#[derive(Clone)]
pub struct Decimal {
@ -205,29 +205,32 @@ impl Decimal {
pub fn parse_decimal(mut s: &[u8]) -> Decimal {
let mut d = Decimal::default();
let start = s;
s = s.skip_chars(b'0');
parse_digits(&mut s, |digit| d.try_add_digit(digit));
if s.first_is(b'.') {
s = s.advance(1);
while let Some((&b'0', s_next)) = s.split_first() {
s = s_next;
}
s = s.parse_digits(|digit| d.try_add_digit(digit));
if let Some((b'.', s_next)) = s.split_first() {
s = s_next;
let first = s;
// Skip leading zeros.
if d.num_digits == 0 {
s = s.skip_chars(b'0');
while let Some((&b'0', s_next)) = s.split_first() {
s = s_next;
}
}
while s.len() >= 8 && d.num_digits + 8 < Decimal::MAX_DIGITS {
// SAFETY: s is at least 8 bytes.
let v = unsafe { s.read_u64_unchecked() };
let v = s.read_u64();
if !is_8digits(v) {
break;
}
// SAFETY: d.num_digits + 8 is less than d.digits.len()
unsafe {
d.digits[d.num_digits..].write_u64_unchecked(v - 0x3030_3030_3030_3030);
}
d.digits[d.num_digits..].write_u64(v - 0x3030_3030_3030_3030);
d.num_digits += 8;
s = s.advance(8);
s = &s[8..];
}
parse_digits(&mut s, |digit| d.try_add_digit(digit));
s = s.parse_digits(|digit| d.try_add_digit(digit));
d.decimal_point = s.len() as i32 - first.len() as i32;
}
if d.num_digits != 0 {
@ -248,23 +251,27 @@ pub fn parse_decimal(mut s: &[u8]) -> Decimal {
d.num_digits = Decimal::MAX_DIGITS;
}
}
if s.first_is2(b'e', b'E') {
s = s.advance(1);
if let Some((&ch, s_next)) = s.split_first() {
if ch == b'e' || ch == b'E' {
s = s_next;
let mut neg_exp = false;
if s.first_is(b'-') {
neg_exp = true;
s = s.advance(1);
} else if s.first_is(b'+') {
s = s.advance(1);
if let Some((&ch, s_next)) = s.split_first() {
neg_exp = ch == b'-';
if ch == b'-' || ch == b'+' {
s = s_next;
}
}
let mut exp_num = 0_i32;
parse_digits(&mut s, |digit| {
s.parse_digits(|digit| {
if exp_num < 0x10000 {
exp_num = 10 * exp_num + digit as i32;
}
});
d.decimal_point += if neg_exp { -exp_num } else { exp_num };
}
}
for i in d.num_digits..Decimal::MAX_DIGITS_WITHOUT_OVERFLOW {
d.digits[i] = 0;
}

View File

@ -79,7 +79,7 @@ use crate::error::Error;
use crate::fmt;
use crate::str::FromStr;
use self::common::{BiasedFp, ByteSlice};
use self::common::BiasedFp;
use self::float::RawFloat;
use self::lemire::compute_float;
use self::parse::{parse_inf_nan, parse_number};
@ -238,17 +238,18 @@ pub fn dec2flt<F: RawFloat>(s: &str) -> Result<F, ParseFloatError> {
};
let negative = c == b'-';
if c == b'-' || c == b'+' {
s = s.advance(1);
s = &s[1..];
}
if s.is_empty() {
return Err(pfe_invalid());
}
let num = match parse_number(s, negative) {
let mut num = match parse_number(s) {
Some(r) => r,
None if let Some(value) = parse_inf_nan(s, negative) => return Ok(value),
None => return Err(pfe_invalid()),
};
num.negative = negative;
if let Some(value) = num.try_fast_path::<F>() {
return Ok(value);
}

View File

@ -1,6 +1,6 @@
//! Functions to parse floating-point numbers.
use crate::num::dec2flt::common::{is_8digits, AsciiStr, ByteSlice};
use crate::num::dec2flt::common::{is_8digits, ByteSlice};
use crate::num::dec2flt::float::RawFloat;
use crate::num::dec2flt::number::Number;
@ -26,24 +26,39 @@ fn parse_8digits(mut v: u64) -> u64 {
}
/// Parse digits until a non-digit character is found.
fn try_parse_digits(s: &mut AsciiStr<'_>, x: &mut u64) {
fn try_parse_digits(mut s: &[u8], mut x: u64) -> (&[u8], u64) {
// may cause overflows, to be handled later
s.parse_digits(|digit| {
*x = x.wrapping_mul(10).wrapping_add(digit as _);
while s.len() >= 8 {
let num = s.read_u64();
if is_8digits(num) {
x = x.wrapping_mul(1_0000_0000).wrapping_add(parse_8digits(num));
s = &s[8..];
} else {
break;
}
}
s = s.parse_digits(|digit| {
x = x.wrapping_mul(10).wrapping_add(digit as _);
});
(s, x)
}
/// Parse up to 19 digits (the max that can be stored in a 64-bit integer).
fn try_parse_19digits(s: &mut AsciiStr<'_>, x: &mut u64) {
fn try_parse_19digits(s_ref: &mut &[u8], x: &mut u64) {
let mut s = *s_ref;
while *x < MIN_19DIGIT_INT {
if let Some(&c) = s.as_ref().first() {
// FIXME: Can't use s.split_first() here yet,
// see https://github.com/rust-lang/rust/issues/109328
if let [c, s_next @ ..] = s {
let digit = c.wrapping_sub(b'0');
if digit < 10 {
*x = (*x * 10) + digit as u64; // no overflows here
// SAFETY: cannot be empty
unsafe {
s.step();
}
s = s_next;
} else {
break;
}
@ -51,46 +66,26 @@ fn try_parse_19digits(s: &mut AsciiStr<'_>, x: &mut u64) {
break;
}
}
}
/// Try to parse 8 digits at a time, using an optimized algorithm.
fn try_parse_8digits(s: &mut AsciiStr<'_>, x: &mut u64) {
// may cause overflows, to be handled later
if let Some(v) = s.read_u64() {
if is_8digits(v) {
*x = x.wrapping_mul(1_0000_0000).wrapping_add(parse_8digits(v));
// SAFETY: already ensured the buffer was >= 8 bytes in read_u64.
unsafe {
s.step_by(8);
}
if let Some(v) = s.read_u64() {
if is_8digits(v) {
*x = x.wrapping_mul(1_0000_0000).wrapping_add(parse_8digits(v));
// SAFETY: already ensured the buffer was >= 8 bytes in try_read_u64.
unsafe {
s.step_by(8);
}
}
}
}
}
*s_ref = s;
}
/// Parse the scientific notation component of a float.
fn parse_scientific(s: &mut AsciiStr<'_>) -> Option<i64> {
let mut exponent = 0_i64;
fn parse_scientific(s_ref: &mut &[u8]) -> Option<i64> {
let mut exponent = 0i64;
let mut negative = false;
if let Some(&c) = s.as_ref().get(0) {
let mut s = *s_ref;
if let Some((&c, s_next)) = s.split_first() {
negative = c == b'-';
if c == b'-' || c == b'+' {
// SAFETY: s cannot be empty
unsafe {
s.step();
s = s_next;
}
}
}
if s.first_isdigit() {
s.parse_digits(|digit| {
if matches!(s.first(), Some(&x) if x.is_ascii_digit()) {
*s_ref = s.parse_digits(|digit| {
// no overflows here, saturate well before overflow
if exponent < 0x10000 {
exponent = 10 * exponent + digit as i64;
@ -98,6 +93,7 @@ fn parse_scientific(s: &mut AsciiStr<'_>) -> Option<i64> {
});
if negative { Some(-exponent) } else { Some(exponent) }
} else {
*s_ref = s;
None
}
}
@ -106,28 +102,29 @@ fn parse_scientific(s: &mut AsciiStr<'_>) -> Option<i64> {
///
/// This creates a representation of the float as the
/// significant digits and the decimal exponent.
fn parse_partial_number(s: &[u8], negative: bool) -> Option<(Number, usize)> {
let mut s = AsciiStr::new(s);
let start = s;
fn parse_partial_number(mut s: &[u8]) -> Option<(Number, usize)> {
debug_assert!(!s.is_empty());
// parse initial digits before dot
let mut mantissa = 0_u64;
let digits_start = s;
try_parse_digits(&mut s, &mut mantissa);
let mut n_digits = s.offset_from(&digits_start);
let start = s;
let tmp = try_parse_digits(s, mantissa);
s = tmp.0;
mantissa = tmp.1;
let mut n_digits = s.offset_from(start);
// handle dot with the following digits
let mut n_after_dot = 0;
let mut exponent = 0_i64;
let int_end = s;
if s.first_is(b'.') {
// SAFETY: s cannot be empty due to first_is
unsafe { s.step() };
if let Some((&b'.', s_next)) = s.split_first() {
s = s_next;
let before = s;
try_parse_8digits(&mut s, &mut mantissa);
try_parse_digits(&mut s, &mut mantissa);
n_after_dot = s.offset_from(&before);
let tmp = try_parse_digits(s, mantissa);
s = tmp.0;
mantissa = tmp.1;
n_after_dot = s.offset_from(before);
exponent = -n_after_dot as i64;
}
@ -138,65 +135,60 @@ fn parse_partial_number(s: &[u8], negative: bool) -> Option<(Number, usize)> {
// handle scientific format
let mut exp_number = 0_i64;
if s.first_is2(b'e', b'E') {
// SAFETY: s cannot be empty
unsafe {
s.step();
}
if let Some((&c, s_next)) = s.split_first() {
if c == b'e' || c == b'E' {
s = s_next;
// If None, we have no trailing digits after exponent, or an invalid float.
exp_number = parse_scientific(&mut s)?;
exponent += exp_number;
}
}
let len = s.offset_from(&start) as _;
let len = s.offset_from(start) as _;
// handle uncommon case with many digits
if n_digits <= 19 {
return Some((Number { exponent, mantissa, negative, many_digits: false }, len));
return Some((Number { exponent, mantissa, negative: false, many_digits: false }, len));
}
n_digits -= 19;
let mut many_digits = false;
let mut p = digits_start;
while p.first_is2(b'0', b'.') {
// SAFETY: p cannot be empty due to first_is2
unsafe {
// '0' = b'.' + 2
n_digits -= p.first_unchecked().saturating_sub(b'0' - 1) as isize;
p.step();
let mut p = start;
while let Some((&c, p_next)) = p.split_first() {
if c == b'.' || c == b'0' {
n_digits -= c.saturating_sub(b'0' - 1) as isize;
p = p_next;
} else {
break;
}
}
if n_digits > 0 {
// at this point we have more than 19 significant digits, let's try again
many_digits = true;
mantissa = 0;
let mut s = digits_start;
let mut s = start;
try_parse_19digits(&mut s, &mut mantissa);
exponent = if mantissa >= MIN_19DIGIT_INT {
// big int
int_end.offset_from(&s)
int_end.offset_from(s)
} else {
// SAFETY: the next byte must be present and be '.'
// We know this is true because we had more than 19
// digits previously, so we overflowed a 64-bit integer,
// but parsing only the integral digits produced less
// than 19 digits. That means we must have a decimal
// point, and at least 1 fractional digit.
unsafe { s.step() };
s = &s[1..];
let before = s;
try_parse_19digits(&mut s, &mut mantissa);
-s.offset_from(&before)
-s.offset_from(before)
} as i64;
// add back the explicit part
exponent += exp_number;
}
Some((Number { exponent, mantissa, negative, many_digits }, len))
Some((Number { exponent, mantissa, negative: false, many_digits }, len))
}
/// Try to parse a non-special floating point number.
pub fn parse_number(s: &[u8], negative: bool) -> Option<Number> {
if let Some((float, rest)) = parse_partial_number(s, negative) {
/// Try to parse a non-special floating point number,
/// as well as two slices with integer and fractional parts
/// and the parsed exponent.
pub fn parse_number(s: &[u8]) -> Option<Number> {
if let Some((float, rest)) = parse_partial_number(s) {
if rest == s.len() {
return Some(float);
}
@ -204,30 +196,48 @@ pub fn parse_number(s: &[u8], negative: bool) -> Option<Number> {
None
}
/// Parse a partial representation of a special, non-finite float.
fn parse_partial_inf_nan<F: RawFloat>(s: &[u8]) -> Option<(F, usize)> {
fn parse_inf_rest(s: &[u8]) -> usize {
if s.len() >= 8 && s[3..].as_ref().starts_with_ignore_case(b"inity") { 8 } else { 3 }
}
if s.len() >= 3 {
if s.starts_with_ignore_case(b"nan") {
return Some((F::NAN, 3));
} else if s.starts_with_ignore_case(b"inf") {
return Some((F::INFINITY, parse_inf_rest(s)));
}
}
None
}
/// Try to parse a special, non-finite float.
pub fn parse_inf_nan<F: RawFloat>(s: &[u8], negative: bool) -> Option<F> {
if let Some((mut float, rest)) = parse_partial_inf_nan::<F>(s) {
if rest == s.len() {
if negative {
float = -float;
pub(crate) fn parse_inf_nan<F: RawFloat>(s: &[u8], negative: bool) -> Option<F> {
// Since a valid string has at most the length 8, we can load
// all relevant characters into a u64 and work from there.
// This also generates much better code.
let mut register;
let len: usize;
// All valid strings are either of length 8 or 3.
if s.len() == 8 {
register = s.read_u64();
len = 8;
} else if s.len() == 3 {
let a = s[0] as u64;
let b = s[1] as u64;
let c = s[2] as u64;
register = (c << 16) | (b << 8) | a;
len = 3;
} else {
return None;
}
return Some(float);
}
}
None
// Clear out the bits which turn ASCII uppercase characters into
// lowercase characters. The resulting string is all uppercase.
// What happens to other characters is irrelevant.
register &= 0xDFDFDFDFDFDFDFDF;
// u64 values corresponding to relevant cases
const INF_3: u64 = 0x464E49; // "INF"
const INF_8: u64 = 0x5954494E49464E49; // "INFINITY"
const NAN: u64 = 0x4E414E; // "NAN"
// Match register value to constant to parse string.
// Also match on the string length to catch edge cases
// like "inf\0\0\0\0\0".
let float = match (register, len) {
(INF_3, 3) => F::INFINITY,
(INF_8, 8) => F::INFINITY,
(NAN, 3) => F::NAN,
_ => return None,
};
if negative { Some(-float) } else { Some(float) }
}

View File

@ -32,7 +32,7 @@ fn invalid_chars() {
}
fn parse_positive(s: &[u8]) -> Option<Number> {
parse_number(s, false)
parse_number(s)
}
#[test]