diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
index 1e2abbd4f6..7ac3fccc99 100644
--- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
+++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs
@@ -12,6 +12,7 @@ use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
 use rustc_codegen_ssa::mir::place::PlaceRef;
 use rustc_codegen_ssa::traits::{BuilderMethods, ConstMethods, LayoutTypeMethods, OverflowOp};
 use rustc_codegen_ssa::MemFlags;
+use rustc_middle::bug;
 use rustc_middle::ty::Ty;
 use rustc_span::Span;
 use rustc_target::abi::{Abi, Align, Scalar, Size};
@@ -1973,11 +1974,26 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
         for (argument, argument_type) in args.iter().zip(argument_types) {
             assert_ty_eq!(self, argument.ty, argument_type);
         }
-        let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>();
-        self.emit()
-            .function_call(result_type, None, llfn.def(self), args)
-            .unwrap()
-            .with_type(result_type)
+        let llfn_def = llfn.def(self);
+        let libm_intrinsic = self.libm_intrinsics.borrow().get(&llfn_def).cloned();
+        if let Some(libm_intrinsic) = libm_intrinsic {
+            let result = self.call_libm_intrinsic(libm_intrinsic, result_type, args);
+            if result_type != result.ty {
+                bug!(
+                    "Mismatched libm result type for {:?}: expected {}, got {}",
+                    libm_intrinsic,
+                    self.debug_type(result_type),
+                    self.debug_type(result.ty),
+                );
+            }
+            result
+        } else {
+            let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>();
+            self.emit()
+                .function_call(result_type, None, llfn_def, args)
+                .unwrap()
+                .with_type(result_type)
+        }
     }
 
     fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
diff --git a/crates/rustc_codegen_spirv/src/builder/ext_inst.rs b/crates/rustc_codegen_spirv/src/builder/ext_inst.rs
index 5b69368de3..ef847a2a71 100644
--- a/crates/rustc_codegen_spirv/src/builder/ext_inst.rs
+++ b/crates/rustc_codegen_spirv/src/builder/ext_inst.rs
@@ -53,33 +53,43 @@ impl ExtInst {
 }
 
 impl<'a, 'tcx> Builder<'a, 'tcx> {
-    pub fn gl_op(&mut self, op: GLOp, args: impl AsRef<[SpirvValue]>) -> SpirvValue {
+    pub fn gl_op(
+        &mut self,
+        op: GLOp,
+        result_type: Word,
+        args: impl AsRef<[SpirvValue]>,
+    ) -> SpirvValue {
         let args = args.as_ref();
         let glsl = self.ext_inst.borrow_mut().import_glsl(self);
         self.emit()
             .ext_inst(
-                args[0].ty,
+                result_type,
                 None,
                 glsl,
                 op as u32,
                 args.iter().map(|a| Operand::IdRef(a.def(self))),
             )
             .unwrap()
-            .with_type(args[0].ty)
+            .with_type(result_type)
     }
 
-    pub fn cl_op(&mut self, op: CLOp, args: impl AsRef<[SpirvValue]>) -> SpirvValue {
+    pub fn cl_op(
+        &mut self,
+        op: CLOp,
+        result_type: Word,
+        args: impl AsRef<[SpirvValue]>,
+    ) -> SpirvValue {
         let args = args.as_ref();
         let opencl = self.ext_inst.borrow_mut().import_opencl(self);
         self.emit()
             .ext_inst(
-                args[0].ty,
+                result_type,
                 None,
                 opencl,
                 op as u32,
                 args.iter().map(|a| Operand::IdRef(a.def(self))),
             )
             .unwrap()
-            .with_type(args[0].ty)
+            .with_type(result_type)
     }
 }
diff --git a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs
index 349df7328d..f476d4bab0 100644
--- a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs
+++ b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs
@@ -1,6 +1,6 @@
 use super::Builder;
 use crate::abi::ConvSpirvType;
-use crate::builder_spirv::SpirvValueExt;
+use crate::builder_spirv::{SpirvValue, SpirvValueExt};
 use crate::codegen_cx::CodegenCx;
 use crate::spirv_type::SpirvType;
 use rspirv::spirv::{CLOp, GLOp};
@@ -30,6 +30,36 @@ fn int_type_width_signed(ty: Ty<'_>, cx: &CodegenCx<'_>) -> Option<(u64, bool)>
     }
 }
 
+impl Builder<'_, '_> {
+    pub fn copysign(&mut self, val: SpirvValue, sign: SpirvValue) -> SpirvValue {
+        let width = match self.lookup_type(val.ty) {
+            SpirvType::Float(width) => width,
+            other => bug!(
+                "copysign must have float argument, not {}",
+                other.debug(val.ty, self)
+            ),
+        };
+        let int_ty = SpirvType::Integer(width, false).def(self);
+        let (mask_sign, mask_value) = match width {
+            32 => (
+                self.constant_u32(1 << 31),
+                self.constant_u32(u32::max_value() >> 1),
+            ),
+            64 => (
+                self.constant_u64(1 << 63),
+                self.constant_u64(u64::max_value() >> 1),
+            ),
+            _ => bug!("copysign must have width 32 or 64, not {}", width),
+        };
+        let val_bits = self.bitcast(val, int_ty);
+        let sign_bits = self.bitcast(sign, int_ty);
+        let val_masked = self.and(val_bits, mask_value);
+        let sign_masked = self.and(sign_bits, mask_sign);
+        let result_bits = self.or(val_masked, sign_masked);
+        self.bitcast(result_bits, val.ty)
+    }
+}
+
 impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
     fn codegen_intrinsic_call(
         &mut self,
@@ -53,6 +83,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
         let arg_tys = sig.inputs();
         let name = self.tcx.item_name(def_id);
+        let ret_ty = self.layout_of(sig.output()).spirv_type(self);
 
         let result = PlaceRef::new_sized(llresult, fn_abi.ret.layout);
 
         let value = match name {
@@ -122,76 +153,88 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
             // TODO: Configure these to be ocl vs. gl ext instructions, etc.
             sym::sqrtf32 | sym::sqrtf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::sqrt, [args[0].immediate()])
+                    self.cl_op(CLOp::sqrt, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Sqrt, [args[0].immediate()])
+                    self.gl_op(GLOp::Sqrt, ret_ty, [args[0].immediate()])
                 }
             }
             sym::powif32 | sym::powif64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::pown, [args[0].immediate(), args[1].immediate()])
+                    self.cl_op(
+                        CLOp::pown,
+                        ret_ty,
+                        [args[0].immediate(), args[1].immediate()],
+                    )
                 } else {
                     let float = self.sitofp(args[1].immediate(), args[0].immediate().ty);
-                    self.gl_op(GLOp::Pow, [args[0].immediate(), float])
+                    self.gl_op(GLOp::Pow, ret_ty, [args[0].immediate(), float])
                 }
             }
             sym::sinf32 | sym::sinf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::sin, [args[0].immediate()])
+                    self.cl_op(CLOp::sin, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Sin, [args[0].immediate()])
+                    self.gl_op(GLOp::Sin, ret_ty, [args[0].immediate()])
                 }
             }
             sym::cosf32 | sym::cosf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::cos, [args[0].immediate()])
+                    self.cl_op(CLOp::cos, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Cos, [args[0].immediate()])
+                    self.gl_op(GLOp::Cos, ret_ty, [args[0].immediate()])
                 }
             }
             sym::powf32 | sym::powf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::pow, [args[0].immediate(), args[1].immediate()])
+                    self.cl_op(
+                        CLOp::pow,
+                        ret_ty,
+                        [args[0].immediate(), args[1].immediate()],
+                    )
                 } else {
-                    self.gl_op(GLOp::Pow, [args[0].immediate(), args[1].immediate()])
+                    self.gl_op(
+                        GLOp::Pow,
+                        ret_ty,
+                        [args[0].immediate(), args[1].immediate()],
+                    )
                 }
             }
             sym::expf32 | sym::expf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::exp, [args[0].immediate()])
+                    self.cl_op(CLOp::exp, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Exp, [args[0].immediate()])
+                    self.gl_op(GLOp::Exp, ret_ty, [args[0].immediate()])
                 }
             }
             sym::exp2f32 | sym::exp2f64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::exp2, [args[0].immediate()])
+                    self.cl_op(CLOp::exp2, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Exp2, [args[0].immediate()])
+                    self.gl_op(GLOp::Exp2, ret_ty, [args[0].immediate()])
                 }
             }
             sym::logf32 | sym::logf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::log, [args[0].immediate()])
+                    self.cl_op(CLOp::log, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Log, [args[0].immediate()])
+                    self.gl_op(GLOp::Log, ret_ty, [args[0].immediate()])
                 }
             }
             sym::log2f32 | sym::log2f64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::log2, [args[0].immediate()])
+                    self.cl_op(CLOp::log2, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Log2, [args[0].immediate()])
+                    self.gl_op(GLOp::Log2, ret_ty, [args[0].immediate()])
                 }
             }
             sym::log10f32 | sym::log10f64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::log10, [args[0].immediate()])
+                    self.cl_op(CLOp::log10, ret_ty, [args[0].immediate()])
                 } else {
                     // spir-v glsl doesn't have log10, so,
                     // log10(x) == (1 / ln(10)) * ln(x)
                     let mul = self.constant_float(args[0].immediate().ty, 1.0 / 10.0f64.ln());
-                    let ln = self.gl_op(GLOp::Log, [args[0].immediate()]);
+                    let ln = self.gl_op(GLOp::Log, ret_ty, [args[0].immediate()]);
                     self.mul(mul, ln)
                 }
             }
@@ -199,6 +242,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
                 if self.kernel_mode {
                     self.cl_op(
                         CLOp::fma,
+                        ret_ty,
                         [
                             args[0].immediate(),
                             args[1].immediate(),
@@ -208,6 +252,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
                 } else {
                     self.gl_op(
                         GLOp::Fma,
+                        ret_ty,
                         [
                             args[0].immediate(),
                             args[1].immediate(),
@@ -218,92 +263,88 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
             }
             sym::fabsf32 | sym::fabsf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::fabs, [args[0].immediate()])
+                    self.cl_op(CLOp::fabs, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::FAbs, [args[0].immediate()])
+                    self.gl_op(GLOp::FAbs, ret_ty, [args[0].immediate()])
                 }
             }
             sym::minnumf32 | sym::minnumf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::fmin, [args[0].immediate(), args[1].immediate()])
+                    self.cl_op(
+                        CLOp::fmin,
+                        ret_ty,
+                        [args[0].immediate(), args[1].immediate()],
+                    )
                 } else {
-                    self.gl_op(GLOp::FMin, [args[0].immediate(), args[1].immediate()])
+                    self.gl_op(
+                        GLOp::FMin,
+                        ret_ty,
+                        [args[0].immediate(), args[1].immediate()],
+                    )
                 }
             }
             sym::maxnumf32 | sym::maxnumf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::fmax, [args[0].immediate(), args[1].immediate()])
+                    self.cl_op(
+                        CLOp::fmax,
+                        ret_ty,
+                        [args[0].immediate(), args[1].immediate()],
+                    )
                 } else {
-                    self.gl_op(GLOp::FMax, [args[0].immediate(), args[1].immediate()])
+                    self.gl_op(
+                        GLOp::FMax,
+                        ret_ty,
+                        [args[0].immediate(), args[1].immediate()],
+                    )
                 }
             }
             sym::copysignf32 | sym::copysignf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::copysign, [args[0].immediate(), args[1].immediate()])
+                    self.cl_op(
+                        CLOp::copysign,
+                        ret_ty,
+                        [args[0].immediate(), args[1].immediate()],
+                    )
                 } else {
                     let val = args[0].immediate();
                     let sign = args[1].immediate();
-                    let width = match self.lookup_type(val.ty) {
-                        SpirvType::Float(width) => width,
-                        other => bug!(
-                            "copysign must have float argument, not {}",
-                            other.debug(val.ty, self)
-                        ),
-                    };
-                    let int_ty = SpirvType::Integer(width, false).def(self);
-                    let (mask_sign, mask_value) = match width {
-                        32 => (
-                            self.constant_u32(1 << 31),
-                            self.constant_u32(u32::max_value() >> 1),
-                        ),
-                        64 => (
-                            self.constant_u64(1 << 63),
-                            self.constant_u64(u64::max_value() >> 1),
-                        ),
-                        _ => bug!("copysign must have width 32 or 64, not {}", width),
-                    };
-                    let val_bits = self.bitcast(val, int_ty);
-                    let sign_bits = self.bitcast(sign, int_ty);
-                    let val_masked = self.and(val_bits, mask_value);
-                    let sign_masked = self.and(sign_bits, mask_sign);
-                    let result_bits = self.or(val_masked, sign_masked);
-                    self.bitcast(result_bits, val.ty)
+                    self.copysign(val, sign)
                 }
             }
             sym::floorf32 | sym::floorf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::floor, [args[0].immediate()])
+                    self.cl_op(CLOp::floor, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Floor, [args[0].immediate()])
+                    self.gl_op(GLOp::Floor, ret_ty, [args[0].immediate()])
                 }
             }
             sym::ceilf32 | sym::ceilf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::ceil, [args[0].immediate()])
+                    self.cl_op(CLOp::ceil, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Ceil, [args[0].immediate()])
+                    self.gl_op(GLOp::Ceil, ret_ty, [args[0].immediate()])
                 }
             }
             sym::truncf32 | sym::truncf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::trunc, [args[0].immediate()])
+                    self.cl_op(CLOp::trunc, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Trunc, [args[0].immediate()])
+                    self.gl_op(GLOp::Trunc, ret_ty, [args[0].immediate()])
                 }
             }
             // TODO: Correctness of all these rounds
             sym::rintf32 | sym::rintf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::rint, [args[0].immediate()])
+                    self.cl_op(CLOp::rint, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Round, [args[0].immediate()])
+                    self.gl_op(GLOp::Round, ret_ty, [args[0].immediate()])
                 }
             }
             sym::nearbyintf32 | sym::nearbyintf64 | sym::roundf32 | sym::roundf64 => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::round, [args[0].immediate()])
+                    self.cl_op(CLOp::round, ret_ty, [args[0].immediate()])
                 } else {
-                    self.gl_op(GLOp::Round, [args[0].immediate()])
+                    self.gl_op(GLOp::Round, ret_ty, [args[0].immediate()])
                 }
             }
@@ -317,7 +358,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
             // TODO: Do we want to manually implement these instead of using intel instructions?
             sym::ctlz | sym::ctlz_nonzero => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::clz, [args[0].immediate()])
+                    self.cl_op(CLOp::clz, ret_ty, [args[0].immediate()])
                 } else {
                     self.ext_inst
                         .borrow_mut()
@@ -334,7 +375,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
             }
             sym::cttz | sym::cttz_nonzero => {
                 if self.kernel_mode {
-                    self.cl_op(CLOp::ctz, [args[0].immediate()])
+                    self.cl_op(CLOp::ctz, ret_ty, [args[0].immediate()])
                 } else {
                     self.emit()
                         .u_count_trailing_zeros_intel(
diff --git a/crates/rustc_codegen_spirv/src/builder/libm_intrinsics.rs b/crates/rustc_codegen_spirv/src/builder/libm_intrinsics.rs
new file mode 100644
index 0000000000..de4b089dcb
--- /dev/null
+++ b/crates/rustc_codegen_spirv/src/builder/libm_intrinsics.rs
@@ -0,0 +1,351 @@
+use super::Builder;
+use crate::builder_spirv::{SpirvValue, SpirvValueExt};
+use rspirv::spirv::{GLOp, Word};
+use rustc_codegen_ssa::traits::BuilderMethods;
+
+#[derive(Copy, Clone, Debug)]
+pub enum LibmCustomIntrinsic {
+    CopySign,
+    Cbrt,
+    Erf,
+    Erfc,
+    Exp10,
+    Expm1,
+    Fdim,
+    Fmod,
+    Log10,
+    Hypot,
+    Ilogb,
+    J0,
+    Y0,
+    J1,
+    Y1,
+    Jn,
+    Yn,
+    Lgamma,
+    LgammaR,
+    Tgamma,
+    Log1p,
+    NextAfter,
+    Remainder,
+    RemQuo,
+    Scalbn,
+    SinCos,
+}
+
+#[derive(Copy, Clone, Debug)]
+pub enum LibmIntrinsic {
+    GLOp(GLOp),
+    Custom(LibmCustomIntrinsic),
+}
+
+pub const TABLE: &[(&str, LibmIntrinsic)] = &[
+    ("acos", LibmIntrinsic::GLOp(GLOp::Acos)),
+    ("acosf", LibmIntrinsic::GLOp(GLOp::Acos)),
+    ("acosh", LibmIntrinsic::GLOp(GLOp::Acosh)),
+    ("acoshf", LibmIntrinsic::GLOp(GLOp::Acosh)),
+    ("asin", LibmIntrinsic::GLOp(GLOp::Asin)),
+    ("asinf", LibmIntrinsic::GLOp(GLOp::Asin)),
+    ("asinh", LibmIntrinsic::GLOp(GLOp::Asinh)),
+    ("asinhf", LibmIntrinsic::GLOp(GLOp::Asinh)),
+    ("atan2", LibmIntrinsic::GLOp(GLOp::Atan2)),
+    ("atan2f", LibmIntrinsic::GLOp(GLOp::Atan2)),
+    ("atan", LibmIntrinsic::GLOp(GLOp::Atan)),
+    ("atanf", LibmIntrinsic::GLOp(GLOp::Atan)),
+    ("atanh", LibmIntrinsic::GLOp(GLOp::Atanh)),
+    ("atanhf", LibmIntrinsic::GLOp(GLOp::Atanh)),
+    ("cbrt", LibmIntrinsic::Custom(LibmCustomIntrinsic::Cbrt)),
+    ("cbrtf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Cbrt)),
+    ("ceil", LibmIntrinsic::GLOp(GLOp::Ceil)),
+    ("ceilf", LibmIntrinsic::GLOp(GLOp::Ceil)),
+    (
+        "copysign",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::CopySign),
+    ),
+    (
+        "copysignf",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::CopySign),
+    ),
+    ("cos", LibmIntrinsic::GLOp(GLOp::Cos)),
+    ("cosf", LibmIntrinsic::GLOp(GLOp::Cos)),
+    ("cosh", LibmIntrinsic::GLOp(GLOp::Cosh)),
+    ("coshf", LibmIntrinsic::GLOp(GLOp::Cosh)),
+    ("erf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Erf)),
+    ("erff", LibmIntrinsic::Custom(LibmCustomIntrinsic::Erf)),
+    ("erfc", LibmIntrinsic::Custom(LibmCustomIntrinsic::Erfc)),
+    ("erfcf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Erfc)),
+    ("exp10", LibmIntrinsic::Custom(LibmCustomIntrinsic::Exp10)),
+    ("exp10f", LibmIntrinsic::Custom(LibmCustomIntrinsic::Exp10)),
+    ("exp2", LibmIntrinsic::GLOp(GLOp::Exp2)),
+    ("exp2f", LibmIntrinsic::GLOp(GLOp::Exp2)),
+    ("exp", LibmIntrinsic::GLOp(GLOp::Exp)),
+    ("expf", LibmIntrinsic::GLOp(GLOp::Exp)),
+    ("expm1", LibmIntrinsic::Custom(LibmCustomIntrinsic::Expm1)),
+    ("expm1f", LibmIntrinsic::Custom(LibmCustomIntrinsic::Expm1)),
+    ("fabs", LibmIntrinsic::GLOp(GLOp::FAbs)),
+    ("fabsf", LibmIntrinsic::GLOp(GLOp::FAbs)),
+    ("fdim", LibmIntrinsic::Custom(LibmCustomIntrinsic::Fdim)),
+    ("fdimf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Fdim)),
+    ("floor", LibmIntrinsic::GLOp(GLOp::Floor)),
+    ("floorf", LibmIntrinsic::GLOp(GLOp::Floor)),
+    ("fma", LibmIntrinsic::GLOp(GLOp::Fma)),
+    ("fmaf", LibmIntrinsic::GLOp(GLOp::Fma)),
+    ("fmax", LibmIntrinsic::GLOp(GLOp::FMax)),
+    ("fmaxf", LibmIntrinsic::GLOp(GLOp::FMax)),
+    ("fmin", LibmIntrinsic::GLOp(GLOp::FMin)),
+    ("fminf", LibmIntrinsic::GLOp(GLOp::FMin)),
+    ("fmod", LibmIntrinsic::Custom(LibmCustomIntrinsic::Fmod)),
+    ("fmodf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Fmod)),
+    ("frexp", LibmIntrinsic::GLOp(GLOp::FrexpStruct)),
+    ("frexpf", LibmIntrinsic::GLOp(GLOp::FrexpStruct)),
+    ("hypot", LibmIntrinsic::Custom(LibmCustomIntrinsic::Hypot)),
+    ("hypotf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Hypot)),
+    ("ilogb", LibmIntrinsic::Custom(LibmCustomIntrinsic::Ilogb)),
+    ("ilogbf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Ilogb)),
+    ("j0", LibmIntrinsic::Custom(LibmCustomIntrinsic::J0)),
+    ("j0f", LibmIntrinsic::Custom(LibmCustomIntrinsic::J0)),
+    ("y0", LibmIntrinsic::Custom(LibmCustomIntrinsic::Y0)),
+    ("y0f", LibmIntrinsic::Custom(LibmCustomIntrinsic::Y0)),
+    ("j1", LibmIntrinsic::Custom(LibmCustomIntrinsic::J1)),
+    ("j1f", LibmIntrinsic::Custom(LibmCustomIntrinsic::J1)),
+    ("y1", LibmIntrinsic::Custom(LibmCustomIntrinsic::Y1)),
+    ("y1f", LibmIntrinsic::Custom(LibmCustomIntrinsic::Y1)),
+    ("jn", LibmIntrinsic::Custom(LibmCustomIntrinsic::Jn)),
+    ("jnf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Jn)),
+    ("yn", LibmIntrinsic::Custom(LibmCustomIntrinsic::Yn)),
+    ("ynf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Yn)),
+    ("ldexp", LibmIntrinsic::GLOp(GLOp::Ldexp)),
+    ("ldexpf", LibmIntrinsic::GLOp(GLOp::Ldexp)),
+    ("lgamma", LibmIntrinsic::Custom(LibmCustomIntrinsic::Lgamma)),
+    (
+        "lgammaf",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::Lgamma),
+    ),
+    (
+        "lgamma_r",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::LgammaR),
+    ),
+    (
+        "lgammaf_r",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::LgammaR),
+    ),
+    ("tgamma", LibmIntrinsic::Custom(LibmCustomIntrinsic::Tgamma)),
+    (
+        "tgammaf",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::Tgamma),
+    ),
+    ("log10", LibmIntrinsic::Custom(LibmCustomIntrinsic::Log10)),
+    ("log10f", LibmIntrinsic::Custom(LibmCustomIntrinsic::Log10)),
+    ("log1p", LibmIntrinsic::Custom(LibmCustomIntrinsic::Log1p)),
+    ("log1pf", LibmIntrinsic::Custom(LibmCustomIntrinsic::Log1p)),
+    ("log2", LibmIntrinsic::GLOp(GLOp::Log2)),
+    ("log2f", LibmIntrinsic::GLOp(GLOp::Log2)),
+    ("log", LibmIntrinsic::GLOp(GLOp::Log)),
+    ("logf", LibmIntrinsic::GLOp(GLOp::Log)),
+    ("modf", LibmIntrinsic::GLOp(GLOp::ModfStruct)),
+    ("modff", LibmIntrinsic::GLOp(GLOp::ModfStruct)),
+    (
+        "nextafter",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::NextAfter),
+    ),
+    (
+        "nextafterf",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::NextAfter),
+    ),
+    ("pow", LibmIntrinsic::GLOp(GLOp::Pow)),
+    ("powf", LibmIntrinsic::GLOp(GLOp::Pow)),
+    (
+        "remainder",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::Remainder),
+    ),
+    (
+        "remainderf",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::Remainder),
+    ),
+    ("remquo", LibmIntrinsic::Custom(LibmCustomIntrinsic::RemQuo)),
+    (
+        "remquof",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::RemQuo),
+    ),
+    ("round", LibmIntrinsic::GLOp(GLOp::Round)),
+    ("roundf", LibmIntrinsic::GLOp(GLOp::Round)),
+    ("scalbn", LibmIntrinsic::Custom(LibmCustomIntrinsic::Scalbn)),
+    (
+        "scalbnf",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::Scalbn),
+    ),
+    ("sin", LibmIntrinsic::GLOp(GLOp::Sin)),
+    ("sincos", LibmIntrinsic::Custom(LibmCustomIntrinsic::SinCos)),
+    (
+        "sincosf",
+        LibmIntrinsic::Custom(LibmCustomIntrinsic::SinCos),
+    ),
+    ("sinf", LibmIntrinsic::GLOp(GLOp::Sin)),
+    ("sinh", LibmIntrinsic::GLOp(GLOp::Sinh)),
+    ("sinhf", LibmIntrinsic::GLOp(GLOp::Sinh)),
+    ("sqrt", LibmIntrinsic::GLOp(GLOp::Sqrt)),
+    ("sqrtf", LibmIntrinsic::GLOp(GLOp::Sqrt)),
+    ("tan", LibmIntrinsic::GLOp(GLOp::Tan)),
+    ("tanf", LibmIntrinsic::GLOp(GLOp::Tan)),
+    ("tanh", LibmIntrinsic::GLOp(GLOp::Tanh)),
+    ("tanhf", LibmIntrinsic::GLOp(GLOp::Tanh)),
+    ("trunc", LibmIntrinsic::GLOp(GLOp::Trunc)),
+    ("truncf", LibmIntrinsic::GLOp(GLOp::Trunc)),
+];
+
+impl Builder<'_, '_> {
+    pub fn call_libm_intrinsic(
+        &mut self,
+        intrinsic: LibmIntrinsic,
+        result_type: Word,
+        args: &[SpirvValue],
+    ) -> SpirvValue {
+        match intrinsic {
+            LibmIntrinsic::GLOp(op) => self.gl_op(op, result_type, args),
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::SinCos) => {
+                assert_eq!(args.len(), 1);
+                let x = args[0];
+                let sin = self.gl_op(GLOp::Sin, x.ty, &[x]).def(self);
+                let cos = self.gl_op(GLOp::Cos, x.ty, &[x]).def(self);
+                self.emit()
+                    .composite_construct(result_type, None, [sin, cos].iter().copied())
+                    .unwrap()
+                    .with_type(result_type)
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Fmod) => {
+                assert_eq!(args.len(), 2);
+                self.emit()
+                    .f_mod(result_type, None, args[0].def(self), args[1].def(self))
+                    .unwrap()
+                    .with_type(result_type)
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::CopySign) => {
+                assert_eq!(args.len(), 2);
+                self.copysign(args[0], args[1])
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Cbrt) => {
+                assert_eq!(args.len(), 1);
+                self.gl_op(
+                    GLOp::Pow,
+                    result_type,
+                    &[args[0], self.constant_float(args[0].ty, 1.0 / 3.0)],
+                )
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Log10) => {
+                assert_eq!(args.len(), 1);
+                // log10(x) == (1 / ln(10)) * ln(x)
+                let mul = self.constant_float(args[0].ty, 1.0 / 10.0f64.ln());
+                let ln = self.gl_op(GLOp::Log, result_type, [args[0]]);
+                self.mul(mul, ln)
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Log1p) => {
+                assert_eq!(args.len(), 1);
+                let one = self.constant_float(args[0].ty, 1.0);
+                let add = self.add(args[0], one);
+                self.gl_op(GLOp::Log, result_type, &[add])
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Exp10) => {
+                assert_eq!(args.len(), 1);
+                // exp10(x) == exp(x * log(10));
+                let log10 = self.constant_float(args[0].ty, 10.0f64.ln());
+                let mul = self.mul(args[0], log10);
+                self.gl_op(GLOp::Exp, result_type, [mul])
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Expm1) => {
+                let exp = self.gl_op(GLOp::Exp, args[0].ty, &[args[0]]);
+                let one = self.constant_float(exp.ty, 1.0);
+                self.sub(exp, one)
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Erf) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Erf not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Erfc) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Erfc not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Fdim) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Fdim not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Hypot) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Hypot not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Ilogb) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Ilogb not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::J0) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "J0 not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Y0) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Y0 not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::J1) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "J1 not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Y1) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Y1 not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Jn) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Jn not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Yn) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Yn not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Lgamma) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Lgamma not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::LgammaR) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "LgammaR not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Tgamma) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Tgamma not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::NextAfter) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "NextAfter not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Remainder) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Remainder not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::RemQuo) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "RemQuo not supported yet");
+                undef
+            }
+            LibmIntrinsic::Custom(LibmCustomIntrinsic::Scalbn) => {
+                let undef = self.undef(result_type);
+                self.zombie(undef.def(self), "Scalbn not supported yet");
+                undef
+            }
+        }
+    }
+}
diff --git a/crates/rustc_codegen_spirv/src/builder/mod.rs b/crates/rustc_codegen_spirv/src/builder/mod.rs
index 90eb4e9e85..27f799a56d 100644
--- a/crates/rustc_codegen_spirv/src/builder/mod.rs
+++ b/crates/rustc_codegen_spirv/src/builder/mod.rs
@@ -1,6 +1,7 @@
 mod builder_methods;
 mod ext_inst;
 mod intrinsics;
+pub mod libm_intrinsics;
 mod spirv_asm;
 
 pub use ext_inst::ExtInst;
diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
index cae9016b0a..552a149aa5 100644
--- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
+++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
@@ -10,7 +10,7 @@ use rustc_middle::bug;
 use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
 use rustc_middle::mir::mono::{Linkage, MonoItem, Visibility};
 use rustc_middle::ty::layout::FnAbiExt;
-use rustc_middle::ty::{Instance, ParamEnv, TypeFoldable};
+use rustc_middle::ty::{self, Instance, ParamEnv, TypeFoldable};
 use rustc_span::def_id::DefId;
 use rustc_span::Span;
 use rustc_target::abi::call::FnAbi;
@@ -125,6 +125,23 @@ impl<'tcx> CodegenCx<'tcx> {
             }
         }
 
+        let instance_def_id = instance.def_id();
+        if self.tcx.crate_name(instance_def_id.krate) == self.sym.libm {
+            let item_name = self.tcx.item_name(instance_def_id);
+            let intrinsic = self.sym.libm_intrinsics.get(&item_name);
+            if self.tcx.visibility(instance.def_id()) == ty::Visibility::Public {
+                match intrinsic {
+                    Some(&intrinsic) => {
+                        self.libm_intrinsics.borrow_mut().insert(fn_id, intrinsic);
+                    }
+                    None => self.tcx.sess.err(&format!(
+                        "missing libm intrinsic {}, which is {}",
+                        symbol_name, instance
+                    )),
+                }
+            }
+        }
+
         declared
     }
diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
index 8f318a43df..6adac6b895 100644
--- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
+++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
@@ -53,6 +53,7 @@ pub struct CodegenCx<'tcx> {
     pub instruction_table: InstructionTable,
     pub really_unsafe_ignore_bitcasts: RefCell>,
     pub zombie_undefs_for_system_constant_pointers: RefCell>,
+    pub libm_intrinsics: RefCell>,
     /// Some runtimes (e.g. intel-compute-runtime) disallow atomics on i8 and i16, even though it's allowed by the spec.
     /// This enables/disables them.
     pub i8_i16_atomics_allowed: bool,
@@ -104,6 +105,7 @@ impl<'tcx> CodegenCx<'tcx> {
             instruction_table: InstructionTable::new(),
             really_unsafe_ignore_bitcasts: Default::default(),
             zombie_undefs_for_system_constant_pointers: Default::default(),
+            libm_intrinsics: Default::default(),
             i8_i16_atomics_allowed: false,
         }
     }
@@ -168,6 +170,8 @@ impl<'tcx> CodegenCx<'tcx> {
             .contains_name(self.tcx.hir().krate_attrs(), sym::compiler_builtins)
             || self.tcx.crate_name(LOCAL_CRATE) == sym::core
             || self.tcx.crate_name(LOCAL_CRATE) == self.sym.spirv_std
+            || self.tcx.crate_name(LOCAL_CRATE) == self.sym.libm
+            || self.tcx.crate_name(LOCAL_CRATE) == self.sym.num_traits
     }
 
     pub fn finalize_module(self) -> Module {
diff --git a/crates/rustc_codegen_spirv/src/lib.rs b/crates/rustc_codegen_spirv/src/lib.rs
index 73904ab5ca..6536a7937f 100644
--- a/crates/rustc_codegen_spirv/src/lib.rs
+++ b/crates/rustc_codegen_spirv/src/lib.rs
@@ -111,8 +111,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId};
 use rustc_middle::middle::cstore::{EncodedMetadata, MetadataLoader, MetadataLoaderDyn};
 use rustc_middle::mir::mono::{Linkage, MonoItem, Visibility};
 use rustc_middle::ty::print::with_no_trimmed_paths;
-use rustc_middle::ty::query::Providers;
-use rustc_middle::ty::{self, DefIdTree, Instance, InstanceDef, TyCtxt};
+use rustc_middle::ty::{self, query, DefIdTree, Instance, InstanceDef, TyCtxt};
 use rustc_mir::util::write_mir_pretty;
 use rustc_session::config::{self, OptLevel, OutputFilenames, OutputType};
 use rustc_session::Session;
@@ -267,7 +266,7 @@ impl CodegenBackend for SpirvCodegenBackend {
         Box::new(SpirvMetadataLoader)
     }
 
-    fn provide(&self, providers: &mut Providers) {
+    fn provide(&self, providers: &mut query::Providers) {
         // This is a lil weird: so, we obviously don't support C ABIs at all. However, libcore does declare some extern
         // C functions:
         // https://github.com/rust-lang/rust/blob/5fae56971d8487088c0099c82c0a5ce1638b5f62/library/core/src/slice/cmp.rs#L119
@@ -290,7 +289,7 @@ impl CodegenBackend for SpirvCodegenBackend {
         };
     }
 
-    fn provide_extern(&self, providers: &mut Providers) {
+    fn provide_extern(&self, providers: &mut query::Providers) {
         // See comments in provide(), only this time we use the default *extern* provider.
         providers.fn_sig = |tcx, def_id| {
             let result = (rustc_interface::DEFAULT_EXTERN_QUERY_PROVIDERS.fn_sig)(tcx, def_id);
diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs
index 024221aab5..7b81336052 100644
--- a/crates/rustc_codegen_spirv/src/symbols.rs
+++ b/crates/rustc_codegen_spirv/src/symbols.rs
@@ -1,3 +1,4 @@
+use crate::builder::libm_intrinsics;
 use crate::codegen_cx::CodegenCx;
 use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass};
 use rustc_ast::ast::{AttrKind, Attribute, Lit, LitIntType, LitKind, NestedMetaItem};
@@ -15,6 +16,8 @@ pub struct Symbols {
     pub spirv: Symbol,
     pub spirv_std: Symbol,
+    pub libm: Symbol,
+    pub num_traits: Symbol,
     pub kernel: Symbol,
     pub simple: Symbol,
     pub vulkan: Symbol,
@@ -30,6 +33,7 @@ pub struct Symbols {
     really_unsafe_ignore_bitcasts: Symbol,
     attributes: HashMap,
     execution_modes: HashMap,
+    pub libm_intrinsics: HashMap,
 }
 
 const BUILTINS: &[(&str, BuiltIn)] = {
@@ -322,22 +326,30 @@ impl Symbols {
             .chain(execution_models)
             .map(|(a, b)| (Symbol::intern(a), b));
         let mut attributes = HashMap::new();
-        attributes_iter.for_each(|(a, b)| {
+        for (a, b) in attributes_iter {
             let old = attributes.insert(a, b);
             // `.collect()` into a HashMap does not error on duplicates, so manually write out the
             // loop here to error on duplicates.
             assert!(old.is_none());
-        });
+        }
         let mut execution_modes = HashMap::new();
-        EXECUTION_MODES.iter().for_each(|(key, mode, dim)| {
-            let old = execution_modes.insert(Symbol::intern(key), (*mode, *dim));
+        for &(key, mode, dim) in EXECUTION_MODES {
+            let old = execution_modes.insert(Symbol::intern(key), (mode, dim));
             assert!(old.is_none());
-        });
+        }
+
+        let mut libm_intrinsics = HashMap::new();
+        for &(a, b) in libm_intrinsics::TABLE {
+            let old = libm_intrinsics.insert(Symbol::intern(a), b);
+            assert!(old.is_none());
+        }
 
         Self {
             fmt_decimal: Symbol::intern("fmt_decimal"),
             spirv: Symbol::intern("spirv"),
             spirv_std: Symbol::intern("spirv_std"),
+            libm: Symbol::intern("libm"),
+            num_traits: Symbol::intern("num_traits"),
             kernel: Symbol::intern("kernel"),
             simple: Symbol::intern("simple"),
             vulkan: Symbol::intern("vulkan"),
@@ -353,6 +365,7 @@ impl Symbols {
             really_unsafe_ignore_bitcasts: Symbol::intern("really_unsafe_ignore_bitcasts"),
             attributes,
             execution_modes,
+            libm_intrinsics,
         }
     }
 }