From 524c16d38744a45d25b17542228e4f2f628f66ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Sun, 26 Nov 2023 11:27:41 +0100 Subject: [PATCH] Implement x86 AVX intrinsics --- src/tools/miri/src/shims/x86/avx.rs | 417 ++++++ src/tools/miri/src/shims/x86/mod.rs | 210 ++- .../miri/tests/pass/intrinsics-x86-avx.rs | 1269 +++++++++++++++++ 3 files changed, 1845 insertions(+), 51 deletions(-) create mode 100644 src/tools/miri/src/shims/x86/avx.rs diff --git a/src/tools/miri/src/shims/x86/avx.rs b/src/tools/miri/src/shims/x86/avx.rs new file mode 100644 index 00000000000..65de1607595 --- /dev/null +++ b/src/tools/miri/src/shims/x86/avx.rs @@ -0,0 +1,417 @@ +use rustc_apfloat::{ieee::Double, ieee::Single}; +use rustc_middle::mir; +use rustc_middle::ty::layout::LayoutOf as _; +use rustc_middle::ty::Ty; +use rustc_span::Symbol; +use rustc_target::spec::abi::Abi; + +use super::{ + bin_op_simd_float_all, conditional_dot_product, convert_float_to_int, horizontal_bin_op, + round_all, test_bits_masked, test_high_bits_masked, unary_op_ps, FloatBinOp, FloatUnaryOp, +}; +use crate::*; +use shims::foreign_items::EmulateForeignItemResult; + +impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} +pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: + crate::MiriInterpCxExt<'mir, 'tcx> +{ + fn emulate_x86_avx_intrinsic( + &mut self, + link_name: Symbol, + abi: Abi, + args: &[OpTy<'tcx, Provenance>], + dest: &PlaceTy<'tcx, Provenance>, + ) -> InterpResult<'tcx, EmulateForeignItemResult> { + let this = self.eval_context_mut(); + this.expect_target_feature_for_intrinsic(link_name, "avx")?; + // Prefix should have already been checked. + let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx.").unwrap(); + + match unprefixed_name { + // Used to implement _mm256_min_ps and _mm256_max_ps functions. + // Note that the semantics are a bit different from Rust simd_min + // and simd_max intrinsics regarding handling of NaN and -0.0: Rust + // matches the IEEE min/max operations, while x86 has different + // semantics. + "min.ps.256" | "max.ps.256" => { + let [left, right] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let which = match unprefixed_name { + "min.ps.256" => FloatBinOp::Min, + "max.ps.256" => FloatBinOp::Max, + _ => unreachable!(), + }; + + bin_op_simd_float_all::(this, which, left, right, dest)?; + } + // Used to implement _mm256_min_pd and _mm256_max_pd functions. + "min.pd.256" | "max.pd.256" => { + let [left, right] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let which = match unprefixed_name { + "min.pd.256" => FloatBinOp::Min, + "max.pd.256" => FloatBinOp::Max, + _ => unreachable!(), + }; + + bin_op_simd_float_all::(this, which, left, right, dest)?; + } + // Used to implement the _mm256_round_ps function. + // Rounds the elements of `op` according to `rounding`. + "round.ps.256" => { + let [op, rounding] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + round_all::(this, op, rounding, dest)?; + } + // Used to implement the _mm256_round_pd function. + // Rounds the elements of `op` according to `rounding`. + "round.pd.256" => { + let [op, rounding] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + round_all::(this, op, rounding, dest)?; + } + // Used to implement _mm256_{sqrt,rcp,rsqrt}_ps functions. + // Performs the operations on all components of `op`. 
+ "sqrt.ps.256" | "rcp.ps.256" | "rsqrt.ps.256" => { + let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let which = match unprefixed_name { + "sqrt.ps.256" => FloatUnaryOp::Sqrt, + "rcp.ps.256" => FloatUnaryOp::Rcp, + "rsqrt.ps.256" => FloatUnaryOp::Rsqrt, + _ => unreachable!(), + }; + + unary_op_ps(this, which, op, dest)?; + } + // Used to implement the _mm256_dp_ps function. + "dp.ps.256" => { + let [left, right, imm] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + conditional_dot_product(this, left, right, imm, dest)?; + } + // Used to implement the _mm256_h{add,sub}_p{s,d} functions. + // Horizontally add/subtract adjacent floating point values + // in `left` and `right`. + "hadd.ps.256" | "hadd.pd.256" | "hsub.ps.256" | "hsub.pd.256" => { + let [left, right] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let which = match unprefixed_name { + "hadd.ps.256" | "hadd.pd.256" => mir::BinOp::Add, + "hsub.ps.256" | "hsub.pd.256" => mir::BinOp::Sub, + _ => unreachable!(), + }; + + horizontal_bin_op(this, which, /*saturating*/ false, left, right, dest)?; + } + // Used to implement the _mm256_cmp_ps function. + // Performs a comparison operation on each component of `left` + // and `right`. For each component, returns 0 if false or u32::MAX + // if true. + "cmp.ps.256" => { + let [left, right, imm] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let which = + FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?; + + bin_op_simd_float_all::(this, which, left, right, dest)?; + } + // Used to implement the _mm256_cmp_pd function. + // Performs a comparison operation on each component of `left` + // and `right`. For each component, returns 0 if false or u64::MAX + // if true. + "cmp.pd.256" => { + let [left, right, imm] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let which = + FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?; + + bin_op_simd_float_all::(this, which, left, right, dest)?; + } + // Used to implement the _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_cvtpd_epi32 + // and _mm256_cvttpd_epi32 functions. + // Converts packed f32/f64 to packed i32. + "cvt.ps2dq.256" | "cvtt.ps2dq.256" | "cvt.pd2dq.256" | "cvtt.pd2dq.256" => { + let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let rnd = match unprefixed_name { + // "current SSE rounding mode", assume nearest + "cvt.ps2dq.256" | "cvt.pd2dq.256" => rustc_apfloat::Round::NearestTiesToEven, + // always truncate + "cvtt.ps2dq.256" | "cvtt.pd2dq.256" => rustc_apfloat::Round::TowardZero, + _ => unreachable!(), + }; + + convert_float_to_int(this, op, rnd, dest)?; + } + // Used to implement the _mm_permutevar_ps and _mm256_permutevar_ps functions. + // Shuffles 32-bit floats from `data` using `control` as control. Each 128-bit + // chunk is shuffled independently: this means that we view the vector as a + // sequence of 4-element arrays, and we shuffle each of these arrays, where + // `control` determines which element of the current `data` array is written. 
+ "vpermilvar.ps" | "vpermilvar.ps.256" => { + let [data, control] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (data, data_len) = this.operand_to_simd(data)?; + let (control, control_len) = this.operand_to_simd(control)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(dest_len, data_len); + assert_eq!(dest_len, control_len); + + for i in 0..dest_len { + let control = this.project_index(&control, i)?; + + // Each 128-bit chunk is shuffled independently. Since each chunk contains + // four 32-bit elements, only two bits from `control` are used. To read the + // value from the current chunk, add the destination index truncated to a multiple + // of 4. + let chunk_base = i & !0b11; + let src_i = u64::from(this.read_scalar(&control)?.to_u32()? & 0b11) + .checked_add(chunk_base) + .unwrap(); + + this.copy_op( + &this.project_index(&data, src_i)?, + &this.project_index(&dest, i)?, + )?; + } + } + // Used to implement the _mm_permutevar_pd and _mm256_permutevar_pd functions. + // Shuffles 64-bit floats from `left` using `right` as control. Each 128-bit + // chunk is shuffled independently: this means that we view the vector as + // a sequence of 2-element arrays, and we shuffle each of these arrays, + // where `right` determines which element of the current `left` array is + // written. + "vpermilvar.pd" | "vpermilvar.pd.256" => { + let [data, control] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (data, data_len) = this.operand_to_simd(data)?; + let (control, control_len) = this.operand_to_simd(control)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(dest_len, data_len); + assert_eq!(dest_len, control_len); + + for i in 0..dest_len { + let control = this.project_index(&control, i)?; + + // Each 128-bit chunk is shuffled independently. Since each chunk contains + // two 64-bit elements, only the second bit from `control` is used (yes, the + // second instead of the first, ask Intel). To read the value from the current + // chunk, add the destination index truncated to a multiple of 2. + let chunk_base = i & !1; + let src_i = ((this.read_scalar(&control)?.to_u64()? >> 1) & 1) + .checked_add(chunk_base) + .unwrap(); + + this.copy_op( + &this.project_index(&data, src_i)?, + &this.project_index(&dest, i)?, + )?; + } + } + // Used to implement the _mm256_permute2f128_ps, _mm256_permute2f128_pd and + // _mm256_permute2f128_si256 functions. Regardless of the suffix in the name + // thay all can be considered to operate on vectors of 128-bit elements. + // For each 128-bit element of `dest`, copies one from `left`, `right` or + // zero, according to `imm`. + "vperm2f128.ps.256" | "vperm2f128.pd.256" | "vperm2f128.si.256" => { + let [left, right, imm] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + assert_eq!(dest.layout, left.layout); + assert_eq!(dest.layout, right.layout); + assert_eq!(dest.layout.size.bits(), 256); + + // Transmute to `[u128; 2]` to process each 128-bit chunk independently. 
+ let u128x2_layout = + this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u128, 2))?; + let left = left.transmute(u128x2_layout, this)?; + let right = right.transmute(u128x2_layout, this)?; + let dest = dest.transmute(u128x2_layout, this)?; + + let imm = this.read_scalar(imm)?.to_u8()?; + + for i in 0..2 { + let dest = this.project_index(&dest, i)?; + + let imm = match i { + 0 => imm & 0xF, + 1 => imm >> 4, + _ => unreachable!(), + }; + if imm & 0b100 != 0 { + this.write_scalar(Scalar::from_u128(0), &dest)?; + } else { + let src = match imm { + 0b00 => this.project_index(&left, 0)?, + 0b01 => this.project_index(&left, 1)?, + 0b10 => this.project_index(&right, 0)?, + 0b11 => this.project_index(&right, 1)?, + _ => unreachable!(), + }; + this.copy_op(&src, &dest)?; + } + } + } + // Used to implement the _mm_maskload_ps, _mm_maskload_pd, _mm256_maskload_ps + // and _mm256_maskload_pd functions. + // For the element `i`, if the high bit of the `i`-th element of `mask` + // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is + // loaded. + "maskload.ps" | "maskload.pd" | "maskload.ps.256" | "maskload.pd.256" => { + let [ptr, mask] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + mask_load(this, ptr, mask, dest)?; + } + // Used to implement the _mm_maskstore_ps, _mm_maskstore_pd, _mm256_maskstore_ps + // and _mm256_maskstore_pd functions. + // For the element `i`, if the high bit of the element `i`-th of `mask` + // is one, it is stored into `ptr.wapping_add(i)`. + // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores. + "maskstore.ps" | "maskstore.pd" | "maskstore.ps.256" | "maskstore.pd.256" => { + let [ptr, mask, value] = + this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + mask_store(this, ptr, mask, value)?; + } + // Used to implement the _mm256_lddqu_si256 function. + // Reads a 256-bit vector from an unaligned pointer. This intrinsic + // is expected to perform better than a regular unaligned read when + // the data crosses a cache line, but for Miri this is just a regular + // unaligned read. + "ldu.dq.256" => { + let [src_ptr] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + let src_ptr = this.read_pointer(src_ptr)?; + let dest = dest.force_mplace(this)?; + + // Unaligned copy, which is what we want. + this.mem_copy(src_ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?; + } + // Used to implement the _mm256_testz_si256, _mm256_testc_si256 and + // _mm256_testnzc_si256 functions. + // Tests `op & mask == 0`, `op & mask == mask` or + // `op & mask != 0 && op & mask != mask` + "ptestz.256" | "ptestc.256" | "ptestnzc.256" => { + let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (all_zero, masked_set) = test_bits_masked(this, op, mask)?; + let res = match unprefixed_name { + "ptestz.256" => all_zero, + "ptestc.256" => masked_set, + "ptestnzc.256" => !all_zero && !masked_set, + _ => unreachable!(), + }; + + this.write_scalar(Scalar::from_i32(res.into()), dest)?; + } + // Used to implement the _mm256_testz_pd, _mm256_testc_pd, _mm256_testnzc_pd + // _mm_testz_pd, _mm_testc_pd, _mm_testnzc_pd, _mm256_testz_ps, + // _mm256_testc_ps, _mm256_testnzc_ps, _mm_testz_ps, _mm_testc_ps and + // _mm_testnzc_ps functions. + // Calculates two booleans: + // `direct`, which is true when the highest bit of each element of `op & mask` is zero. + // `negated`, which is true when the highest bit of each element of `!op & mask` is zero. 
+ // Return `direct` (testz), `negated` (testc) or `!direct & !negated` (testnzc) + "vtestz.pd.256" | "vtestc.pd.256" | "vtestnzc.pd.256" | "vtestz.pd" | "vtestc.pd" + | "vtestnzc.pd" | "vtestz.ps.256" | "vtestc.ps.256" | "vtestnzc.ps.256" + | "vtestz.ps" | "vtestc.ps" | "vtestnzc.ps" => { + let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; + + let (direct, negated) = test_high_bits_masked(this, op, mask)?; + let res = match unprefixed_name { + "vtestz.pd.256" | "vtestz.pd" | "vtestz.ps.256" | "vtestz.ps" => direct, + "vtestc.pd.256" | "vtestc.pd" | "vtestc.ps.256" | "vtestc.ps" => negated, + "vtestnzc.pd.256" | "vtestnzc.pd" | "vtestnzc.ps.256" | "vtestnzc.ps" => + !direct && !negated, + _ => unreachable!(), + }; + + this.write_scalar(Scalar::from_i32(res.into()), dest)?; + } + _ => return Ok(EmulateForeignItemResult::NotSupported), + } + Ok(EmulateForeignItemResult::NeedsJumping) + } +} + +/// Conditionally loads from `ptr` according the high bit of each +/// element of `mask`. `ptr` does not need to be aligned. +fn mask_load<'tcx>( + this: &mut crate::MiriInterpCx<'_, 'tcx>, + ptr: &OpTy<'tcx, Provenance>, + mask: &OpTy<'tcx, Provenance>, + dest: &PlaceTy<'tcx, Provenance>, +) -> InterpResult<'tcx, ()> { + let (mask, mask_len) = this.operand_to_simd(mask)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(dest_len, mask_len); + + let mask_item_size = mask.layout.field(this, 0).size; + let high_bit_offset = mask_item_size.bits().checked_sub(1).unwrap(); + + let ptr = this.read_pointer(ptr)?; + for i in 0..dest_len { + let mask = this.project_index(&mask, i)?; + let dest = this.project_index(&dest, i)?; + + if this.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 { + // Size * u64 is implemented as always checked + #[allow(clippy::arithmetic_side_effects)] + let ptr = ptr.wrapping_offset(dest.layout.size * i, &this.tcx); + // Unaligned copy, which is what we want. + this.mem_copy(ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?; + } else { + this.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?; + } + } + + Ok(()) +} + +/// Conditionally stores into `ptr` according the high bit of each +/// element of `mask`. `ptr` does not need to be aligned. +fn mask_store<'tcx>( + this: &mut crate::MiriInterpCx<'_, 'tcx>, + ptr: &OpTy<'tcx, Provenance>, + mask: &OpTy<'tcx, Provenance>, + value: &OpTy<'tcx, Provenance>, +) -> InterpResult<'tcx, ()> { + let (mask, mask_len) = this.operand_to_simd(mask)?; + let (value, value_len) = this.operand_to_simd(value)?; + + assert_eq!(value_len, mask_len); + + let mask_item_size = mask.layout.field(this, 0).size; + let high_bit_offset = mask_item_size.bits().checked_sub(1).unwrap(); + + let ptr = this.read_pointer(ptr)?; + for i in 0..value_len { + let mask = this.project_index(&mask, i)?; + let value = this.project_index(&value, i)?; + + if this.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 { + // Size * u64 is implemented as always checked + #[allow(clippy::arithmetic_side_effects)] + let ptr = ptr.wrapping_offset(value.layout.size * i, &this.tcx); + // Unaligned copy, which is what we want. 
+ this.mem_copy(value.ptr(), ptr, value.layout.size, /*nonoverlapping*/ true)?;
+ }
+ }
+
+ Ok(())
+}
diff --git a/src/tools/miri/src/shims/x86/mod.rs b/src/tools/miri/src/shims/x86/mod.rs
index 115be1d6f22..9cfee20014f 100644
--- a/src/tools/miri/src/shims/x86/mod.rs
+++ b/src/tools/miri/src/shims/x86/mod.rs
@@ -1,6 +1,8 @@
use rand::Rng as _;
-use rustc_apfloat::{ieee::Single, Float as _};
+use rustc_apfloat::{ieee::Single, Float};
+use rustc_middle::ty::layout::LayoutOf as _;
+use rustc_middle::ty::Ty;
use rustc_middle::{mir, ty};
use rustc_span::Symbol;
use rustc_target::abi::Size;
@@ -11,6 +13,7 @@ use helpers::bool_to_simd_element;
use shims::foreign_items::EmulateForeignItemResult;
mod aesni;
+mod avx;
mod sse;
mod sse2;
mod sse3;
@@ -115,6 +118,11 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
this, link_name, abi, args, dest,
);
}
+ name if name.starts_with("avx.") => {
+ return avx::EvalContextExt::emulate_x86_avx_intrinsic(
+ this, link_name, abi, args, dest,
+ );
+ }
_ => return Ok(EmulateForeignItemResult::NotSupported),
}
@@ -563,8 +571,65 @@ fn convert_float_to_int<'tcx>(
Ok(())
}
+/// Splits `left`, `right` and `dest` (which must be SIMD vectors)
+/// into 128-bit chunks.
+///
+/// `left`, `right` and `dest` cannot have different types.
+///
+/// Returns a tuple where:
+/// * The first element is the number of 128-bit chunks (let's call it `N`).
+/// * The second element is the number of elements per chunk (let's call it `M`).
+/// * The third element is the `left` vector split into chunks, i.e., its
+/// type is `[[T; M]; N]`.
+/// * The fourth element is the `right` vector split into chunks.
+/// * The fifth element is the `dest` vector split into chunks.
+fn split_simd_to_128bit_chunks<'tcx>(
+ this: &mut crate::MiriInterpCx<'_, 'tcx>,
+ left: &OpTy<'tcx, Provenance>,
+ right: &OpTy<'tcx, Provenance>,
+ dest: &PlaceTy<'tcx, Provenance>,
+) -> InterpResult<
+ 'tcx,
+ (u64, u64, MPlaceTy<'tcx, Provenance>, MPlaceTy<'tcx, Provenance>, MPlaceTy<'tcx, Provenance>),
+> {
+ assert_eq!(dest.layout, left.layout);
+ assert_eq!(dest.layout, right.layout);
+
+ let (left, left_len) = this.operand_to_simd(left)?;
+ let (right, right_len) = this.operand_to_simd(right)?;
+ let (dest, dest_len) = this.place_to_simd(dest)?;
+
+ assert_eq!(dest_len, left_len);
+ assert_eq!(dest_len, right_len);
+
+ assert_eq!(dest.layout.size.bits() % 128, 0);
+ let num_chunks = dest.layout.size.bits() / 128;
+ assert_eq!(dest_len.checked_rem(num_chunks), Some(0));
+ let items_per_chunk = dest_len.checked_div(num_chunks).unwrap();
+
+ // Transmute to `[[T; items_per_chunk]; num_chunks]`
+ let element_layout = left.layout.field(this, 0);
+ let chunked_layout = this.layout_of(Ty::new_array(
+ this.tcx.tcx,
+ Ty::new_array(this.tcx.tcx, element_layout.ty, items_per_chunk),
+ num_chunks,
+ ))?;
+ let left = left.transmute(chunked_layout, this)?;
+ let right = right.transmute(chunked_layout, this)?;
+ let dest = dest.transmute(chunked_layout, this)?;
+
+ Ok((num_chunks, items_per_chunk, left, right, dest))
+}
+
/// Horizontally performs `which` operation on adjacent values of
/// `left` and `right` SIMD vectors and stores the result in `dest`.
+/// "Horizontal" means that the i-th output element is calculated
+/// from the elements 2*i and 2*i+1 of the concatenation of `left` and
+/// `right`.
+///
+/// Each 128-bit chunk is treated independently (i.e., the value for
+/// the i-th 128-bit chunk of `dest` is calculated with the i-th
+/// 128-bit chunks of `left` and `right`).
fn horizontal_bin_op<'tcx>( this: &mut crate::MiriInterpCx<'_, 'tcx>, which: mir::BinOp, @@ -573,32 +638,34 @@ fn horizontal_bin_op<'tcx>( right: &OpTy<'tcx, Provenance>, dest: &PlaceTy<'tcx, Provenance>, ) -> InterpResult<'tcx, ()> { - let (left, left_len) = this.operand_to_simd(left)?; - let (right, right_len) = this.operand_to_simd(right)?; - let (dest, dest_len) = this.place_to_simd(dest)?; + let (num_chunks, items_per_chunk, left, right, dest) = + split_simd_to_128bit_chunks(this, left, right, dest)?; - assert_eq!(dest_len, left_len); - assert_eq!(dest_len, right_len); - assert_eq!(dest_len % 2, 0); + let middle = items_per_chunk / 2; + for i in 0..num_chunks { + let left = this.project_index(&left, i)?; + let right = this.project_index(&right, i)?; + let dest = this.project_index(&dest, i)?; - let middle = dest_len / 2; - for i in 0..dest_len { - // `i` is the index in `dest` - // `j` is the index of the 2-item chunk in `src` - let (j, src) = - if i < middle { (i, &left) } else { (i.checked_sub(middle).unwrap(), &right) }; - // `base_i` is the index of the first item of the 2-item chunk in `src` - let base_i = j.checked_mul(2).unwrap(); - let lhs = this.read_immediate(&this.project_index(src, base_i)?)?; - let rhs = this.read_immediate(&this.project_index(src, base_i.checked_add(1).unwrap())?)?; + for j in 0..items_per_chunk { + // `j` is the index in `dest` + // `k` is the index of the 2-item chunk in `src` + let (k, src) = + if j < middle { (j, &left) } else { (j.checked_sub(middle).unwrap(), &right) }; + // `base_i` is the index of the first item of the 2-item chunk in `src` + let base_i = k.checked_mul(2).unwrap(); + let lhs = this.read_immediate(&this.project_index(src, base_i)?)?; + let rhs = + this.read_immediate(&this.project_index(src, base_i.checked_add(1).unwrap())?)?; - let res = if saturating { - Immediate::from(this.saturating_arith(which, &lhs, &rhs)?) - } else { - *this.wrapping_binary_op(which, &lhs, &rhs)? - }; + let res = if saturating { + Immediate::from(this.saturating_arith(which, &lhs, &rhs)?) + } else { + *this.wrapping_binary_op(which, &lhs, &rhs)? + }; - this.write_immediate(res, &this.project_index(&dest, i)?)?; + this.write_immediate(res, &this.project_index(&dest, j)?)?; + } } Ok(()) @@ -608,6 +675,10 @@ fn horizontal_bin_op<'tcx>( /// `left` and `right` using the high 4 bits in `imm`, sums the calculated /// products (up to 4), and conditionally stores the sum in `dest` using /// the low 4 bits of `imm`. +/// +/// Each 128-bit chunk is treated independently (i.e., the value for +/// the is i-th 128-bit chunk of `dest` is calculated with the i-th +/// 128-bit blocks of `left` and `right`). 
fn conditional_dot_product<'tcx>( this: &mut crate::MiriInterpCx<'_, 'tcx>, left: &OpTy<'tcx, Provenance>, @@ -615,39 +686,43 @@ fn conditional_dot_product<'tcx>( imm: &OpTy<'tcx, Provenance>, dest: &PlaceTy<'tcx, Provenance>, ) -> InterpResult<'tcx, ()> { - let (left, left_len) = this.operand_to_simd(left)?; - let (right, right_len) = this.operand_to_simd(right)?; - let (dest, dest_len) = this.place_to_simd(dest)?; + let (num_chunks, items_per_chunk, left, right, dest) = + split_simd_to_128bit_chunks(this, left, right, dest)?; - assert_eq!(left_len, right_len); - assert!(dest_len <= 4); + let element_layout = left.layout.field(this, 0).field(this, 0); + assert!(items_per_chunk <= 4); - let imm = this.read_scalar(imm)?.to_u8()?; + // `imm` is a `u8` for SSE4.1 or an `i32` for AVX :/ + let imm = this.read_scalar(imm)?.to_uint(imm.layout.size)?; - let element_layout = left.layout.field(this, 0); - - // Calculate dot product - // Elements are floating point numbers, but we can use `from_int` - // because the representation of 0.0 is all zero bits. - let mut sum = ImmTy::from_int(0u8, element_layout); - for i in 0..left_len { - if imm & (1 << i.checked_add(4).unwrap()) != 0 { - let left = this.read_immediate(&this.project_index(&left, i)?)?; - let right = this.read_immediate(&this.project_index(&right, i)?)?; - - let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?; - sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?; - } - } - - // Write to destination (conditioned to imm) - for i in 0..dest_len { + for i in 0..num_chunks { + let left = this.project_index(&left, i)?; + let right = this.project_index(&right, i)?; let dest = this.project_index(&dest, i)?; - if imm & (1 << i) != 0 { - this.write_immediate(*sum, &dest)?; - } else { - this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?; + // Calculate dot product + // Elements are floating point numbers, but we can use `from_int` + // for the initial value because the representation of 0.0 is all zero bits. + let mut sum = ImmTy::from_int(0u8, element_layout); + for j in 0..items_per_chunk { + if imm & (1 << j.checked_add(4).unwrap()) != 0 { + let left = this.read_immediate(&this.project_index(&left, j)?)?; + let right = this.read_immediate(&this.project_index(&right, j)?)?; + + let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?; + sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?; + } + } + + // Write to destination (conditioned to imm) + for j in 0..items_per_chunk { + let dest = this.project_index(&dest, j)?; + + if imm & (1 << j) != 0 { + this.write_immediate(*sum, &dest)?; + } else { + this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?; + } } } @@ -684,3 +759,36 @@ fn test_bits_masked<'tcx>( Ok((all_zero, masked_set)) } + +/// Calculates two booleans. +/// +/// The first is true when the highest bit of each element of `op & mask` is zero. +/// The second is true when the highest bit of each element of `!op & mask` is zero. 
+fn test_high_bits_masked<'tcx>( + this: &crate::MiriInterpCx<'_, 'tcx>, + op: &OpTy<'tcx, Provenance>, + mask: &OpTy<'tcx, Provenance>, +) -> InterpResult<'tcx, (bool, bool)> { + assert_eq!(op.layout, mask.layout); + + let (op, op_len) = this.operand_to_simd(op)?; + let (mask, mask_len) = this.operand_to_simd(mask)?; + + assert_eq!(op_len, mask_len); + + let high_bit_offset = op.layout.field(this, 0).size.bits().checked_sub(1).unwrap(); + + let mut direct = true; + let mut negated = true; + for i in 0..op_len { + let op = this.project_index(&op, i)?; + let mask = this.project_index(&mask, i)?; + + let op = this.read_scalar(&op)?.to_uint(op.layout.size)?; + let mask = this.read_scalar(&mask)?.to_uint(mask.layout.size)?; + direct &= (op & mask) >> high_bit_offset == 0; + negated &= (!op & mask) >> high_bit_offset == 0; + } + + Ok((direct, negated)) +} diff --git a/src/tools/miri/tests/pass/intrinsics-x86-avx.rs b/src/tools/miri/tests/pass/intrinsics-x86-avx.rs index 933e3d4153a..7d43cc596ae 100644 --- a/src/tools/miri/tests/pass/intrinsics-x86-avx.rs +++ b/src/tools/miri/tests/pass/intrinsics-x86-avx.rs @@ -25,6 +25,528 @@ fn main() { #[target_feature(enable = "avx")] unsafe fn test_avx() { + // Mostly copied from library/stdarch/crates/core_arch/src/x86/avx.rs + + macro_rules! assert_approx_eq { + ($a:expr, $b:expr, $eps:expr) => {{ + let (a, b) = (&$a, &$b); + assert!( + (*a - *b).abs() < $eps, + "assertion failed: `(left !== right)` \ + (left: `{:?}`, right: `{:?}`, expect diff: `{:?}`, real diff: `{:?}`)", + *a, + *b, + $eps, + (*a - *b).abs() + ); + }}; + } + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_max_pd() { + let a = _mm256_setr_pd(1., 4., 5., 8.); + let b = _mm256_setr_pd(2., 3., 6., 7.); + let r = _mm256_max_pd(a, b); + let e = _mm256_setr_pd(2., 4., 6., 8.); + assert_eq_m256d(r, e); + // > If the values being compared are both 0.0s (of either sign), the + // > value in the second operand (source operand) is returned. + let w = _mm256_max_pd(_mm256_set1_pd(0.0), _mm256_set1_pd(-0.0)); + let x = _mm256_max_pd(_mm256_set1_pd(-0.0), _mm256_set1_pd(0.0)); + let wu: [u64; 4] = transmute(w); + let xu: [u64; 4] = transmute(x); + assert_eq!(wu, [0x8000_0000_0000_0000u64; 4]); + assert_eq!(xu, [0u64; 4]); + // > If only one value is a NaN (SNaN or QNaN) for this instruction, the + // > second operand (source operand), either a NaN or a valid + // > floating-point value, is written to the result. + let y = _mm256_max_pd(_mm256_set1_pd(f64::NAN), _mm256_set1_pd(0.0)); + let z = _mm256_max_pd(_mm256_set1_pd(0.0), _mm256_set1_pd(f64::NAN)); + let yf: [f64; 4] = transmute(y); + let zf: [f64; 4] = transmute(z); + assert_eq!(yf, [0.0; 4]); + assert!(zf.iter().all(|f| f.is_nan()), "{:?}", zf); + } + test_mm256_max_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_max_ps() { + let a = _mm256_setr_ps(1., 4., 5., 8., 9., 12., 13., 16.); + let b = _mm256_setr_ps(2., 3., 6., 7., 10., 11., 14., 15.); + let r = _mm256_max_ps(a, b); + let e = _mm256_setr_ps(2., 4., 6., 8., 10., 12., 14., 16.); + assert_eq_m256(r, e); + // > If the values being compared are both 0.0s (of either sign), the + // > value in the second operand (source operand) is returned. 
+ let w = _mm256_max_ps(_mm256_set1_ps(0.0), _mm256_set1_ps(-0.0)); + let x = _mm256_max_ps(_mm256_set1_ps(-0.0), _mm256_set1_ps(0.0)); + let wu: [u32; 8] = transmute(w); + let xu: [u32; 8] = transmute(x); + assert_eq!(wu, [0x8000_0000u32; 8]); + assert_eq!(xu, [0u32; 8]); + // > If only one value is a NaN (SNaN or QNaN) for this instruction, the + // > second operand (source operand), either a NaN or a valid + // > floating-point value, is written to the result. + let y = _mm256_max_ps(_mm256_set1_ps(f32::NAN), _mm256_set1_ps(0.0)); + let z = _mm256_max_ps(_mm256_set1_ps(0.0), _mm256_set1_ps(f32::NAN)); + let yf: [f32; 8] = transmute(y); + let zf: [f32; 8] = transmute(z); + assert_eq!(yf, [0.0; 8]); + assert!(zf.iter().all(|f| f.is_nan()), "{:?}", zf); + } + test_mm256_max_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_min_pd() { + let a = _mm256_setr_pd(1., 4., 5., 8.); + let b = _mm256_setr_pd(2., 3., 6., 7.); + let r = _mm256_min_pd(a, b); + let e = _mm256_setr_pd(1., 3., 5., 7.); + assert_eq_m256d(r, e); + // > If the values being compared are both 0.0s (of either sign), the + // > value in the second operand (source operand) is returned. + let w = _mm256_min_pd(_mm256_set1_pd(0.0), _mm256_set1_pd(-0.0)); + let x = _mm256_min_pd(_mm256_set1_pd(-0.0), _mm256_set1_pd(0.0)); + let wu: [u64; 4] = transmute(w); + let xu: [u64; 4] = transmute(x); + assert_eq!(wu, [0x8000_0000_0000_0000u64; 4]); + assert_eq!(xu, [0u64; 4]); + // > If only one value is a NaN (SNaN or QNaN) for this instruction, the + // > second operand (source operand), either a NaN or a valid + // > floating-point value, is written to the result. + let y = _mm256_min_pd(_mm256_set1_pd(f64::NAN), _mm256_set1_pd(0.0)); + let z = _mm256_min_pd(_mm256_set1_pd(0.0), _mm256_set1_pd(f64::NAN)); + let yf: [f64; 4] = transmute(y); + let zf: [f64; 4] = transmute(z); + assert_eq!(yf, [0.0; 4]); + assert!(zf.iter().all(|f| f.is_nan()), "{:?}", zf); + } + test_mm256_min_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_min_ps() { + let a = _mm256_setr_ps(1., 4., 5., 8., 9., 12., 13., 16.); + let b = _mm256_setr_ps(2., 3., 6., 7., 10., 11., 14., 15.); + let r = _mm256_min_ps(a, b); + let e = _mm256_setr_ps(1., 3., 5., 7., 9., 11., 13., 15.); + assert_eq_m256(r, e); + // > If the values being compared are both 0.0s (of either sign), the + // > value in the second operand (source operand) is returned. + let w = _mm256_min_ps(_mm256_set1_ps(0.0), _mm256_set1_ps(-0.0)); + let x = _mm256_min_ps(_mm256_set1_ps(-0.0), _mm256_set1_ps(0.0)); + let wu: [u32; 8] = transmute(w); + let xu: [u32; 8] = transmute(x); + assert_eq!(wu, [0x8000_0000u32; 8]); + assert_eq!(xu, [0u32; 8]); + // > If only one value is a NaN (SNaN or QNaN) for this instruction, the + // > second operand (source operand), either a NaN or a valid + // > floating-point value, is written to the result. 
+ let y = _mm256_min_ps(_mm256_set1_ps(f32::NAN), _mm256_set1_ps(0.0)); + let z = _mm256_min_ps(_mm256_set1_ps(0.0), _mm256_set1_ps(f32::NAN)); + let yf: [f32; 8] = transmute(y); + let zf: [f32; 8] = transmute(z); + assert_eq!(yf, [0.0; 8]); + assert!(zf.iter().all(|f| f.is_nan()), "{:?}", zf); + } + test_mm256_min_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_round_nearest_f32() { + #[target_feature(enable = "avx")] + unsafe fn test(x: f32, res: f32) { + let a = _mm256_set1_ps(x); + let e = _mm256_set1_ps(res); + let r = _mm256_round_ps::<_MM_FROUND_TO_NEAREST_INT>(a); + assert_eq_m256(r, e); + // Assume round-to-nearest by default + let r = _mm256_round_ps::<_MM_FROUND_CUR_DIRECTION>(a); + assert_eq_m256(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -2.0); + test(-1.5, -2.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 2.0); + test(1.75, 2.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm256_setr_ps(1.5, 3.5, 5.5, 7.5, 9.5, 11.5, 13.5, 15.5); + let e = _mm256_setr_ps(2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0); + let r = _mm256_round_ps::<_MM_FROUND_TO_NEAREST_INT>(a); + assert_eq_m256(r, e); + // Assume round-to-nearest by default + let r = _mm256_round_ps::<_MM_FROUND_CUR_DIRECTION>(a); + assert_eq_m256(r, e); + } + test_round_nearest_f32(); + + #[target_feature(enable = "avx")] + unsafe fn test_round_floor_f32() { + #[target_feature(enable = "avx")] + unsafe fn test(x: f32, res: f32) { + let a = _mm256_set1_ps(x); + let e = _mm256_set1_ps(res); + let r = _mm256_floor_ps(a); + assert_eq_m256(r, e); + let r = _mm256_round_ps::<_MM_FROUND_TO_NEG_INF>(a); + assert_eq_m256(r, e); + } + + // Test rounding direction + test(-2.5, -3.0); + test(-1.75, -2.0); + test(-1.5, -2.0); + test(-1.25, -2.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 1.0); + test(1.75, 1.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm256_setr_ps(1.5, 3.5, 5.5, 7.5, 9.5, 11.5, 13.5, 15.5); + let e = _mm256_setr_ps(1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0); + let r = _mm256_floor_ps(a); + assert_eq_m256(r, e); + let r = _mm256_round_ps::<_MM_FROUND_TO_NEG_INF>(a); + assert_eq_m256(r, e); + } + test_round_floor_f32(); + + #[target_feature(enable = "avx")] + unsafe fn test_round_ceil_f32() { + #[target_feature(enable = "avx")] + unsafe fn test(x: f32, res: f32) { + let a = _mm256_set1_ps(x); + let e = _mm256_set1_ps(res); + let r = _mm256_ceil_ps(a); + assert_eq_m256(r, e); + let r = _mm256_round_ps::<_MM_FROUND_TO_POS_INF>(a); + assert_eq_m256(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -1.0); + test(-1.5, -1.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 2.0); + test(1.5, 2.0); + test(1.75, 2.0); + test(2.5, 3.0); + + // Test that each element is rounded + let a = _mm256_setr_ps(1.5, 3.5, 5.5, 7.5, 9.5, 11.5, 13.5, 15.5); + let e = _mm256_setr_ps(2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0); + let r = _mm256_ceil_ps(a); + assert_eq_m256(r, e); + let r = _mm256_round_ps::<_MM_FROUND_TO_POS_INF>(a); + assert_eq_m256(r, e); + } + test_round_ceil_f32(); + + #[target_feature(enable = "avx")] + unsafe fn test_round_trunc_f32() { + #[target_feature(enable = "avx")] + unsafe fn test(x: f32, res: f32) { + let a = _mm256_set1_ps(x); + let e = _mm256_set1_ps(res); + let r = _mm256_round_ps::<_MM_FROUND_TO_ZERO>(a); + assert_eq_m256(r, e); + } + + // 
Test rounding direction + test(-2.5, -2.0); + test(-1.75, -1.0); + test(-1.5, -1.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 1.0); + test(1.75, 1.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm256_setr_ps(1.5, 3.5, 5.5, 7.5, 9.5, 11.5, 13.5, 15.5); + let e = _mm256_setr_ps(1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0); + let r = _mm256_round_ps::<_MM_FROUND_TO_ZERO>(a); + assert_eq_m256(r, e); + } + test_round_trunc_f32(); + + #[target_feature(enable = "avx")] + unsafe fn test_round_nearest_f64() { + #[target_feature(enable = "avx")] + unsafe fn test(x: f64, res: f64) { + let a = _mm256_set1_pd(x); + let e = _mm256_set1_pd(res); + let r = _mm256_round_pd::<_MM_FROUND_TO_NEAREST_INT>(a); + assert_eq_m256d(r, e); + // Assume round-to-nearest by default + let r = _mm256_round_pd::<_MM_FROUND_CUR_DIRECTION>(a); + assert_eq_m256d(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -2.0); + test(-1.5, -2.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 2.0); + test(1.75, 2.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm256_setr_pd(1.5, 3.5, 5.5, 7.5); + let e = _mm256_setr_pd(2.0, 4.0, 6.0, 8.0); + let r = _mm256_round_pd::<_MM_FROUND_TO_NEAREST_INT>(a); + assert_eq_m256d(r, e); + // Assume round-to-nearest by default + let r = _mm256_round_pd::<_MM_FROUND_CUR_DIRECTION>(a); + assert_eq_m256d(r, e); + } + test_round_nearest_f64(); + + #[target_feature(enable = "avx")] + unsafe fn test_round_floor_f64() { + #[target_feature(enable = "avx")] + unsafe fn test(x: f64, res: f64) { + let a = _mm256_set1_pd(x); + let e = _mm256_set1_pd(res); + let r = _mm256_floor_pd(a); + assert_eq_m256d(r, e); + let r = _mm256_round_pd::<_MM_FROUND_TO_NEG_INF>(a); + assert_eq_m256d(r, e); + } + + // Test rounding direction + test(-2.5, -3.0); + test(-1.75, -2.0); + test(-1.5, -2.0); + test(-1.25, -2.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 1.0); + test(1.75, 1.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm256_setr_pd(1.5, 3.5, 5.5, 7.5); + let e = _mm256_setr_pd(1.0, 3.0, 5.0, 7.0); + let r = _mm256_floor_pd(a); + assert_eq_m256d(r, e); + let r = _mm256_round_pd::<_MM_FROUND_TO_NEG_INF>(a); + assert_eq_m256d(r, e); + } + test_round_floor_f64(); + + #[target_feature(enable = "avx")] + unsafe fn test_round_ceil_f64() { + #[target_feature(enable = "avx")] + unsafe fn test(x: f64, res: f64) { + let a = _mm256_set1_pd(x); + let e = _mm256_set1_pd(res); + let r = _mm256_ceil_pd(a); + assert_eq_m256d(r, e); + let r = _mm256_round_pd::<_MM_FROUND_TO_POS_INF>(a); + assert_eq_m256d(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -1.0); + test(-1.5, -1.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 2.0); + test(1.5, 2.0); + test(1.75, 2.0); + test(2.5, 3.0); + + // Test that each element is rounded + let a = _mm256_setr_pd(1.5, 3.5, 5.5, 7.5); + let e = _mm256_setr_pd(2.0, 4.0, 6.0, 8.0); + let r = _mm256_ceil_pd(a); + assert_eq_m256d(r, e); + let r = _mm256_round_pd::<_MM_FROUND_TO_POS_INF>(a); + assert_eq_m256d(r, e); + } + test_round_ceil_f64(); + + #[target_feature(enable = "avx")] + unsafe fn test_round_trunc_f64() { + #[target_feature(enable = "avx")] + unsafe fn test(x: f64, res: f64) { + let a = _mm256_set1_pd(x); + let e = _mm256_set1_pd(res); + let r = 
_mm256_round_pd::<_MM_FROUND_TO_ZERO>(a); + assert_eq_m256d(r, e); + } + + // Test rounding direction + test(-2.5, -2.0); + test(-1.75, -1.0); + test(-1.5, -1.0); + test(-1.25, -1.0); + test(-1.0, -1.0); + test(0.0, 0.0); + test(1.0, 1.0); + test(1.25, 1.0); + test(1.5, 1.0); + test(1.75, 1.0); + test(2.5, 2.0); + + // Test that each element is rounded + let a = _mm256_setr_pd(1.5, 3.5, 5.5, 7.5); + let e = _mm256_setr_pd(1.0, 3.0, 5.0, 7.0); + let r = _mm256_round_pd::<_MM_FROUND_TO_ZERO>(a); + assert_eq_m256d(r, e); + } + test_round_trunc_f64(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_sqrt_ps() { + let a = _mm256_setr_ps(4., 9., 16., 25., 4., 9., 16., 25.); + let r = _mm256_sqrt_ps(a); + let e = _mm256_setr_ps(2., 3., 4., 5., 2., 3., 4., 5.); + assert_eq_m256(r, e); + } + test_mm256_sqrt_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_rcp_ps() { + let a = _mm256_setr_ps(1., 2., 3., 4., 5., 6., 7., 8.); + let r = _mm256_rcp_ps(a); + #[rustfmt::skip] + let e = _mm256_setr_ps( + 0.99975586, 0.49987793, 0.33325195, 0.24993896, + 0.19995117, 0.16662598, 0.14282227, 0.12496948, + ); + let rel_err = 0.00048828125; + + let r: [f32; 8] = transmute(r); + let e: [f32; 8] = transmute(e); + for i in 0..8 { + assert_approx_eq!(r[i], e[i], 2. * rel_err); + } + } + test_mm256_rcp_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_rsqrt_ps() { + let a = _mm256_setr_ps(1., 2., 3., 4., 5., 6., 7., 8.); + let r = _mm256_rsqrt_ps(a); + #[rustfmt::skip] + let e = _mm256_setr_ps( + 0.99975586, 0.7069092, 0.5772705, 0.49987793, + 0.44714355, 0.40820313, 0.3779297, 0.3534546, + ); + let rel_err = 0.00048828125; + + let r: [f32; 8] = transmute(r); + let e: [f32; 8] = transmute(e); + for i in 0..8 { + assert_approx_eq!(r[i], e[i], 2. 
* rel_err); + } + } + test_mm256_rsqrt_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_dp_ps() { + let a = _mm256_setr_ps(4., 9., 16., 25., 4., 9., 16., 25.); + let b = _mm256_setr_ps(4., 3., 2., 5., 8., 9., 64., 50.); + let r = _mm256_dp_ps::<0xFF>(a, b); + let e = _mm256_setr_ps(200., 200., 200., 200., 2387., 2387., 2387., 2387.); + assert_eq_m256(r, e); + } + test_mm256_dp_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_hadd_pd() { + let a = _mm256_setr_pd(4., 9., 16., 25.); + let b = _mm256_setr_pd(4., 3., 2., 5.); + let r = _mm256_hadd_pd(a, b); + let e = _mm256_setr_pd(13., 7., 41., 7.); + assert_eq_m256d(r, e); + + let a = _mm256_setr_pd(1., 2., 3., 4.); + let b = _mm256_setr_pd(5., 6., 7., 8.); + let r = _mm256_hadd_pd(a, b); + let e = _mm256_setr_pd(3., 11., 7., 15.); + assert_eq_m256d(r, e); + } + test_mm256_hadd_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_hadd_ps() { + let a = _mm256_setr_ps(4., 9., 16., 25., 4., 9., 16., 25.); + let b = _mm256_setr_ps(4., 3., 2., 5., 8., 9., 64., 50.); + let r = _mm256_hadd_ps(a, b); + let e = _mm256_setr_ps(13., 41., 7., 7., 13., 41., 17., 114.); + assert_eq_m256(r, e); + + let a = _mm256_setr_ps(1., 2., 3., 4., 1., 2., 3., 4.); + let b = _mm256_setr_ps(5., 6., 7., 8., 5., 6., 7., 8.); + let r = _mm256_hadd_ps(a, b); + let e = _mm256_setr_ps(3., 7., 11., 15., 3., 7., 11., 15.); + assert_eq_m256(r, e); + } + test_mm256_hadd_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_hsub_pd() { + let a = _mm256_setr_pd(4., 9., 16., 25.); + let b = _mm256_setr_pd(4., 3., 2., 5.); + let r = _mm256_hsub_pd(a, b); + let e = _mm256_setr_pd(-5., 1., -9., -3.); + assert_eq_m256d(r, e); + + let a = _mm256_setr_pd(1., 2., 3., 4.); + let b = _mm256_setr_pd(5., 6., 7., 8.); + let r = _mm256_hsub_pd(a, b); + let e = _mm256_setr_pd(-1., -1., -1., -1.); + assert_eq_m256d(r, e); + } + test_mm256_hsub_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_hsub_ps() { + let a = _mm256_setr_ps(4., 9., 16., 25., 4., 9., 16., 25.); + let b = _mm256_setr_ps(4., 3., 2., 5., 8., 9., 64., 50.); + let r = _mm256_hsub_ps(a, b); + let e = _mm256_setr_ps(-5., -9., 1., -3., -5., -9., -1., 14.); + assert_eq_m256(r, e); + + let a = _mm256_setr_ps(1., 2., 3., 4., 1., 2., 3., 4.); + let b = _mm256_setr_ps(5., 6., 7., 8., 5., 6., 7., 8.); + let r = _mm256_hsub_ps(a, b); + let e = _mm256_setr_ps(-1., -1., -1., -1., -1., -1., -1., -1.); + assert_eq_m256(r, e); + } + test_mm256_hsub_ps(); + fn expected_cmp(imm: i32, lhs: F, rhs: F, if_t: F, if_f: F) -> F { let res = match imm { _CMP_EQ_OQ => lhs == rhs, @@ -135,12 +657,54 @@ unsafe fn test_avx() { } } + #[target_feature(enable = "avx")] + unsafe fn test_mm256_cmp_ps() { + let values = [ + (1.0, 1.0), + (0.0, 1.0), + (1.0, 0.0), + (f32::NAN, 0.0), + (0.0, f32::NAN), + (f32::NAN, f32::NAN), + ]; + + for (lhs, rhs) in values { + let a = _mm256_set1_ps(lhs); + let b = _mm256_set1_ps(rhs); + let r: [u32; 8] = transmute(_mm256_cmp_ps::(a, b)); + let e: [u32; 8] = transmute(_mm256_set1_ps(expected_cmp_f32(IMM, lhs, rhs))); + assert_eq!(r, e); + } + } + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_cmp_pd() { + let values = [ + (1.0, 1.0), + (0.0, 1.0), + (1.0, 0.0), + (f64::NAN, 0.0), + (0.0, f64::NAN), + (f64::NAN, f64::NAN), + ]; + + for (lhs, rhs) in values { + let a = _mm256_set1_pd(lhs); + let b = _mm256_set1_pd(rhs); + let r: [u64; 4] = transmute(_mm256_cmp_pd::(a, b)); + let e: [u64; 4] = transmute(_mm256_set1_pd(expected_cmp_f64(IMM, lhs, 
rhs))); + assert_eq!(r, e); + } + } + #[target_feature(enable = "avx")] unsafe fn test_cmp() { test_mm_cmp_ss::(); test_mm_cmp_ps::(); test_mm_cmp_sd::(); test_mm_cmp_pd::(); + test_mm256_cmp_ps::(); + test_mm256_cmp_pd::(); } test_cmp::<_CMP_EQ_OQ>(); @@ -159,4 +723,709 @@ unsafe fn test_avx() { test_cmp::<_CMP_GE_OS>(); test_cmp::<_CMP_GT_OS>(); test_cmp::<_CMP_TRUE_US>(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_cvtps_epi32() { + let a = _mm256_setr_ps(4., 9., 16., 25., 4., 9., 16., 25.); + let r = _mm256_cvtps_epi32(a); + let e = _mm256_setr_epi32(4, 9, 16, 25, 4, 9, 16, 25); + assert_eq_m256i(r, e); + + let a = _mm256_setr_ps( + f32::NEG_INFINITY, + f32::INFINITY, + f32::MIN, + f32::MAX, + f32::NAN, + f32::NAN, + f32::NAN, + f32::NAN, + ); + let r = _mm256_cvtps_epi32(a); + assert_eq_m256i(r, _mm256_set1_epi32(i32::MIN)); + } + test_mm256_cvtps_epi32(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_cvttps_epi32() { + let a = _mm256_setr_ps(4., 9., 16., 25., 4., 9., 16., 25.); + let r = _mm256_cvttps_epi32(a); + let e = _mm256_setr_epi32(4, 9, 16, 25, 4, 9, 16, 25); + assert_eq_m256i(r, e); + + let a = _mm256_setr_ps( + f32::NEG_INFINITY, + f32::INFINITY, + f32::MIN, + f32::MAX, + f32::NAN, + f32::NAN, + f32::NAN, + f32::NAN, + ); + let r = _mm256_cvttps_epi32(a); + assert_eq_m256i(r, _mm256_set1_epi32(i32::MIN)); + } + test_mm256_cvttps_epi32(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_cvtpd_epi32() { + let a = _mm256_setr_pd(4., 9., 16., 25.); + let r = _mm256_cvtpd_epi32(a); + let e = _mm_setr_epi32(4, 9, 16, 25); + assert_eq_m128i(r, e); + + let a = _mm256_setr_pd(f64::NEG_INFINITY, f64::INFINITY, f64::MIN, f64::MAX); + let r = _mm256_cvtpd_epi32(a); + assert_eq_m128i(r, _mm_set1_epi32(i32::MIN)); + } + test_mm256_cvtpd_epi32(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_cvttpd_epi32() { + let a = _mm256_setr_pd(4., 9., 16., 25.); + let r = _mm256_cvttpd_epi32(a); + let e = _mm_setr_epi32(4, 9, 16, 25); + assert_eq_m128i(r, e); + + let a = _mm256_setr_pd(f64::NEG_INFINITY, f64::INFINITY, f64::MIN, f64::MAX); + let r = _mm256_cvttpd_epi32(a); + assert_eq_m128i(r, _mm_set1_epi32(i32::MIN)); + } + test_mm256_cvttpd_epi32(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm_permutevar_ps() { + let a = _mm_setr_ps(4., 3., 2., 5.); + let b = _mm_setr_epi32(1, 2, 3, 4); + let r = _mm_permutevar_ps(a, b); + let e = _mm_setr_ps(3., 2., 5., 4.); + assert_eq_m128(r, e); + } + test_mm_permutevar_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_permutevar_ps() { + let a = _mm256_setr_ps(4., 3., 2., 5., 8., 9., 64., 50.); + let b = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 8); + let r = _mm256_permutevar_ps(a, b); + let e = _mm256_setr_ps(3., 2., 5., 4., 9., 64., 50., 8.); + assert_eq_m256(r, e); + } + test_mm256_permutevar_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm_permutevar_pd() { + let a = _mm_setr_pd(4., 3.); + let b = _mm_setr_epi64x(3, 0); + let r = _mm_permutevar_pd(a, b); + let e = _mm_setr_pd(3., 4.); + assert_eq_m128d(r, e); + } + test_mm_permutevar_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_permutevar_pd() { + let a = _mm256_setr_pd(4., 3., 2., 5.); + let b = _mm256_setr_epi64x(1, 2, 3, 4); + let r = _mm256_permutevar_pd(a, b); + let e = _mm256_setr_pd(4., 3., 5., 2.); + assert_eq_m256d(r, e); + } + test_mm256_permutevar_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_permute2f128_ps() { + let a = _mm256_setr_ps(1., 2., 3., 
4., 1., 2., 3., 4.); + let b = _mm256_setr_ps(5., 6., 7., 8., 5., 6., 7., 8.); + let r = _mm256_permute2f128_ps::<0x13>(a, b); + let e = _mm256_setr_ps(5., 6., 7., 8., 1., 2., 3., 4.); + assert_eq_m256(r, e); + + let r = _mm256_permute2f128_ps::<0x44>(a, b); + let e = _mm256_setr_ps(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + assert_eq_m256(r, e); + } + test_mm256_permute2f128_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_permute2f128_pd() { + let a = _mm256_setr_pd(1., 2., 3., 4.); + let b = _mm256_setr_pd(5., 6., 7., 8.); + let r = _mm256_permute2f128_pd::<0x31>(a, b); + let e = _mm256_setr_pd(3., 4., 7., 8.); + assert_eq_m256d(r, e); + + let r = _mm256_permute2f128_pd::<0x44>(a, b); + let e = _mm256_setr_pd(0.0, 0.0, 0.0, 0.0); + assert_eq_m256d(r, e); + } + test_mm256_permute2f128_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_permute2f128_si256() { + let a = _mm256_setr_epi32(1, 2, 3, 4, 1, 2, 3, 4); + let b = _mm256_setr_epi32(5, 6, 7, 8, 5, 6, 7, 8); + let r = _mm256_permute2f128_si256::<0x20>(a, b); + let e = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 8); + assert_eq_m256i(r, e); + + let r = _mm256_permute2f128_si256::<0x44>(a, b); + let e = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0); + assert_eq_m256i(r, e); + } + test_mm256_permute2f128_si256(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm_maskload_ps() { + let a = &[1.0f32, 2., 3., 4.]; + let mask = _mm_setr_epi32(0, !0, 0, !0); + let r = _mm_maskload_ps(a.as_ptr(), mask); + let e = _mm_setr_ps(0., 2., 0., 4.); + assert_eq_m128(r, e); + + // Unaligned pointer + let a = Unaligned::new([1.0f32, 2., 3., 4.]); + let mask = _mm_setr_epi32(0, !0, 0, !0); + let r = _mm_maskload_ps(a.as_ptr().cast(), mask); + let e = _mm_setr_ps(0., 2., 0., 4.); + assert_eq_m128(r, e); + + // Only loading first element, so slice can be short. + let a = &[2.0f32]; + let mask = _mm_setr_epi32(!0, 0, 0, 0); + let r = _mm_maskload_ps(a.as_ptr(), mask); + let e = _mm_setr_ps(2.0, 0.0, 0.0, 0.0); + assert_eq_m128(r, e); + + // Only loading last element, so slice can be short. + let a = &[2.0f32]; + let mask = _mm_setr_epi32(0, 0, 0, !0); + let r = _mm_maskload_ps(a.as_ptr().wrapping_sub(3), mask); + let e = _mm_setr_ps(0.0, 0.0, 0.0, 2.0); + assert_eq_m128(r, e); + } + test_mm_maskload_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm_maskload_pd() { + let a = &[1.0f64, 2.]; + let mask = _mm_setr_epi64x(0, !0); + let r = _mm_maskload_pd(a.as_ptr(), mask); + let e = _mm_setr_pd(0., 2.); + assert_eq_m128d(r, e); + + // Unaligned pointer + let a = Unaligned::new([1.0f64, 2.]); + let mask = _mm_setr_epi64x(0, !0); + let r = _mm_maskload_pd(a.as_ptr().cast(), mask); + let e = _mm_setr_pd(0., 2.); + assert_eq_m128d(r, e); + + // Only loading first element, so slice can be short. + let a = &[2.0f64]; + let mask = _mm_setr_epi64x(!0, 0); + let r = _mm_maskload_pd(a.as_ptr(), mask); + let e = _mm_setr_pd(2.0, 0.0); + assert_eq_m128d(r, e); + + // Only loading last element, so slice can be short. 
+ let a = &[2.0f64]; + let mask = _mm_setr_epi64x(0, !0); + let r = _mm_maskload_pd(a.as_ptr().wrapping_sub(1), mask); + let e = _mm_setr_pd(0.0, 2.0); + assert_eq_m128d(r, e); + } + test_mm_maskload_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_maskload_ps() { + let a = &[1.0f32, 2., 3., 4., 5., 6., 7., 8.]; + let mask = _mm256_setr_epi32(0, !0, 0, !0, 0, !0, 0, !0); + let r = _mm256_maskload_ps(a.as_ptr(), mask); + let e = _mm256_setr_ps(0., 2., 0., 4., 0., 6., 0., 8.); + assert_eq_m256(r, e); + + // Unaligned pointer + let a = Unaligned::new([1.0f32, 2., 3., 4., 5., 6., 7., 8.]); + let mask = _mm256_setr_epi32(0, !0, 0, !0, 0, !0, 0, !0); + let r = _mm256_maskload_ps(a.as_ptr().cast(), mask); + let e = _mm256_setr_ps(0., 2., 0., 4., 0., 6., 0., 8.); + assert_eq_m256(r, e); + + // Only loading first element, so slice can be short. + let a = &[2.0f32]; + let mask = _mm256_setr_epi32(!0, 0, 0, 0, 0, 0, 0, 0); + let r = _mm256_maskload_ps(a.as_ptr(), mask); + let e = _mm256_setr_ps(2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + assert_eq_m256(r, e); + + // Only loading last element, so slice can be short. + let a = &[2.0f32]; + let mask = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 0, !0); + let r = _mm256_maskload_ps(a.as_ptr().wrapping_sub(7), mask); + let e = _mm256_setr_ps(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0); + assert_eq_m256(r, e); + } + test_mm256_maskload_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_maskload_pd() { + let a = &[1.0f64, 2., 3., 4.]; + let mask = _mm256_setr_epi64x(0, !0, 0, !0); + let r = _mm256_maskload_pd(a.as_ptr(), mask); + let e = _mm256_setr_pd(0., 2., 0., 4.); + assert_eq_m256d(r, e); + + // Unaligned pointer + let a = Unaligned::new([1.0f64, 2., 3., 4.]); + let mask = _mm256_setr_epi64x(0, !0, 0, !0); + let r = _mm256_maskload_pd(a.as_ptr().cast(), mask); + let e = _mm256_setr_pd(0., 2., 0., 4.); + assert_eq_m256d(r, e); + + // Only loading first element, so slice can be short. + let a = &[2.0f64]; + let mask = _mm256_setr_epi64x(!0, 0, 0, 0); + let r = _mm256_maskload_pd(a.as_ptr(), mask); + let e = _mm256_setr_pd(2.0, 0.0, 0.0, 0.0); + assert_eq_m256d(r, e); + + // Only loading last element, so slice can be short. + let a = &[2.0f64]; + let mask = _mm256_setr_epi64x(0, 0, 0, !0); + let r = _mm256_maskload_pd(a.as_ptr().wrapping_sub(3), mask); + let e = _mm256_setr_pd(0.0, 0.0, 0.0, 2.0); + assert_eq_m256d(r, e); + } + test_mm256_maskload_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm_maskstore_ps() { + let mut r = _mm_set1_ps(0.); + let mask = _mm_setr_epi32(0, !0, 0, !0); + let a = _mm_setr_ps(1., 2., 3., 4.); + _mm_maskstore_ps(&mut r as *mut _ as *mut f32, mask, a); + let e = _mm_setr_ps(0., 2., 0., 4.); + assert_eq_m128(r, e); + + // Unaligned pointer + let mut r = Unaligned::new([0.0f32; 4]); + let mask = _mm_setr_epi32(0, !0, 0, !0); + let a = _mm_setr_ps(1., 2., 3., 4.); + _mm_maskstore_ps(r.as_mut_ptr().cast(), mask, a); + let e = [0., 2., 0., 4.]; + assert_eq!(r.read(), e); + + // Only storing first element, so slice can be short. + let mut r = [0.0f32]; + let mask = _mm_setr_epi32(!0, 0, 0, 0); + let a = _mm_setr_ps(1., 2., 3., 4.); + _mm_maskstore_ps(r.as_mut_ptr(), mask, a); + let e = [1.0f32]; + assert_eq!(r, e); + + // Only storing last element, so slice can be short. 
+ let mut r = [0.0f32]; + let mask = _mm_setr_epi32(0, 0, 0, !0); + let a = _mm_setr_ps(1., 2., 3., 4.); + _mm_maskstore_ps(r.as_mut_ptr().wrapping_sub(3), mask, a); + let e = [4.0f32]; + assert_eq!(r, e); + } + test_mm_maskstore_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm_maskstore_pd() { + let mut r = _mm_set1_pd(0.); + let mask = _mm_setr_epi64x(0, !0); + let a = _mm_setr_pd(1., 2.); + _mm_maskstore_pd(&mut r as *mut _ as *mut f64, mask, a); + let e = _mm_setr_pd(0., 2.); + assert_eq_m128d(r, e); + + // Unaligned pointer + let mut r = Unaligned::new([0.0f64; 2]); + let mask = _mm_setr_epi64x(0, !0); + let a = _mm_setr_pd(1., 2.); + _mm_maskstore_pd(r.as_mut_ptr().cast(), mask, a); + let e = [0., 2.]; + assert_eq!(r.read(), e); + + // Only storing first element, so slice can be short. + let mut r = [0.0f64]; + let mask = _mm_setr_epi64x(!0, 0); + let a = _mm_setr_pd(1., 2.); + _mm_maskstore_pd(r.as_mut_ptr(), mask, a); + let e = [1.0f64]; + assert_eq!(r, e); + + // Only storing last element, so slice can be short. + let mut r = [0.0f64]; + let mask = _mm_setr_epi64x(0, !0); + let a = _mm_setr_pd(1., 2.); + _mm_maskstore_pd(r.as_mut_ptr().wrapping_sub(1), mask, a); + let e = [2.0f64]; + assert_eq!(r, e); + } + test_mm_maskstore_pd(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_maskstore_ps() { + let mut r = _mm256_set1_ps(0.); + let mask = _mm256_setr_epi32(0, !0, 0, !0, 0, !0, 0, !0); + let a = _mm256_setr_ps(1., 2., 3., 4., 5., 6., 7., 8.); + _mm256_maskstore_ps(&mut r as *mut _ as *mut f32, mask, a); + let e = _mm256_setr_ps(0., 2., 0., 4., 0., 6., 0., 8.); + assert_eq_m256(r, e); + + // Unaligned pointer + let mut r = Unaligned::new([0.0f32; 8]); + let mask = _mm256_setr_epi32(0, !0, 0, !0, 0, !0, 0, !0); + let a = _mm256_setr_ps(1., 2., 3., 4., 5., 6., 7., 8.); + _mm256_maskstore_ps(r.as_mut_ptr().cast(), mask, a); + let e = [0., 2., 0., 4., 0., 6., 0., 8.]; + assert_eq!(r.read(), e); + + // Only storing first element, so slice can be short. + let mut r = [0.0f32]; + let mask = _mm256_setr_epi32(!0, 0, 0, 0, 0, 0, 0, 0); + let a = _mm256_setr_ps(1., 2., 3., 4., 5., 6., 7., 8.); + _mm256_maskstore_ps(r.as_mut_ptr(), mask, a); + let e = [1.0f32]; + assert_eq!(r, e); + + // Only storing last element, so slice can be short. + let mut r = [0.0f32]; + let mask = _mm256_setr_epi32(0, 0, 0, 0, 0, 0, 0, !0); + let a = _mm256_setr_ps(1., 2., 3., 4., 5., 6., 7., 8.); + _mm256_maskstore_ps(r.as_mut_ptr().wrapping_sub(7), mask, a); + let e = [8.0f32]; + assert_eq!(r, e); + } + test_mm256_maskstore_ps(); + + #[target_feature(enable = "avx")] + unsafe fn test_mm256_maskstore_pd() { + let mut r = _mm256_set1_pd(0.); + let mask = _mm256_setr_epi64x(0, !0, 0, !0); + let a = _mm256_setr_pd(1., 2., 3., 4.); + _mm256_maskstore_pd(&mut r as *mut _ as *mut f64, mask, a); + let e = _mm256_setr_pd(0., 2., 0., 4.); + assert_eq_m256d(r, e); + + // Unaligned pointer + let mut r = Unaligned::new([0.0f64; 4]); + let mask = _mm256_setr_epi64x(0, !0, 0, !0); + let a = _mm256_setr_pd(1., 2., 3., 4.); + _mm256_maskstore_pd(r.as_mut_ptr().cast(), mask, a); + let e = [0., 2., 0., 4.]; + assert_eq!(r.read(), e); + + // Only storing first element, so slice can be short. + let mut r = [0.0f64]; + let mask = _mm256_setr_epi64x(!0, 0, 0, 0); + let a = _mm256_setr_pd(1., 2., 3., 4.); + _mm256_maskstore_pd(r.as_mut_ptr(), mask, a); + let e = [1.0f64]; + assert_eq!(r, e); + + // Only storing last element, so slice can be short. 
+        let mut r = [0.0f64];
+        let mask = _mm256_setr_epi64x(0, 0, 0, !0);
+        let a = _mm256_setr_pd(1., 2., 3., 4.);
+        _mm256_maskstore_pd(r.as_mut_ptr().wrapping_sub(3), mask, a);
+        let e = [4.0f64];
+        assert_eq!(r, e);
+    }
+    test_mm256_maskstore_pd();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_lddqu_si256() {
+        #[rustfmt::skip]
+        let a = _mm256_setr_epi8(
+            1, 2, 3, 4, 5, 6, 7, 8,
+            9, 10, 11, 12, 13, 14, 15, 16,
+            17, 18, 19, 20, 21, 22, 23, 24,
+            25, 26, 27, 28, 29, 30, 31, 32,
+        );
+        let p = &a as *const _;
+        let r = _mm256_lddqu_si256(p);
+        #[rustfmt::skip]
+        let e = _mm256_setr_epi8(
+            1, 2, 3, 4, 5, 6, 7, 8,
+            9, 10, 11, 12, 13, 14, 15, 16,
+            17, 18, 19, 20, 21, 22, 23, 24,
+            25, 26, 27, 28, 29, 30, 31, 32,
+        );
+        assert_eq_m256i(r, e);
+    }
+    test_mm256_lddqu_si256();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_testz_si256() {
+        let a = _mm256_setr_epi64x(1, 2, 3, 4);
+        let b = _mm256_setr_epi64x(5, 6, 7, 8);
+        let r = _mm256_testz_si256(a, b);
+        assert_eq!(r, 0);
+        let b = _mm256_set1_epi64x(0);
+        let r = _mm256_testz_si256(a, b);
+        assert_eq!(r, 1);
+    }
+    test_mm256_testz_si256();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_testc_si256() {
+        let a = _mm256_setr_epi64x(1, 2, 3, 4);
+        let b = _mm256_setr_epi64x(5, 6, 7, 8);
+        let r = _mm256_testc_si256(a, b);
+        assert_eq!(r, 0);
+        let b = _mm256_set1_epi64x(0);
+        let r = _mm256_testc_si256(a, b);
+        assert_eq!(r, 1);
+    }
+    test_mm256_testc_si256();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_testnzc_si256() {
+        let a = _mm256_setr_epi64x(1, 2, 3, 4);
+        let b = _mm256_setr_epi64x(5, 6, 7, 8);
+        let r = _mm256_testnzc_si256(a, b);
+        assert_eq!(r, 1);
+        let a = _mm256_setr_epi64x(0, 0, 0, 0);
+        let b = _mm256_setr_epi64x(0, 0, 0, 0);
+        let r = _mm256_testnzc_si256(a, b);
+        assert_eq!(r, 0);
+    }
+    test_mm256_testnzc_si256();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_testz_pd() {
+        let a = _mm256_setr_pd(1., 2., 3., 4.);
+        let b = _mm256_setr_pd(5., 6., 7., 8.);
+        let r = _mm256_testz_pd(a, b);
+        assert_eq!(r, 1);
+        let a = _mm256_set1_pd(-1.);
+        let r = _mm256_testz_pd(a, a);
+        assert_eq!(r, 0);
+    }
+    test_mm256_testz_pd();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_testc_pd() {
+        let a = _mm256_setr_pd(1., 2., 3., 4.);
+        let b = _mm256_setr_pd(5., 6., 7., 8.);
+        let r = _mm256_testc_pd(a, b);
+        assert_eq!(r, 1);
+        let a = _mm256_set1_pd(1.);
+        let b = _mm256_set1_pd(-1.);
+        let r = _mm256_testc_pd(a, b);
+        assert_eq!(r, 0);
+    }
+    test_mm256_testc_pd();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_testnzc_pd() {
+        let a = _mm256_setr_pd(1., 2., 3., 4.);
+        let b = _mm256_setr_pd(5., 6., 7., 8.);
+        let r = _mm256_testnzc_pd(a, b);
+        assert_eq!(r, 0);
+        let a = _mm256_setr_pd(1., -1., -1., -1.);
+        let b = _mm256_setr_pd(-1., -1., 1., 1.);
+        let r = _mm256_testnzc_pd(a, b);
+        assert_eq!(r, 1);
+    }
+    test_mm256_testnzc_pd();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_testz_pd() {
+        let a = _mm_setr_pd(1., 2.);
+        let b = _mm_setr_pd(5., 6.);
+        let r = _mm_testz_pd(a, b);
+        assert_eq!(r, 1);
+        let a = _mm_set1_pd(-1.);
+        let r = _mm_testz_pd(a, a);
+        assert_eq!(r, 0);
+    }
+    test_mm_testz_pd();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_testc_pd() {
+        let a = _mm_setr_pd(1., 2.);
+        let b = _mm_setr_pd(5., 6.);
+        let r = _mm_testc_pd(a, b);
+        assert_eq!(r, 1);
+        let a = _mm_set1_pd(1.);
+        let b = _mm_set1_pd(-1.);
+        let r = _mm_testc_pd(a, b);
+        assert_eq!(r, 0);
+    }
+    test_mm_testc_pd();
+
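+    // Note on the `_mm*_test{z,c,nzc}_*` intrinsics exercised above and below:
+    // `testz` returns 1 iff `a AND b` is all zeros (the ZF result of PTEST/VTESTP*),
+    // `testc` returns 1 iff `(NOT a) AND b` is all zeros (the CF result), and
+    // `testnzc` returns 1 iff both of those ANDs are non-zero (ZF = 0 and CF = 0).
+    // For the `_ps`/`_pd` variants only the sign bit of each lane takes part in
+    // the AND, which is why comparing all-positive inputs yields ZF = CF = 1.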
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_testnzc_pd() {
+        let a = _mm_setr_pd(1., 2.);
+        let b = _mm_setr_pd(5., 6.);
+        let r = _mm_testnzc_pd(a, b);
+        assert_eq!(r, 0);
+        let a = _mm_setr_pd(1., -1.);
+        let b = _mm_setr_pd(-1., -1.);
+        let r = _mm_testnzc_pd(a, b);
+        assert_eq!(r, 1);
+    }
+    test_mm_testnzc_pd();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_testz_ps() {
+        let a = _mm256_set1_ps(1.);
+        let r = _mm256_testz_ps(a, a);
+        assert_eq!(r, 1);
+        let a = _mm256_set1_ps(-1.);
+        let r = _mm256_testz_ps(a, a);
+        assert_eq!(r, 0);
+    }
+    test_mm256_testz_ps();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_testc_ps() {
+        let a = _mm256_set1_ps(1.);
+        let r = _mm256_testc_ps(a, a);
+        assert_eq!(r, 1);
+        let b = _mm256_set1_ps(-1.);
+        let r = _mm256_testc_ps(a, b);
+        assert_eq!(r, 0);
+    }
+    test_mm256_testc_ps();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm256_testnzc_ps() {
+        let a = _mm256_set1_ps(1.);
+        let r = _mm256_testnzc_ps(a, a);
+        assert_eq!(r, 0);
+        let a = _mm256_setr_ps(1., -1., -1., -1., -1., -1., -1., -1.);
+        let b = _mm256_setr_ps(-1., -1., 1., 1., 1., 1., 1., 1.);
+        let r = _mm256_testnzc_ps(a, b);
+        assert_eq!(r, 1);
+    }
+    test_mm256_testnzc_ps();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_testz_ps() {
+        let a = _mm_set1_ps(1.);
+        let r = _mm_testz_ps(a, a);
+        assert_eq!(r, 1);
+        let a = _mm_set1_ps(-1.);
+        let r = _mm_testz_ps(a, a);
+        assert_eq!(r, 0);
+    }
+    test_mm_testz_ps();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_testc_ps() {
+        let a = _mm_set1_ps(1.);
+        let r = _mm_testc_ps(a, a);
+        assert_eq!(r, 1);
+        let b = _mm_set1_ps(-1.);
+        let r = _mm_testc_ps(a, b);
+        assert_eq!(r, 0);
+    }
+    test_mm_testc_ps();
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_testnzc_ps() {
+        let a = _mm_set1_ps(1.);
+        let r = _mm_testnzc_ps(a, a);
+        assert_eq!(r, 0);
+        let a = _mm_setr_ps(1., -1., -1., -1.);
+        let b = _mm_setr_ps(-1., -1., 1., 1.);
+        let r = _mm_testnzc_ps(a, b);
+        assert_eq!(r, 1);
+    }
+    test_mm_testnzc_ps();
+}
+
+#[target_feature(enable = "sse2")]
+unsafe fn _mm_setr_epi64x(a: i64, b: i64) -> __m128i {
+    _mm_set_epi64x(b, a)
+}
+
+#[track_caller]
+#[target_feature(enable = "sse")]
+unsafe fn assert_eq_m128(a: __m128, b: __m128) {
+    let r = _mm_cmpeq_ps(a, b);
+    if _mm_movemask_ps(r) != 0b1111 {
+        panic!("{:?} != {:?}", a, b);
+    }
+}
+
+#[track_caller]
+#[target_feature(enable = "sse2")]
+unsafe fn assert_eq_m128d(a: __m128d, b: __m128d) {
+    if _mm_movemask_pd(_mm_cmpeq_pd(a, b)) != 0b11 {
+        panic!("{:?} != {:?}", a, b);
+    }
+}
+
+#[track_caller]
+#[target_feature(enable = "sse2")]
+unsafe fn assert_eq_m128i(a: __m128i, b: __m128i) {
+    assert_eq!(transmute::<_, [u64; 2]>(a), transmute::<_, [u64; 2]>(b))
+}
+
+#[track_caller]
+#[target_feature(enable = "avx")]
+unsafe fn assert_eq_m256(a: __m256, b: __m256) {
+    let cmp = _mm256_cmp_ps::<_CMP_EQ_OQ>(a, b);
+    if _mm256_movemask_ps(cmp) != 0b11111111 {
+        panic!("{:?} != {:?}", a, b);
+    }
+}
+
+#[track_caller]
+#[target_feature(enable = "avx")]
+unsafe fn assert_eq_m256d(a: __m256d, b: __m256d) {
+    let cmp = _mm256_cmp_pd::<_CMP_EQ_OQ>(a, b);
+    if _mm256_movemask_pd(cmp) != 0b1111 {
+        panic!("{:?} != {:?}", a, b);
+    }
+}
+
+#[track_caller]
+#[target_feature(enable = "avx")]
+unsafe fn assert_eq_m256i(a: __m256i, b: __m256i) {
+    assert_eq!(transmute::<_, [u64; 4]>(a), transmute::<_, [u64; 4]>(b))
+}
+
+/// Stores `T` at an unaligned address
+struct Unaligned<T> {
+    buf: Vec<u8>,
+    offset: bool,
+    _marker: std::marker::PhantomData<T>,
+}
+
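+// `Unaligned` copies its value into a byte buffer at an odd address, so the
+// maskload/maskstore tests above can hand the intrinsics a pointer that carries
+// no alignment guarantee at all.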
+impl<T> Unaligned<T> {
+    fn new(value: T) -> Self {
+        // Allocate an extra byte of headroom so the value can be stored unaligned.
+        let len = std::mem::size_of::<T>();
+        let mut buf = Vec::<u8>::with_capacity(len + 1);
+        // Force the address to be a non-multiple of 2, so it is as unaligned as it can get.
+        let offset = (buf.as_ptr() as usize % 2) == 0;
+        let value_ptr: *const T = &value;
+        unsafe {
+            buf.as_mut_ptr().add(offset.into()).copy_from_nonoverlapping(value_ptr.cast(), len);
+        }
+        Self { buf, offset, _marker: std::marker::PhantomData }
+    }
+
+    fn as_ptr(&self) -> *const T {
+        unsafe { self.buf.as_ptr().add(self.offset.into()).cast() }
+    }
+
+    fn as_mut_ptr(&mut self) -> *mut T {
+        unsafe { self.buf.as_mut_ptr().add(self.offset.into()).cast() }
+    }
+
+    fn read(&self) -> T {
+        unsafe { self.as_ptr().read_unaligned() }
+    }
+}