From cf8e8c37321dd711bc26c500bad81014b5d9224f Mon Sep 17 00:00:00 2001 From: Rua Date: Sat, 5 Aug 2023 16:47:17 +0200 Subject: [PATCH] Some shader fixes and improvements (#2280) --- examples/src/bin/runtime-shader/main.rs | 16 ++------- vulkano-shaders/src/lib.rs | 17 ++-------- vulkano-shaders/src/structs.rs | 12 +++---- vulkano/src/shader/mod.rs | 16 +++------ vulkano/src/shader/spirv.rs | 45 ++++++++++++++++++++++++- 5 files changed, 59 insertions(+), 47 deletions(-) diff --git a/examples/src/bin/runtime-shader/main.rs b/examples/src/bin/runtime-shader/main.rs index 970ec498..f04fca88 100644 --- a/examples/src/bin/runtime-shader/main.rs +++ b/examples/src/bin/runtime-shader/main.rs @@ -440,17 +440,7 @@ fn read_spirv_words_from_file(path: impl AsRef) -> Vec { }); file.read_to_end(&mut bytes).unwrap(); - // Convert the bytes to words. - // SPIR-V is defined to be always little-endian, so this may need an endianness conversion. - assert!( - bytes.len() % 4 == 0, - "file `{}` does not contain a whole number of SPIR-V words", - path.display(), - ); - - // TODO: Use `slice::array_chunks` once it's stable. - bytes - .chunks_exact(4) - .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) - .collect() + vulkano::shader::spirv::bytes_to_words(&bytes) + .unwrap_or_else(|err| panic!("file `{}`: {}", path.display(), err)) + .into_owned() } diff --git a/vulkano-shaders/src/lib.rs b/vulkano-shaders/src/lib.rs index d87e9cbd..300ffffe 100644 --- a/vulkano-shaders/src/lib.rs +++ b/vulkano-shaders/src/lib.rs @@ -230,7 +230,6 @@ use shaderc::{EnvVersion, SpirvVersion}; use std::{ env, fs, mem, path::{Path, PathBuf}, - slice, }; use structs::TypeRegistry; use syn::{ @@ -336,20 +335,10 @@ fn shader_inner(mut input: MacroInput) -> Result { let bytes = fs::read(&full_path) .or_else(|err| bail!(path, "failed to read source `{full_path:?}`: {err}"))?; - if bytes.len() % 4 != 0 { - bail!(path, "SPIR-V bytes must be an integer multiple of 4"); - } + let words = vulkano::shader::spirv::bytes_to_words(&bytes) + .or_else(|err| bail!(path, "failed to read source `{full_path:?}`: {err}"))?; - // Here, we are praying that the system allocator of the user aligns allocations to - // at least 4, which *should* be the case on all targets. - assert_eq!(bytes.as_ptr() as usize % 4, 0); - - // SAFETY: We checked that the bytes are aligned correctly for `u32`, and that - // there is an integer number of `u32`s contained. - let words = - unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len() / 4) }; - - codegen::reflect(&input, path, name, words, Vec::new(), &mut type_registry)? + codegen::reflect(&input, path, name, &words, Vec::new(), &mut type_registry)? } }; diff --git a/vulkano-shaders/src/structs.rs b/vulkano-shaders/src/structs.rs index 9d6dd519..e6f7b6a5 100644 --- a/vulkano-shaders/src/structs.rs +++ b/vulkano-shaders/src/structs.rs @@ -667,12 +667,7 @@ impl TypeStruct { Instruction::Name { name, .. } => Some(Ident::new(name, Span::call_site())), _ => None, }) - .ok_or_else(|| { - Error::new_spanned( - &shader.source, - "expected struct in shader interface to have an associated `Name` instruction", - ) - })?; + .unwrap_or_else(|| format_ident!("Unnamed{}", struct_id.as_raw())); let mut members = Vec::::with_capacity(member_type_ids.len()); @@ -822,7 +817,10 @@ impl TypeStruct { members.push(Member { ident, ty, offset }); } - Ok(TypeStruct { ident, members }) + Ok(TypeStruct { + ident, + members, + }) } fn size(&self) -> Option { diff --git a/vulkano/src/shader/mod.rs b/vulkano/src/shader/mod.rs index 0c5512ed..55a455ae 100644 --- a/vulkano/src/shader/mod.rs +++ b/vulkano/src/shader/mod.rs @@ -151,7 +151,7 @@ use spirv::ExecutionModel; use std::{ borrow::Cow, collections::hash_map::Entry, - mem::{align_of, discriminant, size_of, size_of_val, MaybeUninit}, + mem::{discriminant, size_of_val, MaybeUninit}, num::NonZeroU64, ptr, sync::Arc, @@ -357,23 +357,15 @@ impl ShaderModule { /// - Panics if the length of `bytes` is not a multiple of 4. #[deprecated( since = "0.34.0", - note = "read little-endian words yourself, and then use `new` instead" + note = "use `shader::spirv::bytes_to_words`, and then use `new` instead" )] #[inline] pub unsafe fn from_bytes( device: Arc, bytes: &[u8], ) -> Result, Validated> { - assert!(bytes.as_ptr() as usize % align_of::() == 0); - assert!(bytes.len() % size_of::() == 0); - - Self::new( - device, - ShaderModuleCreateInfo::new(std::slice::from_raw_parts( - bytes.as_ptr() as *const _, - bytes.len() / size_of::(), - )), - ) + let words = spirv::bytes_to_words(bytes).unwrap(); + Self::new(device, ShaderModuleCreateInfo::new(&words)) } /// Returns information about the entry point with the provided name. Returns `None` if no entry diff --git a/vulkano/src/shader/spirv.rs b/vulkano/src/shader/spirv.rs index 14180d30..b0c43cb3 100644 --- a/vulkano/src/shader/spirv.rs +++ b/vulkano/src/shader/spirv.rs @@ -18,6 +18,7 @@ use crate::Version; use ahash::{HashMap, HashMapExt}; use std::{ + borrow::Cow, error::Error, fmt::{Display, Error as FmtError, Formatter}, ops::Range, @@ -563,10 +564,18 @@ impl<'a> StructMemberInfo<'a> { #[repr(transparent)] pub struct Id(u32); +impl Id { + // Returns the raw numeric value of this Id. + #[inline] + pub const fn as_raw(self) -> u32 { + self.0 + } +} + impl From for u32 { #[inline] fn from(id: Id) -> u32 { - id.0 + id.as_raw() } } @@ -792,3 +801,37 @@ impl Display for ParseErrors { } } } + +/// Converts SPIR-V bytes to words. If necessary, the byte order is swapped from little-endian +/// to native-endian. +pub fn bytes_to_words(bytes: &[u8]) -> Result, SpirvBytesNotMultipleOf4> { + // If the current target is little endian, and the slice already has the right size and + // alignment, then we can just transmute the slice with bytemuck. + #[cfg(target_endian = "little")] + if let Ok(words) = bytemuck::try_cast_slice(bytes) { + return Ok(Cow::Borrowed(words)); + } + + if bytes.len() % 4 != 0 { + return Err(SpirvBytesNotMultipleOf4); + } + + // TODO: Use `slice::array_chunks` once it's stable. + let words: Vec = bytes + .chunks_exact(4) + .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) + .collect(); + + Ok(Cow::Owned(words)) +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct SpirvBytesNotMultipleOf4; + +impl Error for SpirvBytesNotMultipleOf4 {} + +impl Display for SpirvBytesNotMultipleOf4 { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "the length of the provided slice is not a multiple of 4") + } +}