Fix #2398 Reflect push constant requirements by variable usage. (#2405)

* Reflect push constants by variable usage.

* Added shader::reflect::tests::push_constant_range.

* Reflect using interface with spirv_1_4.

* Fix spirv version.
This commit is contained in:
charles-r-earp 2023-11-15 07:18:21 -08:00 committed by GitHub
parent 6e07c01478
commit 84c6dbe18b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,9 +8,10 @@ use crate::{
shader::{
spirv::{Decoration, Dim, ExecutionModel, Id, Instruction, Spirv, StorageClass},
DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, NumericType, ShaderInterface,
ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage, SpecializationConstant,
ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage, ShaderStages,
SpecializationConstant,
},
DeviceSize,
DeviceSize, Version,
};
use ahash::{HashMap, HashSet};
use half::f16;
@ -41,7 +42,12 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = (Id, EntryPointInfo)>
stage,
function_id,
);
let push_constant_requirements = push_constant_requirements(spirv, stage);
let push_constant_requirements = push_constant_requirements(
&interface_variables.push_constant,
spirv,
stage,
function_id,
);
let input_interface = shader_interface(
spirv,
interface,
@ -77,6 +83,7 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = (Id, EntryPointInfo)>
#[derive(Clone, Debug, Default)]
struct InterfaceVariables {
descriptor_binding: HashMap<Id, DescriptorBindingVariable>,
push_constant: HashMap<Id, PushConstantRange>,
}
// See also section 14.5.2 of the Vulkan specs: Descriptor Set Interface.
@ -93,7 +100,7 @@ fn interface_variables(spirv: &Spirv) -> InterfaceVariables {
for instruction in spirv.iter_global() {
if let Instruction::Variable {
result_id,
result_type_id: _,
result_type_id,
storage_class,
..
} = instruction
@ -107,6 +114,12 @@ fn interface_variables(spirv: &Spirv) -> InterfaceVariables {
descriptor_binding_requirements_of(spirv, *result_id),
);
}
StorageClass::PushConstant => {
variables.push_constant.insert(
*result_id,
push_constant_requirements_of(spirv, *result_type_id),
);
}
_ => (),
}
}
@ -841,22 +854,7 @@ fn descriptor_binding_requirements_of(spirv: &Spirv, variable_id: Id) -> Descrip
}
}
/// Extracts the `PushConstantRange` from `spirv`.
fn push_constant_requirements(spirv: &Spirv, stage: ShaderStage) -> Option<PushConstantRange> {
spirv
.iter_global()
// TODO: doesn't work with more than one entry point in the shader.
// We should really use the interface variables in Instruction::EntryPoint,
// but before SPIR-V 1.4 these do not contain push constants.
.find_map(|instruction| match *instruction {
Instruction::Variable {
result_type_id,
storage_class: StorageClass::PushConstant,
..
} => Some(result_type_id),
_ => None,
})
.map(|pointer_type_id| {
fn push_constant_requirements_of(spirv: &Spirv, pointer_type_id: Id) -> PushConstantRange {
let struct_type_id = match *spirv.id(pointer_type_id).instruction() {
Instruction::TypePointer {
ty,
@ -874,15 +872,107 @@ fn push_constant_requirements(spirv: &Spirv, stage: ShaderStage) -> Option<PushC
"VUID-StandaloneSpirv-PushConstant-06808"
);
let start = offset_of_struct(spirv, struct_type_id);
let end = size_of_type(spirv, struct_type_id)
.expect("Found runtime-sized push constants") as u32;
let end =
size_of_type(spirv, struct_type_id).expect("Found runtime-sized push constants") as u32;
PushConstantRange {
stages: stage.into(),
stages: ShaderStages::default(),
offset: start,
size: end - start,
}
})
}
/// Extracts the `PushConstantRange` from `spirv`.
fn push_constant_requirements(
global: &HashMap<Id, PushConstantRange>,
spirv: &Spirv,
stage: ShaderStage,
function_id: Id,
) -> Option<PushConstantRange> {
fn find_variables_used(
function_id: Id,
global: &HashMap<Id, PushConstantRange>,
spirv: &Spirv,
visited_fns: &mut HashSet<Id>,
variables: &mut HashSet<Id>,
) {
visited_fns.insert(function_id);
let function_info = spirv.function(function_id);
for instruction in function_info.iter_instructions() {
match instruction {
Instruction::FunctionCall {
function,
arguments,
..
} => {
for arg in arguments {
if global.contains_key(arg) {
variables.insert(*arg);
}
}
if !visited_fns.contains(function) {
find_variables_used(*function, global, spirv, visited_fns, variables);
}
}
Instruction::AccessChain {
base: variable_id, ..
}
| Instruction::InBoundsAccessChain {
base: variable_id, ..
}
| Instruction::PtrAccessChain {
base: variable_id, ..
}
| Instruction::InBoundsPtrAccessChain {
base: variable_id, ..
}
| Instruction::Load {
pointer: variable_id,
..
}
| Instruction::CopyMemory {
source: variable_id,
..
}
| Instruction::CopyObject {
result_id: variable_id,
..
} => {
if global.contains_key(variable_id) {
variables.insert(*variable_id);
}
}
_ => (),
}
}
}
if global.is_empty() {
return None;
}
let mut variables = HashSet::default();
if spirv.version() < Version::V1_4 {
let mut visited_fns = HashSet::default();
find_variables_used(function_id, global, spirv, &mut visited_fns, &mut variables);
} else if let Instruction::EntryPoint { interface, .. } =
spirv.function(function_id).entry_point().unwrap()
{
for id in interface {
if global.contains_key(id) {
variables.insert(*id);
}
}
} else {
unreachable!();
}
assert!(
variables.len() <= 1,
"VUID-StandaloneSpirv-OpEntryPoint-06674"
);
let variable_id = variables.into_iter().next()?;
let mut push_constant_range = global.get(&variable_id).copied().unwrap();
push_constant_range.stages = stage.into();
Some(push_constant_range)
}
/// Extracts the `SpecializationConstant` map from `spirv`.
@ -1339,3 +1429,202 @@ fn is_builtin(spirv: &Spirv, id: Id) -> bool {
_ => false,
}
}
#[cfg(test)]
mod tests {
use super::{HashMap, PushConstantRange, ShaderStages, Version};
#[test]
fn push_constant_range() {
/*
; SPIR-V
; Version: 1.0
; Generator: Google Shaderc over Glslang; 10
; Bound: 27
; Schema: 0
OpCapability Shader
%glsl_std450 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main_cs "main_cs" %push_cs
OpEntryPoint Fragment %main_fs "main_fs" %push_fs
OpExecutionMode %main_cs LocalSize 1 1 1
OpExecutionMode %main_fs OriginUpperLeft
OpName %main_cs "main_cs"
OpName %PushConstsCS "PushCS"
OpMemberName %PushConstsCS 0 "a"
OpMemberName %PushConstsCS 1 "b"
OpName %main_fs "main_fs"
OpName %push_fs "PushFS"
OpMemberName %PushConstsFS 0 "a"
OpMemberDecorate %PushConstsCS 0 Offset 0
OpMemberDecorate %PushConstsCS 1 Offset 4
OpDecorate %PushConstsCS Block
OpMemberDecorate %PushConstsFS 0 Offset 0
OpDecorate %PushConstsFS Block
%void = OpTypeVoid
%fn_void = OpTypeFunction %void
%uint = OpTypeInt 32 0
%int = OpTypeInt 32 1
%float = OpTypeFloat 32
%PushConstsCS = OpTypeStruct %uint %float
%_ptr_PushConstant_PushConstsCS = OpTypePointer PushConstant %PushConstsCS
%push_cs = OpVariable %_ptr_PushConstant_PushConstsCS PushConstant
%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
%_ptr_PushConstant_float = OpTypePointer PushConstant %float
%PushConstsFS = OpTypeStruct %float
%_ptr_PushConstant_PushConstsFS = OpTypePointer PushConstant %PushConstsFS
%push_fs = OpVariable %_ptr_PushConstant_PushConstsFS PushConstant
%main_cs = OpFunction %void None %fn_void
%main_cs_label = OpLabel
%push_cs_access_0 = OpAccessChain %_ptr_PushConstant_uint %push_cs %int_0
%push_cs_access_1 = OpAccessChain %_ptr_PushConstant_float %push_cs %int_1
%push_cs_load_0 = OpLoad %uint %push_cs_access_0
%push_cs_load_1 = OpLoad %float %push_cs_access_1
OpReturn
OpFunctionEnd
%main_fs = OpFunction %void None %fn_void
%main_fs_label = OpLabel
%push_fs_access_0 = OpAccessChain %_ptr_PushConstant_float %push_fs %int_0
%push_fs_load_0 = OpLoad %float %push_fs_access_0
OpReturn
OpFunctionEnd
*/
const MODULE: [u32; 186] = [
119734787, 65536, 458752, 27, 0, 131089, 1, 393227, 1, 1280527431, 1685353262,
808793134, 0, 196622, 0, 1, 393231, 5, 2, 1852399981, 7562079, 3, 393231, 4, 4,
1852399981, 7562847, 5, 393232, 2, 17, 1, 1, 1, 196624, 4, 7, 262149, 2, 1852399981,
7562079, 262149, 6, 1752397136, 21315, 262150, 6, 0, 97, 262150, 6, 1, 98, 262149, 4,
1852399981, 7562847, 262149, 5, 1752397136, 21318, 262150, 7, 0, 97, 327752, 6, 0, 35,
0, 327752, 6, 1, 35, 4, 196679, 6, 2, 327752, 7, 0, 35, 0, 196679, 7, 2, 131091, 8,
196641, 9, 8, 262165, 10, 32, 0, 262165, 11, 32, 1, 196630, 12, 32, 262174, 6, 10, 12,
262176, 13, 9, 6, 262203, 13, 3, 9, 262176, 14, 9, 10, 262187, 11, 15, 0, 262187, 11,
16, 1, 262176, 17, 9, 12, 196638, 7, 12, 262176, 18, 9, 7, 262203, 18, 5, 9, 327734, 8,
2, 0, 9, 131320, 19, 327745, 14, 20, 3, 15, 327745, 17, 21, 3, 16, 262205, 10, 22, 20,
262205, 12, 23, 21, 65789, 65592, 327734, 8, 4, 0, 9, 131320, 24, 327745, 17, 25, 5,
15, 262205, 12, 26, 25, 65789, 65592,
];
let spirv = crate::shader::spirv::Spirv::new(&MODULE).unwrap();
assert_eq!(spirv.version(), Version::V1_0);
let entry_points: HashMap<_, _> = super::entry_points(&spirv)
.map(|(_, v)| (v.name.clone(), v))
.collect();
assert_eq!(entry_points.len(), 2);
let main_cs = &entry_points["main_cs"];
assert_eq!(
main_cs.push_constant_requirements,
Some(PushConstantRange {
stages: ShaderStages::COMPUTE,
offset: 0,
size: 8,
})
);
let main_fs = &entry_points["main_fs"];
assert_eq!(
main_fs.push_constant_requirements,
Some(PushConstantRange {
stages: ShaderStages::FRAGMENT,
offset: 0,
size: 4,
})
);
}
#[test]
fn push_constant_range_spirv_1_4() {
/*
; SPIR-V
; Version: 1.4
; Generator: Google Shaderc over Glslang; 10
; Bound: 27
; Schema: 0
OpCapability Shader
%glsl_std450 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main_cs "main_cs" %push_cs
OpEntryPoint Fragment %main_fs "main_fs" %push_fs
OpExecutionMode %main_cs LocalSize 1 1 1
OpExecutionMode %main_fs OriginUpperLeft
OpName %main_cs "main_cs"
OpName %PushConstsCS "PushCS"
OpMemberName %PushConstsCS 0 "a"
OpMemberName %PushConstsCS 1 "b"
OpName %main_fs "main_fs"
OpName %push_fs "PushFS"
OpMemberName %PushConstsFS 0 "a"
OpMemberDecorate %PushConstsCS 0 Offset 0
OpMemberDecorate %PushConstsCS 1 Offset 4
OpDecorate %PushConstsCS Block
OpMemberDecorate %PushConstsFS 0 Offset 0
OpDecorate %PushConstsFS Block
%void = OpTypeVoid
%fn_void = OpTypeFunction %void
%uint = OpTypeInt 32 0
%int = OpTypeInt 32 1
%float = OpTypeFloat 32
%PushConstsCS = OpTypeStruct %uint %float
%_ptr_PushConstant_PushConstsCS = OpTypePointer PushConstant %PushConstsCS
%push_cs = OpVariable %_ptr_PushConstant_PushConstsCS PushConstant
%_ptr_PushConstant_uint = OpTypePointer PushConstant %uint
%int_0 = OpConstant %int 0
%int_1 = OpConstant %int 1
%_ptr_PushConstant_float = OpTypePointer PushConstant %float
%PushConstsFS = OpTypeStruct %float
%_ptr_PushConstant_PushConstsFS = OpTypePointer PushConstant %PushConstsFS
%push_fs = OpVariable %_ptr_PushConstant_PushConstsFS PushConstant
%main_cs = OpFunction %void None %fn_void
%main_cs_label = OpLabel
%push_cs_access_0 = OpAccessChain %_ptr_PushConstant_uint %push_cs %int_0
%push_cs_access_1 = OpAccessChain %_ptr_PushConstant_float %push_cs %int_1
%push_cs_load_0 = OpLoad %uint %push_cs_access_0
%push_cs_load_1 = OpLoad %float %push_cs_access_1
OpReturn
OpFunctionEnd
%main_fs = OpFunction %void None %fn_void
%main_fs_label = OpLabel
%push_fs_access_0 = OpAccessChain %_ptr_PushConstant_float %push_fs %int_0
%push_fs_load_0 = OpLoad %float %push_fs_access_0
OpReturn
OpFunctionEnd
*/
const MODULE: [u32; 186] = [
119734787, 66560, 458752, 27, 0, 131089, 1, 393227, 1, 1280527431, 1685353262,
808793134, 0, 196622, 0, 1, 393231, 5, 2, 1852399981, 7562079, 3, 393231, 4, 4,
1852399981, 7562847, 5, 393232, 2, 17, 1, 1, 1, 196624, 4, 7, 262149, 2, 1852399981,
7562079, 262149, 6, 1752397136, 21315, 262150, 6, 0, 97, 262150, 6, 1, 98, 262149, 4,
1852399981, 7562847, 262149, 5, 1752397136, 21318, 262150, 7, 0, 97, 327752, 6, 0, 35,
0, 327752, 6, 1, 35, 4, 196679, 6, 2, 327752, 7, 0, 35, 0, 196679, 7, 2, 131091, 8,
196641, 9, 8, 262165, 10, 32, 0, 262165, 11, 32, 1, 196630, 12, 32, 262174, 6, 10, 12,
262176, 13, 9, 6, 262203, 13, 3, 9, 262176, 14, 9, 10, 262187, 11, 15, 0, 262187, 11,
16, 1, 262176, 17, 9, 12, 196638, 7, 12, 262176, 18, 9, 7, 262203, 18, 5, 9, 327734, 8,
2, 0, 9, 131320, 19, 327745, 14, 20, 3, 15, 327745, 17, 21, 3, 16, 262205, 10, 22, 20,
262205, 12, 23, 21, 65789, 65592, 327734, 8, 4, 0, 9, 131320, 24, 327745, 17, 25, 5,
15, 262205, 12, 26, 25, 65789, 65592,
];
let spirv = crate::shader::spirv::Spirv::new(&MODULE).unwrap();
assert_eq!(spirv.version(), Version::V1_4);
let entry_points: HashMap<_, _> = super::entry_points(&spirv)
.map(|(_, v)| (v.name.clone(), v))
.collect();
assert_eq!(entry_points.len(), 2);
let main_cs = &entry_points["main_cs"];
assert_eq!(
main_cs.push_constant_requirements,
Some(PushConstantRange {
stages: ShaderStages::COMPUTE,
offset: 0,
size: 8,
})
);
let main_fs = &entry_points["main_fs"];
assert_eq!(
main_fs.push_constant_requirements,
Some(PushConstantRange {
stages: ShaderStages::FRAGMENT,
offset: 0,
size: 4,
})
);
}
}