Refactor some shader things and add more validation (#2335)

* Refactor some shader things and add more validation

* Remove pub
This commit is contained in:
Rua 2023-09-21 12:18:31 +02:00 committed by GitHub
parent e9790c1fc3
commit a8ca0a7f7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 603 additions and 559 deletions

View File

@ -554,7 +554,7 @@ mod tests {
let spirv = Spirv::new(&instructions).unwrap(); let spirv = Spirv::new(&instructions).unwrap();
let mut descriptors = Vec::new(); let mut descriptors = Vec::new();
for info in reflect::entry_points(&spirv) { for (_, info) in reflect::entry_points(&spirv) {
descriptors.push(info.descriptor_binding_requirements); descriptors.push(info.descriptor_binding_requirements);
} }
@ -622,7 +622,7 @@ mod tests {
.unwrap(); .unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap(); let spirv = Spirv::new(comp.as_binary()).unwrap();
if let Some(info) = reflect::entry_points(&spirv).next() { if let Some((_, info)) = reflect::entry_points(&spirv).next() {
let mut bindings = Vec::new(); let mut bindings = Vec::new();
for (loc, _reqs) in info.descriptor_binding_requirements { for (loc, _reqs) in info.descriptor_binding_requirements {
bindings.push(loc); bindings.push(loc);

View File

@ -28,7 +28,7 @@ use crate::{
instance::InstanceOwnedDebugWrapper, instance::InstanceOwnedDebugWrapper,
macros::impl_id_counter, macros::impl_id_counter,
pipeline::{cache::PipelineCache, layout::PipelineLayout, Pipeline, PipelineBindPoint}, pipeline::{cache::PipelineCache, layout::PipelineLayout, Pipeline, PipelineBindPoint},
shader::{DescriptorBindingRequirements, ShaderExecution, ShaderStage}, shader::{spirv::ExecutionModel, DescriptorBindingRequirements, ShaderStage},
Validated, ValidationError, VulkanError, VulkanObject, Validated, ValidationError, VulkanError, VulkanObject,
}; };
use ahash::HashMap; use ahash::HashMap;
@ -155,7 +155,7 @@ impl ComputePipeline {
}, },
), ),
flags: flags.into(), flags: flags.into(),
stage: ShaderStage::from(&entry_point_info.execution).into(), stage: ShaderStage::from(entry_point_info.execution_model).into(),
module: entry_point.module().handle(), module: entry_point.module().handle(),
p_name: name_vk.as_ptr(), p_name: name_vk.as_ptr(),
p_specialization_info: if specialization_info_vk.data_size == 0 { p_specialization_info: if specialization_info_vk.data_size == 0 {
@ -410,7 +410,7 @@ impl ComputePipelineCreateInfo {
let entry_point_info = entry_point.info(); let entry_point_info = entry_point.info();
if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) { if !matches!(entry_point_info.execution_model, ExecutionModel::GLCompute) {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
context: "stage.entry_point".into(), context: "stage.entry_point".into(),
problem: "is not a `ShaderStage::Compute` entry point".into(), problem: "is not a `ShaderStage::Compute` entry point".into(),

View File

@ -90,8 +90,8 @@ use crate::{
PartialStateMode, PartialStateMode,
}, },
shader::{ shader::{
DescriptorBindingRequirements, FragmentShaderExecution, FragmentTestsStages, spirv::{ExecutionMode, ExecutionModel, Instruction},
ShaderExecution, ShaderStage, ShaderStages, DescriptorBindingRequirements, ShaderStage, ShaderStages,
}, },
Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, VulkanError, VulkanObject, Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, VulkanError, VulkanObject,
}; };
@ -220,7 +220,7 @@ impl GraphicsPipeline {
} = stage; } = stage;
let entry_point_info = entry_point.info(); let entry_point_info = entry_point.info();
let stage = ShaderStage::from(&entry_point_info.execution); let stage = ShaderStage::from(entry_point_info.execution_model);
let mut specialization_data_vk: Vec<u8> = Vec::new(); let mut specialization_data_vk: Vec<u8> = Vec::new();
let specialization_map_entries_vk: Vec<_> = entry_point let specialization_map_entries_vk: Vec<_> = entry_point
@ -1223,15 +1223,28 @@ impl GraphicsPipeline {
} = stage; } = stage;
let entry_point_info = entry_point.info(); let entry_point_info = entry_point.info();
let stage = ShaderStage::from(&entry_point_info.execution); let stage = ShaderStage::from(entry_point_info.execution_model);
shaders.insert(stage, ()); shaders.insert(stage, ());
if let ShaderExecution::Fragment(FragmentShaderExecution { let spirv = entry_point.module().spirv();
fragment_tests_stages: s, let entry_point_function = spirv.function(entry_point.id());
..
}) = entry_point_info.execution if matches!(entry_point_info.execution_model, ExecutionModel::Fragment) {
{ fragment_tests_stages = Some(FragmentTestsStages::Late);
fragment_tests_stages = Some(s)
for instruction in entry_point_function.iter_execution_mode() {
if let Instruction::ExecutionMode { mode, .. } = *instruction {
match mode {
ExecutionMode::EarlyFragmentTests => {
fragment_tests_stages = Some(FragmentTestsStages::Early);
}
ExecutionMode::EarlyAndLateFragmentTestsAMD => {
fragment_tests_stages = Some(FragmentTestsStages::EarlyAndLate);
}
_ => (),
}
}
}
} }
for (&loc, reqs) in &entry_point_info.descriptor_binding_requirements { for (&loc, reqs) in &entry_point_info.descriptor_binding_requirements {
@ -1989,7 +2002,7 @@ impl GraphicsPipelineCreateInfo {
for (stage_index, stage) in stages.iter().enumerate() { for (stage_index, stage) in stages.iter().enumerate() {
let entry_point_info = stage.entry_point.info(); let entry_point_info = stage.entry_point.info();
let stage_enum = ShaderStage::from(&entry_point_info.execution); let stage_enum = ShaderStage::from(entry_point_info.execution_model);
let stage_flag = ShaderStages::from(stage_enum); let stage_flag = ShaderStages::from(stage_enum);
if stages_present.intersects(stage_flag) { if stages_present.intersects(stage_flag) {
@ -2081,9 +2094,12 @@ impl GraphicsPipelineCreateInfo {
} }
let need_vertex_input_state = need_pre_rasterization_shader_state let need_vertex_input_state = need_pre_rasterization_shader_state
&& stages && stages.iter().any(|stage| {
.iter() matches!(
.any(|stage| matches!(stage.entry_point.info().execution, ShaderExecution::Vertex)); stage.entry_point.info().execution_model,
ExecutionModel::Vertex
)
});
let need_fragment_shader_state = need_pre_rasterization_shader_state let need_fragment_shader_state = need_pre_rasterization_shader_state
&& rasterization_state && rasterization_state
.as_ref() .as_ref()
@ -2535,8 +2551,8 @@ impl GraphicsPipelineCreateInfo {
problem: format!( problem: format!(
"the output interface of the `ShaderStage::{:?}` stage does not \ "the output interface of the `ShaderStage::{:?}` stage does not \
match the input interface of the `ShaderStage::{:?}` stage: {}", match the input interface of the `ShaderStage::{:?}` stage: {}",
ShaderStage::from(&output.entry_point.info().execution), ShaderStage::from(output.entry_point.info().execution_model),
ShaderStage::from(&input.entry_point.info().execution), ShaderStage::from(input.entry_point.info().execution_model),
err err
) )
.into(), .into(),
@ -2816,11 +2832,30 @@ impl GraphicsPipelineCreateInfo {
geometry_stage, geometry_stage,
input_assembly_state, input_assembly_state,
) { ) {
let entry_point_info = geometry_stage.entry_point.info(); let spirv = geometry_stage.entry_point.module().spirv();
let input = match entry_point_info.execution { let entry_point_function = spirv.function(geometry_stage.entry_point.id());
ShaderExecution::Geometry(execution) => execution.input,
_ => unreachable!(), let input = entry_point_function
}; .iter_execution_mode()
.find_map(|instruction| {
if let Instruction::ExecutionMode { mode, .. } = *instruction {
match mode {
ExecutionMode::InputPoints => Some(GeometryShaderInput::Points),
ExecutionMode::InputLines => Some(GeometryShaderInput::Lines),
ExecutionMode::InputLinesAdjacency => {
Some(GeometryShaderInput::LinesWithAdjacency)
}
ExecutionMode::Triangles => Some(GeometryShaderInput::Triangles),
ExecutionMode::InputTrianglesAdjacency => {
Some(GeometryShaderInput::TrianglesWithAdjacency)
}
_ => None,
}
} else {
None
}
})
.unwrap();
if let PartialStateMode::Fixed(topology) = input_assembly_state.topology { if let PartialStateMode::Fixed(topology) = input_assembly_state.topology {
if !input.is_compatible_with(topology) { if !input.is_compatible_with(topology) {
@ -3104,3 +3139,51 @@ impl GraphicsPipelineCreateInfo {
Ok(()) Ok(())
} }
} }
/// The input primitive type that is expected by a geometry shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum GeometryShaderInput {
Points,
Lines,
LinesWithAdjacency,
Triangles,
TrianglesWithAdjacency,
}
impl GeometryShaderInput {
/// Returns true if the given primitive topology can be used as input for this geometry shader.
#[inline]
fn is_compatible_with(self, topology: PrimitiveTopology) -> bool {
match self {
Self::Points => matches!(topology, PrimitiveTopology::PointList),
Self::Lines => matches!(
topology,
PrimitiveTopology::LineList | PrimitiveTopology::LineStrip
),
Self::LinesWithAdjacency => matches!(
topology,
PrimitiveTopology::LineListWithAdjacency
| PrimitiveTopology::LineStripWithAdjacency
),
Self::Triangles => matches!(
topology,
PrimitiveTopology::TriangleList
| PrimitiveTopology::TriangleStrip
| PrimitiveTopology::TriangleFan,
),
Self::TrianglesWithAdjacency => matches!(
topology,
PrimitiveTopology::TriangleListWithAdjacency
| PrimitiveTopology::TriangleStripWithAdjacency,
),
}
}
}
/// The fragment tests stages that will be executed in a fragment shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FragmentTestsStages {
Early,
Late,
EarlyAndLate,
}

View File

@ -21,7 +21,10 @@ pub use self::{compute::ComputePipeline, graphics::GraphicsPipeline, layout::Pip
use crate::{ use crate::{
device::{Device, DeviceOwned}, device::{Device, DeviceOwned},
macros::{vulkan_bitflags, vulkan_enum}, macros::{vulkan_bitflags, vulkan_enum},
shader::{DescriptorBindingRequirements, EntryPoint, ShaderExecution, ShaderStage}, shader::{
spirv::{BuiltIn, Decoration, ExecutionMode, Id, Instruction},
DescriptorBindingRequirements, EntryPoint, ShaderStage,
},
Requires, RequiresAllOf, RequiresOneOf, ValidationError, Requires, RequiresAllOf, RequiresOneOf, ValidationError,
}; };
use ahash::HashMap; use ahash::HashMap;
@ -355,7 +358,7 @@ impl PipelineShaderStageCreateInfo {
})?; })?;
let entry_point_info = entry_point.info(); let entry_point_info = entry_point.info();
let stage_enum = ShaderStage::from(&entry_point_info.execution); let stage_enum = ShaderStage::from(entry_point_info.execution_model);
stage_enum.validate_device(device).map_err(|err| { stage_enum.validate_device(device).map_err(|err| {
err.add_context("entry_point.info().execution") err.add_context("entry_point.info().execution")
@ -451,10 +454,241 @@ impl PipelineShaderStageCreateInfo {
ShaderStage::SubpassShading => (), ShaderStage::SubpassShading => (),
} }
let workgroup_size = if let ShaderExecution::Compute(execution) = let spirv = entry_point.module().spirv();
&entry_point_info.execution let entry_point_function = spirv.function(entry_point.id());
let mut clip_distance_array_size = 0;
let mut cull_distance_array_size = 0;
for instruction in spirv.iter_decoration() {
if let Instruction::Decorate {
target,
decoration: Decoration::BuiltIn { built_in },
} = *instruction
{ {
let local_size = execution.local_size; let variable_array_size = |variable| {
let result_type_id = match *spirv.id(variable).instruction() {
Instruction::Variable { result_type_id, .. } => result_type_id,
_ => return None,
};
let length = match *spirv.id(result_type_id).instruction() {
Instruction::TypeArray { length, .. } => length,
_ => return None,
};
let value = match *spirv.id(length).instruction() {
Instruction::Constant { ref value, .. } => {
if value.len() > 1 {
u32::MAX
} else {
value[0]
}
}
_ => return None,
};
Some(value)
};
match built_in {
BuiltIn::ClipDistance => {
clip_distance_array_size = variable_array_size(target).unwrap();
if clip_distance_array_size > properties.max_clip_distances {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the number of elements in the `ClipDistance` built-in \
variable is greater than the \
`max_clip_distances` device limit"
.into(),
vuids: &[
"VUID-VkPipelineShaderStageCreateInfo-maxClipDistances-00708",
],
..Default::default()
}));
}
}
BuiltIn::CullDistance => {
cull_distance_array_size = variable_array_size(target).unwrap();
if cull_distance_array_size > properties.max_cull_distances {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the number of elements in the `CullDistance` built-in \
variable is greater than the \
`max_cull_distances` device limit"
.into(),
vuids: &[
"VUID-VkPipelineShaderStageCreateInfo-maxCullDistances-00709",
],
..Default::default()
}));
}
}
BuiltIn::SampleMask => {
if variable_array_size(target).unwrap() > properties.max_sample_mask_words {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the number of elements in the `SampleMask` built-in \
variable is greater than the \
`max_sample_mask_words` device limit"
.into(),
vuids: &[
"VUID-VkPipelineShaderStageCreateInfo-maxSampleMaskWords-00711",
],
..Default::default()
}));
}
}
_ => (),
}
}
}
if clip_distance_array_size
.checked_add(cull_distance_array_size)
.map_or(true, |sum| {
sum > properties.max_combined_clip_and_cull_distances
})
{
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the sum of the number of elements in the `ClipDistance` and \
`CullDistance` built-in variables is greater than the \
`max_combined_clip_and_cull_distances` device limit"
.into(),
vuids: &[
"VUID-VkPipelineShaderStageCreateInfo-maxCombinedClipAndCullDistances-00710",
],
..Default::default()
}));
}
for instruction in entry_point_function.iter_execution_mode() {
if let Instruction::ExecutionMode {
mode: ExecutionMode::OutputVertices { vertex_count },
..
} = *instruction
{
match stage_enum {
ShaderStage::TessellationControl | ShaderStage::TessellationEvaluation => {
if vertex_count == 0 {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the `vertex_count` of the \
`ExecutionMode::OutputVertices` is zero"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-stage-00713"],
..Default::default()
}));
}
if vertex_count > properties.max_tessellation_patch_size {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the `vertex_count` of the \
`ExecutionMode::OutputVertices` is greater than the \
`max_tessellation_patch_size` device limit"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-stage-00713"],
..Default::default()
}));
}
}
ShaderStage::Geometry => {
if vertex_count == 0 {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the `vertex_count` of the \
`ExecutionMode::OutputVertices` is zero"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-stage-00714"],
..Default::default()
}));
}
if vertex_count > properties.max_geometry_output_vertices {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the `vertex_count` of the \
`ExecutionMode::OutputVertices` is greater than the \
`max_geometry_output_vertices` device limit"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-stage-00714"],
..Default::default()
}));
}
}
_ => (),
}
}
}
let local_size = (spirv
.iter_decoration()
.find_map(|instruction| match *instruction {
Instruction::Decorate {
target,
decoration:
Decoration::BuiltIn {
built_in: BuiltIn::WorkgroupSize,
},
} => {
let constituents: &[Id; 3] = match *spirv.id(target).instruction() {
Instruction::ConstantComposite {
ref constituents, ..
} => constituents.as_slice().try_into().unwrap(),
_ => unreachable!(),
};
let local_size = constituents.map(|id| match *spirv.id(id).instruction() {
Instruction::Constant { ref value, .. } => {
assert!(value.len() == 1);
value[0]
}
_ => unreachable!(),
});
Some(local_size)
}
_ => None,
}))
.or_else(|| {
entry_point_function
.iter_execution_mode()
.find_map(|instruction| match *instruction {
Instruction::ExecutionMode {
mode:
ExecutionMode::LocalSize {
x_size,
y_size,
z_size,
},
..
} => Some([x_size, y_size, z_size]),
Instruction::ExecutionModeId {
mode:
ExecutionMode::LocalSizeId {
x_size,
y_size,
z_size,
},
..
} => Some([x_size, y_size, z_size].map(
|id| match *spirv.id(id).instruction() {
Instruction::Constant { ref value, .. } => {
assert!(value.len() == 1);
value[0]
}
_ => unreachable!(),
},
)),
_ => None,
})
})
.unwrap_or_default();
let workgroup_size = local_size.into_iter().try_fold(1, u32::checked_mul);
match stage_enum { match stage_enum {
ShaderStage::Compute => { ShaderStage::Compute => {
@ -488,22 +722,18 @@ impl PipelineShaderStageCreateInfo {
})); }));
} }
let workgroup_size = local_size if workgroup_size.map_or(true, |size| {
.into_iter() size > properties.max_compute_work_group_invocations
.try_fold(1, u32::checked_mul) }) {
.filter(|&x| x <= properties.max_compute_work_group_invocations) return Err(Box::new(ValidationError {
.ok_or_else(|| {
Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \ problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \ `local_size_z` of `entry_point` is greater than the \
`max_compute_work_group_invocations` device limit" `max_compute_work_group_invocations` device limit"
.into(), .into(),
vuids: &["VUID-RuntimeSpirv-x-06432"], vuids: &["VUID-RuntimeSpirv-x-06432"],
..Default::default() ..Default::default()
}) }));
})?; }
Some(workgroup_size)
} }
ShaderStage::Task => { ShaderStage::Task => {
if local_size[0] > properties.max_task_work_group_size.unwrap_or_default()[0] { if local_size[0] > properties.max_task_work_group_size.unwrap_or_default()[0] {
@ -536,26 +766,20 @@ impl PipelineShaderStageCreateInfo {
})); }));
} }
let workgroup_size = local_size if workgroup_size.map_or(true, |size| {
.into_iter() size > properties
.try_fold(1, u32::checked_mul)
.filter(|&x| {
x <= properties
.max_task_work_group_invocations .max_task_work_group_invocations
.unwrap_or_default() .unwrap_or_default()
}) }) {
.ok_or_else(|| { return Err(Box::new(ValidationError {
Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \ problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \ `local_size_z` of `entry_point` is greater than the \
`max_task_work_group_invocations` device limit" `max_task_work_group_invocations` device limit"
.into(), .into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07294"], vuids: &["VUID-RuntimeSpirv-TaskEXT-07294"],
..Default::default() ..Default::default()
}) }));
})?; }
Some(workgroup_size)
} }
ShaderStage::Mesh => { ShaderStage::Mesh => {
if local_size[0] > properties.max_mesh_work_group_size.unwrap_or_default()[0] { if local_size[0] > properties.max_mesh_work_group_size.unwrap_or_default()[0] {
@ -588,33 +812,25 @@ impl PipelineShaderStageCreateInfo {
})); }));
} }
let workgroup_size = local_size if workgroup_size.map_or(true, |size| {
.into_iter() size > properties
.try_fold(1, u32::checked_mul)
.filter(|&x| {
x <= properties
.max_mesh_work_group_invocations .max_mesh_work_group_invocations
.unwrap_or_default() .unwrap_or_default()
}) }) {
.ok_or_else(|| { return Err(Box::new(ValidationError {
Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \ problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \ `local_size_z` of `entry_point` is greater than the \
`max_mesh_work_group_invocations` device limit" `max_mesh_work_group_invocations` device limit"
.into(), .into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07298"], vuids: &["VUID-RuntimeSpirv-MeshEXT-07298"],
..Default::default() ..Default::default()
}) }));
})?; }
}
_ => (),
}
Some(workgroup_size) let workgroup_size = workgroup_size.unwrap();
}
// TODO: Additional stages when `.local_size()` supports them.
_ => unreachable!(),
}
} else {
None
};
if let Some(required_subgroup_size) = required_subgroup_size { if let Some(required_subgroup_size) = required_subgroup_size {
if !device.enabled_features().subgroup_size_control { if !device.enabled_features().subgroup_size_control {
@ -670,9 +886,10 @@ impl PipelineShaderStageCreateInfo {
})); }));
} }
if let Some(workgroup_size) = workgroup_size { if matches!(
if stage_enum == ShaderStage::Compute { stage_enum,
if workgroup_size ShaderStage::Compute | ShaderStage::Mesh | ShaderStage::Task
) && workgroup_size
> required_subgroup_size > required_subgroup_size
.checked_mul( .checked_mul(
properties properties
@ -692,8 +909,9 @@ impl PipelineShaderStageCreateInfo {
})); }));
} }
} }
}
} // TODO:
// VUID-VkPipelineShaderStageCreateInfo-module-08987
Ok(()) Ok(())
} }

View File

@ -131,7 +131,7 @@
//! [`scalar_block_layout`]: crate::device::Features::scalar_block_layout //! [`scalar_block_layout`]: crate::device::Features::scalar_block_layout
//! [`uniform_buffer_standard_layout`]: crate::device::Features::uniform_buffer_standard_layout //! [`uniform_buffer_standard_layout`]: crate::device::Features::uniform_buffer_standard_layout
use self::spirv::Instruction; use self::spirv::{Id, Instruction};
use crate::{ use crate::{
descriptor_set::layout::DescriptorType, descriptor_set::layout::DescriptorType,
device::{Device, DeviceOwned}, device::{Device, DeviceOwned},
@ -139,7 +139,7 @@ use crate::{
image::view::ImageViewType, image::view::ImageViewType,
instance::InstanceOwnedDebugWrapper, instance::InstanceOwnedDebugWrapper,
macros::{impl_id_counter, vulkan_bitflags_enum}, macros::{impl_id_counter, vulkan_bitflags_enum},
pipeline::{graphics::input_assembly::PrimitiveTopology, layout::PushConstantRange}, pipeline::layout::PushConstantRange,
shader::spirv::{Capability, Spirv}, shader::spirv::{Capability, Spirv},
sync::PipelineStages, sync::PipelineStages,
Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, Version, VulkanError, Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, Version, VulkanError,
@ -699,8 +699,8 @@ impl From<f64> for SpecializationConstant {
pub struct SpecializedShaderModule { pub struct SpecializedShaderModule {
base_module: Arc<ShaderModule>, base_module: Arc<ShaderModule>,
specialization_info: HashMap<u32, SpecializationConstant>, specialization_info: HashMap<u32, SpecializationConstant>,
_spirv: Option<Spirv>, spirv: Option<Spirv>,
entry_point_infos: SmallVec<[EntryPointInfo; 1]>, entry_point_infos: SmallVec<[(Id, EntryPointInfo); 1]>,
} }
impl SpecializedShaderModule { impl SpecializedShaderModule {
@ -760,7 +760,7 @@ impl SpecializedShaderModule {
Arc::new(Self { Arc::new(Self {
base_module, base_module,
specialization_info, specialization_info,
_spirv: spirv, spirv,
entry_point_infos, entry_point_infos,
}) })
} }
@ -777,6 +777,12 @@ impl SpecializedShaderModule {
&self.specialization_info &self.specialization_info
} }
/// Returns the SPIR-V code of this module.
#[inline]
pub(crate) fn spirv(&self) -> &Spirv {
self.spirv.as_ref().unwrap_or(&self.base_module.spirv)
}
/// Returns information about the entry point with the provided name. Returns `None` if no entry /// Returns information about the entry point with the provided name. Returns `None` if no entry
/// point with that name exists in the shader module or if multiple entry points with the same /// point with that name exists in the shader module or if multiple entry points with the same
/// name exist. /// name exist.
@ -794,7 +800,7 @@ impl SpecializedShaderModule {
execution: ExecutionModel, execution: ExecutionModel,
) -> Option<EntryPoint> { ) -> Option<EntryPoint> {
self.single_entry_point_filter(|info| { self.single_entry_point_filter(|info| {
info.name == name && ExecutionModel::from(&info.execution) == execution info.name == name && info.execution_model == execution
}) })
} }
@ -808,11 +814,12 @@ impl SpecializedShaderModule {
.entry_point_infos .entry_point_infos
.iter() .iter()
.enumerate() .enumerate()
.filter(|(_, infos)| filter(infos)) .filter(|(_, (_, infos))| filter(infos))
.map(|(x, _)| x); .map(|(x, _)| x);
let info_index = iter.next()?; let info_index = iter.next()?;
iter.next().is_none().then(|| EntryPoint { iter.next().is_none().then(|| EntryPoint {
module: self.clone(), module: self.clone(),
id: self.entry_point_infos[info_index].0,
info_index, info_index,
}) })
} }
@ -832,7 +839,7 @@ impl SpecializedShaderModule {
self: &Arc<Self>, self: &Arc<Self>,
execution: ExecutionModel, execution: ExecutionModel,
) -> Option<EntryPoint> { ) -> Option<EntryPoint> {
self.single_entry_point_filter(|info| ExecutionModel::from(&info.execution) == execution) self.single_entry_point_filter(|info| info.execution_model == execution)
} }
} }
@ -856,7 +863,7 @@ unsafe impl DeviceOwned for SpecializedShaderModule {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct EntryPointInfo { pub struct EntryPointInfo {
pub name: String, pub name: String,
pub execution: ShaderExecution, pub execution_model: ExecutionModel,
pub descriptor_binding_requirements: HashMap<(u32, u32), DescriptorBindingRequirements>, pub descriptor_binding_requirements: HashMap<(u32, u32), DescriptorBindingRequirements>,
pub push_constant_requirements: Option<PushConstantRange>, pub push_constant_requirements: Option<PushConstantRange>,
pub input_interface: ShaderInterface, pub input_interface: ShaderInterface,
@ -869,6 +876,7 @@ pub struct EntryPointInfo {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct EntryPoint { pub struct EntryPoint {
module: Arc<SpecializedShaderModule>, module: Arc<SpecializedShaderModule>,
id: Id,
info_index: usize, info_index: usize,
} }
@ -879,151 +887,18 @@ impl EntryPoint {
&self.module &self.module
} }
/// Returns the Id of the entry point function.
pub(crate) fn id(&self) -> Id {
self.id
}
/// Returns information about the entry point. /// Returns information about the entry point.
#[inline] #[inline]
pub fn info(&self) -> &EntryPointInfo { pub fn info(&self) -> &EntryPointInfo {
&self.module.entry_point_infos[self.info_index] &self.module.entry_point_infos[self.info_index].1
} }
} }
/// The mode in which a shader executes. This includes both information about the shader type/stage,
/// and additional data relevant to particular shader types.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ShaderExecution {
Vertex,
TessellationControl,
TessellationEvaluation,
Geometry(GeometryShaderExecution),
Fragment(FragmentShaderExecution),
Compute(ComputeShaderExecution),
RayGeneration,
AnyHit,
ClosestHit,
Miss,
Intersection,
Callable,
Task, // TODO: like compute?
Mesh, // TODO: like compute?
SubpassShading,
}
impl From<&ShaderExecution> for ExecutionModel {
fn from(value: &ShaderExecution) -> Self {
match value {
ShaderExecution::Vertex => Self::Vertex,
ShaderExecution::TessellationControl => Self::TessellationControl,
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
ShaderExecution::Geometry(_) => Self::Geometry,
ShaderExecution::Fragment(_) => Self::Fragment,
ShaderExecution::Compute(_) => Self::GLCompute,
ShaderExecution::RayGeneration => Self::RayGenerationKHR,
ShaderExecution::AnyHit => Self::AnyHitKHR,
ShaderExecution::ClosestHit => Self::ClosestHitKHR,
ShaderExecution::Miss => Self::MissKHR,
ShaderExecution::Intersection => Self::IntersectionKHR,
ShaderExecution::Callable => Self::CallableKHR,
ShaderExecution::Task => Self::TaskNV,
ShaderExecution::Mesh => Self::MeshNV,
ShaderExecution::SubpassShading => todo!(),
}
}
}
/*#[derive(Clone, Copy, Debug)]
pub struct TessellationShaderExecution {
pub num_output_vertices: u32,
pub point_mode: bool,
pub subdivision: TessellationShaderSubdivision,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TessellationShaderSubdivision {
Triangles,
Quads,
Isolines,
}*/
/// The mode in which a geometry shader executes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GeometryShaderExecution {
pub input: GeometryShaderInput,
/*pub max_output_vertices: u32,
pub num_invocations: u32,
pub output: GeometryShaderOutput,*/
}
/// The input primitive type that is expected by a geometry shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GeometryShaderInput {
Points,
Lines,
LinesWithAdjacency,
Triangles,
TrianglesWithAdjacency,
}
impl GeometryShaderInput {
/// Returns true if the given primitive topology can be used as input for this geometry shader.
#[inline]
pub fn is_compatible_with(self, topology: PrimitiveTopology) -> bool {
match self {
Self::Points => matches!(topology, PrimitiveTopology::PointList),
Self::Lines => matches!(
topology,
PrimitiveTopology::LineList | PrimitiveTopology::LineStrip
),
Self::LinesWithAdjacency => matches!(
topology,
PrimitiveTopology::LineListWithAdjacency
| PrimitiveTopology::LineStripWithAdjacency
),
Self::Triangles => matches!(
topology,
PrimitiveTopology::TriangleList
| PrimitiveTopology::TriangleStrip
| PrimitiveTopology::TriangleFan,
),
Self::TrianglesWithAdjacency => matches!(
topology,
PrimitiveTopology::TriangleListWithAdjacency
| PrimitiveTopology::TriangleStripWithAdjacency,
),
}
}
}
/*#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GeometryShaderOutput {
Points,
LineStrip,
TriangleStrip,
}*/
/// The mode in which a fragment shader executes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FragmentShaderExecution {
pub fragment_tests_stages: FragmentTestsStages,
}
/// The fragment tests stages that will be executed in a fragment shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FragmentTestsStages {
Early,
Late,
EarlyAndLate,
}
/// The mode in which the compute shader executes.
///
/// The `WorkgroupSize` builtin overrides the values specified in the
/// execution mode. It can decorate a 3 component ConstantComposite or
/// SpecConstantComposite vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ComputeShaderExecution {
/// Workgroup size in x, y, and z.
pub local_size: [u32; 3],
}
/// The requirements imposed by a shader on a binding within a descriptor set layout, and on any /// The requirements imposed by a shader on a binding within a descriptor set layout, and on any
/// resource that is bound to that binding. /// resource that is bound to that binding.
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
@ -1454,25 +1329,27 @@ vulkan_bitflags_enum! {
]), ]),
} }
impl From<&ShaderExecution> for ShaderStage { impl From<ExecutionModel> for ShaderStage {
#[inline] #[inline]
fn from(value: &ShaderExecution) -> Self { fn from(value: ExecutionModel) -> Self {
match value { match value {
ShaderExecution::Vertex => Self::Vertex, ExecutionModel::Vertex => ShaderStage::Vertex,
ShaderExecution::TessellationControl => Self::TessellationControl, ExecutionModel::TessellationControl => ShaderStage::TessellationControl,
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation, ExecutionModel::TessellationEvaluation => ShaderStage::TessellationEvaluation,
ShaderExecution::Geometry(_) => Self::Geometry, ExecutionModel::Geometry => ShaderStage::Geometry,
ShaderExecution::Fragment(_) => Self::Fragment, ExecutionModel::Fragment => ShaderStage::Fragment,
ShaderExecution::Compute(_) => Self::Compute, ExecutionModel::GLCompute => ShaderStage::Compute,
ShaderExecution::RayGeneration => Self::Raygen, ExecutionModel::Kernel => {
ShaderExecution::AnyHit => Self::AnyHit, unimplemented!("the `Kernel` execution model is not supported by Vulkan")
ShaderExecution::ClosestHit => Self::ClosestHit, }
ShaderExecution::Miss => Self::Miss, ExecutionModel::TaskNV | ExecutionModel::TaskEXT => ShaderStage::Task,
ShaderExecution::Intersection => Self::Intersection, ExecutionModel::MeshNV | ExecutionModel::MeshEXT => ShaderStage::Mesh,
ShaderExecution::Callable => Self::Callable, ExecutionModel::RayGenerationKHR => ShaderStage::Raygen,
ShaderExecution::Task => Self::Task, ExecutionModel::IntersectionKHR => ShaderStage::Intersection,
ShaderExecution::Mesh => Self::Mesh, ExecutionModel::AnyHitKHR => ShaderStage::AnyHit,
ShaderExecution::SubpassShading => Self::SubpassShading, ExecutionModel::ClosestHitKHR => ShaderStage::ClosestHit,
ExecutionModel::MissKHR => ShaderStage::Miss,
ExecutionModel::CallableKHR => ShaderStage::Callable,
} }
} }
} }

View File

@ -9,20 +9,15 @@
//! Extraction of information from SPIR-V modules, that is needed by the rest of Vulkano. //! Extraction of information from SPIR-V modules, that is needed by the rest of Vulkano.
use super::{DescriptorBindingRequirements, FragmentShaderExecution, FragmentTestsStages}; use super::DescriptorBindingRequirements;
use crate::{ use crate::{
descriptor_set::layout::DescriptorType, descriptor_set::layout::DescriptorType,
image::view::ImageViewType, image::view::ImageViewType,
pipeline::layout::PushConstantRange, pipeline::layout::PushConstantRange,
shader::{ shader::{
spirv::{ spirv::{Decoration, Dim, ExecutionModel, Id, Instruction, Spirv, StorageClass},
BuiltIn, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, Spirv, DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, NumericType, ShaderInterface,
StorageClass, ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage, SpecializationConstant,
},
ComputeShaderExecution, DescriptorIdentifier, DescriptorRequirements, EntryPointInfo,
GeometryShaderExecution, GeometryShaderInput, NumericType, ShaderExecution,
ShaderInterface, ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage,
SpecializationConstant,
}, },
DeviceSize, DeviceSize,
}; };
@ -32,7 +27,7 @@ use std::borrow::Cow;
/// Returns an iterator over all entry points in `spirv`, with information about the entry point. /// Returns an iterator over all entry points in `spirv`, with information about the entry point.
#[inline] #[inline]
pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_ { pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = (Id, EntryPointInfo)> + '_ {
let interface_variables = interface_variables(spirv); let interface_variables = interface_variables(spirv);
spirv.iter_entry_point().filter_map(move |instruction| { spirv.iter_entry_point().filter_map(move |instruction| {
@ -47,8 +42,7 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
_ => return None, _ => return None,
}; };
let execution = shader_execution(spirv, execution_model, function_id); let stage = ShaderStage::from(execution_model);
let stage = ShaderStage::from(&execution);
let descriptor_binding_requirements = inspect_entry_point( let descriptor_binding_requirements = inspect_entry_point(
&interface_variables.descriptor_binding, &interface_variables.descriptor_binding,
@ -75,187 +69,18 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
matches!(execution_model, ExecutionModel::TessellationControl), matches!(execution_model, ExecutionModel::TessellationControl),
); );
Some(EntryPointInfo { Some((
function_id,
EntryPointInfo {
name: entry_point_name.clone(), name: entry_point_name.clone(),
execution, execution_model,
descriptor_binding_requirements, descriptor_binding_requirements,
push_constant_requirements, push_constant_requirements,
input_interface, input_interface,
output_interface, output_interface,
})
})
}
/// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`.
fn shader_execution(
spirv: &Spirv,
execution_model: ExecutionModel,
function_id: Id,
) -> ShaderExecution {
match execution_model {
ExecutionModel::Vertex => ShaderExecution::Vertex,
ExecutionModel::TessellationControl => ShaderExecution::TessellationControl,
ExecutionModel::TessellationEvaluation => ShaderExecution::TessellationEvaluation,
ExecutionModel::Geometry => {
let mut input = None;
for instruction in spirv.iter_execution_mode() {
let mode = match instruction {
Instruction::ExecutionMode {
entry_point, mode, ..
} if *entry_point == function_id => mode,
_ => continue,
};
match mode {
ExecutionMode::InputPoints => {
input = Some(GeometryShaderInput::Points);
}
ExecutionMode::InputLines => {
input = Some(GeometryShaderInput::Lines);
}
ExecutionMode::InputLinesAdjacency => {
input = Some(GeometryShaderInput::LinesWithAdjacency);
}
ExecutionMode::Triangles => {
input = Some(GeometryShaderInput::Triangles);
}
ExecutionMode::InputTrianglesAdjacency => {
input = Some(GeometryShaderInput::TrianglesWithAdjacency);
}
_ => (),
}
}
ShaderExecution::Geometry(GeometryShaderExecution {
input: input
.expect("Geometry shader does not have an input primitive ExecutionMode"),
})
}
ExecutionModel::Fragment => {
let mut fragment_tests_stages = FragmentTestsStages::Late;
for instruction in spirv.iter_execution_mode() {
let mode = match instruction {
Instruction::ExecutionMode {
entry_point, mode, ..
} if *entry_point == function_id => mode,
_ => continue,
};
match mode {
ExecutionMode::EarlyFragmentTests => {
fragment_tests_stages = FragmentTestsStages::Early;
}
ExecutionMode::EarlyAndLateFragmentTestsAMD => {
fragment_tests_stages = FragmentTestsStages::EarlyAndLate;
}
_ => (),
}
}
ShaderExecution::Fragment(FragmentShaderExecution {
fragment_tests_stages,
})
}
ExecutionModel::GLCompute => {
let local_size = (spirv
.iter_decoration()
.find_map(|instruction| match *instruction {
Instruction::Decorate {
target,
decoration:
Decoration::BuiltIn {
built_in: BuiltIn::WorkgroupSize,
}, },
} => match *spirv.id(target).instruction() { ))
Instruction::ConstantComposite {
ref constituents, ..
} => {
match *constituents.as_slice() {
[x_size, y_size, z_size] => {
Some([x_size, y_size, z_size].map(|id| {
match *spirv.id(id).instruction() {
Instruction::Constant { ref value, .. } => {
assert!(value.len() == 1);
value[0]
}
// VUID-WorkgroupSize-WorkgroupSize-04426
// VUID-WorkgroupSize-WorkgroupSize-04427
_ => panic!("WorkgroupSize is not a constant"),
}
}))
}
// VUID-WorkgroupSize-WorkgroupSize-04427
_ => panic!("WorkgroupSize must be 3 component vector!"),
}
}
// VUID-WorkgroupSize-WorkgroupSize-04426
_ => panic!("WorkgroupSize is not a constant"),
},
_ => None,
}))
.or_else(|| {
spirv
.iter_execution_mode()
.find_map(|instruction| match *instruction {
Instruction::ExecutionMode {
entry_point,
mode:
ExecutionMode::LocalSize {
x_size,
y_size,
z_size,
},
} if entry_point == function_id => Some([x_size, y_size, z_size]),
Instruction::ExecutionModeId {
entry_point,
mode:
ExecutionMode::LocalSizeId {
x_size,
y_size,
z_size,
},
} if entry_point == function_id => Some([x_size, y_size, z_size].map(
|id| match *spirv.id(id).instruction() {
Instruction::Constant { ref value, .. } => {
assert!(value.len() == 1);
value[0]
}
_ => panic!("LocalSizeId is not a constant"),
},
)),
_ => None,
}) })
});
ShaderExecution::Compute(ComputeShaderExecution {
local_size: local_size.expect(
"Geometry shader does not have a WorkgroupSize builtin, \
or LocalSize or LocalSizeId ExecutionMode",
),
})
}
ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration,
ExecutionModel::IntersectionKHR => ShaderExecution::Intersection,
ExecutionModel::AnyHitKHR => ShaderExecution::AnyHit,
ExecutionModel::ClosestHitKHR => ShaderExecution::ClosestHit,
ExecutionModel::MissKHR => ShaderExecution::Miss,
ExecutionModel::CallableKHR => ShaderExecution::Callable,
ExecutionModel::TaskEXT => ShaderExecution::Task,
ExecutionModel::TaskNV => todo!(),
ExecutionModel::MeshEXT => ShaderExecution::Mesh,
ExecutionModel::MeshNV => todo!(),
ExecutionModel::Kernel => todo!(),
}
} }
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]

View File

@ -132,8 +132,35 @@ impl Spirv {
let destination = match instruction { let destination = match instruction {
Instruction::Function { result_id, .. } => { Instruction::Function { result_id, .. } => {
current_function = None; current_function = None;
let function = functions.entry(result_id).or_insert(FunctionInfo { let function = functions.entry(result_id).or_insert_with(|| {
let entry_point = instructions_entry_point
.iter()
.find(|instruction| {
matches!(
**instruction,
Instruction::EntryPoint { entry_point, .. }
if entry_point == result_id
)
})
.cloned();
let execution_modes = instructions_execution_mode
.iter()
.filter(|instruction| {
matches!(
**instruction,
Instruction::ExecutionMode { entry_point, .. }
| Instruction::ExecutionModeId { entry_point, .. }
if entry_point == result_id
)
})
.cloned()
.collect();
FunctionInfo {
instructions: Vec::new(), instructions: Vec::new(),
entry_point,
execution_modes,
}
}); });
current_function.insert(&mut function.instructions) current_function.insert(&mut function.instructions)
} }
@ -611,9 +638,11 @@ impl StructMemberInfo {
} }
/// Information associated with a function. /// Information associated with a function.
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug)]
pub struct FunctionInfo { pub struct FunctionInfo {
instructions: Vec<Instruction>, instructions: Vec<Instruction>,
entry_point: Option<Instruction>,
execution_modes: Vec<Instruction>,
} }
impl FunctionInfo { impl FunctionInfo {
@ -622,6 +651,18 @@ impl FunctionInfo {
pub fn iter_instructions(&self) -> impl ExactSizeIterator<Item = &Instruction> { pub fn iter_instructions(&self) -> impl ExactSizeIterator<Item = &Instruction> {
self.instructions.iter() self.instructions.iter()
} }
/// Returns the `EntryPoint` instruction that targets this function, if there is one.
#[inline]
pub fn entry_point(&self) -> Option<&Instruction> {
self.entry_point.as_ref()
}
/// Returns an iterator over all execution mode instructions that target this function.
#[inline]
pub fn iter_execution_mode(&self) -> impl ExactSizeIterator<Item = &Instruction> {
self.execution_modes.iter()
}
} }
fn iter_instructions( fn iter_instructions(