Added required_subgroup_size to PipelineShaderStageCreateInfo (#2235)

* Added required_subgroup_size to PipelineShaderStageCreateInfo

* Added validation errors.

* Fixed error msgs / vuids.

* ComputeShaderExecution for validating local_size.

* WorkgroupSizeId reflection.

* contains_enum

* Reworked ComputeShaderExecution.

* panic msgs.

* workgroup size validation

* unused import

* Fixed use of deprecated function in test.

* catch workgroup size overflow

* EntryPointInfo::local_size docs

* comments

* typo + error msg
charles-r-earp 2023-08-18 04:37:10 -07:00 committed by GitHub
parent 4133a3bf63
commit ee4e3089d3
6 changed files with 676 additions and 18 deletions


@@ -73,7 +73,49 @@ fn write_shader_execution(execution: &ShaderExecution) -> TokenStream {
)
}
}
ShaderExecution::Compute => quote! { ::vulkano::shader::ShaderExecution::Compute },
ShaderExecution::Compute(execution) => {
use ::quote::ToTokens;
use ::vulkano::shader::{ComputeShaderExecution, LocalSize};
struct LocalSizeToTokens(LocalSize);
impl ToTokens for LocalSizeToTokens {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self.0 {
LocalSize::Literal(literal) => quote! {
::vulkano::shader::LocalSize::Literal(#literal)
},
LocalSize::SpecId(id) => quote! {
::vulkano::shader::LocalSize::SpecId(#id)
},
}
.to_tokens(tokens);
}
}
match execution {
ComputeShaderExecution::LocalSize([x, y, z]) => {
let [x, y, z] = [
LocalSizeToTokens(*x),
LocalSizeToTokens(*y),
LocalSizeToTokens(*z),
];
quote! { ::vulkano::shader::ShaderExecution::Compute(
::vulkano::shader::ComputeShaderExecution::LocalSize([#x, #y, #z])
) }
}
ComputeShaderExecution::LocalSizeId([x, y, z]) => {
let [x, y, z] = [
LocalSizeToTokens(*x),
LocalSizeToTokens(*y),
LocalSizeToTokens(*z),
];
quote! { ::vulkano::shader::ShaderExecution::Compute(
::vulkano::shader::ComputeShaderExecution::LocalSizeId([#x, #y, #z])
) }
}
}
}
ShaderExecution::RayGeneration => {
quote! { ::vulkano::shader::ShaderExecution::RayGeneration }
}


@@ -74,11 +74,9 @@ impl ComputePipeline {
if let Some(cache) = &cache {
assert_eq!(device, cache.device().as_ref());
}
create_info
.validate(device)
.map_err(|err| err.add_context("create_info"))?;
Ok(())
}
@@ -100,12 +98,14 @@ impl ComputePipeline {
let specialization_info_vk;
let specialization_map_entries_vk: Vec<_>;
let mut specialization_data_vk: Vec<u8>;
let required_subgroup_size_create_info;
{
let &PipelineShaderStageCreateInfo {
flags,
ref entry_point,
ref specialization_info,
ref required_subgroup_size,
_ne: _,
} = stage;
@@ -135,7 +135,20 @@ impl ComputePipeline {
data_size: specialization_data_vk.len(),
p_data: specialization_data_vk.as_ptr() as *const _,
};
required_subgroup_size_create_info =
required_subgroup_size.map(|required_subgroup_size| {
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
required_subgroup_size,
..Default::default()
}
});
stage_vk = ash::vk::PipelineShaderStageCreateInfo {
p_next: required_subgroup_size_create_info.as_ref().map_or(
ptr::null(),
|required_subgroup_size_create_info| {
required_subgroup_size_create_info as *const _ as _
},
),
flags: flags.into(),
stage: ShaderStage::from(&entry_point_info.execution).into(),
module: entry_point.module().handle(),
@@ -333,12 +346,13 @@ impl ComputePipelineCreateInfo {
flags: _,
ref entry_point,
specialization_info: _,
required_subgroup_size: _vk,
_ne: _,
} = &stage;
let entry_point_info = entry_point.info();
if !matches!(entry_point_info.execution, ShaderExecution::Compute) {
if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) {
return Err(Box::new(ValidationError {
context: "stage.entry_point".into(),
problem: "is not a `ShaderStage::Compute` entry point".into(),
@@ -514,4 +528,135 @@ mod tests {
let data_buffer_content = data_buffer.read().unwrap();
assert_eq!(*data_buffer_content, 0x12345678);
}
#[test]
fn required_subgroup_size() {
// This test checks whether required_subgroup_size works.
// It executes a compute shader in which a single invocation writes the subgroup size
// to a buffer. The buffer content is then checked for the expected value.
let (device, queue) = gfx_dev_and_queue!(subgroup_size_control);
let cs = unsafe {
/*
#version 450
#extension GL_KHR_shader_subgroup_basic: enable
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
layout(set = 0, binding = 0) buffer Output {
uint write;
} write;
void main() {
if (gl_GlobalInvocationID.x == 0) {
write.write = gl_SubgroupSize;
}
}
*/
const MODULE: [u32; 246] = [
119734787, 65536, 851978, 30, 0, 131089, 1, 131089, 61, 393227, 1, 1280527431,
1685353262, 808793134, 0, 196622, 0, 1, 458767, 5, 4, 1852399981, 0, 9, 23, 393232,
4, 17, 128, 1, 1, 196611, 2, 450, 655364, 1197427783, 1279741775, 1885560645,
1953718128, 1600482425, 1701734764, 1919509599, 1769235301, 25974, 524292,
1197427783, 1279741775, 1852399429, 1685417059, 1768185701, 1952671090, 6649449,
589828, 1264536647, 1935626824, 1701077352, 1970495346, 1869768546, 1650421877,
1667855201, 0, 262149, 4, 1852399981, 0, 524293, 9, 1197436007, 1633841004,
1986939244, 1952539503, 1231974249, 68, 262149, 18, 1886680399, 29813, 327686, 18,
0, 1953067639, 101, 262149, 20, 1953067639, 101, 393221, 23, 1398762599,
1919378037, 1399879023, 6650473, 262215, 9, 11, 28, 327752, 18, 0, 35, 0, 196679,
18, 3, 262215, 20, 34, 0, 262215, 20, 33, 0, 196679, 23, 0, 262215, 23, 11, 36,
196679, 24, 0, 262215, 29, 11, 25, 131091, 2, 196641, 3, 2, 262165, 6, 32, 0,
262167, 7, 6, 3, 262176, 8, 1, 7, 262203, 8, 9, 1, 262187, 6, 10, 0, 262176, 11, 1,
6, 131092, 14, 196638, 18, 6, 262176, 19, 2, 18, 262203, 19, 20, 2, 262165, 21, 32,
1, 262187, 21, 22, 0, 262203, 11, 23, 1, 262176, 25, 2, 6, 262187, 6, 27, 128,
262187, 6, 28, 1, 393260, 7, 29, 27, 28, 28, 327734, 2, 4, 0, 3, 131320, 5, 327745,
11, 12, 9, 10, 262205, 6, 13, 12, 327850, 14, 15, 13, 10, 196855, 17, 0, 262394,
15, 16, 17, 131320, 16, 262205, 6, 24, 23, 327745, 25, 26, 20, 22, 196670, 26, 24,
131321, 17, 131320, 17, 65789, 65592,
];
let module =
ShaderModule::new(device.clone(), ShaderModuleCreateInfo::new(&MODULE)).unwrap();
module.entry_point("main").unwrap()
};
let properties = device.physical_device().properties();
let subgroup_size = properties.min_subgroup_size.unwrap_or(1);
let pipeline = {
let stage = PipelineShaderStageCreateInfo {
required_subgroup_size: Some(subgroup_size),
..PipelineShaderStageCreateInfo::new(cs)
};
let layout = PipelineLayout::new(
device.clone(),
PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage])
.into_pipeline_layout_create_info(device.clone())
.unwrap(),
)
.unwrap();
ComputePipeline::new(
device.clone(),
None,
ComputePipelineCreateInfo::stage_layout(stage, layout),
)
.unwrap()
};
let memory_allocator = StandardMemoryAllocator::new_default(device.clone());
let data_buffer = Buffer::from_data(
&memory_allocator,
BufferCreateInfo {
usage: BufferUsage::STORAGE_BUFFER,
..Default::default()
},
AllocationCreateInfo {
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE
| MemoryTypeFilter::HOST_RANDOM_ACCESS,
..Default::default()
},
0,
)
.unwrap();
let ds_allocator = StandardDescriptorSetAllocator::new(device.clone());
let set = PersistentDescriptorSet::new(
&ds_allocator,
pipeline.layout().set_layouts().get(0).unwrap().clone(),
[WriteDescriptorSet::buffer(0, data_buffer.clone())],
[],
)
.unwrap();
let cb_allocator = StandardCommandBufferAllocator::new(device.clone(), Default::default());
let mut cbb = AutoCommandBufferBuilder::primary(
&cb_allocator,
queue.queue_family_index(),
CommandBufferUsage::OneTimeSubmit,
)
.unwrap();
cbb.bind_pipeline_compute(pipeline.clone())
.unwrap()
.bind_descriptor_sets(
PipelineBindPoint::Compute,
pipeline.layout().clone(),
0,
set,
)
.unwrap()
.dispatch([128, 1, 1])
.unwrap();
let cb = cbb.build().unwrap();
let future = now(device)
.then_execute(queue, cb)
.unwrap()
.then_signal_fence_and_flush()
.unwrap();
future.wait(None).unwrap();
let data_buffer_content = data_buffer.read().unwrap();
assert_eq!(*data_buffer_content, subgroup_size);
}
}


@@ -167,7 +167,6 @@ impl GraphicsPipeline {
create_info
.validate(device)
.map_err(|err| err.add_context("create_info"))?;
Ok(())
}
@@ -204,6 +203,8 @@ impl GraphicsPipeline {
specialization_info_vk: ash::vk::SpecializationInfo,
specialization_map_entries_vk: Vec<ash::vk::SpecializationMapEntry>,
specialization_data_vk: Vec<u8>,
required_subgroup_size_create_info:
Option<ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo>,
}
let (mut stages_vk, mut per_stage_vk): (SmallVec<[_; 5]>, SmallVec<[_; 5]>) = stages
@@ -213,6 +214,7 @@ impl GraphicsPipeline {
flags,
ref entry_point,
ref specialization_info,
ref required_subgroup_size,
_ne: _,
} = stage;
@@ -235,7 +237,13 @@ impl GraphicsPipeline {
}
})
.collect();
let required_subgroup_size_create_info =
required_subgroup_size.map(|required_subgroup_size| {
ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo {
required_subgroup_size,
..Default::default()
}
});
(
ash::vk::PipelineShaderStageCreateInfo {
flags: flags.into(),
@@ -255,6 +263,7 @@ impl GraphicsPipeline {
},
specialization_map_entries_vk,
specialization_data_vk,
required_subgroup_size_create_info,
},
)
})
@@ -267,10 +276,17 @@ impl GraphicsPipeline {
specialization_info_vk,
specialization_map_entries_vk,
specialization_data_vk,
required_subgroup_size_create_info,
},
) in (stages_vk.iter_mut()).zip(per_stage_vk.iter_mut())
{
*stage_vk = ash::vk::PipelineShaderStageCreateInfo {
p_next: required_subgroup_size_create_info.as_ref().map_or(
ptr::null(),
|required_subgroup_size_create_info| {
required_subgroup_size_create_info as *const _ as _
},
),
p_name: name_vk.as_ptr(),
p_specialization_info: specialization_info_vk,
..*stage_vk
@@ -2420,6 +2436,7 @@ impl GraphicsPipelineCreateInfo {
flags: _,
ref entry_point,
specialization_info: _,
required_subgroup_size: _vk,
_ne: _,
} = stage;


@@ -321,6 +321,21 @@ pub struct PipelineShaderStageCreateInfo {
/// The default value is empty.
pub specialization_info: HashMap<u32, SpecializationConstant>,
/// The required subgroup size.
///
/// Requires [`subgroup_size_control`](crate::device::Features::subgroup_size_control). The
/// shader stage must be included in
/// [`required_subgroup_size_stages`](crate::device::Properties::required_subgroup_size_stages).
/// The subgroup size must be a power of two, and must lie between
/// [`min_subgroup_size`](crate::device::Properties::min_subgroup_size)
/// and [`max_subgroup_size`](crate::device::Properties::max_subgroup_size) inclusive.
///
/// For compute shaders, `max_compute_workgroup_subgroups * required_subgroup_size` must be
/// greater than or equal to `workgroup_size.x * workgroup_size.y * workgroup_size.z`.
///
/// The default value is `None`.
pub required_subgroup_size: Option<u32>,
pub _ne: crate::NonExhaustive,
}
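For reference, here is a minimal usage sketch (not part of the diff) that mirrors the `required_subgroup_size` test added in this commit. It assumes `device` is an `Arc<Device>` created with the `subgroup_size_control` feature enabled and that `cs` is the `EntryPoint` of a compute shader with a 128-invocation workgroup:

use vulkano::pipeline::PipelineShaderStageCreateInfo;

// Sketch only: request an explicit subgroup size for a compute stage.
let properties = device.physical_device().properties();
let subgroup_size = properties.min_subgroup_size.unwrap_or(1);

// The value must be a power of two within [min_subgroup_size, max_subgroup_size],
// and the workgroup size (128 invocations here) must not exceed
// max_compute_workgroup_subgroups * required_subgroup_size.
let stage = PipelineShaderStageCreateInfo {
    required_subgroup_size: Some(subgroup_size),
    ..PipelineShaderStageCreateInfo::new(cs)
};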
@@ -332,6 +347,7 @@ impl PipelineShaderStageCreateInfo {
flags: PipelineShaderStageCreateFlags::empty(),
entry_point,
specialization_info: HashMap::default(),
required_subgroup_size: None,
_ne: crate::NonExhaustive(()),
}
}
@@ -341,6 +357,7 @@ impl PipelineShaderStageCreateInfo {
flags,
ref entry_point,
ref specialization_info,
ref required_subgroup_size,
_ne: _,
} = self;
@@ -469,10 +486,145 @@ impl PipelineShaderStageCreateInfo {
}
}
let workgroup_size = if let Some(local_size) =
entry_point_info.local_size(specialization_info)?
{
let [x, y, z] = local_size;
if x == 0 || y == 0 || z == 0 {
return Err(Box::new(ValidationError {
problem: format!("`workgroup size` {local_size:?} cannot be 0").into(),
..Default::default()
}));
}
let properties = device.physical_device().properties();
if stage_enum == ShaderStage::Compute {
let max_compute_work_group_size = properties.max_compute_work_group_size;
let [max_x, max_y, max_z] = max_compute_work_group_size;
if x > max_x || y > max_y || z > max_z {
return Err(Box::new(ValidationError {
problem: format!("`workgroup size` {local_size:?} is greater than `max_compute_work_group_size` {max_compute_work_group_size:?}").into(),
..Default::default()
}));
}
let max_invocations = properties.max_compute_work_group_invocations;
if let Some(workgroup_size) = (|| x.checked_mul(y)?.checked_mul(z))() {
if workgroup_size > max_invocations {
return Err(Box::new(ValidationError {
problem: format!("the product of `workgroup size` {local_size:?} = {workgroup_size} is greater than `max_compute_work_group_invocations` {max_invocations}").into(),
..Default::default()
}));
}
Some(workgroup_size)
} else {
return Err(Box::new(ValidationError {
problem: format!("the product of `workgroup size` {local_size:?} = (overflow) is greater than `max_compute_work_group_invocations` {max_invocations}").into(),
..Default::default()
}));
}
} else {
// TODO: Additional stages when `.local_size()` supports them.
unreachable!()
}
} else {
None
};
if let Some(required_subgroup_size) = required_subgroup_size {
validate_required_subgroup_size(
device,
stage_enum,
workgroup_size,
*required_subgroup_size,
)?;
}
Ok(())
}
}
pub(crate) fn validate_required_subgroup_size(
device: &Device,
stage: ShaderStage,
workgroup_size: Option<u32>,
subgroup_size: u32,
) -> Result<(), Box<ValidationError>> {
if !device.enabled_features().subgroup_size_control {
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::Feature(
"subgroup_size_control",
)])]),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
..Default::default()
}));
}
let properties = device.physical_device().properties();
if !properties
.required_subgroup_size_stages
.unwrap_or_default()
.contains_enum(stage)
{
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem: format!("`shader stage` {stage:?} is not in `required_subgroup_size_stages`")
.into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"],
..Default::default()
}));
}
if !subgroup_size.is_power_of_two() {
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem: format!("`subgroup_size` {subgroup_size} is not a power of 2").into(),
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02760"],
..Default::default()
}));
}
let min_subgroup_size = properties.min_subgroup_size.unwrap_or(1);
if subgroup_size < min_subgroup_size {
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem: format!(
"`subgroup_size` {subgroup_size} is less than `min_subgroup_size` {min_subgroup_size}"
)
.into(),
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02761"],
..Default::default()
}));
}
let max_subgroup_size = properties.max_subgroup_size.unwrap_or(128);
if subgroup_size > max_subgroup_size {
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem:
format!("`subgroup_size` {subgroup_size} is greater than `max_subgroup_size` {max_subgroup_size}")
.into(),
vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02762"],
..Default::default()
}));
}
if let Some(workgroup_size) = workgroup_size {
if stage == ShaderStage::Compute {
let max_compute_workgroup_subgroups = properties
.max_compute_workgroup_subgroups
.unwrap_or_default();
if max_compute_workgroup_subgroups
.checked_mul(subgroup_size)
.unwrap_or(u32::MAX)
< workgroup_size
{
return Err(Box::new(ValidationError {
context: "required_subgroup_size".into(),
problem: format!("`subgroup_size` {subgroup_size} creates more than {max_compute_workgroup_subgroups} subgroups").into(),
vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02756"],
..Default::default()
}));
}
}
}
Ok(())
}
vulkan_bitflags! {
#[non_exhaustive]


@@ -549,6 +549,78 @@ pub struct EntryPointInfo {
pub output_interface: ShaderInterface,
}
impl EntryPointInfo {
/// Returns the local workgroup size for compute shaders, or `None` for other stages.
///
/// `specialization_info` is used to resolve `LocalSizeId` / `WorkgroupSizeId` values,
/// falling back to the entry point's `specialization_constants` when no value is provided.
/// Returns an error if a referenced specialization constant is missing or is not a `u32`.
pub(crate) fn local_size(
&self,
specialization_info: &HashMap<u32, SpecializationConstant>,
) -> Result<Option<[u32; 3]>, Box<ValidationError>> {
if let ShaderExecution::Compute(execution) = self.execution {
match execution {
ComputeShaderExecution::LocalSize(local_size)
| ComputeShaderExecution::LocalSizeId(local_size) => {
let mut output = [0; 3];
for (output, local_size) in output.iter_mut().zip(local_size) {
let id = match local_size {
LocalSize::Literal(literal) => {
*output = literal;
continue;
}
LocalSize::SpecId(id) => id,
};
if let Some(default) = self.specialization_constants.get(&id) {
let default_value = if let SpecializationConstant::U32(default_value) =
default
{
default_value
} else {
return Err(Box::new(ValidationError {
problem: format!(
"`entry_point.info().specialization_constants[{id}]` is not a 32 bit integer"
)
.into(),
..Default::default()
}));
};
// Apply the default value first; a provided specialization value,
// if any, overrides it below.
*output = *default_value;
if let Some(provided) = specialization_info.get(&id) {
if let SpecializationConstant::U32(provided_value) = provided {
*output = *provided_value;
} else {
return Err(Box::new(ValidationError {
problem: format!(
"`specialization_info[{0}]` does not have the same type as \
`entry_point.info().specialization_constants[{0}]`",
id,
)
.into(),
vuids: &["VUID-VkSpecializationMapEntry-constantID-00776"],
..Default::default()
}));
}
}
} else {
return Err(Box::new(ValidationError {
problem: format!(
"specialization constant {id} not found in `entry_point.info().specialization_constants`"
)
.into(),
..Default::default()
}));
}
}
Ok(Some(output))
}
}
} else {
Ok(None)
}
}
}
/// Represents a shader entry point in a shader module.
///
/// Can be obtained by calling [`entry_point`](ShaderModule::entry_point) on the shader module.
@@ -581,15 +653,15 @@ pub enum ShaderExecution {
TessellationEvaluation,
Geometry(GeometryShaderExecution),
Fragment(FragmentShaderExecution),
Compute,
Compute(ComputeShaderExecution),
RayGeneration,
AnyHit,
ClosestHit,
Miss,
Intersection,
Callable,
Task,
Mesh,
Task, // TODO: like compute?
Mesh, // TODO: like compute?
SubpassShading,
}
@@ -601,7 +673,7 @@ impl From<&ShaderExecution> for ExecutionModel {
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
ShaderExecution::Geometry(_) => Self::Geometry,
ShaderExecution::Fragment(_) => Self::Fragment,
ShaderExecution::Compute => Self::GLCompute,
ShaderExecution::Compute(_) => Self::GLCompute,
ShaderExecution::RayGeneration => Self::RayGenerationKHR,
ShaderExecution::AnyHit => Self::AnyHitKHR,
ShaderExecution::ClosestHit => Self::ClosestHitKHR,
@@ -699,6 +771,32 @@ pub enum FragmentTestsStages {
EarlyAndLate,
}
/// One dimension of a compute shader's local workgroup size.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LocalSize {
/// The dimension is given by a literal value.
Literal(u32),
/// The dimension is given by the specialization constant with this id.
SpecId(u32),
}
/// The mode in which the compute shader executes.
///
/// The workgroup size is specified for x, y, and z dimensions.
///
/// Constants are resolved to literals, while specialization constants
/// map to spec ids.
///
/// A constant decorated with the `WorkgroupSize` builtin overrides the values
/// specified in the execution mode. The decoration may be applied to a
/// 3-component `ConstantComposite` or `SpecConstantComposite` vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ComputeShaderExecution {
/// Workgroup size in x, y, and z.
LocalSize([LocalSize; 3]),
/// Like `LocalSize`, but the sizes are given by the ids of constants or
/// specialization constants. Requires SPIR-V 1.2.
LocalSizeId([LocalSize; 3]),
}
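As an illustration (not part of the diff), a hypothetical helper showing how these variants can be resolved to a concrete workgroup size once the specialization values are known; `EntryPointInfo::local_size` above performs essentially this resolution, returning validation errors instead of `None`:

use std::collections::HashMap;

// Hypothetical helper, for illustration only.
fn resolve_local_size(
    execution: &ComputeShaderExecution,
    spec_values: &HashMap<u32, u32>,
) -> Option<[u32; 3]> {
    // Both variants carry three dimensions; they differ only in how the
    // values were declared in the SPIR-V module.
    let dims = match execution {
        ComputeShaderExecution::LocalSize(dims)
        | ComputeShaderExecution::LocalSizeId(dims) => dims,
    };
    let mut out = [0u32; 3];
    for (out, dim) in out.iter_mut().zip(dims) {
        *out = match dim {
            LocalSize::Literal(value) => *value,
            // A spec id must have a provided (or default) value to resolve.
            LocalSize::SpecId(id) => spec_values.get(id).copied()?,
        };
    }
    Some(out)
}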
/// The requirements imposed by a shader on a binding within a descriptor set layout, and on any
/// resource that is bound to that binding.
#[derive(Clone, Debug, Default)]
@@ -1263,7 +1361,7 @@ impl From<&ShaderExecution> for ShaderStage {
ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
ShaderExecution::Geometry(_) => Self::Geometry,
ShaderExecution::Fragment(_) => Self::Fragment,
ShaderExecution::Compute => Self::Compute,
ShaderExecution::Compute(_) => Self::Compute,
ShaderExecution::RayGeneration => Self::Raygen,
ShaderExecution::AnyHit => Self::AnyHit,
ShaderExecution::ClosestHit => Self::ClosestHit,


@@ -16,12 +16,13 @@ use crate::{
pipeline::layout::PushConstantRange,
shader::{
spirv::{
Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, Spirv,
StorageClass,
BuiltIn, Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction,
Spirv, StorageClass,
},
DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, GeometryShaderExecution,
GeometryShaderInput, NumericType, ShaderExecution, ShaderInterface, ShaderInterfaceEntry,
ShaderInterfaceEntryType, ShaderStage, SpecializationConstant,
ComputeShaderExecution, DescriptorIdentifier, DescriptorRequirements, EntryPointInfo,
GeometryShaderExecution, GeometryShaderInput, LocalSize, NumericType, ShaderExecution,
ShaderInterface, ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage,
SpecializationConstant,
},
DeviceSize,
};
@@ -55,6 +56,9 @@ pub fn spirv_extensions(spirv: &Spirv) -> impl Iterator<Item = &str> {
#[inline]
pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_ {
let interface_variables = interface_variables(spirv);
let u32_constants = u32_constants(spirv);
let specialization_constant_ids = specialization_constant_ids(spirv);
let workgroup_size_decorations = workgroup_size_decorations(spirv);
spirv.iter_entry_point().filter_map(move |instruction| {
let (execution_model, function_id, entry_point_name, interface) = match instruction {
@@ -68,7 +72,14 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
_ => return None,
};
let execution = shader_execution(spirv, execution_model, function_id);
let execution = shader_execution(
spirv,
execution_model,
function_id,
&u32_constants,
&specialization_constant_ids,
&workgroup_size_decorations,
);
let stage = ShaderStage::from(&execution);
let descriptor_binding_requirements = inspect_entry_point(
@@ -109,11 +120,87 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = EntryPointInfo> + '_
})
}
/// Extracts the u32 constants from `spirv`.
fn u32_constants(spirv: &Spirv) -> HashMap<Id, u32> {
let type_u32s: HashSet<Id> = spirv
.iter_global()
.filter_map(|inst| {
if let Instruction::TypeInt {
result_id,
width,
signedness,
} = inst
{
if *width == 32 && *signedness == 0 {
return Some(*result_id);
}
}
None
})
.collect();
spirv
.iter_global()
.filter_map(|inst| {
if let Instruction::Constant {
result_type_id,
result_id,
value,
} = inst
{
if type_u32s.contains(result_type_id) {
if let [value] = value.as_slice() {
return Some((*result_id, *value));
}
}
}
None
})
.collect()
}
/// Extracts the specialization constant ids from `spirv`.
fn specialization_constant_ids(spirv: &Spirv) -> HashMap<Id, u32> {
spirv
.iter_decoration()
.filter_map(|inst| {
if let Instruction::Decorate { target, decoration } = inst {
if let Decoration::SpecId {
specialization_constant_id,
} = decoration
{
return Some((*target, *specialization_constant_id));
}
}
None
})
.collect()
}
/// Extracts the ids decorated with the `WorkgroupSize` builtin from `spirv`.
fn workgroup_size_decorations(spirv: &Spirv) -> HashSet<Id> {
spirv
.iter_decoration()
.filter_map(|inst| {
if let Instruction::Decorate { target, decoration } = inst {
if let Decoration::BuiltIn { built_in } = decoration {
if *built_in == BuiltIn::WorkgroupSize {
return Some(*target);
}
}
}
None
})
.collect()
}
/// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`.
fn shader_execution(
spirv: &Spirv,
execution_model: ExecutionModel,
function_id: Id,
u32_constants: &HashMap<Id, u32>,
specialization_constant_ids: &HashMap<Id, u32>,
workgroup_size_decorations: &HashSet<Id>,
) -> ShaderExecution {
match execution_model {
ExecutionModel::Vertex => ShaderExecution::Vertex,
@@ -187,7 +274,124 @@ fn shader_execution(
})
}
ExecutionModel::GLCompute => ShaderExecution::Compute,
ExecutionModel::GLCompute => {
let mut execution = ComputeShaderExecution::LocalSize([LocalSize::Literal(0); 3]);
for instruction in spirv.iter_execution_mode() {
match instruction {
Instruction::ExecutionMode { entry_point, mode }
if *entry_point == function_id =>
{
if let ExecutionMode::LocalSize {
x_size,
y_size,
z_size,
} = mode
{
execution = ComputeShaderExecution::LocalSize([
LocalSize::Literal(*x_size),
LocalSize::Literal(*y_size),
LocalSize::Literal(*z_size),
]);
break;
}
}
Instruction::ExecutionModeId { entry_point, mode }
if *entry_point == function_id =>
{
if let ExecutionMode::LocalSizeId {
x_size,
y_size,
z_size,
} = mode
{
let mut local_size = [LocalSize::Literal(0); 3];
for (local_size, id) in
local_size.iter_mut().zip([*x_size, *y_size, *z_size])
{
if let Some(constant) = u32_constants.get(&id) {
*local_size = LocalSize::Literal(*constant);
} else if let Some(spec_id) = specialization_constant_ids.get(&id) {
*local_size = LocalSize::SpecId(*spec_id);
} else {
panic!("LocalSizeId {id:?} not defined!");
}
}
execution = ComputeShaderExecution::LocalSizeId(local_size);
break;
}
}
_ => continue,
};
}
if !workgroup_size_decorations.is_empty() {
let mut in_function = false;
for instruction in spirv.instructions() {
if !in_function {
match *instruction {
Instruction::Function { result_id, .. } if result_id == function_id => {
in_function = true;
}
_ => {}
}
} else {
let mut local_size = [LocalSize::Literal(0); 3];
match instruction {
Instruction::ConstantComposite {
result_type_id: _,
result_id,
constituents,
} => {
if workgroup_size_decorations.contains(result_id) {
if constituents.len() != 3 {
panic!("WorkgroupSize must be 3 component vector!");
}
for (local_size, id) in
local_size.iter_mut().zip(constituents.iter())
{
if let Some(constant) = u32_constants.get(id) {
*local_size = LocalSize::Literal(*constant);
} else {
panic!("WorkgroupSize Constant {id:?} not defined!");
};
}
}
}
Instruction::SpecConstantComposite {
result_type_id: _,
result_id,
constituents,
} => {
if workgroup_size_decorations.contains(result_id) {
if constituents.len() != 3 {
panic!("WorkgroupSize must be 3 component vector!");
}
for (local_size, id) in
local_size.iter_mut().zip(constituents.iter())
{
if let Some(spec_id) = specialization_constant_ids.get(id) {
*local_size = LocalSize::SpecId(*spec_id);
} else {
panic!("WorkgroupSize SpecializationConstant {id:?} not defined!");
};
}
}
}
Instruction::FunctionEnd => break,
_ => continue,
}
match &mut execution {
ComputeShaderExecution::LocalSize(output) => {
*output = local_size;
}
ComputeShaderExecution::LocalSizeId(output) => {
*output = local_size;
}
}
}
}
}
ShaderExecution::Compute(execution)
}
ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration,
ExecutionModel::IntersectionKHR => ShaderExecution::Intersection,