From 8a5c187948ea2db82539e91a43ea8b8e5d13029a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Sun, 10 Nov 2024 23:14:16 +0100 Subject: [PATCH] miri: implement square root without relying on host floats --- src/tools/miri/src/intrinsics/mod.rs | 42 ++++--- src/tools/miri/src/intrinsics/simd.rs | 39 +++--- src/tools/miri/src/lib.rs | 1 + src/tools/miri/src/math.rs | 164 ++++++++++++++++++++++++++ src/tools/miri/src/shims/x86/mod.rs | 27 +---- src/tools/miri/tests/pass/float.rs | 14 ++- 6 files changed, 219 insertions(+), 68 deletions(-) create mode 100644 src/tools/miri/src/math.rs diff --git a/src/tools/miri/src/intrinsics/mod.rs b/src/tools/miri/src/intrinsics/mod.rs index 272dca1594e..9eebbc5d363 100644 --- a/src/tools/miri/src/intrinsics/mod.rs +++ b/src/tools/miri/src/intrinsics/mod.rs @@ -218,20 +218,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { => { let [f] = check_arg_count(args)?; let f = this.read_scalar(f)?.to_f32()?; - // Using host floats (but it's fine, these operations do not have guaranteed precision). - let f_host = f.to_host(); + // Using host floats except for sqrt (but it's fine, these operations do not have + // guaranteed precision). let res = match intrinsic_name { - "sinf32" => f_host.sin(), - "cosf32" => f_host.cos(), - "sqrtf32" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats - "expf32" => f_host.exp(), - "exp2f32" => f_host.exp2(), - "logf32" => f_host.ln(), - "log10f32" => f_host.log10(), - "log2f32" => f_host.log2(), + "sinf32" => f.to_host().sin().to_soft(), + "cosf32" => f.to_host().cos().to_soft(), + "sqrtf32" => math::sqrt(f), + "expf32" => f.to_host().exp().to_soft(), + "exp2f32" => f.to_host().exp2().to_soft(), + "logf32" => f.to_host().ln().to_soft(), + "log10f32" => f.to_host().log10().to_soft(), + "log2f32" => f.to_host().log2().to_soft(), _ => bug!(), }; - let res = res.to_soft(); let res = this.adjust_nan(res, &[f]); this.write_scalar(res, dest)?; } @@ -247,20 +246,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { => { let [f] = check_arg_count(args)?; let f = this.read_scalar(f)?.to_f64()?; - // Using host floats (but it's fine, these operations do not have guaranteed precision). - let f_host = f.to_host(); + // Using host floats except for sqrt (but it's fine, these operations do not have + // guaranteed precision). let res = match intrinsic_name { - "sinf64" => f_host.sin(), - "cosf64" => f_host.cos(), - "sqrtf64" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats - "expf64" => f_host.exp(), - "exp2f64" => f_host.exp2(), - "logf64" => f_host.ln(), - "log10f64" => f_host.log10(), - "log2f64" => f_host.log2(), + "sinf64" => f.to_host().sin().to_soft(), + "cosf64" => f.to_host().cos().to_soft(), + "sqrtf64" => math::sqrt(f), + "expf64" => f.to_host().exp().to_soft(), + "exp2f64" => f.to_host().exp2().to_soft(), + "logf64" => f.to_host().ln().to_soft(), + "log10f64" => f.to_host().log10().to_soft(), + "log2f64" => f.to_host().log2().to_soft(), _ => bug!(), }; - let res = res.to_soft(); let res = this.adjust_nan(res, &[f]); this.write_scalar(res, dest)?; } diff --git a/src/tools/miri/src/intrinsics/simd.rs b/src/tools/miri/src/intrinsics/simd.rs index d5c417e7231..075b6f35e0e 100644 --- a/src/tools/miri/src/intrinsics/simd.rs +++ b/src/tools/miri/src/intrinsics/simd.rs @@ -104,42 +104,39 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let ty::Float(float_ty) = op.layout.ty.kind() else { span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) }; - // Using host floats (but it's fine, these operations do not have guaranteed precision). + // Using host floats except for sqrt (but it's fine, these operations do not + // have guaranteed precision). match float_ty { FloatTy::F16 => unimplemented!("f16_f128"), FloatTy::F32 => { let f = op.to_scalar().to_f32()?; - let f_host = f.to_host(); let res = match host_op { - "fsqrt" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats - "fsin" => f_host.sin(), - "fcos" => f_host.cos(), - "fexp" => f_host.exp(), - "fexp2" => f_host.exp2(), - "flog" => f_host.ln(), - "flog2" => f_host.log2(), - "flog10" => f_host.log10(), + "fsqrt" => math::sqrt(f), + "fsin" => f.to_host().sin().to_soft(), + "fcos" => f.to_host().cos().to_soft(), + "fexp" => f.to_host().exp().to_soft(), + "fexp2" => f.to_host().exp2().to_soft(), + "flog" => f.to_host().ln().to_soft(), + "flog2" => f.to_host().log2().to_soft(), + "flog10" => f.to_host().log10().to_soft(), _ => bug!(), }; - let res = res.to_soft(); let res = this.adjust_nan(res, &[f]); Scalar::from(res) } FloatTy::F64 => { let f = op.to_scalar().to_f64()?; - let f_host = f.to_host(); let res = match host_op { - "fsqrt" => f_host.sqrt(), - "fsin" => f_host.sin(), - "fcos" => f_host.cos(), - "fexp" => f_host.exp(), - "fexp2" => f_host.exp2(), - "flog" => f_host.ln(), - "flog2" => f_host.log2(), - "flog10" => f_host.log10(), + "fsqrt" => math::sqrt(f), + "fsin" => f.to_host().sin().to_soft(), + "fcos" => f.to_host().cos().to_soft(), + "fexp" => f.to_host().exp().to_soft(), + "fexp2" => f.to_host().exp2().to_soft(), + "flog" => f.to_host().ln().to_soft(), + "flog2" => f.to_host().log2().to_soft(), + "flog10" => f.to_host().log10().to_soft(), _ => bug!(), }; - let res = res.to_soft(); let res = this.adjust_nan(res, &[f]); Scalar::from(res) } diff --git a/src/tools/miri/src/lib.rs b/src/tools/miri/src/lib.rs index e69651631bb..85c896563da 100644 --- a/src/tools/miri/src/lib.rs +++ b/src/tools/miri/src/lib.rs @@ -83,6 +83,7 @@ mod eval; mod helpers; mod intrinsics; mod machine; +mod math; mod mono_hash_map; mod operator; mod provenance_gc; diff --git a/src/tools/miri/src/math.rs b/src/tools/miri/src/math.rs new file mode 100644 index 00000000000..ed3d2d55678 --- /dev/null +++ b/src/tools/miri/src/math.rs @@ -0,0 +1,164 @@ +use rand::Rng as _; +use rand::distributions::Distribution as _; +use rustc_apfloat::Float as _; +use rustc_apfloat::ieee::IeeeFloat; + +/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale). +pub(crate) fn apply_random_float_error( + ecx: &mut crate::MiriInterpCx<'_>, + val: F, + err_scale: i32, +) -> F { + let rng = ecx.machine.rng.get_mut(); + // Generate a random integer in the range [0, 2^PREC). + let dist = rand::distributions::Uniform::new(0, 1 << F::PRECISION); + let err = F::from_u128(dist.sample(rng)) + .value + .scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap())); + // give it a random sign + let err = if rng.gen::() { -err } else { err }; + // multiple the value with (1+err) + (val * (F::from_u128(1).value + err).value).value +} + +pub(crate) fn sqrt(x: IeeeFloat) -> IeeeFloat { + match x.category() { + // preserve zero sign + rustc_apfloat::Category::Zero => x, + // propagate NaN + rustc_apfloat::Category::NaN => x, + // sqrt of negative number is NaN + _ if x.is_negative() => IeeeFloat::NAN, + // sqrt(∞) = ∞ + rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY, + rustc_apfloat::Category::Normal => { + // Floating point precision, excluding the integer bit + let prec = i32::try_from(S::PRECISION).unwrap() - 1; + + // x = 2^(exp - prec) * mant + // where mant is an integer with prec+1 bits + // mant is a u128, which should be large enough for the largest prec (112 for f128) + let mut exp = x.ilogb(); + let mut mant = x.scalbn(prec - exp).to_u128(128).value; + + if exp % 2 != 0 { + // Make exponent even, so it can be divided by 2 + exp -= 1; + mant <<= 1; + } + + // Bit-by-bit (base-2 digit-by-digit) sqrt of mant. + // mant is treated here as a fixed point number with prec fractional bits. + // mant will be shifted left by one bit to have an extra fractional bit, which + // will be used to determine the rounding direction. + + // res is the truncated sqrt of mant, where one bit is added at each iteration. + let mut res = 0u128; + // rem is the remainder with the current res + // rem_i = 2^i * ((mant<<1) - res_i^2) + // starting with res = 0, rem = mant<<1 + let mut rem = mant << 1; + // s_i = 2*res_i + let mut s = 0u128; + // d is used to iterate over bits, from high to low (d_i = 2^(-i)) + let mut d = 1u128 << (prec + 1); + + // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that + // (res_i + b_j * 2^(-j))^2 <= mant<<1 + // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b: + // res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1 + // And rearranging the terms: + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2) + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i + + while d != 0 { + // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1: + // t = 2*res_i + 2^(-j) + let t = s + d; + if rem >= t { + // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem + res += d; + s += d + d; + rem -= t; + } + // Adjust rem for next iteration + rem <<= 1; + // Shift iterator + d >>= 1; + } + + // Remove extra fractional bit from result, rounding to nearest. + // If the last bit is 0, then the nearest neighbor is definitely the lower one. + // If the last bit is 1, it sounds like this may either be a tie (if there's + // infinitely many 0s after this 1), or the nearest neighbor is the upper one. + // However, since square roots are either exact or irrational, and an exact root + // would lead to the last "extra" bit being 0, we can exclude a tie in this case. + // We therefore always round up if the last bit is 1. When the last bit is 0, + // adding 1 will not do anything since the shift will discard it. + res = (res + 1) >> 1; + + // Build resulting value with res as mantissa and exp/2 as exponent + IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec) + } + } +} + +#[cfg(test)] +mod tests { + use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS}; + + use super::sqrt; + + #[test] + fn test_sqrt() { + #[track_caller] + fn test(x: &str, expected: &str) { + let x: IeeeFloat = x.parse().unwrap(); + let expected: IeeeFloat = expected.parse().unwrap(); + let result = sqrt(x); + assert_eq!(result, expected); + } + + fn exact_tests() { + test::("0", "0"); + test::("1", "1"); + test::("1.5625", "1.25"); + test::("2.25", "1.5"); + test::("4", "2"); + test::("5.0625", "2.25"); + test::("9", "3"); + test::("16", "4"); + test::("25", "5"); + test::("36", "6"); + test::("49", "7"); + test::("64", "8"); + test::("81", "9"); + test::("100", "10"); + + test::("0.5625", "0.75"); + test::("0.25", "0.5"); + test::("0.0625", "0.25"); + test::("0.00390625", "0.0625"); + } + + exact_tests::(); + exact_tests::(); + exact_tests::(); + exact_tests::(); + + test::("2", "1.4142135"); + test::("2", "1.4142135623730951"); + + test::("1.1", "1.0488088"); + test::("1.1", "1.0488088481701516"); + + test::("2.2", "1.4832398"); + test::("2.2", "1.4832396974191326"); + + test::("1.22101e-40", "1.10499205e-20"); + test::("1.22101e-310", "1.1049932126488395e-155"); + + test::("3.4028235e38", "1.8446743e19"); + test::("1.7976931348623157e308", "1.3407807929942596e154"); + } +} diff --git a/src/tools/miri/src/shims/x86/mod.rs b/src/tools/miri/src/shims/x86/mod.rs index 66c8f3b4c2b..3e02a4b3637 100644 --- a/src/tools/miri/src/shims/x86/mod.rs +++ b/src/tools/miri/src/shims/x86/mod.rs @@ -1,4 +1,3 @@ -use rand::Rng as _; use rustc_abi::{ExternAbi, Size}; use rustc_apfloat::Float; use rustc_apfloat::ieee::Single; @@ -408,38 +407,20 @@ fn unary_op_f32<'tcx>( let div = (Single::from_u128(1).value / op).value; // Apply a relative error with a magnitude on the order of 2^-12 to simulate the // inaccuracy of RCP. - let res = apply_random_float_error(ecx, div, -12); + let res = math::apply_random_float_error(ecx, div, -12); interp_ok(Scalar::from_f32(res)) } FloatUnaryOp::Rsqrt => { - let op = op.to_scalar().to_u32()?; - // FIXME using host floats - let sqrt = Single::from_bits(f32::from_bits(op).sqrt().to_bits().into()); - let rsqrt = (Single::from_u128(1).value / sqrt).value; + let op = op.to_scalar().to_f32()?; + let rsqrt = (Single::from_u128(1).value / math::sqrt(op)).value; // Apply a relative error with a magnitude on the order of 2^-12 to simulate the // inaccuracy of RSQRT. - let res = apply_random_float_error(ecx, rsqrt, -12); + let res = math::apply_random_float_error(ecx, rsqrt, -12); interp_ok(Scalar::from_f32(res)) } } } -/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale). -#[expect(clippy::arithmetic_side_effects)] // floating point arithmetic cannot panic -fn apply_random_float_error( - ecx: &mut crate::MiriInterpCx<'_>, - val: F, - err_scale: i32, -) -> F { - let rng = ecx.machine.rng.get_mut(); - // generates rand(0, 2^64) * 2^(scale - 64) = rand(0, 1) * 2^scale - let err = F::from_u128(rng.gen::().into()).value.scalbn(err_scale.strict_sub(64)); - // give it a random sign - let err = if rng.gen::() { -err } else { err }; - // multiple the value with (1+err) - (val * (F::from_u128(1).value + err).value).value -} - /// Performs `which` operation on the first component of `op` and copies /// the other components. The result is stored in `dest`. fn unary_op_ss<'tcx>( diff --git a/src/tools/miri/tests/pass/float.rs b/src/tools/miri/tests/pass/float.rs index 66843ca584b..4de315e3589 100644 --- a/src/tools/miri/tests/pass/float.rs +++ b/src/tools/miri/tests/pass/float.rs @@ -959,10 +959,20 @@ pub fn libm() { unsafe { ldexp(a, b) } } - assert_approx_eq!(64f32.sqrt(), 8f32); - assert_approx_eq!(64f64.sqrt(), 8f64); + assert_eq!(64_f32.sqrt(), 8_f32); + assert_eq!(64_f64.sqrt(), 8_f64); + assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY); + assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY); + assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); assert!((-5.0_f32).sqrt().is_nan()); assert!((-5.0_f64).sqrt().is_nan()); + assert!(f32::NEG_INFINITY.sqrt().is_nan()); + assert!(f64::NEG_INFINITY.sqrt().is_nan()); + assert!(f32::NAN.sqrt().is_nan()); + assert!(f64::NAN.sqrt().is_nan()); assert_approx_eq!(25f32.powi(-2), 0.0016f32); assert_approx_eq!(23.2f64.powi(2), 538.24f64);