diff --git a/src/tools/miri/src/shims/intrinsics/simd.rs b/src/tools/miri/src/shims/intrinsics/simd.rs index c97a052f517..af98b38af8c 100644 --- a/src/tools/miri/src/shims/intrinsics/simd.rs +++ b/src/tools/miri/src/shims/intrinsics/simd.rs @@ -33,6 +33,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { | "round" | "trunc" | "fsqrt" + | "fsin" + | "fcos" + | "fexp" + | "fexp2" + | "flog" + | "flog2" + | "flog10" | "ctlz" | "cttz" | "bswap" @@ -45,17 +52,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { assert_eq!(dest_len, op_len); #[derive(Copy, Clone)] - enum Op { + enum Op<'a> { MirOp(mir::UnOp), Abs, - Sqrt, Round(rustc_apfloat::Round), Numeric(Symbol), + HostOp(&'a str), } let which = match intrinsic_name { "neg" => Op::MirOp(mir::UnOp::Neg), "fabs" => Op::Abs, - "fsqrt" => Op::Sqrt, "ceil" => Op::Round(rustc_apfloat::Round::TowardPositive), "floor" => Op::Round(rustc_apfloat::Round::TowardNegative), "round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway), @@ -64,7 +70,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { "cttz" => Op::Numeric(sym::cttz), "bswap" => Op::Numeric(sym::bswap), "bitreverse" => Op::Numeric(sym::bitreverse), - _ => unreachable!(), + _ => Op::HostOp(intrinsic_name), }; for i in 0..dest_len { @@ -89,7 +95,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { FloatTy::F128 => unimplemented!("f16_f128"), } } - Op::Sqrt => { + Op::HostOp(host_op) => { let ty::Float(float_ty) = op.layout.ty.kind() else { span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) }; @@ -98,13 +104,37 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { FloatTy::F16 => unimplemented!("f16_f128"), FloatTy::F32 => { let f = op.to_scalar().to_f32()?; - let res = f.to_host().sqrt().to_soft(); + let f_host = f.to_host(); + let res = match host_op { + "fsqrt" => f_host.sqrt(), + "fsin" => f_host.sin(), + "fcos" => f_host.cos(), + "fexp" => f_host.exp(), + "fexp2" => f_host.exp2(), + "flog" => f_host.ln(), + "flog2" => f_host.log2(), + "flog10" => f_host.log10(), + _ => bug!(), + }; + let res = res.to_soft(); let res = this.adjust_nan(res, &[f]); Scalar::from(res) } FloatTy::F64 => { let f = op.to_scalar().to_f64()?; - let res = f.to_host().sqrt().to_soft(); + let f_host = f.to_host(); + let res = match host_op { + "fsqrt" => f_host.sqrt(), + "fsin" => f_host.sin(), + "fcos" => f_host.cos(), + "fexp" => f_host.exp(), + "fexp2" => f_host.exp2(), + "flog" => f_host.ln(), + "flog2" => f_host.log2(), + "flog10" => f_host.log10(), + _ => bug!(), + }; + let res = res.to_soft(); let res = this.adjust_nan(res, &[f]); Scalar::from(res) } diff --git a/src/tools/miri/tests/pass/portable-simd.rs b/src/tools/miri/tests/pass/portable-simd.rs index 399913a757b..cdb441b450b 100644 --- a/src/tools/miri/tests/pass/portable-simd.rs +++ b/src/tools/miri/tests/pass/portable-simd.rs @@ -526,6 +526,23 @@ fn simd_intrinsics() { } } +fn simd_float_intrinsics() { + use intrinsics::*; + + // These are just smoke tests to ensure the intrinsics can be called. + unsafe { + let a = f32x4::splat(10.0); + simd_fsqrt(a); + simd_fsin(a); + simd_fcos(a); + simd_fexp(a); + simd_fexp2(a); + simd_flog(a); + simd_flog2(a); + simd_flog10(a); + } +} + fn simd_masked_loadstore() { // The buffer is deliberarely too short, so reading the last element would be UB. let buf = [3i32; 3]; @@ -559,5 +576,6 @@ fn main() { simd_gather_scatter(); simd_round(); simd_intrinsics(); + simd_float_intrinsics(); simd_masked_loadstore(); }