From 906877421f5faf48133f0d586342e1eb16d33078 Mon Sep 17 00:00:00 2001 From: Rua Date: Sun, 19 Dec 2021 11:44:58 +0100 Subject: [PATCH] Always generate shader requirements for the exact entrypoint interface (#1778) --- vulkano-shaders/src/codegen.rs | 7 +++---- vulkano-shaders/src/lib.rs | 4 ---- vulkano/src/shader/mod.rs | 28 ++++++++++++++++++---------- vulkano/src/shader/reflect.rs | 12 ++++-------- 4 files changed, 25 insertions(+), 26 deletions(-) diff --git a/vulkano-shaders/src/codegen.rs b/vulkano-shaders/src/codegen.rs index f4b85568..40b0dda3 100644 --- a/vulkano-shaders/src/codegen.rs +++ b/vulkano-shaders/src/codegen.rs @@ -210,7 +210,6 @@ pub(super) fn reflect<'a, I>( words: &[u32], types_meta: &TypesMeta, input_paths: I, - exact_entrypoint_interface: bool, shared_constants: bool, types_registry: &'a mut HashMap, ) -> Result<(TokenStream, TokenStream), Error> @@ -244,7 +243,7 @@ where quote! { &Capability::#name } }); let spirv_extensions = reflect::spirv_extensions(&spirv); - let entry_points = reflect::entry_points(&spirv, exact_entrypoint_interface) + let entry_points = reflect::entry_points(&spirv) .map(|(name, model, info)| entry_point::write_entry_point(&name, model, &info)); let specialization_constants = structs::write_specialization_constants( @@ -711,7 +710,7 @@ mod tests { let spirv = Spirv::new(&instructions).unwrap(); let mut descriptors = Vec::new(); - for (_, _, info) in reflect::entry_points(&spirv, true) { + for (_, _, info) in reflect::entry_points(&spirv) { descriptors.push(info.descriptor_requirements); } @@ -782,7 +781,7 @@ mod tests { .unwrap(); let spirv = Spirv::new(comp.as_binary()).unwrap(); - for (_, _, info) in reflect::entry_points(&spirv, true) { + for (_, _, info) in reflect::entry_points(&spirv) { let mut bindings = Vec::new(); for (loc, _reqs) in info.descriptor_requirements { bindings.push(loc); diff --git a/vulkano-shaders/src/lib.rs b/vulkano-shaders/src/lib.rs index 6392ea8e..a25bc124 100644 --- a/vulkano-shaders/src/lib.rs +++ b/vulkano-shaders/src/lib.rs @@ -356,7 +356,6 @@ impl RegisteredType { struct MacroInput { dump: bool, - exact_entrypoint_interface: bool, include_directories: Vec, macro_defines: Vec<(String, String)>, shared_constants: bool, @@ -762,7 +761,6 @@ impl Parse for MacroInput { Ok(Self { dump: dump.unwrap_or(false), - exact_entrypoint_interface: exact_entrypoint_interface.unwrap_or(false), include_directories, macro_defines, shared_constants: shared_constants.unwrap_or(false), @@ -818,7 +816,6 @@ pub fn shader(input: proc_macro::TokenStream) -> proc_macro::TokenStream { unsafe { from_raw_parts(bytes.as_slice().as_ptr() as *const u32, bytes.len() / 4) }, &input.types_meta, empty(), - input.exact_entrypoint_interface, input.shared_constants, &mut types_registry, ) @@ -884,7 +881,6 @@ pub fn shader(input: proc_macro::TokenStream) -> proc_macro::TokenStream { content.as_binary(), &input.types_meta, input_paths, - input.exact_entrypoint_interface, input.shared_constants, &mut types_registry, ) diff --git a/vulkano/src/shader/mod.rs b/vulkano/src/shader/mod.rs index 1264cdf7..3e50d21f 100644 --- a/vulkano/src/shader/mod.rs +++ b/vulkano/src/shader/mod.rs @@ -80,7 +80,7 @@ impl ShaderModule { spirv.version(), reflect::spirv_capabilities(&spirv), reflect::spirv_extensions(&spirv), - reflect::entry_points(&spirv, false), + reflect::entry_points(&spirv), ) } @@ -170,14 +170,18 @@ impl ShaderModule { .collect::>() .iter() .map(|name| { - ((*name).clone(), - entries.iter().filter_map(|(entry_name, entry_model, info)| { - if &entry_name == name { - Some((*entry_model, info.clone())) - } else { - None - } - }).collect::>() + ( + (*name).clone(), + entries + .iter() + .filter_map(|(entry_name, entry_model, info)| { + if &entry_name == name { + Some((*entry_model, info.clone())) + } else { + None + } + }) + .collect::>(), ) }) .collect(); @@ -235,7 +239,11 @@ impl ShaderModule { /// Returns information about the entry point with the provided name and execution model. Returns /// `None` if no entry and execution model exists in the shader module. - pub fn entry_point_with_execution<'a>(&'a self, name: &str, execution: ExecutionModel) -> Option> { + pub fn entry_point_with_execution<'a>( + &'a self, + name: &str, + execution: ExecutionModel, + ) -> Option> { self.entry_points.get(name).and_then(|infos| { infos.get(&execution).map(|info| EntryPoint { module: self, diff --git a/vulkano/src/shader/reflect.rs b/vulkano/src/shader/reflect.rs index 749444b0..855b3621 100644 --- a/vulkano/src/shader/reflect.rs +++ b/vulkano/src/shader/reflect.rs @@ -53,7 +53,6 @@ pub fn spirv_extensions<'a>(spirv: &'a Spirv) -> impl Iterator { /// Returns an iterator over all entry points in `spirv`, with information about the entry point. pub fn entry_points<'a>( spirv: &'a Spirv, - exact_interface: bool, ) -> impl Iterator + 'a { spirv.iter_entry_point().filter_map(move |instruction| { let (execution_model, function_id, entry_point_name, interface) = match instruction { @@ -70,7 +69,7 @@ pub fn entry_points<'a>( let execution = shader_execution(&spirv, execution_model, function_id); let stage = ShaderStage::from(execution); let descriptor_requirements = - descriptor_requirements(&spirv, function_id, stage, interface, exact_interface); + descriptor_requirements(&spirv, function_id, stage, interface); let push_constant_requirements = push_constant_requirements(&spirv, stage); let specialization_constant_requirements = specialization_constant_requirements(&spirv); let input_interface = shader_interface( @@ -169,13 +168,12 @@ fn descriptor_requirements( function_id: Id, stage: ShaderStage, interface: &[Id], - exact: bool, ) -> FnvHashMap<(u32, u32), DescriptorRequirements> { // For SPIR-V 1.4+, the entrypoint interface can specify variables of all storage classes, // and most tools will put all used variables in the entrypoint interface. However, // SPIR-V 1.0-1.3 do not specify variables other than Input/Output ones in the interface, // and instead the function itself must be inspected. - let variables = if exact { + let variables = { let mut found_variables: HashSet = interface.iter().cloned().collect(); let mut inspected_functions: HashSet = HashSet::new(); find_variables_in_function( @@ -184,9 +182,7 @@ fn descriptor_requirements( &mut inspected_functions, &mut found_variables, ); - Some(found_variables) - } else { - None + found_variables }; // Looping to find all the global variables that have the `DescriptorSet` decoration. @@ -216,7 +212,7 @@ fn descriptor_requirements( _ => return None, }; - if exact && !variables.as_ref().unwrap().contains(&variable_id) { + if !variables.contains(&variable_id) { return None; }