Refactor some shader things and add more validation (#2335)

* Refactor some shader things and add more validation

* Remove pub
This commit is contained in:
Rua 2023-09-21 12:18:31 +02:00 committed by GitHub
parent e9790c1fc3
commit a8ca0a7f7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 603 additions and 559 deletions

View File

@ -554,7 +554,7 @@ mod tests {
let spirv = Spirv::new(&instructions).unwrap();
let mut descriptors = Vec::new();
for info in reflect::entry_points(&spirv) {
for (_, info) in reflect::entry_points(&spirv) {
descriptors.push(info.descriptor_binding_requirements);
}
@ -622,7 +622,7 @@ mod tests {
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();
if let Some(info) = reflect::entry_points(&spirv).next() {
if let Some((_, info)) = reflect::entry_points(&spirv).next() {
let mut bindings = Vec::new();
for (loc, _reqs) in info.descriptor_binding_requirements {
bindings.push(loc);

View File

@ -28,7 +28,7 @@ use crate::{
instance::InstanceOwnedDebugWrapper,
macros::impl_id_counter,
pipeline::{cache::PipelineCache, layout::PipelineLayout, Pipeline, PipelineBindPoint},
shader::{DescriptorBindingRequirements, ShaderExecution, ShaderStage},
shader::{spirv::ExecutionModel, DescriptorBindingRequirements, ShaderStage},
Validated, ValidationError, VulkanError, VulkanObject,
};
use ahash::HashMap;
@ -155,7 +155,7 @@ impl ComputePipeline {
},
),
flags: flags.into(),
stage: ShaderStage::from(&entry_point_info.execution).into(),
stage: ShaderStage::from(entry_point_info.execution_model).into(),
module: entry_point.module().handle(),
p_name: name_vk.as_ptr(),
p_specialization_info: if specialization_info_vk.data_size == 0 {
@ -410,7 +410,7 @@ impl ComputePipelineCreateInfo {
let entry_point_info = entry_point.info();
if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) {
if !matches!(entry_point_info.execution_model, ExecutionModel::GLCompute) {
return Err(Box::new(ValidationError {
context: "stage.entry_point".into(),
problem: "is not a `ShaderStage::Compute` entry point".into(),

View File

@ -90,8 +90,8 @@ use crate::{
PartialStateMode,
},
shader::{
DescriptorBindingRequirements, FragmentShaderExecution, FragmentTestsStages,
ShaderExecution, ShaderStage, ShaderStages,
spirv::{ExecutionMode, ExecutionModel, Instruction},
DescriptorBindingRequirements, ShaderStage, ShaderStages,
},
Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, VulkanError, VulkanObject,
};
@ -220,7 +220,7 @@ impl GraphicsPipeline {
} = stage;
let entry_point_info = entry_point.info();
let stage = ShaderStage::from(&entry_point_info.execution);
let stage = ShaderStage::from(entry_point_info.execution_model);
let mut specialization_data_vk: Vec<u8> = Vec::new();
let specialization_map_entries_vk: Vec<_> = entry_point
@ -1223,15 +1223,28 @@ impl GraphicsPipeline {
} = stage;
let entry_point_info = entry_point.info();
let stage = ShaderStage::from(&entry_point_info.execution);
let stage = ShaderStage::from(entry_point_info.execution_model);
shaders.insert(stage, ());
if let ShaderExecution::Fragment(FragmentShaderExecution {
fragment_tests_stages: s,
..
}) = entry_point_info.execution
{
fragment_tests_stages = Some(s)
let spirv = entry_point.module().spirv();
let entry_point_function = spirv.function(entry_point.id());
if matches!(entry_point_info.execution_model, ExecutionModel::Fragment) {
fragment_tests_stages = Some(FragmentTestsStages::Late);
for instruction in entry_point_function.iter_execution_mode() {
if let Instruction::ExecutionMode { mode, .. } = *instruction {
match mode {
ExecutionMode::EarlyFragmentTests => {
fragment_tests_stages = Some(FragmentTestsStages::Early);
}
ExecutionMode::EarlyAndLateFragmentTestsAMD => {
fragment_tests_stages = Some(FragmentTestsStages::EarlyAndLate);
}
_ => (),
}
}
}
}
for (&loc, reqs) in &entry_point_info.descriptor_binding_requirements {
@ -1989,7 +2002,7 @@ impl GraphicsPipelineCreateInfo {
for (stage_index, stage) in stages.iter().enumerate() {
let entry_point_info = stage.entry_point.info();
let stage_enum = ShaderStage::from(&entry_point_info.execution);
let stage_enum = ShaderStage::from(entry_point_info.execution_model);
let stage_flag = ShaderStages::from(stage_enum);
if stages_present.intersects(stage_flag) {
@ -2081,9 +2094,12 @@ impl GraphicsPipelineCreateInfo {
}
let need_vertex_input_state = need_pre_rasterization_shader_state
&& stages
.iter()
.any(|stage| matches!(stage.entry_point.info().execution, ShaderExecution::Vertex));
&& stages.iter().any(|stage| {
matches!(
stage.entry_point.info().execution_model,
ExecutionModel::Vertex
)
});
let need_fragment_shader_state = need_pre_rasterization_shader_state
&& rasterization_state
.as_ref()
@ -2535,8 +2551,8 @@ impl GraphicsPipelineCreateInfo {
problem: format!(
"the output interface of the `ShaderStage::{:?}` stage does not \
match the input interface of the `ShaderStage::{:?}` stage: {}",
ShaderStage::from(&output.entry_point.info().execution),
ShaderStage::from(&input.entry_point.info().execution),
ShaderStage::from(output.entry_point.info().execution_model),
ShaderStage::from(input.entry_point.info().execution_model),
err
)
.into(),
@ -2816,11 +2832,30 @@ impl GraphicsPipelineCreateInfo {
geometry_stage,
input_assembly_state,
) {
let entry_point_info = geometry_stage.entry_point.info();
let input = match entry_point_info.execution {
ShaderExecution::Geometry(execution) => execution.input,
_ => unreachable!(),
};
let spirv = geometry_stage.entry_point.module().spirv();
let entry_point_function = spirv.function(geometry_stage.entry_point.id());
let input = entry_point_function
.iter_execution_mode()
.find_map(|instruction| {
if let Instruction::ExecutionMode { mode, .. } = *instruction {
match mode {
ExecutionMode::InputPoints => Some(GeometryShaderInput::Points),
ExecutionMode::InputLines => Some(GeometryShaderInput::Lines),
ExecutionMode::InputLinesAdjacency => {
Some(GeometryShaderInput::LinesWithAdjacency)
}
ExecutionMode::Triangles => Some(GeometryShaderInput::Triangles),
ExecutionMode::InputTrianglesAdjacency => {
Some(GeometryShaderInput::TrianglesWithAdjacency)
}
_ => None,
}
} else {
None
}
})
.unwrap();
if let PartialStateMode::Fixed(topology) = input_assembly_state.topology {
if !input.is_compatible_with(topology) {
@ -3104,3 +3139,51 @@ impl GraphicsPipelineCreateInfo {
Ok(())
}
}
/// The input primitive type that is expected by a geometry shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum GeometryShaderInput {
Points,
Lines,
LinesWithAdjacency,
Triangles,
TrianglesWithAdjacency,
}
impl GeometryShaderInput {
/// Returns true if the given primitive topology can be used as input for this geometry shader.
#[inline]
fn is_compatible_with(self, topology: PrimitiveTopology) -> bool {
match self {
Self::Points => matches!(topology, PrimitiveTopology::PointList),
Self::Lines => matches!(
topology,
PrimitiveTopology::LineList | PrimitiveTopology::LineStrip
),
Self::LinesWithAdjacency => matches!(
topology,
PrimitiveTopology::LineListWithAdjacency
| PrimitiveTopology::LineStripWithAdjacency
),
Self::Triangles => matches!(
topology,
PrimitiveTopology::TriangleList
| PrimitiveTopology::TriangleStrip
| PrimitiveTopology::TriangleFan,
),
Self::TrianglesWithAdjacency => matches!(
topology,
PrimitiveTopology::TriangleListWithAdjacency
| PrimitiveTopology::TriangleStripWithAdjacency,
),
}
}
}
/// The fragment tests stages that will be executed in a fragment shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FragmentTestsStages {
Early,
Late,
EarlyAndLate,
}

View File

@ -21,7 +21,10 @@ pub use self::{compute::ComputePipeline, graphics::GraphicsPipeline, layout::Pip
use crate::{
device::{Device, DeviceOwned},
macros::{vulkan_bitflags, vulkan_enum},
shader::{DescriptorBindingRequirements, EntryPoint, ShaderExecution, ShaderStage},
shader::{
spirv::{BuiltIn, Decoration, ExecutionMode, Id, Instruction},
DescriptorBindingRequirements, EntryPoint, ShaderStage,
},
Requires, RequiresAllOf, RequiresOneOf, ValidationError,
};
use ahash::HashMap;
@ -355,7 +358,7 @@ impl PipelineShaderStageCreateInfo {
})?;
let entry_point_info = entry_point.info();
let stage_enum = ShaderStage::from(&entry_point_info.execution);
let stage_enum = ShaderStage::from(entry_point_info.execution_model);
stage_enum.validate_device(device).map_err(|err| {
err.add_context("entry_point.info().execution")
@ -451,10 +454,241 @@ impl PipelineShaderStageCreateInfo {
ShaderStage::SubpassShading => (),
}
let workgroup_size = if let ShaderExecution::Compute(execution) =
&entry_point_info.execution
let spirv = entry_point.module().spirv();
let entry_point_function = spirv.function(entry_point.id());
let mut clip_distance_array_size = 0;
let mut cull_distance_array_size = 0;
for instruction in spirv.iter_decoration() {
if let Instruction::Decorate {
target,
decoration: Decoration::BuiltIn { built_in },
} = *instruction
{
let local_size = execution.local_size;
let variable_array_size = |variable| {
let result_type_id = match *spirv.id(variable).instruction() {
Instruction::Variable { result_type_id, .. } => result_type_id,
_ => return None,
};
let length = match *spirv.id(result_type_id).instruction() {
Instruction::TypeArray { length, .. } => length,
_ => return None,
};
let value = match *spirv.id(length).instruction() {
Instruction::Constant { ref value, .. } => {
if value.len() > 1 {
u32::MAX
} else {
value[0]
}
}
_ => return None,
};
Some(value)
};
match built_in {
BuiltIn::ClipDistance => {
clip_distance_array_size = variable_array_size(target).unwrap();
if clip_distance_array_size > properties.max_clip_distances {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the number of elements in the `ClipDistance` built-in \
variable is greater than the \
`max_clip_distances` device limit"
.into(),
vuids: &[
"VUID-VkPipelineShaderStageCreateInfo-maxClipDistances-00708",
],
..Default::default()
}));
}
}
BuiltIn::CullDistance => {
cull_distance_array_size = variable_array_size(target).unwrap();
if cull_distance_array_size > properties.max_cull_distances {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the number of elements in the `CullDistance` built-in \
variable is greater than the \
`max_cull_distances` device limit"
.into(),
vuids: &[
"VUID-VkPipelineShaderStageCreateInfo-maxCullDistances-00709",
],
..Default::default()
}));
}
}
BuiltIn::SampleMask => {
if variable_array_size(target).unwrap() > properties.max_sample_mask_words {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the number of elements in the `SampleMask` built-in \
variable is greater than the \
`max_sample_mask_words` device limit"
.into(),
vuids: &[
"VUID-VkPipelineShaderStageCreateInfo-maxSampleMaskWords-00711",
],
..Default::default()
}));
}
}
_ => (),
}
}
}
if clip_distance_array_size
.checked_add(cull_distance_array_size)
.map_or(true, |sum| {
sum > properties.max_combined_clip_and_cull_distances
})
{
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the sum of the number of elements in the `ClipDistance` and \
`CullDistance` built-in variables is greater than the \
`max_combined_clip_and_cull_distances` device limit"
.into(),
vuids: &[
"VUID-VkPipelineShaderStageCreateInfo-maxCombinedClipAndCullDistances-00710",
],
..Default::default()
}));
}
for instruction in entry_point_function.iter_execution_mode() {
if let Instruction::ExecutionMode {
mode: ExecutionMode::OutputVertices { vertex_count },
..
} = *instruction
{
match stage_enum {
ShaderStage::TessellationControl | ShaderStage::TessellationEvaluation => {
if vertex_count == 0 {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the `vertex_count` of the \
`ExecutionMode::OutputVertices` is zero"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-stage-00713"],
..Default::default()
}));
}
if vertex_count > properties.max_tessellation_patch_size {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the `vertex_count` of the \
`ExecutionMode::OutputVertices` is greater than the \
`max_tessellation_patch_size` device limit"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-stage-00713"],
..Default::default()
}));
}
}
ShaderStage::Geometry => {
if vertex_count == 0 {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the `vertex_count` of the \
`ExecutionMode::OutputVertices` is zero"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-stage-00714"],
..Default::default()
}));
}
if vertex_count > properties.max_geometry_output_vertices {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "the `vertex_count` of the \
`ExecutionMode::OutputVertices` is greater than the \
`max_geometry_output_vertices` device limit"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-stage-00714"],
..Default::default()
}));
}
}
_ => (),
}
}
}
let local_size = (spirv
.iter_decoration()
.find_map(|instruction| match *instruction {
Instruction::Decorate {
target,
decoration:
Decoration::BuiltIn {
built_in: BuiltIn::WorkgroupSize,
},
} => {
let constituents: &[Id; 3] = match *spirv.id(target).instruction() {
Instruction::ConstantComposite {
ref constituents, ..
} => constituents.as_slice().try_into().unwrap(),
_ => unreachable!(),
};
let local_size = constituents.map(|id| match *spirv.id(id).instruction() {
Instruction::Constant { ref value, .. } => {
assert!(value.len() == 1);
value[0]
}
_ => unreachable!(),
});
Some(local_size)
}
_ => None,
}))
.or_else(|| {
entry_point_function
.iter_execution_mode()
.find_map(|instruction| match *instruction {
Instruction::ExecutionMode {
mode:
ExecutionMode::LocalSize {
x_size,
y_size,
z_size,
},
..
} => Some([x_size, y_size, z_size]),
Instruction::ExecutionModeId {
mode:
ExecutionMode::LocalSizeId {
x_size,
y_size,
z_size,
},
..
} => Some([x_size, y_size, z_size].map(
|id| match *spirv.id(id).instruction() {
Instruction::Constant { ref value, .. } => {
assert!(value.len() == 1);
value[0]
}
_ => unreachable!(),
},
)),
_ => None,
})
})
.unwrap_or_default();
let workgroup_size = local_size.into_iter().try_fold(1, u32::checked_mul);
match stage_enum {
ShaderStage::Compute => {
@ -488,22 +722,18 @@ impl PipelineShaderStageCreateInfo {
}));
}
let workgroup_size = local_size
.into_iter()
.try_fold(1, u32::checked_mul)
.filter(|&x| x <= properties.max_compute_work_group_invocations)
.ok_or_else(|| {
Box::new(ValidationError {
if workgroup_size.map_or(true, |size| {
size > properties.max_compute_work_group_invocations
}) {
return Err(Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \
`max_compute_work_group_invocations` device limit"
.into(),
vuids: &["VUID-RuntimeSpirv-x-06432"],
..Default::default()
})
})?;
Some(workgroup_size)
}));
}
}
ShaderStage::Task => {
if local_size[0] > properties.max_task_work_group_size.unwrap_or_default()[0] {
@ -536,26 +766,20 @@ impl PipelineShaderStageCreateInfo {
}));
}
let workgroup_size = local_size
.into_iter()
.try_fold(1, u32::checked_mul)
.filter(|&x| {
x <= properties
if workgroup_size.map_or(true, |size| {
size > properties
.max_task_work_group_invocations
.unwrap_or_default()
})
.ok_or_else(|| {
Box::new(ValidationError {
}) {
return Err(Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \
`max_task_work_group_invocations` device limit"
.into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07294"],
..Default::default()
})
})?;
Some(workgroup_size)
}));
}
}
ShaderStage::Mesh => {
if local_size[0] > properties.max_mesh_work_group_size.unwrap_or_default()[0] {
@ -588,33 +812,25 @@ impl PipelineShaderStageCreateInfo {
}));
}
let workgroup_size = local_size
.into_iter()
.try_fold(1, u32::checked_mul)
.filter(|&x| {
x <= properties
if workgroup_size.map_or(true, |size| {
size > properties
.max_mesh_work_group_invocations
.unwrap_or_default()
})
.ok_or_else(|| {
Box::new(ValidationError {
}) {
return Err(Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \
`max_mesh_work_group_invocations` device limit"
.into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07298"],
..Default::default()
})
})?;
}));
}
}
_ => (),
}
Some(workgroup_size)
}
// TODO: Additional stages when `.local_size()` supports them.
_ => unreachable!(),
}
} else {
None
};
let workgroup_size = workgroup_size.unwrap();
if let Some(required_subgroup_size) = required_subgroup_size {
if !device.enabled_features().subgroup_size_control {
@ -670,9 +886,10 @@ impl PipelineShaderStageCreateInfo {
}));
}
if let Some(workgroup_size) = workgroup_size {
if stage_enum == ShaderStage::Compute {
if workgroup_size
if matches!(
stage_enum,
ShaderStage::Compute | ShaderStage::Mesh | ShaderStage::Task
) && workgroup_size
> required_subgroup_size
.checked_mul(
properties
@ -692,8 +909,9 @@ impl PipelineShaderStageCreateInfo {
}));
}
}
}
}
// TODO:
// VUID-VkPipelineShaderStageCreateInfo-module-08987
Ok(())
}

View File

@ -131,7 +131,7 @@
//! [`scalar_block_layout`]: crate::device::Features::scalar_block_layout
//! [`uniform_buffer_standard_layout`]: crate::device::Features::uniform_buffer_standard_layout
use self::spirv::Instruction;
use self::spirv::{Id, Instruction};
use crate::{
descriptor_set::layout::DescriptorType,
device::{Device, DeviceOwned},
@ -139,7 +139,7 @@ use crate::{
image::view::ImageViewType,
instance::InstanceOwnedDebugWrapper,
macros::{impl_id_counter, vulkan_bitflags_enum},
pipeline::{graphics::input_assembly::PrimitiveTopology, layout::PushConstantRange},
pipeline::layout::PushConstantRange,
shader::spirv::{Capability, Spirv},
sync::PipelineStages,
Requires, RequiresAllOf, RequiresOneOf, Validated, ValidationError, Version, VulkanError,
@ -699,8 +699,8 @@ impl From<f64> for SpecializationConstant {
pub struct SpecializedShaderModule {
base_module: Arc<ShaderModule>,
specialization_info: HashMap<u32, SpecializationConstant>,
_spirv: Option<Spirv>,
entry_point_infos: SmallVec<[EntryPointInfo; 1]>,
spirv: Option<Spirv>,
entry_point_infos: SmallVec<[(Id, EntryPointInfo); 1]>,
}
impl SpecializedShaderModule {
@ -760,7 +760,7 @@ impl SpecializedShaderModule {
Arc::new(Self {
base_module,
specialization_info,
_spirv: spirv,
spirv,
entry_point_infos,
})
}
@ -777,6 +777,12 @@ impl SpecializedShaderModule {
&self.specialization_info
}
/// Returns the SPIR-V code of this module.
#[inline]
pub(crate) fn spirv(&self) -> &Spirv {
self.spirv.as_ref().unwrap_or(&self.base_module.spirv)
}
/// Returns information about the entry point with the provided name. Returns `None` if no entry
/// point with that name exists in the shader module or if multiple entry points with the same
/// name exist.
@ -794,7 +800,7 @@ impl SpecializedShaderModule {
execution: ExecutionModel,
) -> Option<EntryPoint> {
self.single_entry_point_filter(|info| {
info.name == name && ExecutionModel::from(&info.execution) == execution
info.name == name && info.execution_model == execution
})
}
@ -808,11 +814,12 @@ impl SpecializedShaderModule {
.entry_point_infos
.iter()
.enumerate()
.filter(|(_, infos)| filter(infos))
.filter(|(_, (_, infos))| filter(infos))
.map(|(x, _)| x);
let info_index = iter.next()?;
iter.next().is_none().then(|| EntryPoint {
module: self.clone(),
id: self.entry_point_infos[info_index].0,
info_index,
})
}
@ -832,7 +839,7 @@ impl SpecializedShaderModule {
self: &Arc<Self>,
execution: ExecutionModel,
) -> Option<EntryPoint> {
self.single_entry_point_filter(|info| ExecutionModel::from(&info.execution) == execution)
self.single_entry_point_filter(|info| info.execution_model == execution)
}
}
@ -856,7 +863,7 @@ unsafe impl DeviceOwned for SpecializedShaderModule {
#[derive(Clone, Debug)]
pub struct EntryPointInfo {
pub name: String,
pub execution: ShaderExecution,
pub execution_model: ExecutionModel,
pub descriptor_binding_requirements: HashMap<(u32, u32), DescriptorBindingRequirements>,
pub push_constant_requirements: Option<PushConstantRange>,
pub input_interface: ShaderInterface,
@ -869,6 +876,7 @@ pub struct EntryPointInfo {
#[derive(Clone, Debug)]
pub struct EntryPoint {
module: Arc<SpecializedShaderModule>,
id: Id,
info_index: usize,
}
@ -879,151 +887,18 @@ impl EntryPoint {
&self.module
}
/// Returns the Id of the entry point function.
pub(crate) fn id(&self) -> Id {
self.id
}
/// Returns information about the entry point.
#[inline]
pub fn info(&self) -> &EntryPointInfo {
&self.module.entry_point_infos[self.info_index]
&self.module.entry_point_infos[self.info_index].1
}
}
/// The mode in which a shader executes. This includes both information about the shader type/stage,
/// and additional data relevant to particular shader types.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ShaderExecution {
Vertex,
TessellationControl,
TessellationEvaluation,
Geometry(GeometryShaderExecution),
Fragment(FragmentShaderExecution),
Compute(ComputeShaderExecution),
RayGeneration,
AnyHit,
ClosestHit,
Miss,
Intersection,
Callable,
Task, // TODO: like compute?
Mesh, // TODO: like compute?
SubpassShading,
}
impl From<&ShaderExecution> for ExecutionModel {
fn from(value: &ShaderExecution) -> Self {
match value {
ShaderExecution::Vertex => Self::Vertex,
ShaderExecution::TessellationControl => Self::TessellationControl,
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
ShaderExecution::Geometry(_) => Self::Geometry,
ShaderExecution::Fragment(_) => Self::Fragment,
ShaderExecution::Compute(_) => Self::GLCompute,
ShaderExecution::RayGeneration => Self::RayGenerationKHR,
ShaderExecution::AnyHit => Self::AnyHitKHR,
ShaderExecution::ClosestHit => Self::ClosestHitKHR,
ShaderExecution::Miss => Self::MissKHR,
ShaderExecution::Intersection => Self::IntersectionKHR,
ShaderExecution::Callable => Self::CallableKHR,
ShaderExecution::Task => Self::TaskNV,
ShaderExecution::Mesh => Self::MeshNV,
ShaderExecution::SubpassShading => todo!(),
}
}
}
/*#[derive(Clone, Copy, Debug)]
pub struct TessellationShaderExecution {
pub num_output_vertices: u32,
pub point_mode: bool,
pub subdivision: TessellationShaderSubdivision,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TessellationShaderSubdivision {
Triangles,
Quads,
Isolines,
}*/
/// The mode in which a geometry shader executes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GeometryShaderExecution {
pub input: GeometryShaderInput,
/*pub max_output_vertices: u32,
pub num_invocations: u32,
pub output: GeometryShaderOutput,*/
}
/// The input primitive type that is expected by a geometry shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GeometryShaderInput {
Points,
Lines,
LinesWithAdjacency,
Triangles,
TrianglesWithAdjacency,
}
impl GeometryShaderInput {
/// Returns true if the given primitive topology can be used as input for this geometry shader.
#[inline]
pub fn is_compatible_with(self, topology: PrimitiveTopology) -> bool {
match self {
Self::Points => matches!(topology, PrimitiveTopology::PointList),
Self::Lines => matches!(
topology,
PrimitiveTopology::LineList | PrimitiveTopology::LineStrip
),
Self::LinesWithAdjacency => matches!(
topology,
PrimitiveTopology::LineListWithAdjacency
| PrimitiveTopology::LineStripWithAdjacency
),
Self::Triangles => matches!(
topology,
PrimitiveTopology::TriangleList
| PrimitiveTopology::TriangleStrip
| PrimitiveTopology::TriangleFan,
),
Self::TrianglesWithAdjacency => matches!(
topology,
PrimitiveTopology::TriangleListWithAdjacency
| PrimitiveTopology::TriangleStripWithAdjacency,
),
}
}
}
/*#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GeometryShaderOutput {
Points,
LineStrip,
TriangleStrip,
}*/
/// The mode in which a fragment shader executes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FragmentShaderExecution {
pub fragment_tests_stages: FragmentTestsStages,
}
/// The fragment tests stages that will be executed in a fragment shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FragmentTestsStages {
Early,
Late,
EarlyAndLate,
}
/// The mode in which the compute shader executes.
///
/// The `WorkgroupSize` builtin overrides the values specified in the
/// execution mode. It can decorate a 3 component ConstantComposite or
/// SpecConstantComposite vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ComputeShaderExecution {
/// Workgroup size in x, y, and z.
pub local_size: [u32; 3],
}
/// The requirements imposed by a shader on a binding within a descriptor set layout, and on any
/// resource that is bound to that binding.
#[derive(Clone, Debug, Default)]
@ -1454,25 +1329,27 @@ vulkan_bitflags_enum! {
]),
}
impl From<&ShaderExecution> for ShaderStage {
impl From<ExecutionModel> for ShaderStage {
#[inline]
fn from(value: &ShaderExecution) -> Self {
fn from(value: ExecutionModel) -> Self {
match value {
ShaderExecution::Vertex => Self::Vertex,
ShaderExecution::TessellationControl => Self::TessellationControl,
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
ShaderExecution::Geometry(_) => Self::Geometry,
ShaderExecution::Fragment(_) => Self::Fragment,
ShaderExecution::Compute(_) => Self::Compute,
ShaderExecution::RayGeneration => Self::Raygen,
ShaderExecution::AnyHit => Self::AnyHit,
ShaderExecution::ClosestHit => Self::ClosestHit,
ShaderExecution::Miss => Self::Miss,
ShaderExecution::Intersection => Self::Intersection,
ShaderExecution::Callable => Self::Callable,
ShaderExecution::Task => Self::Task,
ShaderExecution::Mesh => Self::Mesh,
ShaderExecution::SubpassShading => Self::SubpassShading,
ExecutionModel::Vertex => ShaderStage::Vertex,
ExecutionModel::TessellationControl => ShaderStage::TessellationControl,
ExecutionModel::TessellationEvaluation => ShaderStage::TessellationEvaluation,
ExecutionModel::Geometry => ShaderStage::Geometry,
ExecutionModel::Fragment => ShaderStage::Fragment,
ExecutionModel::GLCompute => ShaderStage::Compute,
ExecutionModel::Kernel => {
unimplemented!("the `Kernel` execution model is not supported by Vulkan")
}
ExecutionModel::TaskNV | ExecutionModel::TaskEXT => ShaderStage::Task,
ExecutionModel::MeshNV | ExecutionModel::MeshEXT => ShaderStage::Mesh,
ExecutionModel::RayGenerationKHR => ShaderStage::Raygen,
ExecutionModel::IntersectionKHR => ShaderStage::Intersection,
ExecutionModel::AnyHitKHR => ShaderStage::AnyHit,
ExecutionModel::ClosestHitKHR => ShaderStage::ClosestHit,
ExecutionModel::MissKHR => ShaderStage::Miss,
ExecutionModel::CallableKHR => ShaderStage::Callable,
}
}
}

View File

@ -9,20 +9,15 @@
//! Extraction of information from SPIR-V modules, that is needed by the rest of Vulkano.
use super::{DescriptorBindingRequirements, FragmentShaderExecution, FragmentTestsStages};
use super::DescriptorBindingRequirements;
use crate::{
descriptor_set::layout::DescriptorType,
image::view::ImageViewType,
pipeline::layout::PushConstantRange,
shader::{
spirv::{
BuiltIn, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, Spirv,
StorageClass,
},
ComputeShaderExecution, DescriptorIdentifier, DescriptorRequirements, EntryPointInfo,
GeometryShaderExecution, GeometryShaderInput, NumericType, ShaderExecution,
ShaderInterface, ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage,
SpecializationConstant,
spirv::{Decoration, Dim, ExecutionModel, Id, Instruction, Spirv, StorageClass},
DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, NumericType, ShaderInterface,
ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage, SpecializationConstant,
},
DeviceSize,
};
@ -32,7 +27,7 @@ use std::borrow::Cow;
/// Returns an iterator over all entry points in `spirv`, with information about the entry point.
#[inline]
pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_ {
pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = (Id, EntryPointInfo)> + '_ {
let interface_variables = interface_variables(spirv);
spirv.iter_entry_point().filter_map(move |instruction| {
@ -47,8 +42,7 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
_ => return None,
};
let execution = shader_execution(spirv, execution_model, function_id);
let stage = ShaderStage::from(&execution);
let stage = ShaderStage::from(execution_model);
let descriptor_binding_requirements = inspect_entry_point(
&interface_variables.descriptor_binding,
@ -75,187 +69,18 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
matches!(execution_model, ExecutionModel::TessellationControl),
);
Some(EntryPointInfo {
Some((
function_id,
EntryPointInfo {
name: entry_point_name.clone(),
execution,
execution_model,
descriptor_binding_requirements,
push_constant_requirements,
input_interface,
output_interface,
})
})
}
/// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`.
fn shader_execution(
spirv: &Spirv,
execution_model: ExecutionModel,
function_id: Id,
) -> ShaderExecution {
match execution_model {
ExecutionModel::Vertex => ShaderExecution::Vertex,
ExecutionModel::TessellationControl => ShaderExecution::TessellationControl,
ExecutionModel::TessellationEvaluation => ShaderExecution::TessellationEvaluation,
ExecutionModel::Geometry => {
let mut input = None;
for instruction in spirv.iter_execution_mode() {
let mode = match instruction {
Instruction::ExecutionMode {
entry_point, mode, ..
} if *entry_point == function_id => mode,
_ => continue,
};
match mode {
ExecutionMode::InputPoints => {
input = Some(GeometryShaderInput::Points);
}
ExecutionMode::InputLines => {
input = Some(GeometryShaderInput::Lines);
}
ExecutionMode::InputLinesAdjacency => {
input = Some(GeometryShaderInput::LinesWithAdjacency);
}
ExecutionMode::Triangles => {
input = Some(GeometryShaderInput::Triangles);
}
ExecutionMode::InputTrianglesAdjacency => {
input = Some(GeometryShaderInput::TrianglesWithAdjacency);
}
_ => (),
}
}
ShaderExecution::Geometry(GeometryShaderExecution {
input: input
.expect("Geometry shader does not have an input primitive ExecutionMode"),
})
}
ExecutionModel::Fragment => {
let mut fragment_tests_stages = FragmentTestsStages::Late;
for instruction in spirv.iter_execution_mode() {
let mode = match instruction {
Instruction::ExecutionMode {
entry_point, mode, ..
} if *entry_point == function_id => mode,
_ => continue,
};
match mode {
ExecutionMode::EarlyFragmentTests => {
fragment_tests_stages = FragmentTestsStages::Early;
}
ExecutionMode::EarlyAndLateFragmentTestsAMD => {
fragment_tests_stages = FragmentTestsStages::EarlyAndLate;
}
_ => (),
}
}
ShaderExecution::Fragment(FragmentShaderExecution {
fragment_tests_stages,
})
}
ExecutionModel::GLCompute => {
let local_size = (spirv
.iter_decoration()
.find_map(|instruction| match *instruction {
Instruction::Decorate {
target,
decoration:
Decoration::BuiltIn {
built_in: BuiltIn::WorkgroupSize,
},
} => match *spirv.id(target).instruction() {
Instruction::ConstantComposite {
ref constituents, ..
} => {
match *constituents.as_slice() {
[x_size, y_size, z_size] => {
Some([x_size, y_size, z_size].map(|id| {
match *spirv.id(id).instruction() {
Instruction::Constant { ref value, .. } => {
assert!(value.len() == 1);
value[0]
}
// VUID-WorkgroupSize-WorkgroupSize-04426
// VUID-WorkgroupSize-WorkgroupSize-04427
_ => panic!("WorkgroupSize is not a constant"),
}
}))
}
// VUID-WorkgroupSize-WorkgroupSize-04427
_ => panic!("WorkgroupSize must be 3 component vector!"),
}
}
// VUID-WorkgroupSize-WorkgroupSize-04426
_ => panic!("WorkgroupSize is not a constant"),
},
_ => None,
}))
.or_else(|| {
spirv
.iter_execution_mode()
.find_map(|instruction| match *instruction {
Instruction::ExecutionMode {
entry_point,
mode:
ExecutionMode::LocalSize {
x_size,
y_size,
z_size,
},
} if entry_point == function_id => Some([x_size, y_size, z_size]),
Instruction::ExecutionModeId {
entry_point,
mode:
ExecutionMode::LocalSizeId {
x_size,
y_size,
z_size,
},
} if entry_point == function_id => Some([x_size, y_size, z_size].map(
|id| match *spirv.id(id).instruction() {
Instruction::Constant { ref value, .. } => {
assert!(value.len() == 1);
value[0]
}
_ => panic!("LocalSizeId is not a constant"),
},
)),
_ => None,
))
})
});
ShaderExecution::Compute(ComputeShaderExecution {
local_size: local_size.expect(
"Geometry shader does not have a WorkgroupSize builtin, \
or LocalSize or LocalSizeId ExecutionMode",
),
})
}
ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration,
ExecutionModel::IntersectionKHR => ShaderExecution::Intersection,
ExecutionModel::AnyHitKHR => ShaderExecution::AnyHit,
ExecutionModel::ClosestHitKHR => ShaderExecution::ClosestHit,
ExecutionModel::MissKHR => ShaderExecution::Miss,
ExecutionModel::CallableKHR => ShaderExecution::Callable,
ExecutionModel::TaskEXT => ShaderExecution::Task,
ExecutionModel::TaskNV => todo!(),
ExecutionModel::MeshEXT => ShaderExecution::Mesh,
ExecutionModel::MeshNV => todo!(),
ExecutionModel::Kernel => todo!(),
}
}
#[derive(Clone, Debug, Default)]

View File

@ -132,8 +132,35 @@ impl Spirv {
let destination = match instruction {
Instruction::Function { result_id, .. } => {
current_function = None;
let function = functions.entry(result_id).or_insert(FunctionInfo {
let function = functions.entry(result_id).or_insert_with(|| {
let entry_point = instructions_entry_point
.iter()
.find(|instruction| {
matches!(
**instruction,
Instruction::EntryPoint { entry_point, .. }
if entry_point == result_id
)
})
.cloned();
let execution_modes = instructions_execution_mode
.iter()
.filter(|instruction| {
matches!(
**instruction,
Instruction::ExecutionMode { entry_point, .. }
| Instruction::ExecutionModeId { entry_point, .. }
if entry_point == result_id
)
})
.cloned()
.collect();
FunctionInfo {
instructions: Vec::new(),
entry_point,
execution_modes,
}
});
current_function.insert(&mut function.instructions)
}
@ -611,9 +638,11 @@ impl StructMemberInfo {
}
/// Information associated with a function.
#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug)]
pub struct FunctionInfo {
instructions: Vec<Instruction>,
entry_point: Option<Instruction>,
execution_modes: Vec<Instruction>,
}
impl FunctionInfo {
@ -622,6 +651,18 @@ impl FunctionInfo {
pub fn iter_instructions(&self) -> impl ExactSizeIterator<Item = &Instruction> {
self.instructions.iter()
}
/// Returns the `EntryPoint` instruction that targets this function, if there is one.
#[inline]
pub fn entry_point(&self) -> Option<&Instruction> {
self.entry_point.as_ref()
}
/// Returns an iterator over all execution mode instructions that target this function.
#[inline]
pub fn iter_execution_mode(&self) -> impl ExactSizeIterator<Item = &Instruction> {
self.execution_modes.iter()
}
}
fn iter_instructions(