diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index c8335a1428..d70e3a64ad 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -1,5 +1,6 @@ #![no_std] #![feature(register_attr, repr_simd, core_intrinsics)] +#![cfg_attr(target_arch = "spirv", feature(asm))] #![register_attr(spirv)] // Our standard Clippy lints that we use in Embark projects, we opt out of a few that are not appropriate for the specific crate (yet) #![warn( @@ -97,6 +98,68 @@ pointer_addrspace!(incoming_ray_payload_khr, IncomingRayPayloadKHR, true); pointer_addrspace!(shader_record_buffer_khr, ShaderRecordBufferKHR, true); pointer_addrspace!(physical_storage_buffer, PhysicalStorageBuffer, true); +pub trait Derivative { + fn ddx(self) -> Self; + fn ddx_fine(self) -> Self; + fn ddx_coarse(self) -> Self; + fn ddy(self) -> Self; + fn ddy_fine(self) -> Self; + fn ddy_coarse(self) -> Self; + fn fwidth(self) -> Self; + fn fwidth_fine(self) -> Self; + fn fwidth_coarse(self) -> Self; +} + +#[cfg(target_arch = "spirv")] +macro_rules! deriv_caps { + (true) => { + asm!("OpCapability DerivativeControl") + }; + (false) => {}; +} + +macro_rules! deriv_fn { + ($name:ident, $inst:ident, $needs_caps:tt) => { + fn $name(self) -> Self { + #[cfg(not(target_arch = "spirv"))] + panic!(concat!(stringify!($name), " is not supported on the CPU")); + #[cfg(target_arch = "spirv")] + unsafe { + let o; + deriv_caps!($needs_caps); + asm!( + concat!("{1} = ", stringify!($inst), " typeof{0} {0}"), + in(reg) self, + out(reg) o, + ); + o + } + } + }; +} +macro_rules! deriv_impl { + ($ty:ty) => { + impl Derivative for $ty { + deriv_fn!(ddx, OpDPdx, false); + deriv_fn!(ddx_fine, OpDPdxFine, true); + deriv_fn!(ddx_coarse, OpDPdxCoarse, true); + deriv_fn!(ddy, OpDPdy, false); + deriv_fn!(ddy_fine, OpDPdyFine, true); + deriv_fn!(ddy_coarse, OpDPdyCoarse, true); + deriv_fn!(fwidth, OpFwidth, false); + deriv_fn!(fwidth_fine, OpFwidthFine, true); + deriv_fn!(fwidth_coarse, OpFwidthCoarse, true); + } + }; +} + +// "must be a scalar or vector of floating-point type. The component width must be 32 bits." +deriv_impl!(f32); +// TODO: Fix rustc to support these +// deriv_impl!(glam::Vec2); +// deriv_impl!(glam::Vec3); +// deriv_impl!(glam::Vec4); + /// libcore requires a few external symbols to be defined: /// /// TODO: This is copied from `compiler_builtins/mem.rs`. Can we use that one instead? The note in the above link says