Some shader fixes and improvements (#2280)

This commit is contained in:
Rua 2023-08-05 16:47:17 +02:00 committed by GitHub
parent b00e6627b4
commit cf8e8c3732
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 59 additions and 47 deletions

View File

@ -440,17 +440,7 @@ fn read_spirv_words_from_file(path: impl AsRef<Path>) -> Vec<u32> {
}); });
file.read_to_end(&mut bytes).unwrap(); file.read_to_end(&mut bytes).unwrap();
// Convert the bytes to words. vulkano::shader::spirv::bytes_to_words(&bytes)
// SPIR-V is defined to be always little-endian, so this may need an endianness conversion. .unwrap_or_else(|err| panic!("file `{}`: {}", path.display(), err))
assert!( .into_owned()
bytes.len() % 4 == 0,
"file `{}` does not contain a whole number of SPIR-V words",
path.display(),
);
// TODO: Use `slice::array_chunks` once it's stable.
bytes
.chunks_exact(4)
.map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
.collect()
} }

View File

@ -230,7 +230,6 @@ use shaderc::{EnvVersion, SpirvVersion};
use std::{ use std::{
env, fs, mem, env, fs, mem,
path::{Path, PathBuf}, path::{Path, PathBuf},
slice,
}; };
use structs::TypeRegistry; use structs::TypeRegistry;
use syn::{ use syn::{
@ -336,20 +335,10 @@ fn shader_inner(mut input: MacroInput) -> Result<TokenStream> {
let bytes = fs::read(&full_path) let bytes = fs::read(&full_path)
.or_else(|err| bail!(path, "failed to read source `{full_path:?}`: {err}"))?; .or_else(|err| bail!(path, "failed to read source `{full_path:?}`: {err}"))?;
if bytes.len() % 4 != 0 { let words = vulkano::shader::spirv::bytes_to_words(&bytes)
bail!(path, "SPIR-V bytes must be an integer multiple of 4"); .or_else(|err| bail!(path, "failed to read source `{full_path:?}`: {err}"))?;
}
// Here, we are praying that the system allocator of the user aligns allocations to codegen::reflect(&input, path, name, &words, Vec::new(), &mut type_registry)?
// at least 4, which *should* be the case on all targets.
assert_eq!(bytes.as_ptr() as usize % 4, 0);
// SAFETY: We checked that the bytes are aligned correctly for `u32`, and that
// there is an integer number of `u32`s contained.
let words =
unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len() / 4) };
codegen::reflect(&input, path, name, words, Vec::new(), &mut type_registry)?
} }
}; };

View File

@ -667,12 +667,7 @@ impl TypeStruct {
Instruction::Name { name, .. } => Some(Ident::new(name, Span::call_site())), Instruction::Name { name, .. } => Some(Ident::new(name, Span::call_site())),
_ => None, _ => None,
}) })
.ok_or_else(|| { .unwrap_or_else(|| format_ident!("Unnamed{}", struct_id.as_raw()));
Error::new_spanned(
&shader.source,
"expected struct in shader interface to have an associated `Name` instruction",
)
})?;
let mut members = Vec::<Member>::with_capacity(member_type_ids.len()); let mut members = Vec::<Member>::with_capacity(member_type_ids.len());
@ -822,7 +817,10 @@ impl TypeStruct {
members.push(Member { ident, ty, offset }); members.push(Member { ident, ty, offset });
} }
Ok(TypeStruct { ident, members }) Ok(TypeStruct {
ident,
members,
})
} }
fn size(&self) -> Option<usize> { fn size(&self) -> Option<usize> {

View File

@ -151,7 +151,7 @@ use spirv::ExecutionModel;
use std::{ use std::{
borrow::Cow, borrow::Cow,
collections::hash_map::Entry, collections::hash_map::Entry,
mem::{align_of, discriminant, size_of, size_of_val, MaybeUninit}, mem::{discriminant, size_of_val, MaybeUninit},
num::NonZeroU64, num::NonZeroU64,
ptr, ptr,
sync::Arc, sync::Arc,
@ -357,23 +357,15 @@ impl ShaderModule {
/// - Panics if the length of `bytes` is not a multiple of 4. /// - Panics if the length of `bytes` is not a multiple of 4.
#[deprecated( #[deprecated(
since = "0.34.0", since = "0.34.0",
note = "read little-endian words yourself, and then use `new` instead" note = "use `shader::spirv::bytes_to_words`, and then use `new` instead"
)] )]
#[inline] #[inline]
pub unsafe fn from_bytes( pub unsafe fn from_bytes(
device: Arc<Device>, device: Arc<Device>,
bytes: &[u8], bytes: &[u8],
) -> Result<Arc<ShaderModule>, Validated<VulkanError>> { ) -> Result<Arc<ShaderModule>, Validated<VulkanError>> {
assert!(bytes.as_ptr() as usize % align_of::<u32>() == 0); let words = spirv::bytes_to_words(bytes).unwrap();
assert!(bytes.len() % size_of::<u32>() == 0); Self::new(device, ShaderModuleCreateInfo::new(&words))
Self::new(
device,
ShaderModuleCreateInfo::new(std::slice::from_raw_parts(
bytes.as_ptr() as *const _,
bytes.len() / size_of::<u32>(),
)),
)
} }
/// Returns information about the entry point with the provided name. Returns `None` if no entry /// Returns information about the entry point with the provided name. Returns `None` if no entry

View File

@ -18,6 +18,7 @@
use crate::Version; use crate::Version;
use ahash::{HashMap, HashMapExt}; use ahash::{HashMap, HashMapExt};
use std::{ use std::{
borrow::Cow,
error::Error, error::Error,
fmt::{Display, Error as FmtError, Formatter}, fmt::{Display, Error as FmtError, Formatter},
ops::Range, ops::Range,
@ -563,10 +564,18 @@ impl<'a> StructMemberInfo<'a> {
#[repr(transparent)] #[repr(transparent)]
pub struct Id(u32); pub struct Id(u32);
impl Id {
// Returns the raw numeric value of this Id.
#[inline]
pub const fn as_raw(self) -> u32 {
self.0
}
}
impl From<Id> for u32 { impl From<Id> for u32 {
#[inline] #[inline]
fn from(id: Id) -> u32 { fn from(id: Id) -> u32 {
id.0 id.as_raw()
} }
} }
@ -792,3 +801,37 @@ impl Display for ParseErrors {
} }
} }
} }
/// Converts SPIR-V bytes to words. If necessary, the byte order is swapped from little-endian
/// to native-endian.
pub fn bytes_to_words(bytes: &[u8]) -> Result<Cow<'_, [u32]>, SpirvBytesNotMultipleOf4> {
// If the current target is little endian, and the slice already has the right size and
// alignment, then we can just transmute the slice with bytemuck.
#[cfg(target_endian = "little")]
if let Ok(words) = bytemuck::try_cast_slice(bytes) {
return Ok(Cow::Borrowed(words));
}
if bytes.len() % 4 != 0 {
return Err(SpirvBytesNotMultipleOf4);
}
// TODO: Use `slice::array_chunks` once it's stable.
let words: Vec<u32> = bytes
.chunks_exact(4)
.map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
.collect();
Ok(Cow::Owned(words))
}
#[derive(Clone, Copy, Debug, Default)]
pub struct SpirvBytesNotMultipleOf4;
impl Error for SpirvBytesNotMultipleOf4 {}
impl Display for SpirvBytesNotMultipleOf4 {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "the length of the provided slice is not a multiple of 4")
}
}