diff --git a/vulkano-shaders/src/codegen.rs b/vulkano-shaders/src/codegen.rs index f50ce020..214f6ce5 100644 --- a/vulkano-shaders/src/codegen.rs +++ b/vulkano-shaders/src/codegen.rs @@ -242,7 +242,7 @@ where }); let spirv_extensions = reflect::spirv_extensions(&spirv); let entry_points = reflect::entry_points(&spirv, exact_entrypoint_interface) - .map(|(name, info)| entry_point::write_entry_point(&name, &info)); + .map(|(name, model, info)| entry_point::write_entry_point(&name, model, &info)); let specialization_constants = structs::write_specialization_constants( prefix, @@ -708,7 +708,7 @@ mod tests { let spirv = Spirv::new(&instructions).unwrap(); let mut descriptors = Vec::new(); - for (_, info) in reflect::entry_points(&spirv, true) { + for (_, _, info) in reflect::entry_points(&spirv, true) { descriptors.push(info.descriptor_requirements); } @@ -779,7 +779,7 @@ mod tests { .unwrap(); let spirv = Spirv::new(comp.as_binary()).unwrap(); - for (_, info) in reflect::entry_points(&spirv, true) { + for (_, _, info) in reflect::entry_points(&spirv, true) { let mut bindings = Vec::new(); for (loc, _reqs) in info.descriptor_requirements { bindings.push(loc); diff --git a/vulkano-shaders/src/entry_point.rs b/vulkano-shaders/src/entry_point.rs index 030a5a33..9b969ec1 100644 --- a/vulkano-shaders/src/entry_point.rs +++ b/vulkano-shaders/src/entry_point.rs @@ -15,9 +15,11 @@ use vulkano::shader::{ SpecializationConstantRequirements, }; use vulkano::shader::{EntryPointInfo, ShaderInterface, ShaderStages}; +use vulkano::shader::spirv::ExecutionModel; -pub(super) fn write_entry_point(name: &str, info: &EntryPointInfo) -> TokenStream { +pub(super) fn write_entry_point(name: &str, model: ExecutionModel, info: &EntryPointInfo) -> TokenStream { let execution = write_shader_execution(&info.execution); + let model = syn::parse_str::(&format!("vulkano::shader::spirv::ExecutionModel::{:?}", model)).unwrap(); let descriptor_requirements = write_descriptor_requirements(&info.descriptor_requirements); let push_constant_requirements = write_push_constant_requirements(&info.push_constant_requirements); @@ -29,6 +31,7 @@ pub(super) fn write_entry_point(name: &str, info: &EntryPointInfo) -> TokenStrea quote! { ( #name.to_owned(), + #model, EntryPointInfo { execution: #execution, descriptor_requirements: std::array::IntoIter::new(#descriptor_requirements).collect(), diff --git a/vulkano/autogen/spirv_parse.rs b/vulkano/autogen/spirv_parse.rs index 8cf929f4..19d395a3 100644 --- a/vulkano/autogen/spirv_parse.rs +++ b/vulkano/autogen/spirv_parse.rs @@ -530,8 +530,13 @@ fn value_enum_output(enums: &[(Ident, Vec)]) -> TokenStream { ); let name_string = name.to_string(); + let derives = match name_string.as_str() { + "ExecutionModel" => quote! { #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] }, + _ => quote! { #[derive(Clone, Debug, PartialEq)] }, + }; + quote! { - #[derive(Clone, Debug, PartialEq)] + #derives #[allow(non_camel_case_types)] pub enum #name { #(#members_items)* diff --git a/vulkano/src/shader/mod.rs b/vulkano/src/shader/mod.rs index 896dc358..cd99f9ab 100644 --- a/vulkano/src/shader/mod.rs +++ b/vulkano/src/shader/mod.rs @@ -32,7 +32,7 @@ use crate::Version; use crate::VulkanObject; use fnv::FnvHashMap; use std::borrow::Cow; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::error; use std::error::Error; use std::ffi::CStr; @@ -49,6 +49,8 @@ use std::sync::Arc; pub mod reflect; pub mod spirv; +use spirv::ExecutionModel; + // Generated by build.rs include!(concat!(env!("OUT_DIR"), "/spirv_reqs.rs")); @@ -57,7 +59,7 @@ include!(concat!(env!("OUT_DIR"), "/spirv_reqs.rs")); pub struct ShaderModule { handle: ash::vk::ShaderModule, device: Arc, - entry_points: HashMap, + entry_points: HashMap>, } impl ShaderModule { @@ -116,7 +118,7 @@ impl ShaderModule { spirv_version: Version, spirv_capabilities: impl IntoIterator, spirv_extensions: impl IntoIterator, - entry_points: impl IntoIterator, + entry_points: impl IntoIterator, ) -> Result, ShaderCreationError> { if let Err(reason) = check_spirv_version(&device, spirv_version) { return Err(ShaderCreationError::SpirvVersionNotSupported { @@ -162,10 +164,29 @@ impl ShaderModule { output.assume_init() }; + let entries = entry_points.into_iter().collect::>(); + let entry_points = entries + .iter() + .filter_map(|(name, _, _)| Some(name)) + .collect::>() + .iter() + .map(|name| { + ((*name).clone(), + entries.iter().filter_map(|(entry_name, entry_model, info)| { + if &entry_name == name { + Some((*entry_model, info.clone())) + } else { + None + } + }).collect::>() + ) + }) + .collect(); + Ok(Arc::new(ShaderModule { handle, device, - entry_points: entry_points.into_iter().collect(), + entry_points, })) } @@ -180,7 +201,7 @@ impl ShaderModule { spirv_version: Version, spirv_capabilities: impl IntoIterator, spirv_extensions: impl IntoIterator, - entry_points: impl IntoIterator, + entry_points: impl IntoIterator, ) -> Result, ShaderCreationError> { assert!((bytes.len() % 4) == 0); Self::from_words_with_data( @@ -197,12 +218,31 @@ impl ShaderModule { } /// Returns information about the entry point with the provided name. Returns `None` if no entry - /// point with that name exists in the shader module. + /// point with that name exists in the shader module or if multiple entry points with the same + /// name exist. pub fn entry_point<'a>(&'a self, name: &str) -> Option> { - self.entry_points.get(name).map(|info| EntryPoint { - module: self, - name: CString::new(name).unwrap(), - info, + self.entry_points.get(name).and_then(|infos| { + if infos.len() == 1 { + infos.iter().next().map(|(_, info)| EntryPoint { + module: self, + name: CString::new(name).unwrap(), + info, + }) + } else { + None + } + }) + } + + /// Returns information about the entry point with the provided name and execution model. Returns + /// `None` if no entry and execution model exists in the shader module. + pub fn entry_point_with_execution<'a>(&'a self, name: &str, execution: ExecutionModel) -> Option> { + self.entry_points.get(name).and_then(|infos| { + infos.get(&execution).map(|info| EntryPoint { + module: self, + name: CString::new(name).unwrap(), + info, + }) }) } } @@ -400,7 +440,7 @@ impl<'a> EntryPoint<'a> { /// The mode in which a shader executes. This includes both information about the shader type/stage, /// and additional data relevant to particular shader types. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum ShaderExecution { Vertex, TessellationControl, @@ -425,7 +465,7 @@ pub enum TessellationShaderSubdivision { }*/ /// The mode in which a geometry shader executes. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct GeometryShaderExecution { pub input: GeometryShaderInput, /*pub max_output_vertices: u32, diff --git a/vulkano/src/shader/reflect.rs b/vulkano/src/shader/reflect.rs index 82cce79a..559dc0d0 100644 --- a/vulkano/src/shader/reflect.rs +++ b/vulkano/src/shader/reflect.rs @@ -53,7 +53,7 @@ pub fn spirv_extensions<'a>(spirv: &'a Spirv) -> impl Iterator { pub fn entry_points<'a>( spirv: &'a Spirv, exact_interface: bool, -) -> impl Iterator + 'a { +) -> impl Iterator + 'a { spirv.iter_entry_point().filter_map(move |instruction| { let (execution_model, function_id, entry_point_name, interface) = match instruction { &Instruction::EntryPoint { @@ -92,6 +92,7 @@ pub fn entry_points<'a>( Some(( entry_point_name.clone(), + *execution_model, EntryPointInfo { execution, descriptor_requirements,