Clean up and improve shader stage checks (#2298)

This commit is contained in:
Rua 2023-08-24 01:00:35 +02:00 committed by GitHub
parent b4c628d0af
commit a99aa95e61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 248 additions and 132 deletions

View File

@ -357,10 +357,12 @@ impl PipelineShaderStageCreateInfo {
flags,
ref entry_point,
ref specialization_info,
ref required_subgroup_size,
required_subgroup_size,
_ne: _,
} = self;
let properties = device.physical_device().properties();
flags.validate_device(device).map_err(|err| {
err.add_context("flags")
.set_vuids(&["VUID-VkPipelineShaderStageCreateInfo-flags-parameter"])
@ -489,140 +491,247 @@ impl PipelineShaderStageCreateInfo {
let workgroup_size = if let Some(local_size) =
entry_point_info.local_size(specialization_info)?
{
let [x, y, z] = local_size;
if x == 0 || y == 0 || z == 0 {
match stage_enum {
ShaderStage::Compute => {
if local_size[0] > properties.max_compute_work_group_size[0] {
return Err(Box::new(ValidationError {
problem: format!("`workgroup size` {local_size:?} cannot be 0").into(),
problem: "the `local_size_x` of `entry_point` is greater than \
`max_compute_work_group_size[0]`"
.into(),
vuids: &["VUID-RuntimeSpirv-x-06429"],
..Default::default()
}));
}
let properties = device.physical_device().properties();
if stage_enum == ShaderStage::Compute {
let max_compute_work_group_size = properties.max_compute_work_group_size;
let [max_x, max_y, max_z] = max_compute_work_group_size;
if x > max_x || y > max_y || z > max_z {
if local_size[1] > properties.max_compute_work_group_size[1] {
return Err(Box::new(ValidationError {
problem: format!("`workgroup size` {local_size:?} is greater than `max_compute_work_group_size` {max_compute_work_group_size:?}").into(),
problem: "the `local_size_y` of `entry_point` is greater than \
`max_compute_work_group_size[1]`"
.into(),
vuids: &["VUID-RuntimeSpirv-x-06430"],
..Default::default()
}));
}
let max_invocations = properties.max_compute_work_group_invocations;
if let Some(workgroup_size) = (|| x.checked_mul(y)?.checked_mul(z))() {
if workgroup_size > max_invocations {
if local_size[2] > properties.max_compute_work_group_size[2] {
return Err(Box::new(ValidationError {
problem: format!("the product of `workgroup size` {local_size:?} = {workgroup_size} is greater than `max_compute_work_group_invocations` {max_invocations}").into(),
problem: "the `local_size_x` of `entry_point` is greater than \
`max_compute_work_group_size[2]`"
.into(),
vuids: &["VUID-RuntimeSpirv-x-06431"],
..Default::default()
}));
}
let workgroup_size = local_size
.into_iter()
.try_fold(1, u32::checked_mul)
.filter(|&x| x <= properties.max_compute_work_group_invocations)
.ok_or_else(|| {
Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \
`max_compute_work_group_invocations` device limit"
.into(),
vuids: &["VUID-RuntimeSpirv-x-06432"],
..Default::default()
})
})?;
Some(workgroup_size)
} else {
}
ShaderStage::Task => {
if local_size[0] > properties.max_task_work_group_size.unwrap_or_default()[0] {
return Err(Box::new(ValidationError {
problem: format!("the product of `workgroup size` {local_size:?} = (overflow) is greater than `max_compute_work_group_invocations` {max_invocations}").into(),
problem: "the `local_size_x` of `entry_point` is greater than \
`max_task_work_group_size[0]`"
.into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07291"],
..Default::default()
}));
}
} else {
if local_size[1] > properties.max_task_work_group_size.unwrap_or_default()[1] {
return Err(Box::new(ValidationError {
problem: "the `local_size_y` of `entry_point` is greater than \
`max_task_work_group_size[1]`"
.into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07292"],
..Default::default()
}));
}
if local_size[2] > properties.max_task_work_group_size.unwrap_or_default()[2] {
return Err(Box::new(ValidationError {
problem: "the `local_size_x` of `entry_point` is greater than \
`max_task_work_group_size[2]`"
.into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07293"],
..Default::default()
}));
}
let workgroup_size = local_size
.into_iter()
.try_fold(1, u32::checked_mul)
.filter(|&x| {
x <= properties
.max_task_work_group_invocations
.unwrap_or_default()
})
.ok_or_else(|| {
Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \
`max_task_work_group_invocations` device limit"
.into(),
vuids: &["VUID-RuntimeSpirv-TaskEXT-07294"],
..Default::default()
})
})?;
Some(workgroup_size)
}
ShaderStage::Mesh => {
if local_size[0] > properties.max_mesh_work_group_size.unwrap_or_default()[0] {
return Err(Box::new(ValidationError {
problem: "the `local_size_x` of `entry_point` is greater than \
`max_mesh_work_group_size[0]`"
.into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07295"],
..Default::default()
}));
}
if local_size[1] > properties.max_mesh_work_group_size.unwrap_or_default()[1] {
return Err(Box::new(ValidationError {
problem: "the `local_size_y` of `entry_point` is greater than \
`max_mesh_work_group_size[1]`"
.into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07296"],
..Default::default()
}));
}
if local_size[2] > properties.max_mesh_work_group_size.unwrap_or_default()[2] {
return Err(Box::new(ValidationError {
problem: "the `local_size_x` of `entry_point` is greater than \
`max_mesh_work_group_size[2]`"
.into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07297"],
..Default::default()
}));
}
let workgroup_size = local_size
.into_iter()
.try_fold(1, u32::checked_mul)
.filter(|&x| {
x <= properties
.max_mesh_work_group_invocations
.unwrap_or_default()
})
.ok_or_else(|| {
Box::new(ValidationError {
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the \
`max_mesh_work_group_invocations` device limit"
.into(),
vuids: &["VUID-RuntimeSpirv-MeshEXT-07298"],
..Default::default()
})
})?;
Some(workgroup_size)
}
// TODO: Additional stages when `.local_size()` supports them.
unreachable!()
_ => unreachable!(),
}
} else {
None
};
if let Some(required_subgroup_size) = required_subgroup_size {
validate_required_subgroup_size(
device,
stage_enum,
workgroup_size,
*required_subgroup_size,
)?;
}
Ok(())
}
}
pub(crate) fn validate_required_subgroup_size(
device: &Device,
stage: ShaderStage,
workgroup_size: Option<u32>,
subgroup_size: u32,
) -> Result<(), Box<ValidationError>> {
if !device.enabled_features().subgroup_size_control {
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem: "is `Some`".into(),
requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::Feature(
"subgroup_size_control",
)])]),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
..Default::default()
}));
}
let properties = device.physical_device().properties();
if !properties
.required_subgroup_size_stages
.unwrap_or_default()
.contains_enum(stage)
.contains_enum(stage_enum)
{
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem: format!("`shader stage` {stage:?} is not in `required_subgroup_size_stages`")
problem: "`required_subgroup_size` is `Some`, but the \
`required_subgroup_size_stages` device property does not contain the \
shader stage of `entry_point`"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
..Default::default()
}));
}
if !subgroup_size.is_power_of_two() {
if !required_subgroup_size.is_power_of_two() {
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem: format!("`subgroup_size` {subgroup_size} is not a power of 2").into(),
problem: "is not a power of 2".into(),
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02760"],
..Default::default()
}));
}
let min_subgroup_size = properties.min_subgroup_size.unwrap_or(1);
if subgroup_size < min_subgroup_size {
if required_subgroup_size < properties.min_subgroup_size.unwrap_or(1) {
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem: format!(
"`subgroup_size` {subgroup_size} is less than `min_subgroup_size` {min_subgroup_size}"
)
.into(),
problem: "is less than the `min_subgroup_size` device limit".into(),
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02761"],
..Default::default()
}));
}
let max_subgroup_size = properties.max_subgroup_size.unwrap_or(128);
if subgroup_size > max_subgroup_size {
if required_subgroup_size > properties.max_subgroup_size.unwrap_or(128) {
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem:
format!("`subgroup_size` {subgroup_size} is greater than `max_subgroup_size` {max_subgroup_size}")
.into(),
problem: "is greater than the `max_subgroup_size` device limit".into(),
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02762"],
..Default::default()
}));
}
if let Some(workgroup_size) = workgroup_size {
if stage == ShaderStage::Compute {
let max_compute_workgroup_subgroups = properties
if stage_enum == ShaderStage::Compute {
if workgroup_size
> required_subgroup_size
.checked_mul(
properties
.max_compute_workgroup_subgroups
.unwrap_or_default();
if max_compute_workgroup_subgroups
.checked_mul(subgroup_size)
.unwrap_or_default(),
)
.unwrap_or(u32::MAX)
< workgroup_size
{
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem: format!("`subgroup_size` {subgroup_size} creates more than {max_compute_workgroup_subgroups} subgroups").into(),
problem: "the product of the `local_size_x`, `local_size_y` and \
`local_size_z` of `entry_point` is greater than the the product \
of `required_subgroup_size` and the \
`max_compute_workgroup_subgroups` device limit"
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02756"],
..Default::default()
}));
}
}
}
}
Ok(())
}
}
vulkan_bitflags! {

View File

@ -163,15 +163,18 @@ fn specialization_constant_ids(spirv: &Spirv) -> HashMap<Id, u32> {
spirv
.iter_decoration()
.filter_map(|inst| {
if let Instruction::Decorate { target, decoration } = inst {
if let Decoration::SpecId {
if let Instruction::Decorate {
target,
decoration:
Decoration::SpecId {
specialization_constant_id,
} = decoration
},
} = inst
{
return Some((*target, *specialization_constant_id));
}
}
Some((*target, *specialization_constant_id))
} else {
None
}
})
.collect()
}
@ -181,14 +184,18 @@ fn workgroup_size_decorations(spirv: &Spirv) -> HashSet<Id> {
spirv
.iter_decoration()
.filter_map(|inst| {
if let Instruction::Decorate { target, decoration } = inst {
if let Decoration::BuiltIn { built_in } = decoration {
if *built_in == BuiltIn::WorkgroupSize {
return Some(*target);
}
}
}
if let Instruction::Decorate {
target,
decoration:
Decoration::BuiltIn {
built_in: BuiltIn::WorkgroupSize,
},
} = inst
{
Some(*target)
} else {
None
}
})
.collect()
}