Do not use host floats in simd_{ceil,floor,round,trunc}

This commit is contained in:
Eduardo Sánchez Muñoz 2023-10-06 15:12:36 +02:00
parent 4587c7c1c0
commit e1e880e9c6

View File

@ -32,28 +32,21 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
assert_eq!(dest_len, op_len);
#[derive(Copy, Clone)]
enum HostFloatOp {
Ceil,
Floor,
Round,
Trunc,
Sqrt,
}
#[derive(Copy, Clone)]
enum Op {
MirOp(mir::UnOp),
Abs,
HostOp(HostFloatOp),
Sqrt,
Round(rustc_apfloat::Round),
}
let which = match intrinsic_name {
"neg" => Op::MirOp(mir::UnOp::Neg),
"fabs" => Op::Abs,
"ceil" => Op::HostOp(HostFloatOp::Ceil),
"floor" => Op::HostOp(HostFloatOp::Floor),
"round" => Op::HostOp(HostFloatOp::Round),
"trunc" => Op::HostOp(HostFloatOp::Trunc),
"fsqrt" => Op::HostOp(HostFloatOp::Sqrt),
"fsqrt" => Op::Sqrt,
"ceil" => Op::Round(rustc_apfloat::Round::TowardPositive),
"floor" => Op::Round(rustc_apfloat::Round::TowardNegative),
"round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway),
"trunc" => Op::Round(rustc_apfloat::Round::TowardZero),
_ => unreachable!(),
};
@ -73,7 +66,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
}
}
Op::HostOp(host_op) => {
Op::Sqrt => {
let ty::Float(float_ty) = op.layout.ty.kind() else {
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
};
@ -81,28 +74,32 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
match float_ty {
FloatTy::F32 => {
let f = f32::from_bits(op.to_scalar().to_u32()?);
let res = match host_op {
HostFloatOp::Ceil => f.ceil(),
HostFloatOp::Floor => f.floor(),
HostFloatOp::Round => f.round(),
HostFloatOp::Trunc => f.trunc(),
HostFloatOp::Sqrt => f.sqrt(),
};
let res = f.sqrt();
Scalar::from_u32(res.to_bits())
}
FloatTy::F64 => {
let f = f64::from_bits(op.to_scalar().to_u64()?);
let res = match host_op {
HostFloatOp::Ceil => f.ceil(),
HostFloatOp::Floor => f.floor(),
HostFloatOp::Round => f.round(),
HostFloatOp::Trunc => f.trunc(),
HostFloatOp::Sqrt => f.sqrt(),
};
let res = f.sqrt();
Scalar::from_u64(res.to_bits())
}
}
}
Op::Round(rounding) => {
let ty::Float(float_ty) = op.layout.ty.kind() else {
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
};
match float_ty {
FloatTy::F32 => {
let f = op.to_scalar().to_f32()?;
let res = f.round_to_integral(rounding).value;
Scalar::from_f32(res)
}
FloatTy::F64 => {
let f = op.to_scalar().to_f64()?;
let res = f.round_to_integral(rounding).value;
Scalar::from_f64(res)
}
}
}
};
this.write_scalar(val, &dest)?;