diff --git a/vulkano/src/shader/reflect.rs b/vulkano/src/shader/reflect.rs index c9bbd58f..57d6543e 100644 --- a/vulkano/src/shader/reflect.rs +++ b/vulkano/src/shader/reflect.rs @@ -8,9 +8,10 @@ use crate::{ shader::{ spirv::{Decoration, Dim, ExecutionModel, Id, Instruction, Spirv, StorageClass}, DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, NumericType, ShaderInterface, - ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage, SpecializationConstant, + ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage, ShaderStages, + SpecializationConstant, }, - DeviceSize, + DeviceSize, Version, }; use ahash::{HashMap, HashSet}; use half::f16; @@ -41,7 +42,12 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator stage, function_id, ); - let push_constant_requirements = push_constant_requirements(spirv, stage); + let push_constant_requirements = push_constant_requirements( + &interface_variables.push_constant, + spirv, + stage, + function_id, + ); let input_interface = shader_interface( spirv, interface, @@ -77,6 +83,7 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator #[derive(Clone, Debug, Default)] struct InterfaceVariables { descriptor_binding: HashMap, + push_constant: HashMap, } // See also section 14.5.2 of the Vulkan specs: Descriptor Set Interface. @@ -93,7 +100,7 @@ fn interface_variables(spirv: &Spirv) -> InterfaceVariables { for instruction in spirv.iter_global() { if let Instruction::Variable { result_id, - result_type_id: _, + result_type_id, storage_class, .. } = instruction @@ -107,6 +114,12 @@ fn interface_variables(spirv: &Spirv) -> InterfaceVariables { descriptor_binding_requirements_of(spirv, *result_id), ); } + StorageClass::PushConstant => { + variables.push_constant.insert( + *result_id, + push_constant_requirements_of(spirv, *result_type_id), + ); + } _ => (), } } @@ -841,48 +854,125 @@ fn descriptor_binding_requirements_of(spirv: &Spirv, variable_id: Id) -> Descrip } } +fn push_constant_requirements_of(spirv: &Spirv, pointer_type_id: Id) -> PushConstantRange { + let struct_type_id = match *spirv.id(pointer_type_id).instruction() { + Instruction::TypePointer { + ty, + storage_class: StorageClass::PushConstant, + .. + } => ty, + _ => unreachable!(), + }; + + assert!( + matches!( + spirv.id(struct_type_id).instruction(), + Instruction::TypeStruct { .. } + ), + "VUID-StandaloneSpirv-PushConstant-06808" + ); + let start = offset_of_struct(spirv, struct_type_id); + let end = + size_of_type(spirv, struct_type_id).expect("Found runtime-sized push constants") as u32; + + PushConstantRange { + stages: ShaderStages::default(), + offset: start, + size: end - start, + } +} + /// Extracts the `PushConstantRange` from `spirv`. -fn push_constant_requirements(spirv: &Spirv, stage: ShaderStage) -> Option { - spirv - .iter_global() - // TODO: doesn't work with more than one entry point in the shader. - // We should really use the interface variables in Instruction::EntryPoint, - // but before SPIR-V 1.4 these do not contain push constants. - .find_map(|instruction| match *instruction { - Instruction::Variable { - result_type_id, - storage_class: StorageClass::PushConstant, - .. - } => Some(result_type_id), - _ => None, - }) - .map(|pointer_type_id| { - let struct_type_id = match *spirv.id(pointer_type_id).instruction() { - Instruction::TypePointer { - ty, - storage_class: StorageClass::PushConstant, +fn push_constant_requirements( + global: &HashMap, + spirv: &Spirv, + stage: ShaderStage, + function_id: Id, +) -> Option { + fn find_variables_used( + function_id: Id, + global: &HashMap, + spirv: &Spirv, + visited_fns: &mut HashSet, + variables: &mut HashSet, + ) { + visited_fns.insert(function_id); + let function_info = spirv.function(function_id); + for instruction in function_info.iter_instructions() { + match instruction { + Instruction::FunctionCall { + function, + arguments, .. - } => ty, - _ => unreachable!(), - }; - - assert!( - matches!( - spirv.id(struct_type_id).instruction(), - Instruction::TypeStruct { .. } - ), - "VUID-StandaloneSpirv-PushConstant-06808" - ); - let start = offset_of_struct(spirv, struct_type_id); - let end = size_of_type(spirv, struct_type_id) - .expect("Found runtime-sized push constants") as u32; - - PushConstantRange { - stages: stage.into(), - offset: start, - size: end - start, + } => { + for arg in arguments { + if global.contains_key(arg) { + variables.insert(*arg); + } + } + if !visited_fns.contains(function) { + find_variables_used(*function, global, spirv, visited_fns, variables); + } + } + Instruction::AccessChain { + base: variable_id, .. + } + | Instruction::InBoundsAccessChain { + base: variable_id, .. + } + | Instruction::PtrAccessChain { + base: variable_id, .. + } + | Instruction::InBoundsPtrAccessChain { + base: variable_id, .. + } + | Instruction::Load { + pointer: variable_id, + .. + } + | Instruction::CopyMemory { + source: variable_id, + .. + } + | Instruction::CopyObject { + result_id: variable_id, + .. + } => { + if global.contains_key(variable_id) { + variables.insert(*variable_id); + } + } + _ => (), } - }) + } + } + + if global.is_empty() { + return None; + } + let mut variables = HashSet::default(); + if spirv.version() < Version::V1_4 { + let mut visited_fns = HashSet::default(); + find_variables_used(function_id, global, spirv, &mut visited_fns, &mut variables); + } else if let Instruction::EntryPoint { interface, .. } = + spirv.function(function_id).entry_point().unwrap() + { + for id in interface { + if global.contains_key(id) { + variables.insert(*id); + } + } + } else { + unreachable!(); + } + assert!( + variables.len() <= 1, + "VUID-StandaloneSpirv-OpEntryPoint-06674" + ); + let variable_id = variables.into_iter().next()?; + let mut push_constant_range = global.get(&variable_id).copied().unwrap(); + push_constant_range.stages = stage.into(); + Some(push_constant_range) } /// Extracts the `SpecializationConstant` map from `spirv`. @@ -1339,3 +1429,202 @@ fn is_builtin(spirv: &Spirv, id: Id) -> bool { _ => false, } } + +#[cfg(test)] +mod tests { + use super::{HashMap, PushConstantRange, ShaderStages, Version}; + + #[test] + fn push_constant_range() { + /* + ; SPIR-V + ; Version: 1.0 + ; Generator: Google Shaderc over Glslang; 10 + ; Bound: 27 + ; Schema: 0 + OpCapability Shader + %glsl_std450 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main_cs "main_cs" %push_cs + OpEntryPoint Fragment %main_fs "main_fs" %push_fs + OpExecutionMode %main_cs LocalSize 1 1 1 + OpExecutionMode %main_fs OriginUpperLeft + OpName %main_cs "main_cs" + OpName %PushConstsCS "PushCS" + OpMemberName %PushConstsCS 0 "a" + OpMemberName %PushConstsCS 1 "b" + OpName %main_fs "main_fs" + OpName %push_fs "PushFS" + OpMemberName %PushConstsFS 0 "a" + OpMemberDecorate %PushConstsCS 0 Offset 0 + OpMemberDecorate %PushConstsCS 1 Offset 4 + OpDecorate %PushConstsCS Block + OpMemberDecorate %PushConstsFS 0 Offset 0 + OpDecorate %PushConstsFS Block + %void = OpTypeVoid + %fn_void = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %int = OpTypeInt 32 1 + %float = OpTypeFloat 32 + %PushConstsCS = OpTypeStruct %uint %float + %_ptr_PushConstant_PushConstsCS = OpTypePointer PushConstant %PushConstsCS + %push_cs = OpVariable %_ptr_PushConstant_PushConstsCS PushConstant + %_ptr_PushConstant_uint = OpTypePointer PushConstant %uint + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %_ptr_PushConstant_float = OpTypePointer PushConstant %float + %PushConstsFS = OpTypeStruct %float + %_ptr_PushConstant_PushConstsFS = OpTypePointer PushConstant %PushConstsFS + %push_fs = OpVariable %_ptr_PushConstant_PushConstsFS PushConstant + %main_cs = OpFunction %void None %fn_void + %main_cs_label = OpLabel + %push_cs_access_0 = OpAccessChain %_ptr_PushConstant_uint %push_cs %int_0 + %push_cs_access_1 = OpAccessChain %_ptr_PushConstant_float %push_cs %int_1 + %push_cs_load_0 = OpLoad %uint %push_cs_access_0 + %push_cs_load_1 = OpLoad %float %push_cs_access_1 + OpReturn + OpFunctionEnd + %main_fs = OpFunction %void None %fn_void + %main_fs_label = OpLabel + %push_fs_access_0 = OpAccessChain %_ptr_PushConstant_float %push_fs %int_0 + %push_fs_load_0 = OpLoad %float %push_fs_access_0 + OpReturn + OpFunctionEnd + */ + const MODULE: [u32; 186] = [ + 119734787, 65536, 458752, 27, 0, 131089, 1, 393227, 1, 1280527431, 1685353262, + 808793134, 0, 196622, 0, 1, 393231, 5, 2, 1852399981, 7562079, 3, 393231, 4, 4, + 1852399981, 7562847, 5, 393232, 2, 17, 1, 1, 1, 196624, 4, 7, 262149, 2, 1852399981, + 7562079, 262149, 6, 1752397136, 21315, 262150, 6, 0, 97, 262150, 6, 1, 98, 262149, 4, + 1852399981, 7562847, 262149, 5, 1752397136, 21318, 262150, 7, 0, 97, 327752, 6, 0, 35, + 0, 327752, 6, 1, 35, 4, 196679, 6, 2, 327752, 7, 0, 35, 0, 196679, 7, 2, 131091, 8, + 196641, 9, 8, 262165, 10, 32, 0, 262165, 11, 32, 1, 196630, 12, 32, 262174, 6, 10, 12, + 262176, 13, 9, 6, 262203, 13, 3, 9, 262176, 14, 9, 10, 262187, 11, 15, 0, 262187, 11, + 16, 1, 262176, 17, 9, 12, 196638, 7, 12, 262176, 18, 9, 7, 262203, 18, 5, 9, 327734, 8, + 2, 0, 9, 131320, 19, 327745, 14, 20, 3, 15, 327745, 17, 21, 3, 16, 262205, 10, 22, 20, + 262205, 12, 23, 21, 65789, 65592, 327734, 8, 4, 0, 9, 131320, 24, 327745, 17, 25, 5, + 15, 262205, 12, 26, 25, 65789, 65592, + ]; + let spirv = crate::shader::spirv::Spirv::new(&MODULE).unwrap(); + assert_eq!(spirv.version(), Version::V1_0); + let entry_points: HashMap<_, _> = super::entry_points(&spirv) + .map(|(_, v)| (v.name.clone(), v)) + .collect(); + assert_eq!(entry_points.len(), 2); + let main_cs = &entry_points["main_cs"]; + assert_eq!( + main_cs.push_constant_requirements, + Some(PushConstantRange { + stages: ShaderStages::COMPUTE, + offset: 0, + size: 8, + }) + ); + let main_fs = &entry_points["main_fs"]; + assert_eq!( + main_fs.push_constant_requirements, + Some(PushConstantRange { + stages: ShaderStages::FRAGMENT, + offset: 0, + size: 4, + }) + ); + } + + #[test] + fn push_constant_range_spirv_1_4() { + /* + ; SPIR-V + ; Version: 1.4 + ; Generator: Google Shaderc over Glslang; 10 + ; Bound: 27 + ; Schema: 0 + OpCapability Shader + %glsl_std450 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main_cs "main_cs" %push_cs + OpEntryPoint Fragment %main_fs "main_fs" %push_fs + OpExecutionMode %main_cs LocalSize 1 1 1 + OpExecutionMode %main_fs OriginUpperLeft + OpName %main_cs "main_cs" + OpName %PushConstsCS "PushCS" + OpMemberName %PushConstsCS 0 "a" + OpMemberName %PushConstsCS 1 "b" + OpName %main_fs "main_fs" + OpName %push_fs "PushFS" + OpMemberName %PushConstsFS 0 "a" + OpMemberDecorate %PushConstsCS 0 Offset 0 + OpMemberDecorate %PushConstsCS 1 Offset 4 + OpDecorate %PushConstsCS Block + OpMemberDecorate %PushConstsFS 0 Offset 0 + OpDecorate %PushConstsFS Block + %void = OpTypeVoid + %fn_void = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %int = OpTypeInt 32 1 + %float = OpTypeFloat 32 + %PushConstsCS = OpTypeStruct %uint %float + %_ptr_PushConstant_PushConstsCS = OpTypePointer PushConstant %PushConstsCS + %push_cs = OpVariable %_ptr_PushConstant_PushConstsCS PushConstant + %_ptr_PushConstant_uint = OpTypePointer PushConstant %uint + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %_ptr_PushConstant_float = OpTypePointer PushConstant %float + %PushConstsFS = OpTypeStruct %float + %_ptr_PushConstant_PushConstsFS = OpTypePointer PushConstant %PushConstsFS + %push_fs = OpVariable %_ptr_PushConstant_PushConstsFS PushConstant + %main_cs = OpFunction %void None %fn_void + %main_cs_label = OpLabel + %push_cs_access_0 = OpAccessChain %_ptr_PushConstant_uint %push_cs %int_0 + %push_cs_access_1 = OpAccessChain %_ptr_PushConstant_float %push_cs %int_1 + %push_cs_load_0 = OpLoad %uint %push_cs_access_0 + %push_cs_load_1 = OpLoad %float %push_cs_access_1 + OpReturn + OpFunctionEnd + %main_fs = OpFunction %void None %fn_void + %main_fs_label = OpLabel + %push_fs_access_0 = OpAccessChain %_ptr_PushConstant_float %push_fs %int_0 + %push_fs_load_0 = OpLoad %float %push_fs_access_0 + OpReturn + OpFunctionEnd + */ + const MODULE: [u32; 186] = [ + 119734787, 66560, 458752, 27, 0, 131089, 1, 393227, 1, 1280527431, 1685353262, + 808793134, 0, 196622, 0, 1, 393231, 5, 2, 1852399981, 7562079, 3, 393231, 4, 4, + 1852399981, 7562847, 5, 393232, 2, 17, 1, 1, 1, 196624, 4, 7, 262149, 2, 1852399981, + 7562079, 262149, 6, 1752397136, 21315, 262150, 6, 0, 97, 262150, 6, 1, 98, 262149, 4, + 1852399981, 7562847, 262149, 5, 1752397136, 21318, 262150, 7, 0, 97, 327752, 6, 0, 35, + 0, 327752, 6, 1, 35, 4, 196679, 6, 2, 327752, 7, 0, 35, 0, 196679, 7, 2, 131091, 8, + 196641, 9, 8, 262165, 10, 32, 0, 262165, 11, 32, 1, 196630, 12, 32, 262174, 6, 10, 12, + 262176, 13, 9, 6, 262203, 13, 3, 9, 262176, 14, 9, 10, 262187, 11, 15, 0, 262187, 11, + 16, 1, 262176, 17, 9, 12, 196638, 7, 12, 262176, 18, 9, 7, 262203, 18, 5, 9, 327734, 8, + 2, 0, 9, 131320, 19, 327745, 14, 20, 3, 15, 327745, 17, 21, 3, 16, 262205, 10, 22, 20, + 262205, 12, 23, 21, 65789, 65592, 327734, 8, 4, 0, 9, 131320, 24, 327745, 17, 25, 5, + 15, 262205, 12, 26, 25, 65789, 65592, + ]; + let spirv = crate::shader::spirv::Spirv::new(&MODULE).unwrap(); + assert_eq!(spirv.version(), Version::V1_4); + let entry_points: HashMap<_, _> = super::entry_points(&spirv) + .map(|(_, v)| (v.name.clone(), v)) + .collect(); + assert_eq!(entry_points.len(), 2); + let main_cs = &entry_points["main_cs"]; + assert_eq!( + main_cs.push_constant_requirements, + Some(PushConstantRange { + stages: ShaderStages::COMPUTE, + offset: 0, + size: 8, + }) + ); + let main_fs = &entry_points["main_fs"]; + assert_eq!( + main_fs.push_constant_requirements, + Some(PushConstantRange { + stages: ShaderStages::FRAGMENT, + offset: 0, + size: 4, + }) + ); + } +}