Add f16 and f128 support to Miri

This commit is contained in:
Trevor Gross 2024-06-13 13:04:16 -05:00
parent eab5e8e9d9
commit 5cb58ad503
2 changed files with 56 additions and 101 deletions

View File

@ -7,7 +7,7 @@ use std::time::Duration;
use rand::RngCore; use rand::RngCore;
use rustc_apfloat::ieee::{Double, Single}; use rustc_apfloat::ieee::{Double, Half, Quad, Single};
use rustc_apfloat::Float; use rustc_apfloat::Float;
use rustc_hir::{ use rustc_hir::{
def::{DefKind, Namespace}, def::{DefKind, Namespace},
@ -1201,12 +1201,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}; };
let (val, status) = match fty { let (val, status) = match fty {
FloatTy::F16 => unimplemented!("f16_f128"), FloatTy::F16 =>
float_to_int_inner::<Half>(this, src.to_scalar().to_f16()?, cast_to, round),
FloatTy::F32 => FloatTy::F32 =>
float_to_int_inner::<Single>(this, src.to_scalar().to_f32()?, cast_to, round), float_to_int_inner::<Single>(this, src.to_scalar().to_f32()?, cast_to, round),
FloatTy::F64 => FloatTy::F64 =>
float_to_int_inner::<Double>(this, src.to_scalar().to_f64()?, cast_to, round), float_to_int_inner::<Double>(this, src.to_scalar().to_f64()?, cast_to, round),
FloatTy::F128 => unimplemented!("f16_f128"), FloatTy::F128 =>
float_to_int_inner::<Quad>(this, src.to_scalar().to_f128()?, cast_to, round),
}; };
if status.intersects( if status.intersects(

View File

@ -1,6 +1,8 @@
#![feature(stmt_expr_attributes)] #![feature(stmt_expr_attributes)]
#![feature(float_gamma)] #![feature(float_gamma)]
#![feature(core_intrinsics)] #![feature(core_intrinsics)]
#![feature(f128)]
#![feature(f16)]
#![allow(arithmetic_overflow)] #![allow(arithmetic_overflow)]
use std::fmt::Debug; use std::fmt::Debug;
@ -41,103 +43,23 @@ trait FloatToInt<Int>: Copy {
unsafe fn cast_unchecked(self) -> Int; unsafe fn cast_unchecked(self) -> Int;
} }
impl FloatToInt<i8> for f32 { macro_rules! float_to_int {
fn cast(self) -> i8 { ($fty:ty => $($ity:ty),+ $(,)?) => {
self as _ $(
} impl FloatToInt<$ity> for $fty {
unsafe fn cast_unchecked(self) -> i8 { fn cast(self) -> $ity {
self.to_int_unchecked() self as _
} }
} unsafe fn cast_unchecked(self) -> $ity {
impl FloatToInt<i32> for f32 { self.to_int_unchecked()
fn cast(self) -> i32 { }
self as _ }
} )*
unsafe fn cast_unchecked(self) -> i32 { };
self.to_int_unchecked()
}
}
impl FloatToInt<u32> for f32 {
fn cast(self) -> u32 {
self as _
}
unsafe fn cast_unchecked(self) -> u32 {
self.to_int_unchecked()
}
}
impl FloatToInt<i64> for f32 {
fn cast(self) -> i64 {
self as _
}
unsafe fn cast_unchecked(self) -> i64 {
self.to_int_unchecked()
}
}
impl FloatToInt<u64> for f32 {
fn cast(self) -> u64 {
self as _
}
unsafe fn cast_unchecked(self) -> u64 {
self.to_int_unchecked()
}
} }
impl FloatToInt<i8> for f64 { float_to_int!(f32 => i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
fn cast(self) -> i8 { float_to_int!(f64 => i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
self as _
}
unsafe fn cast_unchecked(self) -> i8 {
self.to_int_unchecked()
}
}
impl FloatToInt<i32> for f64 {
fn cast(self) -> i32 {
self as _
}
unsafe fn cast_unchecked(self) -> i32 {
self.to_int_unchecked()
}
}
impl FloatToInt<u32> for f64 {
fn cast(self) -> u32 {
self as _
}
unsafe fn cast_unchecked(self) -> u32 {
self.to_int_unchecked()
}
}
impl FloatToInt<i64> for f64 {
fn cast(self) -> i64 {
self as _
}
unsafe fn cast_unchecked(self) -> i64 {
self.to_int_unchecked()
}
}
impl FloatToInt<u64> for f64 {
fn cast(self) -> u64 {
self as _
}
unsafe fn cast_unchecked(self) -> u64 {
self.to_int_unchecked()
}
}
impl FloatToInt<i128> for f64 {
fn cast(self) -> i128 {
self as _
}
unsafe fn cast_unchecked(self) -> i128 {
self.to_int_unchecked()
}
}
impl FloatToInt<u128> for f64 {
fn cast(self) -> u128 {
self as _
}
unsafe fn cast_unchecked(self) -> u128 {
self.to_int_unchecked()
}
}
/// Test this cast both via `as` and via `approx_unchecked` (i.e., it must not saturate). /// Test this cast both via `as` and via `approx_unchecked` (i.e., it must not saturate).
#[track_caller] #[track_caller]
@ -153,18 +75,29 @@ where
fn basic() { fn basic() {
// basic arithmetic // basic arithmetic
assert_eq(6.0_f16 * 6.0_f16, 36.0_f16);
assert_eq(6.0_f32 * 6.0_f32, 36.0_f32); assert_eq(6.0_f32 * 6.0_f32, 36.0_f32);
assert_eq(6.0_f64 * 6.0_f64, 36.0_f64); assert_eq(6.0_f64 * 6.0_f64, 36.0_f64);
assert_eq(6.0_f128 * 6.0_f128, 36.0_f128);
assert_eq(-{ 5.0_f16 }, -5.0_f16);
assert_eq(-{ 5.0_f32 }, -5.0_f32); assert_eq(-{ 5.0_f32 }, -5.0_f32);
assert_eq(-{ 5.0_f64 }, -5.0_f64); assert_eq(-{ 5.0_f64 }, -5.0_f64);
assert_eq(-{ 5.0_f128 }, -5.0_f128);
// infinities, NaN // infinities, NaN
// FIXME(f16_f128): add when constants and `is_infinite` are available
assert!((5.0_f32 / 0.0).is_infinite()); assert!((5.0_f32 / 0.0).is_infinite());
assert_ne!({ 5.0_f32 / 0.0 }, { -5.0_f32 / 0.0 }); assert_ne!({ 5.0_f32 / 0.0 }, { -5.0_f32 / 0.0 });
assert!((5.0_f64 / 0.0).is_infinite()); assert!((5.0_f64 / 0.0).is_infinite());
assert_ne!({ 5.0_f64 / 0.0 }, { 5.0_f64 / -0.0 }); assert_ne!({ 5.0_f64 / 0.0 }, { 5.0_f64 / -0.0 });
assert_ne!(f32::NAN, f32::NAN); assert_ne!(f32::NAN, f32::NAN);
assert_ne!(f64::NAN, f64::NAN); assert_ne!(f64::NAN, f64::NAN);
// negative zero // negative zero
let posz = 0.0f16;
let negz = -0.0f16;
assert_eq(posz, negz);
assert_ne!(posz.to_bits(), negz.to_bits());
let posz = 0.0f32; let posz = 0.0f32;
let negz = -0.0f32; let negz = -0.0f32;
assert_eq(posz, negz); assert_eq(posz, negz);
@ -173,15 +106,30 @@ fn basic() {
let negz = -0.0f64; let negz = -0.0f64;
assert_eq(posz, negz); assert_eq(posz, negz);
assert_ne!(posz.to_bits(), negz.to_bits()); assert_ne!(posz.to_bits(), negz.to_bits());
let posz = 0.0f128;
let negz = -0.0f128;
assert_eq(posz, negz);
assert_ne!(posz.to_bits(), negz.to_bits());
// byte-level transmute // byte-level transmute
let x: u64 = unsafe { std::mem::transmute(42.0_f64) }; let x: u16 = unsafe { std::mem::transmute(42.0_f16) };
let y: f64 = unsafe { std::mem::transmute(x) }; let y: f16 = unsafe { std::mem::transmute(x) };
assert_eq(y, 42.0_f64); assert_eq(y, 42.0_f16);
let x: u32 = unsafe { std::mem::transmute(42.0_f32) }; let x: u32 = unsafe { std::mem::transmute(42.0_f32) };
let y: f32 = unsafe { std::mem::transmute(x) }; let y: f32 = unsafe { std::mem::transmute(x) };
assert_eq(y, 42.0_f32); assert_eq(y, 42.0_f32);
let x: u64 = unsafe { std::mem::transmute(42.0_f64) };
let y: f64 = unsafe { std::mem::transmute(x) };
assert_eq(y, 42.0_f64);
let x: u128 = unsafe { std::mem::transmute(42.0_f128) };
let y: f128 = unsafe { std::mem::transmute(x) };
assert_eq(y, 42.0_f128);
// `%` sign behavior, some of this used to be buggy // `%` sign behavior, some of this used to be buggy
assert!((black_box(1.0f16) % 1.0).is_sign_positive());
assert!((black_box(1.0f16) % -1.0).is_sign_positive());
assert!((black_box(-1.0f16) % 1.0).is_sign_negative());
assert!((black_box(-1.0f16) % -1.0).is_sign_negative());
assert!((black_box(1.0f32) % 1.0).is_sign_positive()); assert!((black_box(1.0f32) % 1.0).is_sign_positive());
assert!((black_box(1.0f32) % -1.0).is_sign_positive()); assert!((black_box(1.0f32) % -1.0).is_sign_positive());
assert!((black_box(-1.0f32) % 1.0).is_sign_negative()); assert!((black_box(-1.0f32) % 1.0).is_sign_negative());
@ -190,7 +138,12 @@ fn basic() {
assert!((black_box(1.0f64) % -1.0).is_sign_positive()); assert!((black_box(1.0f64) % -1.0).is_sign_positive());
assert!((black_box(-1.0f64) % 1.0).is_sign_negative()); assert!((black_box(-1.0f64) % 1.0).is_sign_negative());
assert!((black_box(-1.0f64) % -1.0).is_sign_negative()); assert!((black_box(-1.0f64) % -1.0).is_sign_negative());
assert!((black_box(1.0f128) % 1.0).is_sign_positive());
assert!((black_box(1.0f128) % -1.0).is_sign_positive());
assert!((black_box(-1.0f128) % 1.0).is_sign_negative());
assert!((black_box(-1.0f128) % -1.0).is_sign_negative());
// FIXME(f16_f128): add when `abs` is available
assert_eq!((-1.0f32).abs(), 1.0f32); assert_eq!((-1.0f32).abs(), 1.0f32);
assert_eq!(34.2f64.abs(), 34.2f64); assert_eq!(34.2f64.abs(), 34.2f64);
} }