Auto merge of #3192 - eduardosm:x86-avx-intrinsics, r=RalfJung

Implement x86 AVX intrinsics

~Blocked on <https://github.com/rust-lang/miri/pull/3214>~
This commit is contained in:
bors 2024-02-16 20:53:50 +00:00
commit d2a4ef39ca
3 changed files with 1845 additions and 51 deletions

View File

@ -0,0 +1,417 @@
use rustc_apfloat::{ieee::Double, ieee::Single};
use rustc_middle::mir;
use rustc_middle::ty::layout::LayoutOf as _;
use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;
use super::{
bin_op_simd_float_all, conditional_dot_product, convert_float_to_int, horizontal_bin_op,
round_all, test_bits_masked, test_high_bits_masked, unary_op_ps, FloatBinOp, FloatUnaryOp,
};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
crate::MiriInterpCxExt<'mir, 'tcx>
{
fn emulate_x86_avx_intrinsic(
&mut self,
link_name: Symbol,
abi: Abi,
args: &[OpTy<'tcx, Provenance>],
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, EmulateForeignItemResult> {
let this = self.eval_context_mut();
this.expect_target_feature_for_intrinsic(link_name, "avx")?;
// Prefix should have already been checked.
let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx.").unwrap();
match unprefixed_name {
// Used to implement _mm256_min_ps and _mm256_max_ps functions.
// Note that the semantics are a bit different from Rust simd_min
// and simd_max intrinsics regarding handling of NaN and -0.0: Rust
// matches the IEEE min/max operations, while x86 has different
// semantics.
"min.ps.256" | "max.ps.256" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let which = match unprefixed_name {
"min.ps.256" => FloatBinOp::Min,
"max.ps.256" => FloatBinOp::Max,
_ => unreachable!(),
};
bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
}
// Used to implement _mm256_min_pd and _mm256_max_pd functions.
"min.pd.256" | "max.pd.256" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let which = match unprefixed_name {
"min.pd.256" => FloatBinOp::Min,
"max.pd.256" => FloatBinOp::Max,
_ => unreachable!(),
};
bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
}
// Used to implement the _mm256_round_ps function.
// Rounds the elements of `op` according to `rounding`.
"round.ps.256" => {
let [op, rounding] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
}
// Used to implement the _mm256_round_pd function.
// Rounds the elements of `op` according to `rounding`.
"round.pd.256" => {
let [op, rounding] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
}
// Used to implement _mm256_{sqrt,rcp,rsqrt}_ps functions.
// Performs the operations on all components of `op`.
"sqrt.ps.256" | "rcp.ps.256" | "rsqrt.ps.256" => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let which = match unprefixed_name {
"sqrt.ps.256" => FloatUnaryOp::Sqrt,
"rcp.ps.256" => FloatUnaryOp::Rcp,
"rsqrt.ps.256" => FloatUnaryOp::Rsqrt,
_ => unreachable!(),
};
unary_op_ps(this, which, op, dest)?;
}
// Used to implement the _mm256_dp_ps function.
"dp.ps.256" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
conditional_dot_product(this, left, right, imm, dest)?;
}
// Used to implement the _mm256_h{add,sub}_p{s,d} functions.
// Horizontally add/subtract adjacent floating point values
// in `left` and `right`.
"hadd.ps.256" | "hadd.pd.256" | "hsub.ps.256" | "hsub.pd.256" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let which = match unprefixed_name {
"hadd.ps.256" | "hadd.pd.256" => mir::BinOp::Add,
"hsub.ps.256" | "hsub.pd.256" => mir::BinOp::Sub,
_ => unreachable!(),
};
horizontal_bin_op(this, which, /*saturating*/ false, left, right, dest)?;
}
// Used to implement the _mm256_cmp_ps function.
// Performs a comparison operation on each component of `left`
// and `right`. For each component, returns 0 if false or u32::MAX
// if true.
"cmp.ps.256" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let which =
FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
}
// Used to implement the _mm256_cmp_pd function.
// Performs a comparison operation on each component of `left`
// and `right`. For each component, returns 0 if false or u64::MAX
// if true.
"cmp.pd.256" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let which =
FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
}
// Used to implement the _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_cvtpd_epi32
// and _mm256_cvttpd_epi32 functions.
// Converts packed f32/f64 to packed i32.
"cvt.ps2dq.256" | "cvtt.ps2dq.256" | "cvt.pd2dq.256" | "cvtt.pd2dq.256" => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let rnd = match unprefixed_name {
// "current SSE rounding mode", assume nearest
"cvt.ps2dq.256" | "cvt.pd2dq.256" => rustc_apfloat::Round::NearestTiesToEven,
// always truncate
"cvtt.ps2dq.256" | "cvtt.pd2dq.256" => rustc_apfloat::Round::TowardZero,
_ => unreachable!(),
};
convert_float_to_int(this, op, rnd, dest)?;
}
// Used to implement the _mm_permutevar_ps and _mm256_permutevar_ps functions.
// Shuffles 32-bit floats from `data` using `control` as control. Each 128-bit
// chunk is shuffled independently: this means that we view the vector as a
// sequence of 4-element arrays, and we shuffle each of these arrays, where
// `control` determines which element of the current `data` array is written.
"vpermilvar.ps" | "vpermilvar.ps.256" => {
let [data, control] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (data, data_len) = this.operand_to_simd(data)?;
let (control, control_len) = this.operand_to_simd(control)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, data_len);
assert_eq!(dest_len, control_len);
for i in 0..dest_len {
let control = this.project_index(&control, i)?;
// Each 128-bit chunk is shuffled independently. Since each chunk contains
// four 32-bit elements, only two bits from `control` are used. To read the
// value from the current chunk, add the destination index truncated to a multiple
// of 4.
let chunk_base = i & !0b11;
let src_i = u64::from(this.read_scalar(&control)?.to_u32()? & 0b11)
.checked_add(chunk_base)
.unwrap();
this.copy_op(
&this.project_index(&data, src_i)?,
&this.project_index(&dest, i)?,
)?;
}
}
// Used to implement the _mm_permutevar_pd and _mm256_permutevar_pd functions.
// Shuffles 64-bit floats from `left` using `right` as control. Each 128-bit
// chunk is shuffled independently: this means that we view the vector as
// a sequence of 2-element arrays, and we shuffle each of these arrays,
// where `right` determines which element of the current `left` array is
// written.
"vpermilvar.pd" | "vpermilvar.pd.256" => {
let [data, control] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (data, data_len) = this.operand_to_simd(data)?;
let (control, control_len) = this.operand_to_simd(control)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, data_len);
assert_eq!(dest_len, control_len);
for i in 0..dest_len {
let control = this.project_index(&control, i)?;
// Each 128-bit chunk is shuffled independently. Since each chunk contains
// two 64-bit elements, only the second bit from `control` is used (yes, the
// second instead of the first, ask Intel). To read the value from the current
// chunk, add the destination index truncated to a multiple of 2.
let chunk_base = i & !1;
let src_i = ((this.read_scalar(&control)?.to_u64()? >> 1) & 1)
.checked_add(chunk_base)
.unwrap();
this.copy_op(
&this.project_index(&data, src_i)?,
&this.project_index(&dest, i)?,
)?;
}
}
// Used to implement the _mm256_permute2f128_ps, _mm256_permute2f128_pd and
// _mm256_permute2f128_si256 functions. Regardless of the suffix in the name
// thay all can be considered to operate on vectors of 128-bit elements.
// For each 128-bit element of `dest`, copies one from `left`, `right` or
// zero, according to `imm`.
"vperm2f128.ps.256" | "vperm2f128.pd.256" | "vperm2f128.si.256" => {
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
assert_eq!(dest.layout, left.layout);
assert_eq!(dest.layout, right.layout);
assert_eq!(dest.layout.size.bits(), 256);
// Transmute to `[u128; 2]` to process each 128-bit chunk independently.
let u128x2_layout =
this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u128, 2))?;
let left = left.transmute(u128x2_layout, this)?;
let right = right.transmute(u128x2_layout, this)?;
let dest = dest.transmute(u128x2_layout, this)?;
let imm = this.read_scalar(imm)?.to_u8()?;
for i in 0..2 {
let dest = this.project_index(&dest, i)?;
let imm = match i {
0 => imm & 0xF,
1 => imm >> 4,
_ => unreachable!(),
};
if imm & 0b100 != 0 {
this.write_scalar(Scalar::from_u128(0), &dest)?;
} else {
let src = match imm {
0b00 => this.project_index(&left, 0)?,
0b01 => this.project_index(&left, 1)?,
0b10 => this.project_index(&right, 0)?,
0b11 => this.project_index(&right, 1)?,
_ => unreachable!(),
};
this.copy_op(&src, &dest)?;
}
}
}
// Used to implement the _mm_maskload_ps, _mm_maskload_pd, _mm256_maskload_ps
// and _mm256_maskload_pd functions.
// For the element `i`, if the high bit of the `i`-th element of `mask`
// is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
// loaded.
"maskload.ps" | "maskload.pd" | "maskload.ps.256" | "maskload.pd.256" => {
let [ptr, mask] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
mask_load(this, ptr, mask, dest)?;
}
// Used to implement the _mm_maskstore_ps, _mm_maskstore_pd, _mm256_maskstore_ps
// and _mm256_maskstore_pd functions.
// For the element `i`, if the high bit of the element `i`-th of `mask`
// is one, it is stored into `ptr.wapping_add(i)`.
// Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
"maskstore.ps" | "maskstore.pd" | "maskstore.ps.256" | "maskstore.pd.256" => {
let [ptr, mask, value] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
mask_store(this, ptr, mask, value)?;
}
// Used to implement the _mm256_lddqu_si256 function.
// Reads a 256-bit vector from an unaligned pointer. This intrinsic
// is expected to perform better than a regular unaligned read when
// the data crosses a cache line, but for Miri this is just a regular
// unaligned read.
"ldu.dq.256" => {
let [src_ptr] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let src_ptr = this.read_pointer(src_ptr)?;
let dest = dest.force_mplace(this)?;
// Unaligned copy, which is what we want.
this.mem_copy(src_ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
}
// Used to implement the _mm256_testz_si256, _mm256_testc_si256 and
// _mm256_testnzc_si256 functions.
// Tests `op & mask == 0`, `op & mask == mask` or
// `op & mask != 0 && op & mask != mask`
"ptestz.256" | "ptestc.256" | "ptestnzc.256" => {
let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
let res = match unprefixed_name {
"ptestz.256" => all_zero,
"ptestc.256" => masked_set,
"ptestnzc.256" => !all_zero && !masked_set,
_ => unreachable!(),
};
this.write_scalar(Scalar::from_i32(res.into()), dest)?;
}
// Used to implement the _mm256_testz_pd, _mm256_testc_pd, _mm256_testnzc_pd
// _mm_testz_pd, _mm_testc_pd, _mm_testnzc_pd, _mm256_testz_ps,
// _mm256_testc_ps, _mm256_testnzc_ps, _mm_testz_ps, _mm_testc_ps and
// _mm_testnzc_ps functions.
// Calculates two booleans:
// `direct`, which is true when the highest bit of each element of `op & mask` is zero.
// `negated`, which is true when the highest bit of each element of `!op & mask` is zero.
// Return `direct` (testz), `negated` (testc) or `!direct & !negated` (testnzc)
"vtestz.pd.256" | "vtestc.pd.256" | "vtestnzc.pd.256" | "vtestz.pd" | "vtestc.pd"
| "vtestnzc.pd" | "vtestz.ps.256" | "vtestc.ps.256" | "vtestnzc.ps.256"
| "vtestz.ps" | "vtestc.ps" | "vtestnzc.ps" => {
let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
let (direct, negated) = test_high_bits_masked(this, op, mask)?;
let res = match unprefixed_name {
"vtestz.pd.256" | "vtestz.pd" | "vtestz.ps.256" | "vtestz.ps" => direct,
"vtestc.pd.256" | "vtestc.pd" | "vtestc.ps.256" | "vtestc.ps" => negated,
"vtestnzc.pd.256" | "vtestnzc.pd" | "vtestnzc.ps.256" | "vtestnzc.ps" =>
!direct && !negated,
_ => unreachable!(),
};
this.write_scalar(Scalar::from_i32(res.into()), dest)?;
}
_ => return Ok(EmulateForeignItemResult::NotSupported),
}
Ok(EmulateForeignItemResult::NeedsJumping)
}
}
/// Conditionally loads from `ptr` according the high bit of each
/// element of `mask`. `ptr` does not need to be aligned.
fn mask_load<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
ptr: &OpTy<'tcx, Provenance>,
mask: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, mask_len);
let mask_item_size = mask.layout.field(this, 0).size;
let high_bit_offset = mask_item_size.bits().checked_sub(1).unwrap();
let ptr = this.read_pointer(ptr)?;
for i in 0..dest_len {
let mask = this.project_index(&mask, i)?;
let dest = this.project_index(&dest, i)?;
if this.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
// Size * u64 is implemented as always checked
#[allow(clippy::arithmetic_side_effects)]
let ptr = ptr.wrapping_offset(dest.layout.size * i, &this.tcx);
// Unaligned copy, which is what we want.
this.mem_copy(ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
} else {
this.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
}
}
Ok(())
}
/// Conditionally stores into `ptr` according the high bit of each
/// element of `mask`. `ptr` does not need to be aligned.
fn mask_store<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
ptr: &OpTy<'tcx, Provenance>,
mask: &OpTy<'tcx, Provenance>,
value: &OpTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (value, value_len) = this.operand_to_simd(value)?;
assert_eq!(value_len, mask_len);
let mask_item_size = mask.layout.field(this, 0).size;
let high_bit_offset = mask_item_size.bits().checked_sub(1).unwrap();
let ptr = this.read_pointer(ptr)?;
for i in 0..value_len {
let mask = this.project_index(&mask, i)?;
let value = this.project_index(&value, i)?;
if this.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
// Size * u64 is implemented as always checked
#[allow(clippy::arithmetic_side_effects)]
let ptr = ptr.wrapping_offset(value.layout.size * i, &this.tcx);
// Unaligned copy, which is what we want.
this.mem_copy(value.ptr(), ptr, value.layout.size, /*nonoverlapping*/ true)?;
}
}
Ok(())
}

View File

@ -1,6 +1,8 @@
use rand::Rng as _;
use rustc_apfloat::{ieee::Single, Float as _};
use rustc_apfloat::{ieee::Single, Float};
use rustc_middle::ty::layout::LayoutOf as _;
use rustc_middle::ty::Ty;
use rustc_middle::{mir, ty};
use rustc_span::Symbol;
use rustc_target::abi::Size;
@ -11,6 +13,7 @@ use helpers::bool_to_simd_element;
use shims::foreign_items::EmulateForeignItemResult;
mod aesni;
mod avx;
mod sse;
mod sse2;
mod sse3;
@ -115,6 +118,11 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
this, link_name, abi, args, dest,
);
}
name if name.starts_with("avx.") => {
return avx::EvalContextExt::emulate_x86_avx_intrinsic(
this, link_name, abi, args, dest,
);
}
_ => return Ok(EmulateForeignItemResult::NotSupported),
}
@ -563,8 +571,65 @@ fn convert_float_to_int<'tcx>(
Ok(())
}
/// Splits `left`, `right` and `dest` (which must be SIMD vectors)
/// into 128-bit chuncks.
///
/// `left`, `right` and `dest` cannot have different types.
///
/// Returns a tuple where:
/// * The first element is the number of 128-bit chunks (let's call it `N`).
/// * The second element is the number of elements per chunk (let's call it `M`).
/// * The third element is the `left` vector split into chunks, i.e, it's
/// type is `[[T; M]; N]`.
/// * The fourth element is the `right` vector split into chunks.
/// * The fifth element is the `dest` vector split into chunks.
fn split_simd_to_128bit_chunks<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<
'tcx,
(u64, u64, MPlaceTy<'tcx, Provenance>, MPlaceTy<'tcx, Provenance>, MPlaceTy<'tcx, Provenance>),
> {
assert_eq!(dest.layout, left.layout);
assert_eq!(dest.layout, right.layout);
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
assert_eq!(dest.layout.size.bits() % 128, 0);
let num_chunks = dest.layout.size.bits() / 128;
assert_eq!(dest_len.checked_rem(num_chunks), Some(0));
let items_per_chunk = dest_len.checked_div(num_chunks).unwrap();
// Transmute to `[[T; items_per_chunk]; num_chunks]`
let element_layout = left.layout.field(this, 0);
let chunked_layout = this.layout_of(Ty::new_array(
this.tcx.tcx,
Ty::new_array(this.tcx.tcx, element_layout.ty, items_per_chunk),
num_chunks,
))?;
let left = left.transmute(chunked_layout, this)?;
let right = right.transmute(chunked_layout, this)?;
let dest = dest.transmute(chunked_layout, this)?;
Ok((num_chunks, items_per_chunk, left, right, dest))
}
/// Horizontaly performs `which` operation on adjacent values of
/// `left` and `right` SIMD vectors and stores the result in `dest`.
/// "Horizontal" means that the i-th output element is calculated
/// from the elements 2*i and 2*i+1 of the concatenation of `left` and
/// `right`.
///
/// Each 128-bit chunk is treated independently (i.e., the value for
/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
/// 128-bit chunks of `left` and `right`).
fn horizontal_bin_op<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: mir::BinOp,
@ -573,32 +638,34 @@ fn horizontal_bin_op<'tcx>(
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
let (num_chunks, items_per_chunk, left, right, dest) =
split_simd_to_128bit_chunks(this, left, right, dest)?;
assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
assert_eq!(dest_len % 2, 0);
let middle = items_per_chunk / 2;
for i in 0..num_chunks {
let left = this.project_index(&left, i)?;
let right = this.project_index(&right, i)?;
let dest = this.project_index(&dest, i)?;
let middle = dest_len / 2;
for i in 0..dest_len {
// `i` is the index in `dest`
// `j` is the index of the 2-item chunk in `src`
let (j, src) =
if i < middle { (i, &left) } else { (i.checked_sub(middle).unwrap(), &right) };
// `base_i` is the index of the first item of the 2-item chunk in `src`
let base_i = j.checked_mul(2).unwrap();
let lhs = this.read_immediate(&this.project_index(src, base_i)?)?;
let rhs = this.read_immediate(&this.project_index(src, base_i.checked_add(1).unwrap())?)?;
for j in 0..items_per_chunk {
// `j` is the index in `dest`
// `k` is the index of the 2-item chunk in `src`
let (k, src) =
if j < middle { (j, &left) } else { (j.checked_sub(middle).unwrap(), &right) };
// `base_i` is the index of the first item of the 2-item chunk in `src`
let base_i = k.checked_mul(2).unwrap();
let lhs = this.read_immediate(&this.project_index(src, base_i)?)?;
let rhs =
this.read_immediate(&this.project_index(src, base_i.checked_add(1).unwrap())?)?;
let res = if saturating {
Immediate::from(this.saturating_arith(which, &lhs, &rhs)?)
} else {
*this.wrapping_binary_op(which, &lhs, &rhs)?
};
let res = if saturating {
Immediate::from(this.saturating_arith(which, &lhs, &rhs)?)
} else {
*this.wrapping_binary_op(which, &lhs, &rhs)?
};
this.write_immediate(res, &this.project_index(&dest, i)?)?;
this.write_immediate(res, &this.project_index(&dest, j)?)?;
}
}
Ok(())
@ -608,6 +675,10 @@ fn horizontal_bin_op<'tcx>(
/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
/// products (up to 4), and conditionally stores the sum in `dest` using
/// the low 4 bits of `imm`.
///
/// Each 128-bit chunk is treated independently (i.e., the value for
/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
/// 128-bit blocks of `left` and `right`).
fn conditional_dot_product<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
@ -615,39 +686,43 @@ fn conditional_dot_product<'tcx>(
imm: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
let (num_chunks, items_per_chunk, left, right, dest) =
split_simd_to_128bit_chunks(this, left, right, dest)?;
assert_eq!(left_len, right_len);
assert!(dest_len <= 4);
let element_layout = left.layout.field(this, 0).field(this, 0);
assert!(items_per_chunk <= 4);
let imm = this.read_scalar(imm)?.to_u8()?;
// `imm` is a `u8` for SSE4.1 or an `i32` for AVX :/
let imm = this.read_scalar(imm)?.to_uint(imm.layout.size)?;
let element_layout = left.layout.field(this, 0);
// Calculate dot product
// Elements are floating point numbers, but we can use `from_int`
// because the representation of 0.0 is all zero bits.
let mut sum = ImmTy::from_int(0u8, element_layout);
for i in 0..left_len {
if imm & (1 << i.checked_add(4).unwrap()) != 0 {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;
let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
}
}
// Write to destination (conditioned to imm)
for i in 0..dest_len {
for i in 0..num_chunks {
let left = this.project_index(&left, i)?;
let right = this.project_index(&right, i)?;
let dest = this.project_index(&dest, i)?;
if imm & (1 << i) != 0 {
this.write_immediate(*sum, &dest)?;
} else {
this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
// Calculate dot product
// Elements are floating point numbers, but we can use `from_int`
// for the initial value because the representation of 0.0 is all zero bits.
let mut sum = ImmTy::from_int(0u8, element_layout);
for j in 0..items_per_chunk {
if imm & (1 << j.checked_add(4).unwrap()) != 0 {
let left = this.read_immediate(&this.project_index(&left, j)?)?;
let right = this.read_immediate(&this.project_index(&right, j)?)?;
let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
}
}
// Write to destination (conditioned to imm)
for j in 0..items_per_chunk {
let dest = this.project_index(&dest, j)?;
if imm & (1 << j) != 0 {
this.write_immediate(*sum, &dest)?;
} else {
this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
}
}
}
@ -684,3 +759,36 @@ fn test_bits_masked<'tcx>(
Ok((all_zero, masked_set))
}
/// Calculates two booleans.
///
/// The first is true when the highest bit of each element of `op & mask` is zero.
/// The second is true when the highest bit of each element of `!op & mask` is zero.
fn test_high_bits_masked<'tcx>(
this: &crate::MiriInterpCx<'_, 'tcx>,
op: &OpTy<'tcx, Provenance>,
mask: &OpTy<'tcx, Provenance>,
) -> InterpResult<'tcx, (bool, bool)> {
assert_eq!(op.layout, mask.layout);
let (op, op_len) = this.operand_to_simd(op)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
assert_eq!(op_len, mask_len);
let high_bit_offset = op.layout.field(this, 0).size.bits().checked_sub(1).unwrap();
let mut direct = true;
let mut negated = true;
for i in 0..op_len {
let op = this.project_index(&op, i)?;
let mask = this.project_index(&mask, i)?;
let op = this.read_scalar(&op)?.to_uint(op.layout.size)?;
let mask = this.read_scalar(&mask)?.to_uint(mask.layout.size)?;
direct &= (op & mask) >> high_bit_offset == 0;
negated &= (!op & mask) >> high_bit_offset == 0;
}
Ok((direct, negated))
}

File diff suppressed because it is too large Load Diff