From a99aa95e61a2f3f5298ae455765a4cd0a4a542ef Mon Sep 17 00:00:00 2001 From: Rua Date: Thu, 24 Aug 2023 01:00:35 +0200 Subject: [PATCH] Clean up and improve shader stage checks (#2298) --- vulkano/src/pipeline/mod.rs | 343 ++++++++++++++++++++++------------ vulkano/src/shader/reflect.rs | 37 ++-- 2 files changed, 248 insertions(+), 132 deletions(-) diff --git a/vulkano/src/pipeline/mod.rs b/vulkano/src/pipeline/mod.rs index a8123a91..ce10fcfe 100644 --- a/vulkano/src/pipeline/mod.rs +++ b/vulkano/src/pipeline/mod.rs @@ -357,10 +357,12 @@ impl PipelineShaderStageCreateInfo { flags, ref entry_point, ref specialization_info, - ref required_subgroup_size, + required_subgroup_size, _ne: _, } = self; + let properties = device.physical_device().properties(); + flags.validate_device(device).map_err(|err| { err.add_context("flags") .set_vuids(&["VUID-VkPipelineShaderStageCreateInfo-flags-parameter"]) @@ -489,142 +491,249 @@ impl PipelineShaderStageCreateInfo { let workgroup_size = if let Some(local_size) = entry_point_info.local_size(specialization_info)? { - let [x, y, z] = local_size; - if x == 0 || y == 0 || z == 0 { - return Err(Box::new(ValidationError { - problem: format!("`workgroup size` {local_size:?} cannot be 0").into(), - ..Default::default() - })); - } - let properties = device.physical_device().properties(); - if stage_enum == ShaderStage::Compute { - let max_compute_work_group_size = properties.max_compute_work_group_size; - let [max_x, max_y, max_z] = max_compute_work_group_size; - if x > max_x || y > max_y || z > max_z { - return Err(Box::new(ValidationError { - problem: format!("`workgroup size` {local_size:?} is greater than `max_compute_work_group_size` {max_compute_work_group_size:?}").into(), - ..Default::default() - })); - } - let max_invocations = properties.max_compute_work_group_invocations; - if let Some(workgroup_size) = (|| x.checked_mul(y)?.checked_mul(z))() { - if workgroup_size > max_invocations { + match stage_enum { + ShaderStage::Compute => { + if local_size[0] > properties.max_compute_work_group_size[0] { return Err(Box::new(ValidationError { - problem: format!("the product of `workgroup size` {local_size:?} = {workgroup_size} is greater than `max_compute_work_group_invocations` {max_invocations}").into(), + problem: "the `local_size_x` of `entry_point` is greater than \ + `max_compute_work_group_size[0]`" + .into(), + vuids: &["VUID-RuntimeSpirv-x-06429"], ..Default::default() })); } - Some(workgroup_size) - } else { - return Err(Box::new(ValidationError { - problem: format!("the product of `workgroup size` {local_size:?} = (overflow) is greater than `max_compute_work_group_invocations` {max_invocations}").into(), + + if local_size[1] > properties.max_compute_work_group_size[1] { + return Err(Box::new(ValidationError { + problem: "the `local_size_y` of `entry_point` is greater than \ + `max_compute_work_group_size[1]`" + .into(), + vuids: &["VUID-RuntimeSpirv-x-06430"], ..Default::default() })); + } + + if local_size[2] > properties.max_compute_work_group_size[2] { + return Err(Box::new(ValidationError { + problem: "the `local_size_x` of `entry_point` is greater than \ + `max_compute_work_group_size[2]`" + .into(), + vuids: &["VUID-RuntimeSpirv-x-06431"], + ..Default::default() + })); + } + + let workgroup_size = local_size + .into_iter() + .try_fold(1, u32::checked_mul) + .filter(|&x| x <= properties.max_compute_work_group_invocations) + .ok_or_else(|| { + Box::new(ValidationError { + problem: "the product of the `local_size_x`, `local_size_y` and \ + `local_size_z` of `entry_point` is greater than the \ + `max_compute_work_group_invocations` device limit" + .into(), + vuids: &["VUID-RuntimeSpirv-x-06432"], + ..Default::default() + }) + })?; + + Some(workgroup_size) + } + ShaderStage::Task => { + if local_size[0] > properties.max_task_work_group_size.unwrap_or_default()[0] { + return Err(Box::new(ValidationError { + problem: "the `local_size_x` of `entry_point` is greater than \ + `max_task_work_group_size[0]`" + .into(), + vuids: &["VUID-RuntimeSpirv-TaskEXT-07291"], + ..Default::default() + })); + } + + if local_size[1] > properties.max_task_work_group_size.unwrap_or_default()[1] { + return Err(Box::new(ValidationError { + problem: "the `local_size_y` of `entry_point` is greater than \ + `max_task_work_group_size[1]`" + .into(), + vuids: &["VUID-RuntimeSpirv-TaskEXT-07292"], + ..Default::default() + })); + } + + if local_size[2] > properties.max_task_work_group_size.unwrap_or_default()[2] { + return Err(Box::new(ValidationError { + problem: "the `local_size_x` of `entry_point` is greater than \ + `max_task_work_group_size[2]`" + .into(), + vuids: &["VUID-RuntimeSpirv-TaskEXT-07293"], + ..Default::default() + })); + } + + let workgroup_size = local_size + .into_iter() + .try_fold(1, u32::checked_mul) + .filter(|&x| { + x <= properties + .max_task_work_group_invocations + .unwrap_or_default() + }) + .ok_or_else(|| { + Box::new(ValidationError { + problem: "the product of the `local_size_x`, `local_size_y` and \ + `local_size_z` of `entry_point` is greater than the \ + `max_task_work_group_invocations` device limit" + .into(), + vuids: &["VUID-RuntimeSpirv-TaskEXT-07294"], + ..Default::default() + }) + })?; + + Some(workgroup_size) + } + ShaderStage::Mesh => { + if local_size[0] > properties.max_mesh_work_group_size.unwrap_or_default()[0] { + return Err(Box::new(ValidationError { + problem: "the `local_size_x` of `entry_point` is greater than \ + `max_mesh_work_group_size[0]`" + .into(), + vuids: &["VUID-RuntimeSpirv-MeshEXT-07295"], + ..Default::default() + })); + } + + if local_size[1] > properties.max_mesh_work_group_size.unwrap_or_default()[1] { + return Err(Box::new(ValidationError { + problem: "the `local_size_y` of `entry_point` is greater than \ + `max_mesh_work_group_size[1]`" + .into(), + vuids: &["VUID-RuntimeSpirv-MeshEXT-07296"], + ..Default::default() + })); + } + + if local_size[2] > properties.max_mesh_work_group_size.unwrap_or_default()[2] { + return Err(Box::new(ValidationError { + problem: "the `local_size_x` of `entry_point` is greater than \ + `max_mesh_work_group_size[2]`" + .into(), + vuids: &["VUID-RuntimeSpirv-MeshEXT-07297"], + ..Default::default() + })); + } + + let workgroup_size = local_size + .into_iter() + .try_fold(1, u32::checked_mul) + .filter(|&x| { + x <= properties + .max_mesh_work_group_invocations + .unwrap_or_default() + }) + .ok_or_else(|| { + Box::new(ValidationError { + problem: "the product of the `local_size_x`, `local_size_y` and \ + `local_size_z` of `entry_point` is greater than the \ + `max_mesh_work_group_invocations` device limit" + .into(), + vuids: &["VUID-RuntimeSpirv-MeshEXT-07298"], + ..Default::default() + }) + })?; + + Some(workgroup_size) } - } else { // TODO: Additional stages when `.local_size()` supports them. - unreachable!() + _ => unreachable!(), } } else { None }; if let Some(required_subgroup_size) = required_subgroup_size { - validate_required_subgroup_size( - device, - stage_enum, - workgroup_size, - *required_subgroup_size, - )?; + if !device.enabled_features().subgroup_size_control { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + problem: "is `Some`".into(), + requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::Feature( + "subgroup_size_control", + )])]), + vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"], + })); + } + + if !properties + .required_subgroup_size_stages + .unwrap_or_default() + .contains_enum(stage_enum) + { + return Err(Box::new(ValidationError { + problem: "`required_subgroup_size` is `Some`, but the \ + `required_subgroup_size_stages` device property does not contain the \ + shader stage of `entry_point`" + .into(), + vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"], + ..Default::default() + })); + } + + if !required_subgroup_size.is_power_of_two() { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + problem: "is not a power of 2".into(), + vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02760"], + ..Default::default() + })); + } + + if required_subgroup_size < properties.min_subgroup_size.unwrap_or(1) { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + problem: "is less than the `min_subgroup_size` device limit".into(), + vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02761"], + ..Default::default() + })); + } + + if required_subgroup_size > properties.max_subgroup_size.unwrap_or(128) { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + problem: "is greater than the `max_subgroup_size` device limit".into(), + vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02762"], + ..Default::default() + })); + } + + if let Some(workgroup_size) = workgroup_size { + if stage_enum == ShaderStage::Compute { + if workgroup_size + > required_subgroup_size + .checked_mul( + properties + .max_compute_workgroup_subgroups + .unwrap_or_default(), + ) + .unwrap_or(u32::MAX) + { + return Err(Box::new(ValidationError { + problem: "the product of the `local_size_x`, `local_size_y` and \ + `local_size_z` of `entry_point` is greater than the the product \ + of `required_subgroup_size` and the \ + `max_compute_workgroup_subgroups` device limit" + .into(), + vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02756"], + ..Default::default() + })); + } + } + } } Ok(()) } } -pub(crate) fn validate_required_subgroup_size( - device: &Device, - stage: ShaderStage, - workgroup_size: Option, - subgroup_size: u32, -) -> Result<(), Box> { - if !device.enabled_features().subgroup_size_control { - return Err(Box::new(ValidationError { - context: "required_subgroup_size".into(), - requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::Feature( - "subgroup_size_control", - )])]), - vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"], - ..Default::default() - })); - } - let properties = device.physical_device().properties(); - if !properties - .required_subgroup_size_stages - .unwrap_or_default() - .contains_enum(stage) - { - return Err(Box::new(ValidationError { - context: "required_subgroup_size".into(), - problem: format!("`shader stage` {stage:?} is not in `required_subgroup_size_stages`") - .into(), - vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"], - ..Default::default() - })); - } - if !subgroup_size.is_power_of_two() { - return Err(Box::new(ValidationError { - context: "required_subgroup_size".into(), - problem: format!("`subgroup_size` {subgroup_size} is not a power of 2").into(), - vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02760"], - ..Default::default() - })); - } - let min_subgroup_size = properties.min_subgroup_size.unwrap_or(1); - if subgroup_size < min_subgroup_size { - return Err(Box::new(ValidationError { - context: "required_subgroup_size".into(), - problem: format!( - "`subgroup_size` {subgroup_size} is less than `min_subgroup_size` {min_subgroup_size}" - ) - .into(), - vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02761"], - ..Default::default() - })); - } - let max_subgroup_size = properties.max_subgroup_size.unwrap_or(128); - if subgroup_size > max_subgroup_size { - return Err(Box::new(ValidationError { - context: "required_subgroup_size".into(), - problem: - format!("`subgroup_size` {subgroup_size} is greater than `max_subgroup_size` {max_subgroup_size}") - .into(), - vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02762"], - ..Default::default() - })); - } - if let Some(workgroup_size) = workgroup_size { - if stage == ShaderStage::Compute { - let max_compute_workgroup_subgroups = properties - .max_compute_workgroup_subgroups - .unwrap_or_default(); - if max_compute_workgroup_subgroups - .checked_mul(subgroup_size) - .unwrap_or(u32::MAX) - < workgroup_size - { - return Err(Box::new(ValidationError { - context: "required_subgroup_size".into(), - problem: format!("`subgroup_size` {subgroup_size} creates more than {max_compute_workgroup_subgroups} subgroups").into(), - vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02756"], - ..Default::default() - })); - } - } - } - Ok(()) -} - vulkan_bitflags! { #[non_exhaustive] diff --git a/vulkano/src/shader/reflect.rs b/vulkano/src/shader/reflect.rs index 786a14aa..a35f97c8 100644 --- a/vulkano/src/shader/reflect.rs +++ b/vulkano/src/shader/reflect.rs @@ -163,15 +163,18 @@ fn specialization_constant_ids(spirv: &Spirv) -> HashMap { spirv .iter_decoration() .filter_map(|inst| { - if let Instruction::Decorate { target, decoration } = inst { - if let Decoration::SpecId { - specialization_constant_id, - } = decoration - { - return Some((*target, *specialization_constant_id)); - } + if let Instruction::Decorate { + target, + decoration: + Decoration::SpecId { + specialization_constant_id, + }, + } = inst + { + Some((*target, *specialization_constant_id)) + } else { + None } - None }) .collect() } @@ -181,14 +184,18 @@ fn workgroup_size_decorations(spirv: &Spirv) -> HashSet { spirv .iter_decoration() .filter_map(|inst| { - if let Instruction::Decorate { target, decoration } = inst { - if let Decoration::BuiltIn { built_in } = decoration { - if *built_in == BuiltIn::WorkgroupSize { - return Some(*target); - } - } + if let Instruction::Decorate { + target, + decoration: + Decoration::BuiltIn { + built_in: BuiltIn::WorkgroupSize, + }, + } = inst + { + Some(*target) + } else { + None } - None }) .collect() }