diff --git a/src/tools/miri/src/shims/x86/sse41.rs b/src/tools/miri/src/shims/x86/sse41.rs index cfa06ded6e6..523f3bfc26f 100644 --- a/src/tools/miri/src/shims/x86/sse41.rs +++ b/src/tools/miri/src/shims/x86/sse41.rs @@ -148,6 +148,14 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: round_first::(this, left, right, rounding, dest)?; } + // Used to implement the _mm_floor_ps, _mm_ceil_ps and _mm_round_ps + // functions. Rounds the elements of `op` according to `rounding`. + "round.ps" => { + let [op, rounding] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + round_all::(this, op, rounding, dest)?; + } // Used to implement the _mm_floor_sd, _mm_ceil_sd and _mm_round_sd // functions. Rounds the first element of `right` according to `rounding` // and copies the remaining elements from `left`. @@ -157,6 +165,14 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: round_first::(this, left, right, rounding, dest)?; } + // Used to implement the _mm_floor_pd, _mm_ceil_pd and _mm_round_pd + // functions. Rounds the elements of `op` according to `rounding`. + "round.pd" => { + let [op, rounding] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + round_all::(this, op, rounding, dest)?; + } // Used to implement the _mm_minpos_epu16 function. // Find the minimum unsinged 16-bit integer in `op` and // returns its value and position. @@ -283,22 +299,7 @@ fn round_first<'tcx, F: rustc_apfloat::Float>( assert_eq!(dest_len, left_len); assert_eq!(dest_len, right_len); - // The fourth bit of `rounding` only affects the SSE status - // register, which cannot be accessed from Miri (or from Rust, - // for that matter), so we can ignore it. - let rounding = match this.read_scalar(rounding)?.to_i32()? & !0b1000 { - // When the third bit is 0, the rounding mode is determined by the - // first two bits. - 0b000 => rustc_apfloat::Round::NearestTiesToEven, - 0b001 => rustc_apfloat::Round::TowardNegative, - 0b010 => rustc_apfloat::Round::TowardPositive, - 0b011 => rustc_apfloat::Round::TowardZero, - // When the third bit is 1, the rounding mode is determined by the - // SSE status register. Since we do not support modifying it from - // Miri (or Rust), we assume it to be at its default mode (round-to-nearest). - 0b100..=0b111 => rustc_apfloat::Round::NearestTiesToEven, - rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"), - }; + let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?; let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?; let res = op0.round_to_integral(rounding).value; @@ -317,3 +318,50 @@ fn round_first<'tcx, F: rustc_apfloat::Float>( Ok(()) } + +// Rounds all elements of `op` according to `rounding`. +fn round_all<'tcx, F: rustc_apfloat::Float>( + this: &mut crate::MiriInterpCx<'_, 'tcx>, + op: &OpTy<'tcx, Provenance>, + rounding: &OpTy<'tcx, Provenance>, + dest: &PlaceTy<'tcx, Provenance>, +) -> InterpResult<'tcx, ()> { + let (op, op_len) = this.operand_to_simd(op)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(dest_len, op_len); + + let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?; + + for i in 0..dest_len { + let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?; + let res = op.round_to_integral(rounding).value; + this.write_scalar( + Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)), + &this.project_index(&dest, i)?, + )?; + } + + Ok(()) +} + +/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of +/// `round.{ss,sd,ps,pd}` intrinsics. +fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> { + // The fourth bit of `rounding` only affects the SSE status + // register, which cannot be accessed from Miri (or from Rust, + // for that matter), so we can ignore it. + match rounding & !0b1000 { + // When the third bit is 0, the rounding mode is determined by the + // first two bits. + 0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven), + 0b001 => Ok(rustc_apfloat::Round::TowardNegative), + 0b010 => Ok(rustc_apfloat::Round::TowardPositive), + 0b011 => Ok(rustc_apfloat::Round::TowardZero), + // When the third bit is 1, the rounding mode is determined by the + // SSE status register. Since we do not support modifying it from + // Miri (or Rust), we assume it to be at its default mode (round-to-nearest). + 0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven), + rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"), + } +} diff --git a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs index d5489ffaf4b..8c565a2d6e0 100644 --- a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs +++ b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs @@ -73,114 +73,342 @@ unsafe fn test_sse41() { test_mm_dp_ps(); #[target_feature(enable = "sse4.1")] - unsafe fn test_mm_floor_sd() { - let a = _mm_setr_pd(2.5, 4.5); - let b = _mm_setr_pd(-1.5, -3.5); - let r = _mm_floor_sd(a, b); - let e = _mm_setr_pd(-2.0, 4.5); - assert_eq_m128d(r, e); - } - test_mm_floor_sd(); + unsafe fn test_round_nearest_f32() { + #[target_feature(enable = "sse4.1")] + unsafe fn test(x: f32, res: f32) { + let a = _mm_setr_ps(3.5, 2.5, 1.5, 4.5); + let b = _mm_setr_ps(x, -1.5, -3.5, -2.5); + let e = _mm_setr_ps(res, 2.5, 1.5, 4.5); + let r = _mm_round_ss::<_MM_FROUND_TO_NEAREST_INT>(a, b); + assert_eq_m128(r, e); + // Assume round-to-nearest by default + let r = _mm_round_ss::<_MM_FROUND_CUR_DIRECTION>(a, b); + assert_eq_m128(r, e); - #[target_feature(enable = "sse4.1")] - unsafe fn test_mm_floor_ss() { - let a = _mm_setr_ps(2.5, 4.5, 8.5, 16.5); - let b = _mm_setr_ps(-1.5, -3.5, -7.5, -15.5); - let r = _mm_floor_ss(a, b); - let e = _mm_setr_ps(-2.0, 4.5, 8.5, 16.5); + let a = _mm_set1_ps(x); + let e = _mm_set1_ps(res); + let r = _mm_round_ps::<_MM_FROUND_TO_NEAREST_INT>(a); + assert_eq_m128(r, e); + // Assume round-to-nearest by default + let r = _mm_round_ps::<_MM_FROUND_CUR_DIRECTION>(a); + assert_eq_m128(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -2.0); + test(-1.5, -2.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 2.0); + test(1.75, 2.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm_setr_ps(1.5, 3.5, 5.5, 7.5); + let e = _mm_setr_ps(2.0, 4.0, 6.0, 8.0); + let r = _mm_round_ps::<_MM_FROUND_TO_NEAREST_INT>(a); assert_eq_m128(r, e); - } - test_mm_floor_ss(); - - #[target_feature(enable = "sse4.1")] - unsafe fn test_mm_ceil_sd() { - let a = _mm_setr_pd(1.5, 3.5); - let b = _mm_setr_pd(-2.5, -4.5); - let r = _mm_ceil_sd(a, b); - let e = _mm_setr_pd(-2.0, 3.5); - assert_eq_m128d(r, e); - } - test_mm_ceil_sd(); - - #[target_feature(enable = "sse4.1")] - unsafe fn test_mm_ceil_ss() { - let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); - let b = _mm_setr_ps(-2.5, -4.5, -8.5, -16.5); - let r = _mm_ceil_ss(a, b); - let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); - assert_eq_m128(r, e); - } - test_mm_ceil_ss(); - - #[target_feature(enable = "sse4.1")] - unsafe fn test_mm_round_sd() { - let a = _mm_setr_pd(1.5, 3.5); - let b = _mm_setr_pd(-2.5, -4.5); - let r = _mm_round_sd::<_MM_FROUND_TO_NEAREST_INT>(a, b); - let e = _mm_setr_pd(-2.0, 3.5); - assert_eq_m128d(r, e); - - let a = _mm_setr_pd(1.5, 3.5); - let b = _mm_setr_pd(-2.5, -4.5); - let r = _mm_round_sd::<_MM_FROUND_TO_NEG_INF>(a, b); - let e = _mm_setr_pd(-3.0, 3.5); - assert_eq_m128d(r, e); - - let a = _mm_setr_pd(1.5, 3.5); - let b = _mm_setr_pd(-2.5, -4.5); - let r = _mm_round_sd::<_MM_FROUND_TO_POS_INF>(a, b); - let e = _mm_setr_pd(-2.0, 3.5); - assert_eq_m128d(r, e); - - let a = _mm_setr_pd(1.5, 3.5); - let b = _mm_setr_pd(-2.5, -4.5); - let r = _mm_round_sd::<_MM_FROUND_TO_ZERO>(a, b); - let e = _mm_setr_pd(-2.0, 3.5); - assert_eq_m128d(r, e); - // Assume round-to-nearest by default - let a = _mm_setr_pd(1.5, 3.5); - let b = _mm_setr_pd(-2.5, -4.5); - let r = _mm_round_sd::<_MM_FROUND_CUR_DIRECTION>(a, b); - let e = _mm_setr_pd(-2.0, 3.5); - assert_eq_m128d(r, e); + let r = _mm_round_ps::<_MM_FROUND_CUR_DIRECTION>(a); + assert_eq_m128(r, e); } - test_mm_round_sd(); + test_round_nearest_f32(); #[target_feature(enable = "sse4.1")] - unsafe fn test_mm_round_ss() { - let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); - let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); - let r = _mm_round_ss::<_MM_FROUND_TO_NEAREST_INT>(a, b); - let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); - assert_eq_m128(r, e); + unsafe fn test_round_floor_f32() { + #[target_feature(enable = "sse4.1")] + unsafe fn test(x: f32, res: f32) { + let a = _mm_setr_ps(3.5, 2.5, 1.5, 4.5); + let b = _mm_setr_ps(x, -1.5, -3.5, -2.5); + let e = _mm_setr_ps(res, 2.5, 1.5, 4.5); + let r = _mm_floor_ss(a, b); + assert_eq_m128(r, e); + let r = _mm_round_ss::<_MM_FROUND_TO_NEG_INF>(a, b); + assert_eq_m128(r, e); - let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); - let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); - let r = _mm_round_ss::<_MM_FROUND_TO_NEG_INF>(a, b); - let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); - assert_eq_m128(r, e); + let a = _mm_set1_ps(x); + let e = _mm_set1_ps(res); + let r = _mm_floor_ps(a); + assert_eq_m128(r, e); + let r = _mm_round_ps::<_MM_FROUND_TO_NEG_INF>(a); + assert_eq_m128(r, e); + } - let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); - let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); - let r = _mm_round_ss::<_MM_FROUND_TO_POS_INF>(a, b); - let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5); - assert_eq_m128(r, e); + // Test rounding direction + test(-2.5, -3.0); + test(-1.75, -2.0); + test(-1.5, -2.0); + test(-1.25, -2.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 1.0); + test(1.75, 1.0); + test(2.5, 2.0); - let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); - let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); - let r = _mm_round_ss::<_MM_FROUND_TO_ZERO>(a, b); - let e = _mm_setr_ps(-1.0, 3.5, 7.5, 15.5); + // Test that each element is rounded + let a = _mm_setr_ps(1.5, 3.5, 5.5, 7.5); + let e = _mm_setr_ps(1.0, 3.0, 5.0, 7.0); + let r = _mm_floor_ps(a); assert_eq_m128(r, e); - - // Assume round-to-nearest by default - let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5); - let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5); - let r = _mm_round_ss::<_MM_FROUND_CUR_DIRECTION>(a, b); - let e = _mm_setr_ps(-2.0, 3.5, 7.5, 15.5); + let r = _mm_round_ps::<_MM_FROUND_TO_NEG_INF>(a); assert_eq_m128(r, e); } - test_mm_round_ss(); + test_round_floor_f32(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_round_ceil_f32() { + #[target_feature(enable = "sse4.1")] + unsafe fn test(x: f32, res: f32) { + let a = _mm_setr_ps(3.5, 2.5, 1.5, 4.5); + let b = _mm_setr_ps(x, -1.5, -3.5, -2.5); + let e = _mm_setr_ps(res, 2.5, 1.5, 4.5); + let r = _mm_ceil_ss(a, b); + assert_eq_m128(r, e); + let r = _mm_round_ss::<_MM_FROUND_TO_POS_INF>(a, b); + assert_eq_m128(r, e); + + let a = _mm_set1_ps(x); + let e = _mm_set1_ps(res); + let r = _mm_ceil_ps(a); + assert_eq_m128(r, e); + let r = _mm_round_ps::<_MM_FROUND_TO_POS_INF>(a); + assert_eq_m128(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -1.0); + test(-1.5, -1.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 2.0); + test(1.5, 2.0); + test(1.75, 2.0); + test(2.5, 3.0); + + // Test that each element is rounded + let a = _mm_setr_ps(1.5, 3.5, 5.5, 7.5); + let e = _mm_setr_ps(2.0, 4.0, 6.0, 8.0); + let r = _mm_ceil_ps(a); + assert_eq_m128(r, e); + let r = _mm_round_ps::<_MM_FROUND_TO_POS_INF>(a); + assert_eq_m128(r, e); + } + test_round_ceil_f32(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_round_trunc_f32() { + #[target_feature(enable = "sse4.1")] + unsafe fn test(x: f32, res: f32) { + let a = _mm_setr_ps(3.5, 2.5, 1.5, 4.5); + let b = _mm_setr_ps(x, -1.5, -3.5, -2.5); + let e = _mm_setr_ps(res, 2.5, 1.5, 4.5); + let r = _mm_round_ss::<_MM_FROUND_TO_ZERO>(a, b); + assert_eq_m128(r, e); + + let a = _mm_set1_ps(x); + let e = _mm_set1_ps(res); + let r = _mm_round_ps::<_MM_FROUND_TO_ZERO>(a); + assert_eq_m128(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -1.0); + test(-1.5, -1.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 1.0); + test(1.75, 1.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm_setr_ps(1.5, 3.5, 5.5, 7.5); + let e = _mm_setr_ps(1.0, 3.0, 5.0, 7.0); + let r = _mm_round_ps::<_MM_FROUND_TO_ZERO>(a); + assert_eq_m128(r, e); + } + test_round_trunc_f32(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_round_nearest_f64() { + #[target_feature(enable = "sse4.1")] + unsafe fn test(x: f64, res: f64) { + let a = _mm_setr_pd(3.5, 2.5); + let b = _mm_setr_pd(x, -1.5); + let e = _mm_setr_pd(res, 2.5); + let r = _mm_round_sd::<_MM_FROUND_TO_NEAREST_INT>(a, b); + assert_eq_m128d(r, e); + // Assume round-to-nearest by default + let r = _mm_round_sd::<_MM_FROUND_CUR_DIRECTION>(a, b); + assert_eq_m128d(r, e); + + let a = _mm_set1_pd(x); + let e = _mm_set1_pd(res); + let r = _mm_round_pd::<_MM_FROUND_TO_NEAREST_INT>(a); + assert_eq_m128d(r, e); + // Assume round-to-nearest by default + let r = _mm_round_pd::<_MM_FROUND_CUR_DIRECTION>(a); + assert_eq_m128d(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -2.0); + test(-1.5, -2.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 2.0); + test(1.75, 2.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm_setr_pd(1.5, 3.5); + let e = _mm_setr_pd(2.0, 4.0); + let r = _mm_round_pd::<_MM_FROUND_TO_NEAREST_INT>(a); + assert_eq_m128d(r, e); + // Assume round-to-nearest by default + let r = _mm_round_pd::<_MM_FROUND_CUR_DIRECTION>(a); + assert_eq_m128d(r, e); + } + test_round_nearest_f64(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_round_floor_f64() { + #[target_feature(enable = "sse4.1")] + unsafe fn test(x: f64, res: f64) { + let a = _mm_setr_pd(3.5, 2.5); + let b = _mm_setr_pd(x, -1.5); + let e = _mm_setr_pd(res, 2.5); + let r = _mm_floor_sd(a, b); + assert_eq_m128d(r, e); + let r = _mm_round_sd::<_MM_FROUND_TO_NEG_INF>(a, b); + assert_eq_m128d(r, e); + + let a = _mm_set1_pd(x); + let e = _mm_set1_pd(res); + let r = _mm_floor_pd(a); + assert_eq_m128d(r, e); + let r = _mm_round_pd::<_MM_FROUND_TO_NEG_INF>(a); + assert_eq_m128d(r, e); + } + + // Test rounding direction + test(-2.5, -3.0); + test(-1.75, -2.0); + test(-1.5, -2.0); + test(-1.25, -2.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 1.0); + test(1.75, 1.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm_setr_pd(1.5, 3.5); + let e = _mm_setr_pd(1.0, 3.0); + let r = _mm_floor_pd(a); + assert_eq_m128d(r, e); + let r = _mm_round_pd::<_MM_FROUND_TO_NEG_INF>(a); + assert_eq_m128d(r, e); + } + test_round_floor_f64(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_round_ceil_f64() { + #[target_feature(enable = "sse4.1")] + unsafe fn test(x: f64, res: f64) { + let a = _mm_setr_pd(3.5, 2.5); + let b = _mm_setr_pd(x, -1.5); + let e = _mm_setr_pd(res, 2.5); + let r = _mm_ceil_sd(a, b); + assert_eq_m128d(r, e); + let r = _mm_round_sd::<_MM_FROUND_TO_POS_INF>(a, b); + assert_eq_m128d(r, e); + + let a = _mm_set1_pd(x); + let e = _mm_set1_pd(res); + let r = _mm_ceil_pd(a); + assert_eq_m128d(r, e); + let r = _mm_round_pd::<_MM_FROUND_TO_POS_INF>(a); + assert_eq_m128d(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -1.0); + test(-1.5, -1.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 2.0); + test(1.5, 2.0); + test(1.75, 2.0); + test(2.5, 3.0); + + // Test that each element is rounded + let a = _mm_setr_pd(1.5, 3.5); + let e = _mm_setr_pd(2.0, 4.0); + let r = _mm_ceil_pd(a); + assert_eq_m128d(r, e); + let r = _mm_round_pd::<_MM_FROUND_TO_POS_INF>(a); + assert_eq_m128d(r, e); + } + test_round_ceil_f64(); + + #[target_feature(enable = "sse4.1")] + unsafe fn test_round_trunc_f64() { + #[target_feature(enable = "sse4.1")] + unsafe fn test(x: f64, res: f64) { + let a = _mm_setr_pd(3.5, 2.5); + let b = _mm_setr_pd(x, -1.5); + let e = _mm_setr_pd(res, 2.5); + let r = _mm_round_sd::<_MM_FROUND_TO_ZERO>(a, b); + assert_eq_m128d(r, e); + + let a = _mm_set1_pd(x); + let e = _mm_set1_pd(res); + let r = _mm_round_pd::<_MM_FROUND_TO_ZERO>(a); + assert_eq_m128d(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -1.0); + test(-1.5, -1.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 1.0); + test(1.75, 1.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm_setr_pd(1.5, 3.5); + let e = _mm_setr_pd(1.0, 3.0); + let r = _mm_round_pd::<_MM_FROUND_TO_ZERO>(a); + assert_eq_m128d(r, e); + } + test_round_trunc_f64(); #[target_feature(enable = "sse4.1")] unsafe fn test_mm_minpos_epu16() {