Clean up and improve shader stage checks (#2298)

This commit is contained in:
Rua 2023-08-24 01:00:35 +02:00 committed by GitHub
parent b4c628d0af
commit a99aa95e61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 248 additions and 132 deletions

View File

@ -357,10 +357,12 @@ impl PipelineShaderStageCreateInfo {
flags, flags,
ref entry_point, ref entry_point,
ref specialization_info, ref specialization_info,
ref required_subgroup_size, required_subgroup_size,
_ne: _, _ne: _,
} = self; } = self;
let properties = device.physical_device().properties();
flags.validate_device(device).map_err(|err| { flags.validate_device(device).map_err(|err| {
err.add_context("flags") err.add_context("flags")
.set_vuids(&["VUID-VkPipelineShaderStageCreateInfo-flags-parameter"]) .set_vuids(&["VUID-VkPipelineShaderStageCreateInfo-flags-parameter"])
@ -489,140 +491,247 @@ impl PipelineShaderStageCreateInfo {
let workgroup_size = if let Some(local_size) = let workgroup_size = if let Some(local_size) =
entry_point_info.local_size(specialization_info)? entry_point_info.local_size(specialization_info)?
{ {
let [x, y, z] = local_size; match stage_enum {
if x == 0 || y == 0 || z == 0 { ShaderStage::Compute => {
if local_size[0] > properties.max_compute_work_group_size[0] {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
problem: format!("`workgroup size` {local_size:?} cannot be 0").into(), problem: "the `local_size_x` of `entry_point` is greater than \
`max_compute_work_group_size[0]`"
.into(),
vuids: &["VUID-RuntimeSpirv-x-06429"],
..Default::default() ..Default::default()
})); }));
} }
let properties = device.physical_device().properties();
if stage_enum == ShaderStage::Compute { if local_size[1] > properties.max_compute_work_group_size[1] {
let max_compute_work_group_size = properties.max_compute_work_group_size;
let [max_x, max_y, max_z] = max_compute_work_group_size;
if x > max_x || y > max_y || z > max_z {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
problem: format!("`workgroup size` {local_size:?} is greater than `max_compute_work_group_size` {max_compute_work_group_size:?}").into(), problem: "the `local_size_y` of `entry_point` is greater than \
`max_compute_work_group_size[1]`"
.into(),
vuids: &["VUID-RuntimeSpirv-y-06430"],
..Default::default() ..Default::default()
})); }));
} }
let max_invocations = properties.max_compute_work_group_invocations;
if let Some(workgroup_size) = (|| x.checked_mul(y)?.checked_mul(z))() { if local_size[2] > properties.max_compute_work_group_size[2] {
if workgroup_size > max_invocations {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
problem: format!("the product of `workgroup size` {local_size:?} = {workgroup_size} is greater than `max_compute_work_group_invocations` {max_invocations}").into(), problem: "the `local_size_z` of `entry_point` is greater than \
`max_compute_work_group_size[2]`"
.into(),
vuids: &["VUID-RuntimeSpirv-z-06431"],
..Default::default() ..Default::default()
})); }));
} }
let workgroup_size = local_size
.into_iter()
.try_fold(1, u32::checked_mul)
.filter(|&x| x <= properties.max_compute_work_group_invocations)
.ok_or_else(|| {
Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \
`max_compute_work_group_invocations` device limit"
.into(),
vuids: &["VUID-RuntimeSpirv-x-06432"],
..Default::default()
})
})?;
Some(workgroup_size) Some(workgroup_size)
} else { }
ShaderStage::Task => {
if local_size[0] > properties.max_task_work_group_size.unwrap_or_default()[0] {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
problem: format!("the product of `workgroup size` {local_size:?} = (overflow) is greater than `max_compute_work_group_invocations` {max_invocations}").into(), problem: "the `local_size_x` of `entry_point` is greater than \
`max_task_work_group_size[0]`"
.into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07291"],
..Default::default() ..Default::default()
})); }));
} }
} else {
if local_size[1] > properties.max_task_work_group_size.unwrap_or_default()[1] {
return Err(Box::new(ValidationError {
problem: "the `local_size_y` of `entry_point` is greater than \
`max_task_work_group_size[1]`"
.into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07292"],
..Default::default()
}));
}
if local_size[2] > properties.max_task_work_group_size.unwrap_or_default()[2] {
return Err(Box::new(ValidationError {
problem: "the `local_size_z` of `entry_point` is greater than \
`max_task_work_group_size[2]`"
.into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07293"],
..Default::default()
}));
}
let workgroup_size = local_size
.into_iter()
.try_fold(1, u32::checked_mul)
.filter(|&x| {
x <= properties
.max_task_work_group_invocations
.unwrap_or_default()
})
.ok_or_else(|| {
Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \
`max_task_work_group_invocations` device limit"
.into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07294"],
..Default::default()
})
})?;
Some(workgroup_size)
}
ShaderStage::Mesh => {
if local_size[0] > properties.max_mesh_work_group_size.unwrap_or_default()[0] {
return Err(Box::new(ValidationError {
problem: "the `local_size_x` of `entry_point` is greater than \
`max_mesh_work_group_size[0]`"
.into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07295"],
..Default::default()
}));
}
if local_size[1] > properties.max_mesh_work_group_size.unwrap_or_default()[1] {
return Err(Box::new(ValidationError {
problem: "the `local_size_y` of `entry_point` is greater than \
`max_mesh_work_group_size[1]`"
.into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07296"],
..Default::default()
}));
}
if local_size[2] > properties.max_mesh_work_group_size.unwrap_or_default()[2] {
return Err(Box::new(ValidationError {
problem: "the `local_size_z` of `entry_point` is greater than \
`max_mesh_work_group_size[2]`"
.into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07297"],
..Default::default()
}));
}
let workgroup_size = local_size
.into_iter()
.try_fold(1, u32::checked_mul)
.filter(|&x| {
x <= properties
.max_mesh_work_group_invocations
.unwrap_or_default()
})
.ok_or_else(|| {
Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \
`max_mesh_work_group_invocations` device limit"
.into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07298"],
..Default::default()
})
})?;
Some(workgroup_size)
}
// TODO: Additional stages when `.local_size()` supports them. // TODO: Additional stages when `.local_size()` supports them.
unreachable!() _ => unreachable!(),
} }
} else { } else {
None None
}; };
if let Some(required_subgroup_size) = required_subgroup_size { if let Some(required_subgroup_size) = required_subgroup_size {
validate_required_subgroup_size(
device,
stage_enum,
workgroup_size,
*required_subgroup_size,
)?;
}
Ok(())
}
}
pub(crate) fn validate_required_subgroup_size(
device: &Device,
stage: ShaderStage,
workgroup_size: Option<u32>,
subgroup_size: u32,
) -> Result<(), Box<ValidationError>> {
if !device.enabled_features().subgroup_size_control { if !device.enabled_features().subgroup_size_control {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(), context: "required_subgroup_size".into(),
problem: "is `Some`".into(),
requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::Feature( requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::Feature(
"subgroup_size_control", "subgroup_size_control",
)])]), )])]),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"], vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
..Default::default()
})); }));
} }
let properties = device.physical_device().properties();
if !properties if !properties
.required_subgroup_size_stages .required_subgroup_size_stages
.unwrap_or_default() .unwrap_or_default()
.contains_enum(stage) .contains_enum(stage_enum)
{ {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(), problem: "`required_subgroup_size` is `Some`, but the \
problem: format!("`shader stage` {stage:?} is not in `required_subgroup_size_stages`") `required_subgroup_size_stages` device property does not contain the \
shader stage of `entry_point`"
.into(), .into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"], vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
..Default::default() ..Default::default()
})); }));
} }
if !subgroup_size.is_power_of_two() {
if !required_subgroup_size.is_power_of_two() {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(), context: "required_subgroup_size".into(),
problem: format!("`subgroup_size` {subgroup_size} is not a power of 2").into(), problem: "is not a power of 2".into(),
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02760"], vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02760"],
..Default::default() ..Default::default()
})); }));
} }
let min_subgroup_size = properties.min_subgroup_size.unwrap_or(1);
if subgroup_size < min_subgroup_size { if required_subgroup_size < properties.min_subgroup_size.unwrap_or(1) {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(), context: "required_subgroup_size".into(),
problem: format!( problem: "is less than the `min_subgroup_size` device limit".into(),
"`subgroup_size` {subgroup_size} is less than `min_subgroup_size` {min_subgroup_size}"
)
.into(),
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02761"], vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02761"],
..Default::default() ..Default::default()
})); }));
} }
let max_subgroup_size = properties.max_subgroup_size.unwrap_or(128);
if subgroup_size > max_subgroup_size { if required_subgroup_size > properties.max_subgroup_size.unwrap_or(128) {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(), context: "required_subgroup_size".into(),
problem: problem: "is greater than the `max_subgroup_size` device limit".into(),
format!("`subgroup_size` {subgroup_size} is greater than `max_subgroup_size` {max_subgroup_size}")
.into(),
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02762"], vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02762"],
..Default::default() ..Default::default()
})); }));
} }
if let Some(workgroup_size) = workgroup_size { if let Some(workgroup_size) = workgroup_size {
if stage == ShaderStage::Compute { if stage_enum == ShaderStage::Compute {
let max_compute_workgroup_subgroups = properties if workgroup_size
> required_subgroup_size
.checked_mul(
properties
.max_compute_workgroup_subgroups .max_compute_workgroup_subgroups
.unwrap_or_default(); .unwrap_or_default(),
if max_compute_workgroup_subgroups )
.checked_mul(subgroup_size)
.unwrap_or(u32::MAX) .unwrap_or(u32::MAX)
< workgroup_size
{ {
return Err(Box::new(ValidationError { return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(), problem: "the product of the `local_size_x`, `local_size_y` and \
problem: format!("`subgroup_size` {subgroup_size} creates more than {max_compute_workgroup_subgroups} subgroups").into(), `local_size_z` of `entry_point` is greater than the product \
of `required_subgroup_size` and the \
`max_compute_workgroup_subgroups` device limit"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02756"], vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02756"],
..Default::default() ..Default::default()
})); }));
} }
} }
} }
}
Ok(()) Ok(())
}
} }
vulkan_bitflags! { vulkan_bitflags! {

View File

@ -163,15 +163,18 @@ fn specialization_constant_ids(spirv: &Spirv) -> HashMap<Id, u32> {
spirv spirv
.iter_decoration() .iter_decoration()
.filter_map(|inst| { .filter_map(|inst| {
if let Instruction::Decorate { target, decoration } = inst { if let Instruction::Decorate {
if let Decoration::SpecId { target,
decoration:
Decoration::SpecId {
specialization_constant_id, specialization_constant_id,
} = decoration },
} = inst
{ {
return Some((*target, *specialization_constant_id)); Some((*target, *specialization_constant_id))
} } else {
}
None None
}
}) })
.collect() .collect()
} }
@ -181,14 +184,18 @@ fn workgroup_size_decorations(spirv: &Spirv) -> HashSet<Id> {
spirv spirv
.iter_decoration() .iter_decoration()
.filter_map(|inst| { .filter_map(|inst| {
if let Instruction::Decorate { target, decoration } = inst { if let Instruction::Decorate {
if let Decoration::BuiltIn { built_in } = decoration { target,
if *built_in == BuiltIn::WorkgroupSize { decoration:
return Some(*target); Decoration::BuiltIn {
} built_in: BuiltIn::WorkgroupSize,
} },
} } = inst
{
Some(*target)
} else {
None None
}
}) })
.collect() .collect()
} }