From ee4e3089d3b82809875c0ed81fcee8b2fccb5ae9 Mon Sep 17 00:00:00 2001 From: charles-r-earp Date: Fri, 18 Aug 2023 04:37:10 -0700 Subject: [PATCH] Added required_subgroup_size to PipelineShaderStageCreateInfo (#2235) * Added required_subgroup_size to PipelineShaderStageCreateInfo * Added validation errors. * Fixed error msgs / vuids. * ComputeShaderExecution for validating local_size. * WorkgroupSizeId reflection. * contains_enum * Reworked ComputeShaderExecution. * panic msgs. * workgroup size validation * unused import * fixed test deprecated fn * catch workgroup size overflow * EntryPointInfo::local_size docs * comments * typo + error msg --- vulkano-shaders/src/entry_point.rs | 44 +++++- vulkano/src/pipeline/compute.rs | 151 ++++++++++++++++++- vulkano/src/pipeline/graphics/mod.rs | 21 ++- vulkano/src/pipeline/mod.rs | 152 +++++++++++++++++++ vulkano/src/shader/mod.rs | 108 ++++++++++++- vulkano/src/shader/reflect.rs | 218 ++++++++++++++++++++++++++- 6 files changed, 676 insertions(+), 18 deletions(-) diff --git a/vulkano-shaders/src/entry_point.rs b/vulkano-shaders/src/entry_point.rs index 18b03a03..4ca4e15f 100644 --- a/vulkano-shaders/src/entry_point.rs +++ b/vulkano-shaders/src/entry_point.rs @@ -73,7 +73,49 @@ fn write_shader_execution(execution: &ShaderExecution) -> TokenStream { ) } } - ShaderExecution::Compute => quote! { ::vulkano::shader::ShaderExecution::Compute }, + ShaderExecution::Compute(execution) => { + use ::quote::ToTokens; + use ::vulkano::shader::{ComputeShaderExecution, LocalSize}; + + struct LocalSizeToTokens(LocalSize); + + impl ToTokens for LocalSizeToTokens { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self.0 { + LocalSize::Literal(literal) => quote! { + ::vulkano::shader::LocalSize::Literal(#literal) + }, + LocalSize::SpecId(id) => quote! { + ::vulkano::shader::LocalSize::SpecId(#id) + }, + } + .to_tokens(tokens); + } + } + + match execution { + ComputeShaderExecution::LocalSize([x, y, z]) => { + let [x, y, z] = [ + LocalSizeToTokens(*x), + LocalSizeToTokens(*y), + LocalSizeToTokens(*z), + ]; + quote! { ::vulkano::shader::ShaderExecution::Compute( + ::vulkano::shader::ComputeShaderExecution::LocalSize([#x, #y, #z]) + ) } + } + ComputeShaderExecution::LocalSizeId([x, y, z]) => { + let [x, y, z] = [ + LocalSizeToTokens(*x), + LocalSizeToTokens(*y), + LocalSizeToTokens(*z), + ]; + quote! { ::vulkano::shader::ShaderExecution::Compute( + ::vulkano::shader::ComputeShaderExecution::LocalSizeId([#x, #y, #z]) + ) } + } + } + } ShaderExecution::RayGeneration => { quote! { ::vulkano::shader::ShaderExecution::RayGeneration } } diff --git a/vulkano/src/pipeline/compute.rs b/vulkano/src/pipeline/compute.rs index dfb4c78c..281aab9d 100644 --- a/vulkano/src/pipeline/compute.rs +++ b/vulkano/src/pipeline/compute.rs @@ -74,11 +74,9 @@ impl ComputePipeline { if let Some(cache) = &cache { assert_eq!(device, cache.device().as_ref()); } - create_info .validate(device) .map_err(|err| err.add_context("create_info"))?; - Ok(()) } @@ -100,12 +98,14 @@ impl ComputePipeline { let specialization_info_vk; let specialization_map_entries_vk: Vec<_>; let mut specialization_data_vk: Vec; + let required_subgroup_size_create_info; { let &PipelineShaderStageCreateInfo { flags, ref entry_point, ref specialization_info, + ref required_subgroup_size, _ne: _, } = stage; @@ -135,7 +135,20 @@ impl ComputePipeline { data_size: specialization_data_vk.len(), p_data: specialization_data_vk.as_ptr() as *const _, }; + required_subgroup_size_create_info = + required_subgroup_size.map(|required_subgroup_size| { + ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo { + required_subgroup_size, + ..Default::default() + } + }); stage_vk = ash::vk::PipelineShaderStageCreateInfo { + p_next: required_subgroup_size_create_info.as_ref().map_or( + ptr::null(), + |required_subgroup_size_create_info| { + required_subgroup_size_create_info as *const _ as _ + }, + ), flags: flags.into(), stage: ShaderStage::from(&entry_point_info.execution).into(), module: entry_point.module().handle(), @@ -333,12 +346,13 @@ impl ComputePipelineCreateInfo { flags: _, ref entry_point, specialization_info: _, + required_subgroup_size: _vk, _ne: _, } = &stage; let entry_point_info = entry_point.info(); - if !matches!(entry_point_info.execution, ShaderExecution::Compute) { + if !matches!(entry_point_info.execution, ShaderExecution::Compute(_)) { return Err(Box::new(ValidationError { context: "stage.entry_point".into(), problem: "is not a `ShaderStage::Compute` entry point".into(), @@ -514,4 +528,135 @@ mod tests { let data_buffer_content = data_buffer.read().unwrap(); assert_eq!(*data_buffer_content, 0x12345678); } + + #[test] + fn required_subgroup_size() { + // This test checks whether required_subgroup_size works. + // It executes a single compute shader (one invocation) that writes the subgroup size + // to a buffer. The buffer content is then checked for the right value. + + let (device, queue) = gfx_dev_and_queue!(subgroup_size_control); + + let cs = unsafe { + /* + #version 450 + + #extension GL_KHR_shader_subgroup_basic: enable + + layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + + layout(set = 0, binding = 0) buffer Output { + uint write; + } write; + + void main() { + if (gl_GlobalInvocationID.x == 0) { + write.write = gl_SubgroupSize; + } + } + */ + const MODULE: [u32; 246] = [ + 119734787, 65536, 851978, 30, 0, 131089, 1, 131089, 61, 393227, 1, 1280527431, + 1685353262, 808793134, 0, 196622, 0, 1, 458767, 5, 4, 1852399981, 0, 9, 23, 393232, + 4, 17, 128, 1, 1, 196611, 2, 450, 655364, 1197427783, 1279741775, 1885560645, + 1953718128, 1600482425, 1701734764, 1919509599, 1769235301, 25974, 524292, + 1197427783, 1279741775, 1852399429, 1685417059, 1768185701, 1952671090, 6649449, + 589828, 1264536647, 1935626824, 1701077352, 1970495346, 1869768546, 1650421877, + 1667855201, 0, 262149, 4, 1852399981, 0, 524293, 9, 1197436007, 1633841004, + 1986939244, 1952539503, 1231974249, 68, 262149, 18, 1886680399, 29813, 327686, 18, + 0, 1953067639, 101, 262149, 20, 1953067639, 101, 393221, 23, 1398762599, + 1919378037, 1399879023, 6650473, 262215, 9, 11, 28, 327752, 18, 0, 35, 0, 196679, + 18, 3, 262215, 20, 34, 0, 262215, 20, 33, 0, 196679, 23, 0, 262215, 23, 11, 36, + 196679, 24, 0, 262215, 29, 11, 25, 131091, 2, 196641, 3, 2, 262165, 6, 32, 0, + 262167, 7, 6, 3, 262176, 8, 1, 7, 262203, 8, 9, 1, 262187, 6, 10, 0, 262176, 11, 1, + 6, 131092, 14, 196638, 18, 6, 262176, 19, 2, 18, 262203, 19, 20, 2, 262165, 21, 32, + 1, 262187, 21, 22, 0, 262203, 11, 23, 1, 262176, 25, 2, 6, 262187, 6, 27, 128, + 262187, 6, 28, 1, 393260, 7, 29, 27, 28, 28, 327734, 2, 4, 0, 3, 131320, 5, 327745, + 11, 12, 9, 10, 262205, 6, 13, 12, 327850, 14, 15, 13, 10, 196855, 17, 0, 262394, + 15, 16, 17, 131320, 16, 262205, 6, 24, 23, 327745, 25, 26, 20, 22, 196670, 26, 24, + 131321, 17, 131320, 17, 65789, 65592, + ]; + let module = + ShaderModule::new(device.clone(), ShaderModuleCreateInfo::new(&MODULE)).unwrap(); + module.entry_point("main").unwrap() + }; + + let properties = device.physical_device().properties(); + let subgroup_size = properties.min_subgroup_size.unwrap_or(1); + + let pipeline = { + let stage = PipelineShaderStageCreateInfo { + required_subgroup_size: Some(subgroup_size), + ..PipelineShaderStageCreateInfo::new(cs) + }; + let layout = PipelineLayout::new( + device.clone(), + PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage]) + .into_pipeline_layout_create_info(device.clone()) + .unwrap(), + ) + .unwrap(); + ComputePipeline::new( + device.clone(), + None, + ComputePipelineCreateInfo::stage_layout(stage, layout), + ) + .unwrap() + }; + + let memory_allocator = StandardMemoryAllocator::new_default(device.clone()); + let data_buffer = Buffer::from_data( + &memory_allocator, + BufferCreateInfo { + usage: BufferUsage::STORAGE_BUFFER, + ..Default::default() + }, + AllocationCreateInfo { + memory_type_filter: MemoryTypeFilter::PREFER_DEVICE + | MemoryTypeFilter::HOST_RANDOM_ACCESS, + ..Default::default() + }, + 0, + ) + .unwrap(); + + let ds_allocator = StandardDescriptorSetAllocator::new(device.clone()); + let set = PersistentDescriptorSet::new( + &ds_allocator, + pipeline.layout().set_layouts().get(0).unwrap().clone(), + [WriteDescriptorSet::buffer(0, data_buffer.clone())], + [], + ) + .unwrap(); + + let cb_allocator = StandardCommandBufferAllocator::new(device.clone(), Default::default()); + let mut cbb = AutoCommandBufferBuilder::primary( + &cb_allocator, + queue.queue_family_index(), + CommandBufferUsage::OneTimeSubmit, + ) + .unwrap(); + cbb.bind_pipeline_compute(pipeline.clone()) + .unwrap() + .bind_descriptor_sets( + PipelineBindPoint::Compute, + pipeline.layout().clone(), + 0, + set, + ) + .unwrap() + .dispatch([128, 1, 1]) + .unwrap(); + let cb = cbb.build().unwrap(); + + let future = now(device) + .then_execute(queue, cb) + .unwrap() + .then_signal_fence_and_flush() + .unwrap(); + future.wait(None).unwrap(); + + let data_buffer_content = data_buffer.read().unwrap(); + assert_eq!(*data_buffer_content, subgroup_size); + } } diff --git a/vulkano/src/pipeline/graphics/mod.rs b/vulkano/src/pipeline/graphics/mod.rs index 28d14bc1..19d04041 100644 --- a/vulkano/src/pipeline/graphics/mod.rs +++ b/vulkano/src/pipeline/graphics/mod.rs @@ -167,7 +167,6 @@ impl GraphicsPipeline { create_info .validate(device) .map_err(|err| err.add_context("create_info"))?; - Ok(()) } @@ -204,6 +203,8 @@ impl GraphicsPipeline { specialization_info_vk: ash::vk::SpecializationInfo, specialization_map_entries_vk: Vec, specialization_data_vk: Vec, + required_subgroup_size_create_info: + Option, } let (mut stages_vk, mut per_stage_vk): (SmallVec<[_; 5]>, SmallVec<[_; 5]>) = stages @@ -213,6 +214,7 @@ impl GraphicsPipeline { flags, ref entry_point, ref specialization_info, + ref required_subgroup_size, _ne: _, } = stage; @@ -235,7 +237,13 @@ impl GraphicsPipeline { } }) .collect(); - + let required_subgroup_size_create_info = + required_subgroup_size.map(|required_subgroup_size| { + ash::vk::PipelineShaderStageRequiredSubgroupSizeCreateInfo { + required_subgroup_size, + ..Default::default() + } + }); ( ash::vk::PipelineShaderStageCreateInfo { flags: flags.into(), @@ -255,6 +263,7 @@ impl GraphicsPipeline { }, specialization_map_entries_vk, specialization_data_vk, + required_subgroup_size_create_info, }, ) }) @@ -267,10 +276,17 @@ impl GraphicsPipeline { specialization_info_vk, specialization_map_entries_vk, specialization_data_vk, + required_subgroup_size_create_info, }, ) in (stages_vk.iter_mut()).zip(per_stage_vk.iter_mut()) { *stage_vk = ash::vk::PipelineShaderStageCreateInfo { + p_next: required_subgroup_size_create_info.as_ref().map_or( + ptr::null(), + |required_subgroup_size_create_info| { + required_subgroup_size_create_info as *const _ as _ + }, + ), p_name: name_vk.as_ptr(), p_specialization_info: specialization_info_vk, ..*stage_vk @@ -2420,6 +2436,7 @@ impl GraphicsPipelineCreateInfo { flags: _, ref entry_point, specialization_info: _, + required_subgroup_size: _vk, _ne: _, } = stage; diff --git a/vulkano/src/pipeline/mod.rs b/vulkano/src/pipeline/mod.rs index 71903c7f..a8123a91 100644 --- a/vulkano/src/pipeline/mod.rs +++ b/vulkano/src/pipeline/mod.rs @@ -321,6 +321,21 @@ pub struct PipelineShaderStageCreateInfo { /// The default value is empty. pub specialization_info: HashMap, + /// The required subgroup size. + /// + /// Requires [`subgroup_size_control`](crate::device::Features::subgroup_size_control). The + /// shader stage must be included in + /// [`required_subgroup_size_stages`](crate::device::Properties::required_subgroup_size_stages). + /// Subgroup size must be power of 2 and within + /// [`min_subgroup_size`](crate::device::Properties::min_subgroup_size) + /// and [`max_subgroup_size`](crate::device::Properties::max_subgroup_size). + /// + /// For compute shaders, `max_compute_workgroup_subgroups * required_subgroup_size` must be + /// greater than or equal to `workgroup_size.x * workgroup_size.y * workgroup_size.z`. + /// + /// The default value is None. + pub required_subgroup_size: Option, + pub _ne: crate::NonExhaustive, } @@ -332,6 +347,7 @@ impl PipelineShaderStageCreateInfo { flags: PipelineShaderStageCreateFlags::empty(), entry_point, specialization_info: HashMap::default(), + required_subgroup_size: None, _ne: crate::NonExhaustive(()), } } @@ -341,6 +357,7 @@ impl PipelineShaderStageCreateInfo { flags, ref entry_point, ref specialization_info, + ref required_subgroup_size, _ne: _, } = self; @@ -469,10 +486,145 @@ impl PipelineShaderStageCreateInfo { } } + let workgroup_size = if let Some(local_size) = + entry_point_info.local_size(specialization_info)? + { + let [x, y, z] = local_size; + if x == 0 || y == 0 || z == 0 { + return Err(Box::new(ValidationError { + problem: format!("`workgroup size` {local_size:?} cannot be 0").into(), + ..Default::default() + })); + } + let properties = device.physical_device().properties(); + if stage_enum == ShaderStage::Compute { + let max_compute_work_group_size = properties.max_compute_work_group_size; + let [max_x, max_y, max_z] = max_compute_work_group_size; + if x > max_x || y > max_y || z > max_z { + return Err(Box::new(ValidationError { + problem: format!("`workgroup size` {local_size:?} is greater than `max_compute_work_group_size` {max_compute_work_group_size:?}").into(), + ..Default::default() + })); + } + let max_invocations = properties.max_compute_work_group_invocations; + if let Some(workgroup_size) = (|| x.checked_mul(y)?.checked_mul(z))() { + if workgroup_size > max_invocations { + return Err(Box::new(ValidationError { + problem: format!("the product of `workgroup size` {local_size:?} = {workgroup_size} is greater than `max_compute_work_group_invocations` {max_invocations}").into(), + ..Default::default() + })); + } + Some(workgroup_size) + } else { + return Err(Box::new(ValidationError { + problem: format!("the product of `workgroup size` {local_size:?} = (overflow) is greater than `max_compute_work_group_invocations` {max_invocations}").into(), + ..Default::default() + })); + } + } else { + // TODO: Additional stages when `.local_size()` supports them. + unreachable!() + } + } else { + None + }; + + if let Some(required_subgroup_size) = required_subgroup_size { + validate_required_subgroup_size( + device, + stage_enum, + workgroup_size, + *required_subgroup_size, + )?; + } + Ok(()) } } +pub(crate) fn validate_required_subgroup_size( + device: &Device, + stage: ShaderStage, + workgroup_size: Option, + subgroup_size: u32, +) -> Result<(), Box> { + if !device.enabled_features().subgroup_size_control { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + requires_one_of: RequiresOneOf(&[RequiresAllOf(&[Requires::Feature( + "subgroup_size_control", + )])]), + vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"], + ..Default::default() + })); + } + let properties = device.physical_device().properties(); + if !properties + .required_subgroup_size_stages + .unwrap_or_default() + .contains_enum(stage) + { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + problem: format!("`shader stage` {stage:?} is not in `required_subgroup_size_stages`") + .into(), + vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02755"], + ..Default::default() + })); + } + if !subgroup_size.is_power_of_two() { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + problem: format!("`subgroup_size` {subgroup_size} is not a power of 2").into(), + vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02760"], + ..Default::default() + })); + } + let min_subgroup_size = properties.min_subgroup_size.unwrap_or(1); + if subgroup_size < min_subgroup_size { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + problem: format!( + "`subgroup_size` {subgroup_size} is less than `min_subgroup_size` {min_subgroup_size}" + ) + .into(), + vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02761"], + ..Default::default() + })); + } + let max_subgroup_size = properties.max_subgroup_size.unwrap_or(128); + if subgroup_size > max_subgroup_size { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + problem: + format!("`subgroup_size` {subgroup_size} is greater than `max_subgroup_size` {max_subgroup_size}") + .into(), + vuids: &["VUID-VkPipelineShaderStageRequiredSubgroupSizeCreateInfo-requiredSubgroupSize-02762"], + ..Default::default() + })); + } + if let Some(workgroup_size) = workgroup_size { + if stage == ShaderStage::Compute { + let max_compute_workgroup_subgroups = properties + .max_compute_workgroup_subgroups + .unwrap_or_default(); + if max_compute_workgroup_subgroups + .checked_mul(subgroup_size) + .unwrap_or(u32::MAX) + < workgroup_size + { + return Err(Box::new(ValidationError { + context: "required_subgroup_size".into(), + problem: format!("`subgroup_size` {subgroup_size} creates more than {max_compute_workgroup_subgroups} subgroups").into(), + vuids: &["VUID-VkPipelineShaderStageCreateInfo-pNext-02756"], + ..Default::default() + })); + } + } + } + Ok(()) +} + vulkan_bitflags! { #[non_exhaustive] diff --git a/vulkano/src/shader/mod.rs b/vulkano/src/shader/mod.rs index 55a455ae..1072f14a 100644 --- a/vulkano/src/shader/mod.rs +++ b/vulkano/src/shader/mod.rs @@ -549,6 +549,78 @@ pub struct EntryPointInfo { pub output_interface: ShaderInterface, } +impl EntryPointInfo { + /// The local size in Compute shaders, None otherwise. + /// + /// `specialization_info` is used for LocalSizeId / WorkgroupSizeId, using specialization_constants if not found. + /// Errors if specialization constants are not found or are not u32's. + pub(crate) fn local_size( + &self, + specialization_info: &HashMap, + ) -> Result, Box> { + if let ShaderExecution::Compute(execution) = self.execution { + match execution { + ComputeShaderExecution::LocalSize(local_size) + | ComputeShaderExecution::LocalSizeId(local_size) => { + let mut output = [0; 3]; + for (output, local_size) in output.iter_mut().zip(local_size) { + let id = match local_size { + LocalSize::Literal(literal) => { + *output = literal; + continue; + } + LocalSize::SpecId(id) => id, + }; + if let Some(default) = self.specialization_constants.get(&id) { + let default_value = if let SpecializationConstant::U32(default_value) = + default + { + default_value + } else { + return Err(Box::new(ValidationError { + problem: format!( + "`entry_point.info().specialization_constants[{id}]` is not a 32 bit integer" + ) + .into(), + ..Default::default() + })); + }; + if let Some(provided) = specialization_info.get(&id) { + if let SpecializationConstant::U32(provided_value) = provided { + *output = *provided_value; + } else { + return Err(Box::new(ValidationError { + problem: format!( + "`specialization_info[{0}]` does not have the same type as \ + `entry_point.info().specialization_constants[{0}]`", + id, + ) + .into(), + vuids: &["VUID-VkSpecializationMapEntry-constantID-00776"], + ..Default::default() + })); + } + } + *output = *default_value; + } else { + return Err(Box::new(ValidationError { + problem: format!( + "specialization constant {id} not found in `entry_point.info().specialization_constants`" + ) + .into(), + ..Default::default() + })); + } + } + Ok(Some(output)) + } + } + } else { + Ok(None) + } + } +} + /// Represents a shader entry point in a shader module. /// /// Can be obtained by calling [`entry_point`](ShaderModule::entry_point) on the shader module. @@ -581,15 +653,15 @@ pub enum ShaderExecution { TessellationEvaluation, Geometry(GeometryShaderExecution), Fragment(FragmentShaderExecution), - Compute, + Compute(ComputeShaderExecution), RayGeneration, AnyHit, ClosestHit, Miss, Intersection, Callable, - Task, - Mesh, + Task, // TODO: like compute? + Mesh, // TODO: like compute? SubpassShading, } @@ -601,7 +673,7 @@ impl From<&ShaderExecution> for ExecutionModel { ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation, ShaderExecution::Geometry(_) => Self::Geometry, ShaderExecution::Fragment(_) => Self::Fragment, - ShaderExecution::Compute => Self::GLCompute, + ShaderExecution::Compute(_) => Self::GLCompute, ShaderExecution::RayGeneration => Self::RayGenerationKHR, ShaderExecution::AnyHit => Self::AnyHitKHR, ShaderExecution::ClosestHit => Self::ClosestHitKHR, @@ -699,6 +771,32 @@ pub enum FragmentTestsStages { EarlyAndLate, } +/// LocalSize. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum LocalSize { + Literal(u32), + SpecId(u32), +} + +/// The mode in which the compute shader executes. +/// +/// The workgroup size is specified for x, y, and z dimensions. +/// +/// Constants are resolved to literals, while specialization constants +/// map to spec ids. +/// +/// The `WorkgroupSize` builtin overrides the values specified in the +/// execution mode. It can decorate a 3 component ConstantComposite or +/// SpecConstantComposite vector. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ComputeShaderExecution { + /// Workgroup size in x, y, and z. + LocalSize([LocalSize; 3]), + /// Requires spirv 1.2. + /// Like `LocalSize`, but uses ids instead of literals. + LocalSizeId([LocalSize; 3]), +} + /// The requirements imposed by a shader on a binding within a descriptor set layout, and on any /// resource that is bound to that binding. #[derive(Clone, Debug, Default)] @@ -1263,7 +1361,7 @@ impl From<&ShaderExecution> for ShaderStage { ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation, ShaderExecution::Geometry(_) => Self::Geometry, ShaderExecution::Fragment(_) => Self::Fragment, - ShaderExecution::Compute => Self::Compute, + ShaderExecution::Compute(_) => Self::Compute, ShaderExecution::RayGeneration => Self::Raygen, ShaderExecution::AnyHit => Self::AnyHit, ShaderExecution::ClosestHit => Self::ClosestHit, diff --git a/vulkano/src/shader/reflect.rs b/vulkano/src/shader/reflect.rs index 91776942..786a14aa 100644 --- a/vulkano/src/shader/reflect.rs +++ b/vulkano/src/shader/reflect.rs @@ -16,12 +16,13 @@ use crate::{ pipeline::layout::PushConstantRange, shader::{ spirv::{ - Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, Spirv, - StorageClass, + BuiltIn, Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, + Spirv, StorageClass, }, - DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, GeometryShaderExecution, - GeometryShaderInput, NumericType, ShaderExecution, ShaderInterface, ShaderInterfaceEntry, - ShaderInterfaceEntryType, ShaderStage, SpecializationConstant, + ComputeShaderExecution, DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, + GeometryShaderExecution, GeometryShaderInput, LocalSize, NumericType, ShaderExecution, + ShaderInterface, ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage, + SpecializationConstant, }, DeviceSize, }; @@ -55,6 +56,9 @@ pub fn spirv_extensions(spirv: &Spirv) -> impl Iterator { #[inline] pub fn entry_points(spirv: &Spirv) -> impl Iterator + '_ { let interface_variables = interface_variables(spirv); + let u32_constants = u32_constants(spirv); + let specialization_constant_ids = specialization_constant_ids(spirv); + let workgroup_size_decorations = workgroup_size_decorations(spirv); spirv.iter_entry_point().filter_map(move |instruction| { let (execution_model, function_id, entry_point_name, interface) = match instruction { @@ -68,7 +72,14 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator + '_ _ => return None, }; - let execution = shader_execution(spirv, execution_model, function_id); + let execution = shader_execution( + spirv, + execution_model, + function_id, + &u32_constants, + &specialization_constant_ids, + &workgroup_size_decorations, + ); let stage = ShaderStage::from(&execution); let descriptor_binding_requirements = inspect_entry_point( @@ -109,11 +120,87 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator + '_ }) } +/// Extracts the u32 constants from `spirv`. +fn u32_constants(spirv: &Spirv) -> HashMap { + let type_u32s: HashSet = spirv + .iter_global() + .filter_map(|inst| { + if let Instruction::TypeInt { + result_id, + width, + signedness, + } = inst + { + if *width == 0 && *signedness == 0 { + return Some(*result_id); + } + } + None + }) + .collect(); + spirv + .iter_decoration() + .filter_map(|inst| { + if let Instruction::Constant { + result_type_id, + result_id, + value, + } = inst + { + if type_u32s.contains(result_type_id) { + if let [value] = value.as_slice() { + return Some((*result_id, *value)); + } + } + } + None + }) + .collect() +} + +/// Extracts the specialization constant ids from `spirv`. +fn specialization_constant_ids(spirv: &Spirv) -> HashMap { + spirv + .iter_decoration() + .filter_map(|inst| { + if let Instruction::Decorate { target, decoration } = inst { + if let Decoration::SpecId { + specialization_constant_id, + } = decoration + { + return Some((*target, *specialization_constant_id)); + } + } + None + }) + .collect() +} + +/// Extracts the `WorkgroupSize` builtin Id's from `spirv`. +fn workgroup_size_decorations(spirv: &Spirv) -> HashSet { + spirv + .iter_decoration() + .filter_map(|inst| { + if let Instruction::Decorate { target, decoration } = inst { + if let Decoration::BuiltIn { built_in } = decoration { + if *built_in == BuiltIn::WorkgroupSize { + return Some(*target); + } + } + } + None + }) + .collect() +} + /// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`. fn shader_execution( spirv: &Spirv, execution_model: ExecutionModel, function_id: Id, + u32_constants: &HashMap, + specialization_constant_ids: &HashMap, + workgroup_size_decorations: &HashSet, ) -> ShaderExecution { match execution_model { ExecutionModel::Vertex => ShaderExecution::Vertex, @@ -187,7 +274,124 @@ fn shader_execution( }) } - ExecutionModel::GLCompute => ShaderExecution::Compute, + ExecutionModel::GLCompute => { + let mut execution = ComputeShaderExecution::LocalSize([LocalSize::Literal(0); 3]); + for instruction in spirv.iter_execution_mode() { + match instruction { + Instruction::ExecutionMode { entry_point, mode } + if *entry_point == function_id => + { + if let ExecutionMode::LocalSize { + x_size, + y_size, + z_size, + } = mode + { + execution = ComputeShaderExecution::LocalSize([ + LocalSize::Literal(*x_size), + LocalSize::Literal(*y_size), + LocalSize::Literal(*z_size), + ]); + break; + } + } + Instruction::ExecutionModeId { entry_point, mode } + if *entry_point == function_id => + { + if let ExecutionMode::LocalSizeId { + x_size, + y_size, + z_size, + } = mode + { + let mut local_size = [LocalSize::Literal(0); 3]; + for (local_size, id) in + local_size.iter_mut().zip([*x_size, *y_size, *z_size]) + { + if let Some(constant) = u32_constants.get(&id) { + *local_size = LocalSize::Literal(*constant); + } else if let Some(spec_id) = specialization_constant_ids.get(&id) { + *local_size = LocalSize::SpecId(*spec_id); + } else { + panic!("LocalSizeId {id:?} not defined!"); + } + } + execution = ComputeShaderExecution::LocalSizeId(local_size); + break; + } + } + _ => continue, + }; + } + if !workgroup_size_decorations.is_empty() { + let mut in_function = false; + for instruction in spirv.instructions() { + if !in_function { + match *instruction { + Instruction::Function { result_id, .. } if result_id == function_id => { + in_function = true; + } + _ => {} + } + } else { + let mut local_size = [LocalSize::Literal(0); 3]; + match instruction { + Instruction::ConstantComposite { + result_type_id: _, + result_id, + constituents, + } => { + if workgroup_size_decorations.contains(result_id) { + if constituents.len() != 3 { + panic!("WorkgroupSize must be 3 component vector!"); + } + for (local_size, id) in + local_size.iter_mut().zip(constituents.iter()) + { + if let Some(constant) = u32_constants.get(id) { + *local_size = LocalSize::Literal(*constant); + } else { + panic!("WorkgroupSize Constant {id:?} not defined!"); + }; + } + } + } + Instruction::SpecConstantComposite { + result_type_id: _, + result_id, + constituents, + } => { + if workgroup_size_decorations.contains(result_id) { + if constituents.len() != 3 { + panic!("WorkgroupSize must be 3 component vector!"); + } + for (local_size, id) in + local_size.iter_mut().zip(constituents.iter()) + { + if let Some(spec_id) = specialization_constant_ids.get(id) { + *local_size = LocalSize::SpecId(*spec_id); + } else { + panic!("WorkgroupSize SpecializationConstant {id:?} not defined!"); + }; + } + } + } + Instruction::FunctionEnd => break, + _ => continue, + } + match &mut execution { + ComputeShaderExecution::LocalSize(output) => { + *output = local_size; + } + ComputeShaderExecution::LocalSizeId(output) => { + *output = local_size; + } + } + } + } + } + ShaderExecution::Compute(execution) + } ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration, ExecutionModel::IntersectionKHR => ShaderExecution::Intersection,