diff --git a/CHANGELOG.md b/CHANGELOG.md index 705e8130..6b77d6dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -167,7 +167,7 @@ Changes to the `khr_display` extension: - Added `DeviceMemory::{map, unmap, mapping_state, invalidate_range, flush_range}`, `MappedDeviceMemory` has been deprecated. - Added `MemoryMapInfo`, `MemoryUnmapInfo`, `MappingState` and `MappedMemoryRange`. - Added `ShaderModule::single_entry_point()` which may replace `entry_point("main")` calls in common setups. -- Added `ShaderModule::single_entry_point_of_execution`. +- Added `ShaderModule::single_entry_point_with_execution`. - Added `GenericMemoryAllocatorCreateInfo::memory_type_bits` and `AllocationCreateInfo::memory_type_bits`. ### Bugs fixed diff --git a/examples/src/bin/dynamic-local-size.rs b/examples/src/bin/dynamic-local-size.rs index cc6df296..619d8265 100644 --- a/examples/src/bin/dynamic-local-size.rs +++ b/examples/src/bin/dynamic-local-size.rs @@ -179,21 +179,22 @@ fn main() { let pipeline = { let cs = cs::load(device.clone()) + .unwrap() + .specialize( + [ + (0, 0.2f32.into()), + (1, local_size_x.into()), + (2, local_size_y.into()), + (3, 0.5f32.into()), + (4, 1.0f32.into()), + ] + .into_iter() + .collect(), + ) .unwrap() .entry_point("main") .unwrap(); - let stage = PipelineShaderStageCreateInfo { - specialization_info: [ - (0, 0.2f32.into()), - (1, local_size_x.into()), - (2, local_size_y.into()), - (3, 0.5f32.into()), - (4, 1.0f32.into()), - ] - .into_iter() - .collect(), - ..PipelineShaderStageCreateInfo::new(cs) - }; + let stage = PipelineShaderStageCreateInfo::new(cs); let layout = PipelineLayout::new( device.clone(), PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage]) diff --git a/examples/src/bin/shader-types-sharing.rs b/examples/src/bin/shader-types-sharing.rs index c7db2c82..bd6d05b7 100644 --- a/examples/src/bin/shader-types-sharing.rs +++ b/examples/src/bin/shader-types-sharing.rs @@ -256,13 +256,12 @@ fn main() { // Load the first shader, and create a 
pipeline for the shader. let mult_pipeline = { let cs = shaders::load_mult(device.clone()) + .unwrap() + .specialize([(0, true.into())].into_iter().collect()) .unwrap() .entry_point("main") .unwrap(); - let stage = PipelineShaderStageCreateInfo { - specialization_info: [(0, true.into())].into_iter().collect(), - ..PipelineShaderStageCreateInfo::new(cs) - }; + let stage = PipelineShaderStageCreateInfo::new(cs); let layout = PipelineLayout::new( device.clone(), PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage]) @@ -281,13 +280,12 @@ fn main() { // Load the second shader, and create a pipeline for the shader. let add_pipeline = { let cs = shaders::load_add(device.clone()) + .unwrap() + .specialize([(0, true.into())].into_iter().collect()) .unwrap() .entry_point("main") .unwrap(); - let stage = PipelineShaderStageCreateInfo { - specialization_info: [(0, true.into())].into_iter().collect(), - ..PipelineShaderStageCreateInfo::new(cs) - }; + let stage = PipelineShaderStageCreateInfo::new(cs); let layout = PipelineLayout::new( device.clone(), PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage]) diff --git a/examples/src/bin/specialization-constants.rs b/examples/src/bin/specialization-constants.rs index fbbd3d2f..aae8ee98 100644 --- a/examples/src/bin/specialization-constants.rs +++ b/examples/src/bin/specialization-constants.rs @@ -117,15 +117,16 @@ fn main() { let pipeline = { let cs = cs::load(device.clone()) + .unwrap() + .specialize( + [(0, 1i32.into()), (1, 1.0f32.into()), (2, true.into())] + .into_iter() + .collect(), + ) .unwrap() .entry_point("main") .unwrap(); - let stage = PipelineShaderStageCreateInfo { - specialization_info: [(0, 1i32.into()), (1, 1.0f32.into()), (2, true.into())] - .into_iter() - .collect(), - ..PipelineShaderStageCreateInfo::new(cs) - }; + let stage = PipelineShaderStageCreateInfo::new(cs); let layout = PipelineLayout::new( device.clone(), PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage]) diff --git 
a/vulkano-shaders/src/codegen.rs b/vulkano-shaders/src/codegen.rs index 9cf5a4f0..f9223543 100644 --- a/vulkano-shaders/src/codegen.rs +++ b/vulkano-shaders/src/codegen.rs @@ -8,7 +8,6 @@ // according to those terms. use crate::{ - entry_point, structs::{self, TypeRegistry}, MacroInput, }; @@ -23,10 +22,7 @@ use std::{ path::{Path, PathBuf}, }; use syn::{Error, LitStr}; -use vulkano::{ - shader::{reflect, spirv::Spirv}, - Version, -}; +use vulkano::shader::spirv::Spirv; pub struct Shader { pub source: LitStr, @@ -238,29 +234,6 @@ pub(super) fn reflect( } }); - let spirv_version = { - let Version { - major, - minor, - patch, - } = shader.spirv.version(); - - quote! { - ::vulkano::Version { - major: #major, - minor: #minor, - patch: #patch, - } - } - }; - let spirv_capabilities = reflect::spirv_capabilities(&shader.spirv).map(|capability| { - let name = format_ident!("{}", format!("{:?}", capability)); - quote! { &::vulkano::shader::spirv::Capability::#name } - }); - let spirv_extensions = reflect::spirv_extensions(&shader.spirv); - let entry_points = - reflect::entry_points(&shader.spirv).map(|info| entry_point::write_entry_point(&info)); - let load_name = if shader.name.is_empty() { format_ident!("load") } else { @@ -282,13 +255,9 @@ pub(super) fn reflect( static WORDS: &[u32] = &[ #( #words ),* ]; unsafe { - ::vulkano::shader::ShaderModule::new_with_data( + ::vulkano::shader::ShaderModule::new( device, ::vulkano::shader::ShaderModuleCreateInfo::new(&WORDS), - [ #( #entry_points ),* ], - #spirv_version, - [ #( #spirv_capabilities ),* ], - [ #( #spirv_extensions ),* ], ) } } @@ -302,6 +271,7 @@ pub(super) fn reflect( #[cfg(test)] mod tests { use super::*; + use vulkano::shader::reflect; fn convert_paths(root_path: &Path, paths: &[PathBuf]) -> Vec { paths diff --git a/vulkano-shaders/src/entry_point.rs b/vulkano-shaders/src/entry_point.rs deleted file mode 100644 index 4ca4e15f..00000000 --- a/vulkano-shaders/src/entry_point.rs +++ /dev/null @@ -1,424 +0,0 @@ -// 
Copyright (c) 2016 The vulkano developers -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , -// at your option. All files in the project carrying such -// notice may not be copied, modified, or distributed except -// according to those terms. - -use ahash::HashMap; -use proc_macro2::TokenStream; -use vulkano::{ - pipeline::layout::PushConstantRange, - shader::{ - DescriptorBindingRequirements, DescriptorIdentifier, DescriptorRequirements, - EntryPointInfo, ShaderExecution, ShaderInterface, ShaderInterfaceEntry, - ShaderInterfaceEntryType, ShaderStages, SpecializationConstant, - }, -}; - -pub(super) fn write_entry_point(info: &EntryPointInfo) -> TokenStream { - let name = &info.name; - let execution = write_shader_execution(&info.execution); - let descriptor_binding_requirements = - write_descriptor_binding_requirements(&info.descriptor_binding_requirements); - let push_constant_requirements = - write_push_constant_requirements(&info.push_constant_requirements); - let specialization_constants = write_specialization_constants(&info.specialization_constants); - let input_interface = write_interface(&info.input_interface); - let output_interface = write_interface(&info.output_interface); - - quote! { - ::vulkano::shader::EntryPointInfo { - name: #name.to_owned(), - execution: #execution, - descriptor_binding_requirements: #descriptor_binding_requirements.into_iter().collect(), - push_constant_requirements: #push_constant_requirements, - specialization_constants: #specialization_constants.into_iter().collect(), - input_interface: #input_interface, - output_interface: #output_interface, - }, - } -} - -fn write_shader_execution(execution: &ShaderExecution) -> TokenStream { - match execution { - ShaderExecution::Vertex => quote! { ::vulkano::shader::ShaderExecution::Vertex }, - ShaderExecution::TessellationControl => { - quote! 
{ ::vulkano::shader::ShaderExecution::TessellationControl } - } - ShaderExecution::TessellationEvaluation => { - quote! { ::vulkano::shader::ShaderExecution::TessellationEvaluation } - } - ShaderExecution::Geometry(::vulkano::shader::GeometryShaderExecution { input }) => { - let input = format_ident!("{}", format!("{:?}", input)); - quote! { - ::vulkano::shader::ShaderExecution::Geometry( - ::vulkano::shader::GeometryShaderExecution { - input: ::vulkano::shader::GeometryShaderInput::#input - } - ) - } - } - ShaderExecution::Fragment(::vulkano::shader::FragmentShaderExecution { - fragment_tests_stages, - }) => { - let fragment_tests_stages = format_ident!("{}", format!("{:?}", fragment_tests_stages)); - quote! { - ::vulkano::shader::ShaderExecution::Fragment( - ::vulkano::shader::FragmentShaderExecution { - fragment_tests_stages: ::vulkano::shader::FragmentTestsStages::#fragment_tests_stages, - } - ) - } - } - ShaderExecution::Compute(execution) => { - use ::quote::ToTokens; - use ::vulkano::shader::{ComputeShaderExecution, LocalSize}; - - struct LocalSizeToTokens(LocalSize); - - impl ToTokens for LocalSizeToTokens { - fn to_tokens(&self, tokens: &mut TokenStream) { - match self.0 { - LocalSize::Literal(literal) => quote! { - ::vulkano::shader::LocalSize::Literal(#literal) - }, - LocalSize::SpecId(id) => quote! { - ::vulkano::shader::LocalSize::SpecId(#id) - }, - } - .to_tokens(tokens); - } - } - - match execution { - ComputeShaderExecution::LocalSize([x, y, z]) => { - let [x, y, z] = [ - LocalSizeToTokens(*x), - LocalSizeToTokens(*y), - LocalSizeToTokens(*z), - ]; - quote! { ::vulkano::shader::ShaderExecution::Compute( - ::vulkano::shader::ComputeShaderExecution::LocalSize([#x, #y, #z]) - ) } - } - ComputeShaderExecution::LocalSizeId([x, y, z]) => { - let [x, y, z] = [ - LocalSizeToTokens(*x), - LocalSizeToTokens(*y), - LocalSizeToTokens(*z), - ]; - quote! 
{ ::vulkano::shader::ShaderExecution::Compute( - ::vulkano::shader::ComputeShaderExecution::LocalSizeId([#x, #y, #z]) - ) } - } - } - } - ShaderExecution::RayGeneration => { - quote! { ::vulkano::shader::ShaderExecution::RayGeneration } - } - ShaderExecution::AnyHit => quote! { ::vulkano::shader::ShaderExecution::AnyHit }, - ShaderExecution::ClosestHit => quote! { ::vulkano::shader::ShaderExecution::ClosestHit }, - ShaderExecution::Miss => quote! { ::vulkano::shader::ShaderExecution::Miss }, - ShaderExecution::Intersection => { - quote! { ::vulkano::shader::ShaderExecution::Intersection } - } - ShaderExecution::Callable => quote! { ::vulkano::shader::ShaderExecution::Callable }, - ShaderExecution::Task => quote! { ::vulkano::shader::ShaderExecution::Task }, - ShaderExecution::Mesh => quote! { ::vulkano::shader::ShaderExecution::Mesh }, - ShaderExecution::SubpassShading => { - quote! { ::vulkano::shader::ShaderExecution::SubpassShading } - } - } -} - -fn write_descriptor_binding_requirements( - descriptor_binding_requirements: &HashMap<(u32, u32), DescriptorBindingRequirements>, -) -> TokenStream { - let descriptor_binding_requirements = - descriptor_binding_requirements - .iter() - .map(|(loc, binding_reqs)| { - let (set_num, binding_num) = loc; - let DescriptorBindingRequirements { - descriptor_types, - descriptor_count, - image_format, - image_multisampled, - image_scalar_type, - image_view_type, - stages, - descriptors, - } = binding_reqs; - - let descriptor_types_items = descriptor_types.iter().map(|ty| { - let ident = format_ident!("{}", format!("{:?}", ty)); - quote! { ::vulkano::descriptor_set::layout::DescriptorType::#ident } - }); - let descriptor_count = match descriptor_count { - Some(descriptor_count) => quote! { Some(#descriptor_count) }, - None => quote! { None }, - }; - let image_format = match image_format { - Some(image_format) => { - let ident = format_ident!("{}", format!("{:?}", image_format)); - quote! 
{ Some(::vulkano::format::Format::#ident) } - } - None => quote! { None }, - }; - let image_scalar_type = match image_scalar_type { - Some(image_scalar_type) => { - let ident = format_ident!("{}", format!("{:?}", image_scalar_type)); - quote! { Some(::vulkano::format::NumericType::#ident) } - } - None => quote! { None }, - }; - let image_view_type = match image_view_type { - Some(image_view_type) => { - let ident = format_ident!("{}", format!("{:?}", image_view_type)); - quote! { Some(::vulkano::image::view::ImageViewType::#ident) } - } - None => quote! { None }, - }; - let stages = stages_to_items(*stages); - let descriptor_items = descriptors.iter().map(|(index, desc_reqs)| { - let DescriptorRequirements { - memory_read, - memory_write, - sampler_compare, - sampler_no_unnormalized_coordinates, - sampler_no_ycbcr_conversion, - sampler_with_images, - storage_image_atomic, - } = desc_reqs; - - let index = match index { - Some(index) => quote! { Some(#index) }, - None => quote! { None }, - }; - let memory_read = stages_to_items(*memory_read); - let memory_write = stages_to_items(*memory_write); - let sampler_with_images_items = sampler_with_images.iter().map(|DescriptorIdentifier { - set, - binding, - index, - }| { - quote! { - ::vulkano::shader::DescriptorIdentifier { - set: #set, - binding: #binding, - index: #index, - } - } - }); - - quote! { - ( - #index, - ::vulkano::shader::DescriptorRequirements { - memory_read: #memory_read, - memory_write: #memory_write, - sampler_compare: #sampler_compare, - sampler_no_unnormalized_coordinates: #sampler_no_unnormalized_coordinates, - sampler_no_ycbcr_conversion: #sampler_no_ycbcr_conversion, - sampler_with_images: [ #( #sampler_with_images_items ),* ] - .into_iter() - .collect(), - storage_image_atomic: #storage_image_atomic, - } - ) - } - }); - - quote! 
{ - ( - (#set_num, #binding_num), - ::vulkano::shader::DescriptorBindingRequirements { - descriptor_types: vec![ #( #descriptor_types_items ),* ], - descriptor_count: #descriptor_count, - image_format: #image_format, - image_multisampled: #image_multisampled, - image_scalar_type: #image_scalar_type, - image_view_type: #image_view_type, - stages: #stages, - descriptors: [ #( #descriptor_items ),* ].into_iter().collect(), - }, - ) - } - }); - - quote! { - [ - #( #descriptor_binding_requirements ),* - ] - } -} - -fn write_push_constant_requirements( - push_constant_requirements: &Option, -) -> TokenStream { - match push_constant_requirements { - Some(PushConstantRange { - offset, - size, - stages, - }) => { - let stages = stages_to_items(*stages); - - quote! { - Some(::vulkano::pipeline::layout::PushConstantRange { - stages: #stages, - offset: #offset, - size: #size, - }) - } - } - None => quote! { - None - }, - } -} - -fn write_specialization_constants( - specialization_constants: &HashMap, -) -> TokenStream { - let specialization_constants = specialization_constants - .iter() - .map(|(&constant_id, value)| { - let value = match value { - SpecializationConstant::Bool(value) => quote! { Bool(#value) }, - SpecializationConstant::I8(value) => quote! { I8(#value) }, - SpecializationConstant::I16(value) => quote! { I16(#value) }, - SpecializationConstant::I32(value) => quote! { I32(#value) }, - SpecializationConstant::I64(value) => quote! { I64(#value) }, - SpecializationConstant::U8(value) => quote! { U8(#value) }, - SpecializationConstant::U16(value) => quote! { U16(#value) }, - SpecializationConstant::U32(value) => quote! { U32(#value) }, - SpecializationConstant::U64(value) => quote! { U64(#value) }, - SpecializationConstant::F16(value) => { - let bits = value.to_bits(); - quote! { F16(f16::from_bits(#bits)) } - } - SpecializationConstant::F32(value) => { - let bits = value.to_bits(); - quote! 
{ F32(f32::from_bits(#bits)) } - } - SpecializationConstant::F64(value) => { - let bits = value.to_bits(); - quote! { F64(f64::from_bits(#bits)) } - } - }; - - quote! { - ( - #constant_id, - ::vulkano::shader::SpecializationConstant::#value, - ) - } - }); - - quote! { - [ - #( #specialization_constants ),* - ] - } -} - -fn write_interface(interface: &ShaderInterface) -> TokenStream { - let items = interface.elements().iter().map( - |ShaderInterfaceEntry { - location, - component, - ty: - ShaderInterfaceEntryType { - base_type, - num_components, - num_elements, - is_64bit, - }, - name, - }| { - let base_type = format_ident!("{}", format!("{:?}", base_type)); - let name = if let Some(name) = name { - quote! { ::std::option::Option::Some(::std::borrow::Cow::Borrowed(#name)) } - } else { - quote! { ::std::option::Option::None } - }; - - quote! { - ::vulkano::shader::ShaderInterfaceEntry { - location: #location, - component: #component, - ty: ::vulkano::shader::ShaderInterfaceEntryType { - base_type: ::vulkano::format::NumericType::#base_type, - num_components: #num_components, - num_elements: #num_elements, - is_64bit: #is_64bit, - }, - name: #name, - } - } - }, - ); - - quote! { - ::vulkano::shader::ShaderInterface::new_unchecked(vec![ - #( #items ),* - ]) - } -} - -fn stages_to_items(stages: ShaderStages) -> TokenStream { - if stages.is_empty() { - quote! { ::vulkano::shader::ShaderStages::empty() } - } else { - let stages_items = [ - stages.intersects(ShaderStages::VERTEX).then(|| { - quote! { ::vulkano::shader::ShaderStages::VERTEX } - }), - stages - .intersects(ShaderStages::TESSELLATION_CONTROL) - .then(|| { - quote! { ::vulkano::shader::ShaderStages::TESSELLATION_CONTROL } - }), - stages - .intersects(ShaderStages::TESSELLATION_EVALUATION) - .then(|| { - quote! { ::vulkano::shader::ShaderStages::TESSELLATION_EVALUATION } - }), - stages.intersects(ShaderStages::GEOMETRY).then(|| { - quote! 
{ ::vulkano::shader::ShaderStages::GEOMETRY } - }), - stages.intersects(ShaderStages::FRAGMENT).then(|| { - quote! { ::vulkano::shader::ShaderStages::FRAGMENT } - }), - stages.intersects(ShaderStages::COMPUTE).then(|| { - quote! { ::vulkano::shader::ShaderStages::COMPUTE } - }), - stages.intersects(ShaderStages::RAYGEN).then(|| { - quote! { ::vulkano::shader::ShaderStages::RAYGEN } - }), - stages.intersects(ShaderStages::ANY_HIT).then(|| { - quote! { ::vulkano::shader::ShaderStages::ANY_HIT } - }), - stages.intersects(ShaderStages::CLOSEST_HIT).then(|| { - quote! { ::vulkano::shader::ShaderStages::CLOSEST_HIT } - }), - stages.intersects(ShaderStages::MISS).then(|| { - quote! { ::vulkano::shader::ShaderStages::MISS } - }), - stages.intersects(ShaderStages::INTERSECTION).then(|| { - quote! { ::vulkano::shader::ShaderStages::INTERSECTION } - }), - stages.intersects(ShaderStages::CALLABLE).then(|| { - quote! { ::vulkano::shader::ShaderStages::CALLABLE } - }), - ] - .into_iter() - .flatten(); - - quote! { #( #stages_items )|* } - } -} diff --git a/vulkano-shaders/src/lib.rs b/vulkano-shaders/src/lib.rs index 300ffffe..752407b6 100644 --- a/vulkano-shaders/src/lib.rs +++ b/vulkano-shaders/src/lib.rs @@ -238,7 +238,6 @@ use syn::{ }; mod codegen; -mod entry_point; mod structs; #[proc_macro] diff --git a/vulkano/autogen/spirv_parse.rs b/vulkano/autogen/spirv_parse.rs index e17bfc18..1c6457d0 100644 --- a/vulkano/autogen/spirv_parse.rs +++ b/vulkano/autogen/spirv_parse.rs @@ -14,9 +14,13 @@ use once_cell::sync::Lazy; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; +// From the documentation of the OpSpecConstantOp instruction. +// The instructions requiring the Kernel capability are not listed, +// as this capability is not supported by Vulkan. 
static SPEC_CONSTANT_OP: Lazy> = Lazy::new(|| { HashSet::from_iter([ "SConvert", + "UConvert", "FConvert", "SNegate", "Not", @@ -54,27 +58,6 @@ static SPEC_CONSTANT_OP: Lazy> = Lazy::new(|| { "UGreaterThanEqual", "SGreaterThanEqual", "QuantizeToF16", - "ConvertFToS", - "ConvertSToF", - "ConvertFToU", - "ConvertUToF", - "UConvert", - "ConvertPtrToU", - "ConvertUToPtr", - "GenericCastToPtr", - "PtrCastToGeneric", - "Bitcast", - "FNegate", - "FAdd", - "FSub", - "FMul", - "FDiv", - "FRem", - "FMod", - "AccessChain", - "InBoundsAccessChain", - "PtrAccessChain", - "InBoundsPtrAccessChain", ]) }); @@ -228,7 +211,7 @@ fn instruction_output(members: &[InstructionMember], spec_constant: bool) -> Tok impl #enum_name { #[allow(dead_code)] fn parse(reader: &mut InstructionReader<'_>) -> Result { - let opcode = (reader.next_u32()? & 0xffff) as u16; + let opcode = (reader.next_word()? & 0xffff) as u16; Ok(match opcode { #(#parse_items)* @@ -391,7 +374,7 @@ fn bit_enum_output(enums: &[(Ident, Vec)]) -> TokenStream { impl #name { #[allow(dead_code)] fn parse(reader: &mut InstructionReader<'_>) -> Result<#name, ParseError> { - let value = reader.next_u32()?; + let value = reader.next_word()?; Ok(Self { #(#parse_items)* @@ -536,7 +519,7 @@ fn value_enum_output(enums: &[(Ident, Vec)]) -> TokenStream { impl #name { #[allow(dead_code)] fn parse(reader: &mut InstructionReader<'_>) -> Result<#name, ParseError> { - Ok(match reader.next_u32()? { + Ok(match reader.next_word()? { #(#parse_items)* value => return Err(reader.map_err(ParseErrors::UnknownEnumerant(#name_string, value))), }) @@ -632,7 +615,7 @@ fn kinds_to_types(grammar: &SpirvGrammar) -> HashMap<&str, (TokenStream, TokenSt (quote! { Vec }, quote! { reader.remainder() }) } "LiteralInteger" | "LiteralExtInstInteger" => { - (quote! { u32 }, quote! { reader.next_u32()? }) + (quote! { u32 }, quote! { reader.next_word()? }) } "LiteralSpecConstantOpInteger" => ( quote! 
{ SpecConstantInstruction }, @@ -643,8 +626,8 @@ fn kinds_to_types(grammar: &SpirvGrammar) -> HashMap<&str, (TokenStream, TokenSt quote! { (Id, Id) }, quote! { ( - Id(reader.next_u32()?), - Id(reader.next_u32()?), + Id(reader.next_word()?), + Id(reader.next_word()?), ) }, ), @@ -652,8 +635,8 @@ fn kinds_to_types(grammar: &SpirvGrammar) -> HashMap<&str, (TokenStream, TokenSt quote! { (Id, u32) }, quote! { ( - Id(reader.next_u32()?), - reader.next_u32()? + Id(reader.next_word()?), + reader.next_word()? ) }, ), @@ -661,11 +644,13 @@ fn kinds_to_types(grammar: &SpirvGrammar) -> HashMap<&str, (TokenStream, TokenSt quote! { (u32, Id) }, quote! { ( - reader.next_u32()?, - Id(reader.next_u32()?)), + reader.next_word()?, + Id(reader.next_word()?)), }, ), - _ if k.kind.starts_with("Id") => (quote! { Id }, quote! { Id(reader.next_u32()?) }), + _ if k.kind.starts_with("Id") => { + (quote! { Id }, quote! { Id(reader.next_word()?) }) + } ident => { let ident = format_ident!("{}", ident); (quote! { #ident }, quote! { #ident::parse(reader)? }) @@ -678,7 +663,7 @@ fn kinds_to_types(grammar: &SpirvGrammar) -> HashMap<&str, (TokenStream, TokenSt "LiteralFloat", ( quote! { f32 }, - quote! { f32::from_bits(reader.next_u32()?) }, + quote! { f32::from_bits(reader.next_word()?) 
}, ), )]) .collect() diff --git a/vulkano/src/command_buffer/traits.rs b/vulkano/src/command_buffer/traits.rs index adea9c75..50fc98c3 100644 --- a/vulkano/src/command_buffer/traits.rs +++ b/vulkano/src/command_buffer/traits.rs @@ -482,28 +482,28 @@ impl Error for CommandBufferExecError { impl Display for CommandBufferExecError { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { - let value = match self { + match self { CommandBufferExecError::AccessError { - error, + error: _, command_name, command_offset, command_param, - } => return write!( + } => write!( f, - "access to a resource has been denied on command {} (offset: {}, param: {}): {}", - command_name, command_offset, command_param, error + "access to a resource has been denied on command {} (offset: {}, param: {})", + command_name, command_offset, command_param, ), - CommandBufferExecError::OneTimeSubmitAlreadySubmitted => { + CommandBufferExecError::OneTimeSubmitAlreadySubmitted => write!( + f, "the command buffer or one of the secondary command buffers it executes was \ created with the \"one time submit\" flag, but has already been submitted in \ - the past" - } - CommandBufferExecError::ExclusiveAlreadyInUse => { + the past", + ), + CommandBufferExecError::ExclusiveAlreadyInUse => write!( + f, "the command buffer or one of the secondary command buffers it executes is \ already in use was not created with the \"concurrent\" flag" - } - }; - - write!(f, "{}", value) + ), + } } } diff --git a/vulkano/src/pipeline/compute.rs b/vulkano/src/pipeline/compute.rs index 5b58e0d4..19c82972 100644 --- a/vulkano/src/pipeline/compute.rs +++ b/vulkano/src/pipeline/compute.rs @@ -108,7 +108,6 @@ impl ComputePipeline { let &PipelineShaderStageCreateInfo { flags, ref entry_point, - ref specialization_info, ref required_subgroup_size, _ne: _, } = stage; @@ -117,7 +116,9 @@ impl ComputePipeline { name_vk = CString::new(entry_point_info.name.as_str()).unwrap(); specialization_data_vk = Vec::new(); - 
specialization_map_entries_vk = specialization_info + specialization_map_entries_vk = entry_point + .module() + .specialization_info() .iter() .map(|(&constant_id, value)| { let data = value.as_bytes(); @@ -403,7 +404,6 @@ impl ComputePipelineCreateInfo { let &PipelineShaderStageCreateInfo { flags: _, ref entry_point, - specialization_info: _, required_subgroup_size: _vk, _ne: _, } = &stage; @@ -509,14 +509,15 @@ mod tests { ]; let module = ShaderModule::new(device.clone(), ShaderModuleCreateInfo::new(&MODULE)).unwrap(); - module.entry_point("main").unwrap() + module + .specialize([(83, 0x12345678i32.into())].into_iter().collect()) + .unwrap() + .entry_point("main") + .unwrap() }; let pipeline = { - let stage = PipelineShaderStageCreateInfo { - specialization_info: [(83, 0x12345678i32.into())].into_iter().collect(), - ..PipelineShaderStageCreateInfo::new(cs) - }; + let stage = PipelineShaderStageCreateInfo::new(cs); let layout = PipelineLayout::new( device.clone(), PipelineDescriptorSetLayoutCreateInfo::from_stages([&stage]) diff --git a/vulkano/src/pipeline/graphics/mod.rs b/vulkano/src/pipeline/graphics/mod.rs index 5eb96fd8..cb7b4a11 100644 --- a/vulkano/src/pipeline/graphics/mod.rs +++ b/vulkano/src/pipeline/graphics/mod.rs @@ -215,7 +215,6 @@ impl GraphicsPipeline { let &PipelineShaderStageCreateInfo { flags, ref entry_point, - ref specialization_info, ref required_subgroup_size, _ne: _, } = stage; @@ -224,7 +223,9 @@ impl GraphicsPipeline { let stage = ShaderStage::from(&entry_point_info.execution); let mut specialization_data_vk: Vec = Vec::new(); - let specialization_map_entries_vk: Vec<_> = specialization_info + let specialization_map_entries_vk: Vec<_> = entry_point + .module() + .specialization_info() .iter() .map(|(&constant_id, value)| { let data = value.as_bytes(); @@ -2489,7 +2490,6 @@ impl GraphicsPipelineCreateInfo { let &PipelineShaderStageCreateInfo { flags: _, ref entry_point, - specialization_info: _, required_subgroup_size: _vk, _ne: _, } = 
stage; diff --git a/vulkano/src/pipeline/mod.rs b/vulkano/src/pipeline/mod.rs index fa7999e8..a92d1455 100644 --- a/vulkano/src/pipeline/mod.rs +++ b/vulkano/src/pipeline/mod.rs @@ -21,7 +21,7 @@ pub use self::{compute::ComputePipeline, graphics::GraphicsPipeline, layout::Pip use crate::{ device::{Device, DeviceOwned}, macros::{vulkan_bitflags, vulkan_enum}, - shader::{DescriptorBindingRequirements, EntryPoint, ShaderStage, SpecializationConstant}, + shader::{DescriptorBindingRequirements, EntryPoint, ShaderExecution, ShaderStage}, Requires, RequiresAllOf, RequiresOneOf, ValidationError, }; use ahash::HashMap; @@ -304,21 +304,11 @@ pub struct PipelineShaderStageCreateInfo { /// The default value is empty. pub flags: PipelineShaderStageCreateFlags, - /// The shader entry point for the stage. + /// The shader entry point for the stage, which includes any specialization constants. /// /// There is no default value. pub entry_point: EntryPoint, - /// Values for the specialization constants in the shader, indexed by their `constant_id`. - /// - /// Specialization constants are constants whose value can be overridden when you create - /// a pipeline. When provided, they must have the same type as defined in the shader. - /// Constants that are not given a value here will have the default value that was specified - /// for them in the shader code. - /// - /// The default value is empty. - pub specialization_info: HashMap, - /// The required subgroup size. /// /// Requires [`subgroup_size_control`](crate::device::Features::subgroup_size_control). 
The @@ -344,7 +334,6 @@ impl PipelineShaderStageCreateInfo { Self { flags: PipelineShaderStageCreateFlags::empty(), entry_point, - specialization_info: HashMap::default(), required_subgroup_size: None, _ne: crate::NonExhaustive(()), } @@ -354,7 +343,6 @@ impl PipelineShaderStageCreateInfo { let &Self { flags, ref entry_point, - ref specialization_info, required_subgroup_size, _ne: _, } = self; @@ -463,32 +451,11 @@ impl PipelineShaderStageCreateInfo { ShaderStage::SubpassShading => (), } - for (&constant_id, provided_value) in specialization_info { - // Per `VkSpecializationMapEntry` spec: - // "If a constantID value is not a specialization constant ID used in the shader, - // that map entry does not affect the behavior of the pipeline." - // We *may* want to be stricter than this for the sake of catching user errors? - if let Some(default_value) = entry_point_info.specialization_constants.get(&constant_id) - { - // Check for equal types rather than only equal size. - if !provided_value.eq_type(default_value) { - return Err(Box::new(ValidationError { - problem: format!( - "`specialization_info[{0}]` does not have the same type as \ - `entry_point.info().specialization_constants[{0}]`", - constant_id - ) - .into(), - vuids: &["VUID-VkSpecializationMapEntry-constantID-00776"], - ..Default::default() - })); - } - } - } - - let workgroup_size = if let Some(local_size) = - entry_point_info.local_size(specialization_info)? + let workgroup_size = if let ShaderExecution::Compute(execution) = + &entry_point_info.execution { + let local_size = execution.local_size; + match stage_enum { ShaderStage::Compute => { if local_size[0] > properties.max_compute_work_group_size[0] { diff --git a/vulkano/src/shader/mod.rs b/vulkano/src/shader/mod.rs index 315c36b0..71dfa88f 100644 --- a/vulkano/src/shader/mod.rs +++ b/vulkano/src/shader/mod.rs @@ -131,6 +131,7 @@ //! [`scalar_block_layout`]: crate::device::Features::scalar_block_layout //! 
[`uniform_buffer_standard_layout`]: crate::device::Features::uniform_buffer_standard_layout +use self::spirv::Instruction; use crate::{ descriptor_set::layout::DescriptorType, device::{Device, DeviceOwned}, @@ -170,7 +171,9 @@ pub struct ShaderModule { handle: ash::vk::ShaderModule, device: InstanceOwnedDebugWrapper>, id: NonZeroU64, - entry_point_infos: SmallVec<[EntryPointInfo; 1]>, + + spirv: Spirv, + specialization_constants: HashMap, } impl ShaderModule { @@ -192,50 +195,18 @@ impl ShaderModule { }) })?; - Self::new_with_data( - device, - create_info, - reflect::entry_points(&spirv), - spirv.version(), - reflect::spirv_capabilities(&spirv), - reflect::spirv_extensions(&spirv), - ) + Self::validate_new(&device, &create_info, &spirv)?; + + Ok(Self::new_with_spirv_unchecked(device, create_info, spirv)?) } - // This is public only for vulkano-shaders, do not use otherwise. - #[doc(hidden)] - pub unsafe fn new_with_data<'a>( - device: Arc, - create_info: ShaderModuleCreateInfo<'_>, - entry_points: impl IntoIterator, - spirv_version: Version, - spirv_capabilities: impl IntoIterator, - spirv_extensions: impl IntoIterator, - ) -> Result, Validated> { - Self::validate_new( - &device, - &create_info, - spirv_version, - spirv_capabilities, - spirv_extensions, - )?; - - Ok(Self::new_with_data_unchecked( - device, - create_info, - entry_points, - )?) 
- } - - fn validate_new<'a>( + fn validate_new( device: &Device, create_info: &ShaderModuleCreateInfo<'_>, - spirv_version: Version, - spirv_capabilities: impl IntoIterator, - spirv_extensions: impl IntoIterator, + spirv: &Spirv, ) -> Result<(), Box> { create_info - .validate(device, spirv_version, spirv_capabilities, spirv_extensions) + .validate(device, spirv) .map_err(|err| err.add_context("create_info"))?; Ok(()) @@ -247,14 +218,13 @@ impl ShaderModule { create_info: ShaderModuleCreateInfo<'_>, ) -> Result, VulkanError> { let spirv = Spirv::new(create_info.code).unwrap(); - - Self::new_with_data_unchecked(device, create_info, reflect::entry_points(&spirv)) + Self::new_with_spirv_unchecked(device, create_info, spirv) } - unsafe fn new_with_data_unchecked( + unsafe fn new_with_spirv_unchecked( device: Arc, create_info: ShaderModuleCreateInfo<'_>, - entry_points: impl IntoIterator, + spirv: Spirv, ) -> Result, VulkanError> { let &ShaderModuleCreateInfo { code, _ne: _ } = &create_info; @@ -279,11 +249,11 @@ impl ShaderModule { output.assume_init() }; - Ok(Self::from_handle_with_data( + Ok(Self::from_handle_with_spirv( device, handle, create_info, - entry_points, + spirv, )) } @@ -299,22 +269,25 @@ impl ShaderModule { create_info: ShaderModuleCreateInfo<'_>, ) -> Arc { let spirv = Spirv::new(create_info.code).unwrap(); - - Self::from_handle_with_data(device, handle, create_info, reflect::entry_points(&spirv)) + Self::from_handle_with_spirv(device, handle, create_info, spirv) } - unsafe fn from_handle_with_data( + unsafe fn from_handle_with_spirv( device: Arc, handle: ash::vk::ShaderModule, create_info: ShaderModuleCreateInfo<'_>, - entry_points: impl IntoIterator, + spirv: Spirv, ) -> Arc { let ShaderModuleCreateInfo { code: _, _ne: _ } = create_info; + let specialization_constants = reflect::specialization_constants(&spirv); + Arc::new(ShaderModule { handle, device: InstanceOwnedDebugWrapper(device), id: Self::next_id(), - entry_point_infos: 
entry_points.into_iter().collect(), + + spirv, + specialization_constants, }) } @@ -352,6 +325,458 @@ impl ShaderModule { Self::new(device, ShaderModuleCreateInfo::new(&words)) } + /// Returns the specialization constants that are defined in the module, + /// along with their default values. + /// + /// Specialization constants are constants whose value can be overridden when you create + /// a pipeline. They are indexed by their `constant_id`. + #[inline] + pub fn specialization_constants(&self) -> &HashMap { + &self.specialization_constants + } + + /// Applies the specialization constants to the shader module, + /// and returns a specialized version of the module. + /// + /// Constants that are not given a value here will have the default value that was specified + /// for them in the shader code. + /// When provided, they must have the same type as defined in the shader (as returned by + /// [`specialization_constants`]). + /// + /// [`specialization_constants`]: Self::specialization_constants + #[inline] + pub fn specialize( + self: &Arc, + specialization_info: HashMap, + ) -> Result, Box> { + SpecializedShaderModule::new(self.clone(), specialization_info) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + #[inline] + pub unsafe fn specialize_unchecked( + self: &Arc, + specialization_info: HashMap, + ) -> Arc { + SpecializedShaderModule::new_unchecked(self.clone(), specialization_info) + } + + /// Equivalent to calling [`specialize`] with empty specialization info, + /// and then calling [`SpecializedShaderModule::entry_point`]. + /// + /// [`specialize`]: Self::specialize + #[inline] + pub fn entry_point(self: &Arc, name: &str) -> Option { + unsafe { + self.specialize_unchecked(HashMap::default()) + .entry_point(name) + } + } + + /// Equivalent to calling [`specialize`] with empty specialization info, + /// and then calling [`SpecializedShaderModule::entry_point_with_execution`]. 
+ /// + /// [`specialize`]: Self::specialize + #[inline] + pub fn entry_point_with_execution( + self: &Arc, + name: &str, + execution: ExecutionModel, + ) -> Option { + unsafe { + self.specialize_unchecked(HashMap::default()) + .entry_point_with_execution(name, execution) + } + } + + /// Equivalent to calling [`specialize`] with empty specialization info, + /// and then calling [`SpecializedShaderModule::single_entry_point`]. + /// + /// [`specialize`]: Self::specialize + #[inline] + pub fn single_entry_point(self: &Arc) -> Option { + unsafe { + self.specialize_unchecked(HashMap::default()) + .single_entry_point() + } + } + + /// Equivalent to calling [`specialize`] with empty specialization info, + /// and then calling [`SpecializedShaderModule::single_entry_point_with_execution`]. + /// + /// [`specialize`]: Self::specialize + #[inline] + pub fn single_entry_point_with_execution( + self: &Arc, + execution: ExecutionModel, + ) -> Option { + unsafe { + self.specialize_unchecked(HashMap::default()) + .single_entry_point_with_execution(execution) + } + } +} + +impl Drop for ShaderModule { + #[inline] + fn drop(&mut self) { + unsafe { + let fns = self.device.fns(); + (fns.v1_0.destroy_shader_module)(self.device.handle(), self.handle, ptr::null()); + } + } +} + +unsafe impl VulkanObject for ShaderModule { + type Handle = ash::vk::ShaderModule; + + #[inline] + fn handle(&self) -> Self::Handle { + self.handle + } +} + +unsafe impl DeviceOwned for ShaderModule { + #[inline] + fn device(&self) -> &Arc { + &self.device + } +} + +impl_id_counter!(ShaderModule); + +pub struct ShaderModuleCreateInfo<'a> { + /// The SPIR-V code, in the form of 32-bit words. + /// + /// There is no default value. + pub code: &'a [u32], + + pub _ne: crate::NonExhaustive, +} + +impl<'a> ShaderModuleCreateInfo<'a> { + /// Returns a `ShaderModuleCreateInfo` with the specified `code`. 
+ #[inline] + pub fn new(code: &'a [u32]) -> Self { + Self { + code, + _ne: crate::NonExhaustive(()), + } + } + + pub(crate) fn validate( + &self, + device: &Device, + spirv: &Spirv, + ) -> Result<(), Box> { + let &Self { code, _ne: _ } = self; + + if code.is_empty() { + return Err(Box::new(ValidationError { + context: "code".into(), + problem: "is empty".into(), + vuids: &["VUID-VkShaderModuleCreateInfo-codeSize-01085"], + ..Default::default() + })); + } + + let spirv_version = Version { + patch: 0, // Ignore the patch version + ..spirv.version() + }; + + { + match spirv_version { + Version::V1_0 => None, + Version::V1_1 | Version::V1_2 | Version::V1_3 => { + (!(device.api_version() >= Version::V1_1)).then_some(RequiresOneOf(&[ + RequiresAllOf(&[Requires::APIVersion(Version::V1_1)]), + ])) + } + Version::V1_4 => (!(device.api_version() >= Version::V1_2 + || device.enabled_extensions().khr_spirv_1_4)) + .then_some(RequiresOneOf(&[ + RequiresAllOf(&[Requires::APIVersion(Version::V1_2)]), + RequiresAllOf(&[Requires::DeviceExtension("khr_spirv_1_4")]), + ])), + Version::V1_5 => { + (!(device.api_version() >= Version::V1_2)).then_some(RequiresOneOf(&[ + RequiresAllOf(&[Requires::APIVersion(Version::V1_2)]), + ])) + } + Version::V1_6 => { + (!(device.api_version() >= Version::V1_3)).then_some(RequiresOneOf(&[ + RequiresAllOf(&[Requires::APIVersion(Version::V1_3)]), + ])) + } + _ => { + return Err(Box::new(ValidationError { + context: "code".into(), + problem: format!( + "uses SPIR-V version {}.{}, which is not supported by Vulkan", + spirv_version.major, spirv_version.minor + ) + .into(), + // vuids? 
+ ..Default::default() + })); + } + } + } + .map_or(Ok(()), |requires_one_of| { + Err(Box::new(ValidationError { + context: "code".into(), + problem: format!( + "uses SPIR-V version {}.{}", + spirv_version.major, spirv_version.minor + ) + .into(), + requires_one_of, + ..Default::default() + })) + })?; + + for &capability in spirv + .iter_capability() + .filter_map(|instruction| match instruction { + Instruction::Capability { capability } => Some(capability), + _ => None, + }) + { + validate_spirv_capability(device, capability).map_err(|err| err.add_context("code"))?; + } + + for extension in spirv + .iter_extension() + .filter_map(|instruction| match instruction { + Instruction::Extension { name } => Some(name.as_str()), + _ => None, + }) + { + validate_spirv_extension(device, extension).map_err(|err| err.add_context("code"))?; + } + + // VUID-VkShaderModuleCreateInfo-pCode-08736 + // VUID-VkShaderModuleCreateInfo-pCode-08737 + // VUID-VkShaderModuleCreateInfo-pCode-08738 + // Unsafe + + Ok(()) + } +} + +/// The value to provide for a specialization constant, when creating a pipeline. +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum SpecializationConstant { + Bool(bool), + U8(u8), + U16(u16), + U32(u32), + U64(u64), + I8(i8), + I16(i16), + I32(i32), + I64(i64), + F16(f16), + F32(f32), + F64(f64), +} + +impl SpecializationConstant { + /// Returns the value as a byte slice. Booleans are expanded to a `VkBool32` value. 
+ #[inline] + pub fn as_bytes(&self) -> &[u8] { + match self { + Self::Bool(false) => bytes_of(&ash::vk::FALSE), + Self::Bool(true) => bytes_of(&ash::vk::TRUE), + Self::U8(value) => bytes_of(value), + Self::U16(value) => bytes_of(value), + Self::U32(value) => bytes_of(value), + Self::U64(value) => bytes_of(value), + Self::I8(value) => bytes_of(value), + Self::I16(value) => bytes_of(value), + Self::I32(value) => bytes_of(value), + Self::I64(value) => bytes_of(value), + Self::F16(value) => bytes_of(value), + Self::F32(value) => bytes_of(value), + Self::F64(value) => bytes_of(value), + } + } + + /// Returns whether `self` and `other` have the same type, ignoring the value. + #[inline] + pub fn eq_type(&self, other: &Self) -> bool { + discriminant(self) == discriminant(other) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: bool) -> Self { + SpecializationConstant::Bool(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: i8) -> Self { + SpecializationConstant::I8(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: i16) -> Self { + SpecializationConstant::I16(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: i32) -> Self { + SpecializationConstant::I32(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: i64) -> Self { + SpecializationConstant::I64(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: u8) -> Self { + SpecializationConstant::U8(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: u16) -> Self { + SpecializationConstant::U16(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: u32) -> Self { + SpecializationConstant::U32(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: u64) -> Self { + SpecializationConstant::U64(value) + } +} + +impl From for SpecializationConstant { 
+ #[inline] + fn from(value: f16) -> Self { + SpecializationConstant::F16(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: f32) -> Self { + SpecializationConstant::F32(value) + } +} + +impl From for SpecializationConstant { + #[inline] + fn from(value: f64) -> Self { + SpecializationConstant::F64(value) + } +} + +/// A shader module with specialization constants applied. +#[derive(Debug)] +pub struct SpecializedShaderModule { + base_module: Arc, + specialization_info: HashMap, + _spirv: Option, + entry_point_infos: SmallVec<[EntryPointInfo; 1]>, +} + +impl SpecializedShaderModule { + /// Returns `base_module` specialized with `specialization_info`. + #[inline] + pub fn new( + base_module: Arc, + specialization_info: HashMap, + ) -> Result, Box> { + Self::validate_new(&base_module, &specialization_info)?; + + unsafe { Ok(Self::new_unchecked(base_module, specialization_info)) } + } + + fn validate_new( + base_module: &ShaderModule, + specialization_info: &HashMap, + ) -> Result<(), Box> { + for (&constant_id, provided_value) in specialization_info { + // Per `VkSpecializationMapEntry` spec: + // "If a constantID value is not a specialization constant ID used in the shader, + // that map entry does not affect the behavior of the pipeline." + // We *may* want to be stricter than this for the sake of catching user errors? + if let Some(default_value) = base_module.specialization_constants.get(&constant_id) { + // Check for equal types rather than only equal size. 
+ if !provided_value.eq_type(default_value) { + return Err(Box::new(ValidationError { + problem: format!( + "`specialization_info[{0}]` does not have the same type as \ + `base_module.specialization_constants()[{0}]`", + constant_id + ) + .into(), + vuids: &["VUID-VkSpecializationMapEntry-constantID-00776"], + ..Default::default() + })); + } + } + } + + Ok(()) + } + + #[cfg_attr(not(feature = "document_unchecked"), doc(hidden))] + pub unsafe fn new_unchecked( + base_module: Arc, + specialization_info: HashMap, + ) -> Arc { + let spirv = (!base_module.specialization_constants.is_empty()).then(|| { + let mut spirv = base_module.spirv.clone(); + spirv.apply_specialization(&specialization_info); + spirv + }); + let entry_point_infos = + reflect::entry_points(spirv.as_ref().unwrap_or(&base_module.spirv)).collect(); + + Arc::new(Self { + base_module, + specialization_info, + _spirv: spirv, + entry_point_infos, + }) + } + + /// Returns the base module, without specialization applied. + #[inline] + pub fn base_module(&self) -> &Arc { + &self.base_module + } + + /// Returns the specialization constants that have been applied to the module. + #[inline] + pub fn specialization_info(&self) -> &HashMap { + &self.specialization_info + } + /// Returns information about the entry point with the provided name. Returns `None` if no entry /// point with that name exists in the shader module or if multiple entry points with the same /// name exist. @@ -403,7 +828,7 @@ impl ShaderModule { /// with the provided `ExecutionModel`. Returns `None` if no entry point was found or multiple /// entry points have been found matching the provided `ExecutionModel`. 
#[inline] - pub fn single_entry_point_of_execution( + pub fn single_entry_point_with_execution( self: &Arc, execution: ExecutionModel, ) -> Option { @@ -411,138 +836,19 @@ impl ShaderModule { } } -impl Drop for ShaderModule { - #[inline] - fn drop(&mut self) { - unsafe { - let fns = self.device.fns(); - (fns.v1_0.destroy_shader_module)(self.device.handle(), self.handle, ptr::null()); - } - } -} - -unsafe impl VulkanObject for ShaderModule { +unsafe impl VulkanObject for SpecializedShaderModule { type Handle = ash::vk::ShaderModule; #[inline] fn handle(&self) -> Self::Handle { - self.handle + self.base_module.handle } } -unsafe impl DeviceOwned for ShaderModule { +unsafe impl DeviceOwned for SpecializedShaderModule { #[inline] fn device(&self) -> &Arc { - &self.device - } -} - -impl_id_counter!(ShaderModule); - -pub struct ShaderModuleCreateInfo<'a> { - /// The SPIR-V code, in the form of 32-bit words. - /// - /// There is no default value. - pub code: &'a [u32], - - pub _ne: crate::NonExhaustive, -} - -impl<'a> ShaderModuleCreateInfo<'a> { - /// Returns a `ShaderModuleCreateInfo` with the specified `code`. 
- #[inline] - pub fn new(code: &'a [u32]) -> Self { - Self { - code, - _ne: crate::NonExhaustive(()), - } - } - - pub(crate) fn validate<'b>( - &self, - device: &Device, - mut spirv_version: Version, - spirv_capabilities: impl IntoIterator, - spirv_extensions: impl IntoIterator, - ) -> Result<(), Box> { - let &Self { code, _ne: _ } = self; - - if code.is_empty() { - return Err(Box::new(ValidationError { - context: "code".into(), - problem: "is empty".into(), - vuids: &["VUID-VkShaderModuleCreateInfo-codeSize-01085"], - ..Default::default() - })); - } - - { - spirv_version.patch = 0; // Ignore the patch version - - match spirv_version { - Version::V1_0 => None, - Version::V1_1 | Version::V1_2 | Version::V1_3 => { - (!(device.api_version() >= Version::V1_1)).then_some(RequiresOneOf(&[ - RequiresAllOf(&[Requires::APIVersion(Version::V1_1)]), - ])) - } - Version::V1_4 => (!(device.api_version() >= Version::V1_2 - || device.enabled_extensions().khr_spirv_1_4)) - .then_some(RequiresOneOf(&[ - RequiresAllOf(&[Requires::APIVersion(Version::V1_2)]), - RequiresAllOf(&[Requires::DeviceExtension("khr_spirv_1_4")]), - ])), - Version::V1_5 => { - (!(device.api_version() >= Version::V1_2)).then_some(RequiresOneOf(&[ - RequiresAllOf(&[Requires::APIVersion(Version::V1_2)]), - ])) - } - Version::V1_6 => { - (!(device.api_version() >= Version::V1_3)).then_some(RequiresOneOf(&[ - RequiresAllOf(&[Requires::APIVersion(Version::V1_3)]), - ])) - } - _ => { - return Err(Box::new(ValidationError { - context: "code".into(), - problem: format!( - "uses SPIR-V version {}.{}, which is not supported by Vulkan", - spirv_version.major, spirv_version.minor - ) - .into(), - // vuids? 
- ..Default::default() - })); - } - } - } - .map_or(Ok(()), |requires_one_of| { - Err(Box::new(ValidationError { - context: "code".into(), - problem: format!( - "uses SPIR-V version {}.{}", - spirv_version.major, spirv_version.minor - ) - .into(), - requires_one_of, - ..Default::default() - })) - })?; - - for &capability in spirv_capabilities { - validate_spirv_capability(device, capability).map_err(|err| err.add_context("code"))?; - } - - for extension in spirv_extensions { - validate_spirv_extension(device, extension).map_err(|err| err.add_context("code"))?; - } - - // VUID-VkShaderModuleCreateInfo-pCode-08736 - // VUID-VkShaderModuleCreateInfo-pCode-08737 - // VUID-VkShaderModuleCreateInfo-pCode-08738 - // Unsafe - - Ok(()) + &self.base_module.device } } @@ -553,96 +859,23 @@ pub struct EntryPointInfo { pub execution: ShaderExecution, pub descriptor_binding_requirements: HashMap<(u32, u32), DescriptorBindingRequirements>, pub push_constant_requirements: Option, - pub specialization_constants: HashMap, pub input_interface: ShaderInterface, pub output_interface: ShaderInterface, } -impl EntryPointInfo { - /// The local size in Compute shaders, None otherwise. - /// - /// `specialization_info` is used for LocalSizeId / WorkgroupSizeId, using specialization_constants if not found. - /// Errors if specialization constants are not found or are not u32's. 
- pub(crate) fn local_size( - &self, - specialization_info: &HashMap, - ) -> Result, Box> { - if let ShaderExecution::Compute(execution) = self.execution { - match execution { - ComputeShaderExecution::LocalSize(local_size) - | ComputeShaderExecution::LocalSizeId(local_size) => { - let mut output = [0; 3]; - for (output, local_size) in output.iter_mut().zip(local_size) { - let id = match local_size { - LocalSize::Literal(literal) => { - *output = literal; - continue; - } - LocalSize::SpecId(id) => id, - }; - if let Some(default) = self.specialization_constants.get(&id) { - let default_value = if let SpecializationConstant::U32(default_value) = - default - { - default_value - } else { - return Err(Box::new(ValidationError { - problem: format!( - "`entry_point.info().specialization_constants[{id}]` is not a 32 bit integer" - ) - .into(), - ..Default::default() - })); - }; - if let Some(provided) = specialization_info.get(&id) { - if let SpecializationConstant::U32(provided_value) = provided { - *output = *provided_value; - } else { - return Err(Box::new(ValidationError { - problem: format!( - "`specialization_info[{0}]` does not have the same type as \ - `entry_point.info().specialization_constants[{0}]`", - id, - ) - .into(), - vuids: &["VUID-VkSpecializationMapEntry-constantID-00776"], - ..Default::default() - })); - } - } - *output = *default_value; - } else { - return Err(Box::new(ValidationError { - problem: format!( - "specialization constant {id} not found in `entry_point.info().specialization_constants`" - ) - .into(), - ..Default::default() - })); - } - } - Ok(Some(output)) - } - } - } else { - Ok(None) - } - } -} - /// Represents a shader entry point in a shader module. /// /// Can be obtained by calling [`entry_point`](ShaderModule::entry_point) on the shader module. #[derive(Clone, Debug)] pub struct EntryPoint { - module: Arc, + module: Arc, info_index: usize, } impl EntryPoint { /// Returns the module this entry point comes from. 
#[inline] - pub fn module(&self) -> &Arc { + pub fn module(&self) -> &Arc { &self.module } @@ -780,30 +1013,15 @@ pub enum FragmentTestsStages { EarlyAndLate, } -/// LocalSize. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum LocalSize { - Literal(u32), - SpecId(u32), -} - /// The mode in which the compute shader executes. /// -/// The workgroup size is specified for x, y, and z dimensions. -/// -/// Constants are resolved to literals, while specialization constants -/// map to spec ids. -/// /// The `WorkgroupSize` builtin overrides the values specified in the /// execution mode. It can decorate a 3 component ConstantComposite or /// SpecConstantComposite vector. #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ComputeShaderExecution { +pub struct ComputeShaderExecution { /// Workgroup size in x, y, and z. - LocalSize([LocalSize; 3]), - /// Requires spirv 1.2. - /// Like `LocalSize`, but uses ids instead of literals. - LocalSizeId([LocalSize; 3]), + pub local_size: [u32; 3], } /// The requirements imposed by a shader on a binding within a descriptor set layout, and on any @@ -995,135 +1213,6 @@ impl DescriptorRequirements { } } -/// The value to provide for a specialization constant, when creating a pipeline. -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum SpecializationConstant { - Bool(bool), - I8(i8), - I16(i16), - I32(i32), - I64(i64), - U8(u8), - U16(u16), - U32(u32), - U64(u64), - F16(f16), - F32(f32), - F64(f64), -} - -impl SpecializationConstant { - /// Returns the value as a byte slice. Booleans are expanded to a `VkBool32` value. 
- #[inline] - pub fn as_bytes(&self) -> &[u8] { - match self { - Self::Bool(false) => bytes_of(&ash::vk::FALSE), - Self::Bool(true) => bytes_of(&ash::vk::TRUE), - Self::I8(value) => bytes_of(value), - Self::I16(value) => bytes_of(value), - Self::I32(value) => bytes_of(value), - Self::I64(value) => bytes_of(value), - Self::U8(value) => bytes_of(value), - Self::U16(value) => bytes_of(value), - Self::U32(value) => bytes_of(value), - Self::U64(value) => bytes_of(value), - Self::F16(value) => bytes_of(value), - Self::F32(value) => bytes_of(value), - Self::F64(value) => bytes_of(value), - } - } - - /// Returns whether `self` and `other` have the same type, ignoring the value. - #[inline] - pub fn eq_type(&self, other: &Self) -> bool { - discriminant(self) == discriminant(other) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: bool) -> Self { - SpecializationConstant::Bool(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: i8) -> Self { - SpecializationConstant::I8(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: i16) -> Self { - SpecializationConstant::I16(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: i32) -> Self { - SpecializationConstant::I32(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: i64) -> Self { - SpecializationConstant::I64(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: u8) -> Self { - SpecializationConstant::U8(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: u16) -> Self { - SpecializationConstant::U16(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: u32) -> Self { - SpecializationConstant::U32(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: u64) -> Self { - SpecializationConstant::U64(value) - } -} - -impl From for SpecializationConstant { 
- #[inline] - fn from(value: f16) -> Self { - SpecializationConstant::F16(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: f32) -> Self { - SpecializationConstant::F32(value) - } -} - -impl From for SpecializationConstant { - #[inline] - fn from(value: f64) -> Self { - SpecializationConstant::F64(value) - } -} - /// Type that contains the definition of an interface between two shader stages, or between /// the outside and a shader stage. #[derive(Clone, Debug)] diff --git a/vulkano/src/shader/reflect.rs b/vulkano/src/shader/reflect.rs index 7318bf59..d44e55ec 100644 --- a/vulkano/src/shader/reflect.rs +++ b/vulkano/src/shader/reflect.rs @@ -16,11 +16,11 @@ use crate::{ pipeline::layout::PushConstantRange, shader::{ spirv::{ - BuiltIn, Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, - Spirv, StorageClass, + BuiltIn, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, Spirv, + StorageClass, }, ComputeShaderExecution, DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, - GeometryShaderExecution, GeometryShaderInput, LocalSize, NumericType, ShaderExecution, + GeometryShaderExecution, GeometryShaderInput, NumericType, ShaderExecution, ShaderInterface, ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage, SpecializationConstant, }, @@ -30,56 +30,24 @@ use ahash::{HashMap, HashSet}; use half::f16; use std::borrow::Cow; -/// Returns an iterator of the capabilities used by `spirv`. -#[inline] -pub fn spirv_capabilities(spirv: &Spirv) -> impl Iterator { - spirv - .iter_capability() - .filter_map(|instruction| match instruction { - Instruction::Capability { capability } => Some(capability), - _ => None, - }) -} - -/// Returns an iterator of the extensions used by `spirv`. 
-#[inline] -pub fn spirv_extensions(spirv: &Spirv) -> impl Iterator { - spirv - .iter_extension() - .filter_map(|instruction| match instruction { - Instruction::Extension { name } => Some(name.as_str()), - _ => None, - }) -} - /// Returns an iterator over all entry points in `spirv`, with information about the entry point. #[inline] pub fn entry_points(spirv: &Spirv) -> impl Iterator + '_ { let interface_variables = interface_variables(spirv); - let u32_constants = u32_constants(spirv); - let specialization_constant_ids = specialization_constant_ids(spirv); - let workgroup_size_decorations = workgroup_size_decorations(spirv); spirv.iter_entry_point().filter_map(move |instruction| { - let (execution_model, function_id, entry_point_name, interface) = match instruction { + let (execution_model, function_id, entry_point_name, interface) = match *instruction { Instruction::EntryPoint { execution_model, entry_point, - name, - interface, + ref name, + ref interface, .. - } => (*execution_model, *entry_point, name, interface), + } => (execution_model, entry_point, name, interface), _ => return None, }; - let execution = shader_execution( - spirv, - execution_model, - function_id, - &u32_constants, - &specialization_constant_ids, - &workgroup_size_decorations, - ); + let execution = shader_execution(spirv, execution_model, function_id); let stage = ShaderStage::from(&execution); let descriptor_binding_requirements = inspect_entry_point( @@ -89,7 +57,6 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator + '_ function_id, ); let push_constant_requirements = push_constant_requirements(spirv, stage); - let specialization_constants = specialization_constants(spirv); let input_interface = shader_interface( spirv, interface, @@ -113,101 +80,17 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator + '_ execution, descriptor_binding_requirements, push_constant_requirements, - specialization_constants, input_interface, output_interface, }) }) } -/// Extracts the u32 constants 
from `spirv`. -fn u32_constants(spirv: &Spirv) -> HashMap { - let type_u32s: HashSet = spirv - .iter_global() - .filter_map(|inst| { - if let Instruction::TypeInt { - result_id, - width, - signedness, - } = inst - { - if *width == 0 && *signedness == 0 { - return Some(*result_id); - } - } - None - }) - .collect(); - spirv - .iter_decoration() - .filter_map(|inst| { - if let Instruction::Constant { - result_type_id, - result_id, - value, - } = inst - { - if type_u32s.contains(result_type_id) { - if let [value] = value.as_slice() { - return Some((*result_id, *value)); - } - } - } - None - }) - .collect() -} - -/// Extracts the specialization constant ids from `spirv`. -fn specialization_constant_ids(spirv: &Spirv) -> HashMap { - spirv - .iter_decoration() - .filter_map(|inst| { - if let Instruction::Decorate { - target, - decoration: - Decoration::SpecId { - specialization_constant_id, - }, - } = inst - { - Some((*target, *specialization_constant_id)) - } else { - None - } - }) - .collect() -} - -/// Extracts the `WorkgroupSize` builtin Id's from `spirv`. -fn workgroup_size_decorations(spirv: &Spirv) -> HashSet { - spirv - .iter_decoration() - .filter_map(|inst| { - if let Instruction::Decorate { - target, - decoration: - Decoration::BuiltIn { - built_in: BuiltIn::WorkgroupSize, - }, - } = inst - { - Some(*target) - } else { - None - } - }) - .collect() -} - /// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`. 
fn shader_execution( spirv: &Spirv, execution_model: ExecutionModel, function_id: Id, - u32_constants: &HashMap, - specialization_constant_ids: &HashMap, - workgroup_size_decorations: &HashSet, ) -> ShaderExecution { match execution_model { ExecutionModel::Vertex => ShaderExecution::Vertex, @@ -264,14 +147,13 @@ fn shader_execution( _ => continue, }; - #[allow(clippy::single_match)] match mode { ExecutionMode::EarlyFragmentTests => { fragment_tests_stages = FragmentTestsStages::Early; } - /*ExecutionMode::EarlyAndLateFragmentTestsAMD => { + ExecutionMode::EarlyAndLateFragmentTestsAMD => { fragment_tests_stages = FragmentTestsStages::EarlyAndLate; - }*/ + } _ => (), } } @@ -282,122 +164,82 @@ fn shader_execution( } ExecutionModel::GLCompute => { - let mut execution = ComputeShaderExecution::LocalSize([LocalSize::Literal(0); 3]); - for instruction in spirv.iter_execution_mode() { - match instruction { - Instruction::ExecutionMode { entry_point, mode } - if *entry_point == function_id => - { - if let ExecutionMode::LocalSize { - x_size, - y_size, - z_size, - } = mode - { - execution = ComputeShaderExecution::LocalSize([ - LocalSize::Literal(*x_size), - LocalSize::Literal(*y_size), - LocalSize::Literal(*z_size), - ]); - break; - } - } - Instruction::ExecutionModeId { entry_point, mode } - if *entry_point == function_id => - { - if let ExecutionMode::LocalSizeId { - x_size, - y_size, - z_size, - } = mode - { - let mut local_size = [LocalSize::Literal(0); 3]; - for (local_size, id) in - local_size.iter_mut().zip([*x_size, *y_size, *z_size]) - { - if let Some(constant) = u32_constants.get(&id) { - *local_size = LocalSize::Literal(*constant); - } else if let Some(spec_id) = specialization_constant_ids.get(&id) { - *local_size = LocalSize::SpecId(*spec_id); - } else { - panic!("LocalSizeId {id:?} not defined!"); + let local_size = (spirv + .iter_decoration() + .find_map(|instruction| match *instruction { + Instruction::Decorate { + target, + decoration: + 
Decoration::BuiltIn { + built_in: BuiltIn::WorkgroupSize, + }, + } => match *spirv.id(target).instruction() { + Instruction::ConstantComposite { + ref constituents, .. + } => { + match *constituents.as_slice() { + [x_size, y_size, z_size] => { + Some([x_size, y_size, z_size].map(|id| { + match *spirv.id(id).instruction() { + Instruction::Constant { ref value, .. } => { + assert!(value.len() == 1); + value[0] + } + // VUID-WorkgroupSize-WorkgroupSize-04426 + // VUID-WorkgroupSize-WorkgroupSize-04427 + _ => panic!("WorkgroupSize is not a constant"), + } + })) } + // VUID-WorkgroupSize-WorkgroupSize-04427 + _ => panic!("WorkgroupSize must be 3 component vector!"), } - execution = ComputeShaderExecution::LocalSizeId(local_size); - break; } - } - _ => continue, - }; - } - if !workgroup_size_decorations.is_empty() { - let mut in_function = false; - for instruction in spirv.instructions() { - if !in_function { - match *instruction { - Instruction::Function { result_id, .. } if result_id == function_id => { - in_function = true; - } - _ => {} - } - } else { - let mut local_size = [LocalSize::Literal(0); 3]; - match instruction { - Instruction::ConstantComposite { - result_type_id: _, - result_id, - constituents, - } => { - if workgroup_size_decorations.contains(result_id) { - if constituents.len() != 3 { - panic!("WorkgroupSize must be 3 component vector!"); - } - for (local_size, id) in - local_size.iter_mut().zip(constituents.iter()) - { - if let Some(constant) = u32_constants.get(id) { - *local_size = LocalSize::Literal(*constant); - } else { - panic!("WorkgroupSize Constant {id:?} not defined!"); - }; - } + // VUID-WorkgroupSize-WorkgroupSize-04426 + _ => panic!("WorkgroupSize is not a constant"), + }, + _ => None, + })) + .or_else(|| { + spirv + .iter_execution_mode() + .find_map(|instruction| match *instruction { + Instruction::ExecutionMode { + entry_point, + mode: + ExecutionMode::LocalSize { + x_size, + y_size, + z_size, + }, + } if entry_point == function_id => 
Some([x_size, y_size, z_size]), + Instruction::ExecutionModeId { + entry_point, + mode: + ExecutionMode::LocalSizeId { + x_size, + y_size, + z_size, + }, + } if entry_point == function_id => Some([x_size, y_size, z_size].map( + |id| match *spirv.id(id).instruction() { + Instruction::Constant { ref value, .. } => { + assert!(value.len() == 1); + value[0] } - } - Instruction::SpecConstantComposite { - result_type_id: _, - result_id, - constituents, - } => { - if workgroup_size_decorations.contains(result_id) { - if constituents.len() != 3 { - panic!("WorkgroupSize must be 3 component vector!"); - } - for (local_size, id) in - local_size.iter_mut().zip(constituents.iter()) - { - if let Some(spec_id) = specialization_constant_ids.get(id) { - *local_size = LocalSize::SpecId(*spec_id); - } else { - panic!("WorkgroupSize SpecializationConstant {id:?} not defined!"); - }; - } - } - } - Instruction::FunctionEnd => break, - _ => continue, - } - match &mut execution { - ComputeShaderExecution::LocalSize(output) => { - *output = local_size; - } - ComputeShaderExecution::LocalSizeId(output) => { - *output = local_size; - } - } - } - } - } - ShaderExecution::Compute(execution) + _ => panic!("LocalSizeId is not a constant"), + }, + )), + _ => None, + }) + }); + + ShaderExecution::Compute(ComputeShaderExecution { + local_size: local_size.expect( + "Geometry shader does not have a WorkgroupSize builtin, \ + or LocalSize or LocalSizeId ExecutionMode", + ), + }) } ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration, @@ -541,414 +383,375 @@ fn inspect_entry_point( } self.inspected_functions.insert(function); - let mut in_function = false; - for instruction in self.spirv.instructions() { - if !in_function { - match *instruction { - Instruction::Function { result_id, .. 
} if result_id == function => { - in_function = true; + for instruction in self.spirv.function(function).iter_instructions() { + let stage = self.stage; + + match *instruction { + Instruction::AtomicLoad { pointer, .. } => { + // Storage buffer + if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer)) { + desc_reqs.memory_read = stage.into(); + } + + // Storage image + if let Some(desc_reqs) = + desc_reqs(self.instruction_chain([inst_image_texel_pointer], pointer)) + { + desc_reqs.memory_read = stage.into(); + desc_reqs.storage_image_atomic = true; } - _ => {} } - } else { - let stage = self.stage; - match *instruction { - Instruction::AtomicLoad { pointer, .. } => { - // Storage buffer - if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer)) - { - desc_reqs.memory_read = stage.into(); - } - - // Storage image - if let Some(desc_reqs) = desc_reqs( - self.instruction_chain([inst_image_texel_pointer], pointer), - ) { - desc_reqs.memory_read = stage.into(); - desc_reqs.storage_image_atomic = true; - } + Instruction::AtomicStore { pointer, .. } => { + // Storage buffer + if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer)) { + desc_reqs.memory_write = stage.into(); } - Instruction::AtomicStore { pointer, .. } => { - // Storage buffer - if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer)) - { - desc_reqs.memory_write = stage.into(); - } + // Storage image + if let Some(desc_reqs) = + desc_reqs(self.instruction_chain([inst_image_texel_pointer], pointer)) + { + desc_reqs.memory_write = stage.into(); + desc_reqs.storage_image_atomic = true; + } + } - // Storage image - if let Some(desc_reqs) = desc_reqs( - self.instruction_chain([inst_image_texel_pointer], pointer), - ) { - desc_reqs.memory_write = stage.into(); - desc_reqs.storage_image_atomic = true; - } + Instruction::AtomicExchange { pointer, .. } + | Instruction::AtomicCompareExchange { pointer, .. 
} + | Instruction::AtomicCompareExchangeWeak { pointer, .. } + | Instruction::AtomicIIncrement { pointer, .. } + | Instruction::AtomicIDecrement { pointer, .. } + | Instruction::AtomicIAdd { pointer, .. } + | Instruction::AtomicISub { pointer, .. } + | Instruction::AtomicSMin { pointer, .. } + | Instruction::AtomicUMin { pointer, .. } + | Instruction::AtomicSMax { pointer, .. } + | Instruction::AtomicUMax { pointer, .. } + | Instruction::AtomicAnd { pointer, .. } + | Instruction::AtomicOr { pointer, .. } + | Instruction::AtomicXor { pointer, .. } + | Instruction::AtomicFlagTestAndSet { pointer, .. } + | Instruction::AtomicFlagClear { pointer, .. } + | Instruction::AtomicFMinEXT { pointer, .. } + | Instruction::AtomicFMaxEXT { pointer, .. } + | Instruction::AtomicFAddEXT { pointer, .. } => { + // Storage buffer + if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer)) { + desc_reqs.memory_read = stage.into(); + desc_reqs.memory_write = stage.into(); } - Instruction::AtomicExchange { pointer, .. } - | Instruction::AtomicCompareExchange { pointer, .. } - | Instruction::AtomicCompareExchangeWeak { pointer, .. } - | Instruction::AtomicIIncrement { pointer, .. } - | Instruction::AtomicIDecrement { pointer, .. } - | Instruction::AtomicIAdd { pointer, .. } - | Instruction::AtomicISub { pointer, .. } - | Instruction::AtomicSMin { pointer, .. } - | Instruction::AtomicUMin { pointer, .. } - | Instruction::AtomicSMax { pointer, .. } - | Instruction::AtomicUMax { pointer, .. } - | Instruction::AtomicAnd { pointer, .. } - | Instruction::AtomicOr { pointer, .. } - | Instruction::AtomicXor { pointer, .. } - | Instruction::AtomicFlagTestAndSet { pointer, .. } - | Instruction::AtomicFlagClear { pointer, .. } - | Instruction::AtomicFMinEXT { pointer, .. } - | Instruction::AtomicFMaxEXT { pointer, .. } - | Instruction::AtomicFAddEXT { pointer, .. 
} => { - // Storage buffer - if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer)) - { - desc_reqs.memory_read = stage.into(); - desc_reqs.memory_write = stage.into(); - } - - // Storage image - if let Some(desc_reqs) = desc_reqs( - self.instruction_chain([inst_image_texel_pointer], pointer), - ) { - desc_reqs.memory_read = stage.into(); - desc_reqs.memory_write = stage.into(); - desc_reqs.storage_image_atomic = true; - } + // Storage image + if let Some(desc_reqs) = + desc_reqs(self.instruction_chain([inst_image_texel_pointer], pointer)) + { + desc_reqs.memory_read = stage.into(); + desc_reqs.memory_write = stage.into(); + desc_reqs.storage_image_atomic = true; } + } - Instruction::CopyMemory { target, source, .. } => { - self.instruction_chain([], target); - self.instruction_chain([], source); - } + Instruction::CopyMemory { target, source, .. } => { + self.instruction_chain([], target); + self.instruction_chain([], source); + } - Instruction::CopyObject { operand, .. } => { + Instruction::CopyObject { operand, .. } => { + self.instruction_chain([], operand); + } + + Instruction::ExtInst { ref operands, .. } => { + // We don't know which extended instructions take pointers, + // so we must interpret every operand as a pointer. + for &operand in operands { self.instruction_chain([], operand); } + } - Instruction::ExtInst { ref operands, .. } => { - // We don't know which extended instructions take pointers, - // so we must interpret every operand as a pointer. - for &operand in operands { - self.instruction_chain([], operand); - } + Instruction::FunctionCall { + function, + ref arguments, + .. + } => { + // Rather than trying to figure out the type of each argument, we just + // try all of them as pointers. + for &argument in arguments { + self.instruction_chain([], argument); } - Instruction::FunctionCall { - function, - ref arguments, - .. 
- } => { - // Rather than trying to figure out the type of each argument, we just - // try all of them as pointers. - for &argument in arguments { - self.instruction_chain([], argument); - } + if !self.inspected_functions.contains(&function) { + self.inspect_entry_point_r(function); + } + } - if !self.inspected_functions.contains(&function) { - self.inspect_entry_point_r(function); - } - } + Instruction::ImageGather { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSparseGather { + sampled_image, + image_operands, + .. + } => { + if let Some(desc_reqs) = desc_reqs( + self.instruction_chain([inst_sampled_image, inst_load], sampled_image), + ) { + desc_reqs.memory_read = stage.into(); + desc_reqs.sampler_no_ycbcr_conversion = true; - Instruction::FunctionEnd => return, - - Instruction::ImageGather { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSparseGather { - sampled_image, - image_operands, - .. - } => { - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain( - [inst_sampled_image, inst_load], - sampled_image, - )) - { - desc_reqs.memory_read = stage.into(); - desc_reqs.sampler_no_ycbcr_conversion = true; - - if image_operands.as_ref().map_or(false, |image_operands| { - image_operands.bias.is_some() - || image_operands.const_offset.is_some() - || image_operands.offset.is_some() - }) { - desc_reqs.sampler_no_unnormalized_coordinates = true; - } - } - } - - Instruction::ImageDrefGather { sampled_image, .. } - | Instruction::ImageSparseDrefGather { sampled_image, .. } => { - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain( - [inst_sampled_image, inst_load], - sampled_image, - )) - { - desc_reqs.memory_read = stage.into(); - desc_reqs.sampler_no_unnormalized_coordinates = true; - desc_reqs.sampler_no_ycbcr_conversion = true; - } - } - - Instruction::ImageSampleImplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSampleProjImplicitLod { - sampled_image, - image_operands, - .. 
- } - | Instruction::ImageSparseSampleProjImplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSparseSampleImplicitLod { - sampled_image, - image_operands, - .. - } => { - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain( - [inst_sampled_image, inst_load], - sampled_image, - )) - { - desc_reqs.memory_read = stage.into(); - desc_reqs.sampler_no_unnormalized_coordinates = true; - - if image_operands.as_ref().map_or(false, |image_operands| { - image_operands.const_offset.is_some() - || image_operands.offset.is_some() - }) { - desc_reqs.sampler_no_ycbcr_conversion = true; - } - } - } - - Instruction::ImageSampleProjExplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSparseSampleProjExplicitLod { - sampled_image, - image_operands, - .. - } => { - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain( - [inst_sampled_image, inst_load], - sampled_image, - )) - { - desc_reqs.memory_read = stage.into(); - desc_reqs.sampler_no_unnormalized_coordinates = true; - - if image_operands.const_offset.is_some() - || image_operands.offset.is_some() - { - desc_reqs.sampler_no_ycbcr_conversion = true; - } - } - } - - Instruction::ImageSampleDrefImplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSampleProjDrefImplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSparseSampleDrefImplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSparseSampleProjDrefImplicitLod { - sampled_image, - image_operands, - .. 
- } => { - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain( - [inst_sampled_image, inst_load], - sampled_image, - )) - { - desc_reqs.memory_read = stage.into(); - desc_reqs.sampler_no_unnormalized_coordinates = true; - desc_reqs.sampler_compare = true; - - if image_operands.as_ref().map_or(false, |image_operands| { - image_operands.const_offset.is_some() - || image_operands.offset.is_some() - }) { - desc_reqs.sampler_no_ycbcr_conversion = true; - } - } - } - - Instruction::ImageSampleDrefExplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSampleProjDrefExplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSparseSampleDrefExplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSparseSampleProjDrefExplicitLod { - sampled_image, - image_operands, - .. - } => { - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain( - [inst_sampled_image, inst_load], - sampled_image, - )) - { - desc_reqs.memory_read = stage.into(); - desc_reqs.sampler_no_unnormalized_coordinates = true; - desc_reqs.sampler_compare = true; - - if image_operands.const_offset.is_some() - || image_operands.offset.is_some() - { - desc_reqs.sampler_no_ycbcr_conversion = true; - } - } - } - - Instruction::ImageSampleExplicitLod { - sampled_image, - image_operands, - .. - } - | Instruction::ImageSparseSampleExplicitLod { - sampled_image, - image_operands, - .. 
- } => { - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain( - [inst_sampled_image, inst_load], - sampled_image, - )) - { - desc_reqs.memory_read = stage.into(); - - if image_operands.bias.is_some() + if image_operands.as_ref().map_or(false, |image_operands| { + image_operands.bias.is_some() || image_operands.const_offset.is_some() || image_operands.offset.is_some() - { - desc_reqs.sampler_no_unnormalized_coordinates = true; - } - - if image_operands.const_offset.is_some() - || image_operands.offset.is_some() - { - desc_reqs.sampler_no_ycbcr_conversion = true; - } + }) { + desc_reqs.sampler_no_unnormalized_coordinates = true; } } - - Instruction::ImageTexelPointer { image, .. } => { - self.instruction_chain([], image); - } - - Instruction::ImageRead { image, .. } => { - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain([inst_load], image)) - { - desc_reqs.memory_read = stage.into(); - } - } - - Instruction::ImageWrite { image, .. } => { - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain([inst_load], image)) - { - desc_reqs.memory_write = stage.into(); - } - } - - Instruction::Load { pointer, .. } => { - if let Some((binding_variable, index)) = - self.instruction_chain([], pointer) - { - // Only loads on buffers access memory directly. - // Loads on images load the image object itself, but don't touch - // the texels in memory yet. - if binding_variable.reqs.descriptor_types.iter().any(|ty| { - matches!( - ty, - DescriptorType::UniformBuffer - | DescriptorType::UniformBufferDynamic - | DescriptorType::StorageBuffer - | DescriptorType::StorageBufferDynamic - | DescriptorType::InlineUniformBlock - ) - }) { - if let Some(desc_reqs) = - desc_reqs(Some((binding_variable, index))) - { - desc_reqs.memory_read = stage.into(); - } - } - } - } - - Instruction::SampledImage { image, sampler, .. 
} => { - let identifier = match self.instruction_chain([inst_load], image) { - Some((variable, Some(index))) => DescriptorIdentifier { - set: variable.set, - binding: variable.binding, - index, - }, - _ => continue, - }; - - if let Some(desc_reqs) = - desc_reqs(self.instruction_chain([inst_load], sampler)) - { - desc_reqs.sampler_with_images.insert(identifier); - } - } - - Instruction::Store { pointer, .. } => { - // This can only apply to buffers, right? - if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer)) - { - desc_reqs.memory_write = stage.into(); - } - } - - _ => (), } + + Instruction::ImageDrefGather { sampled_image, .. } + | Instruction::ImageSparseDrefGather { sampled_image, .. } => { + if let Some(desc_reqs) = desc_reqs( + self.instruction_chain([inst_sampled_image, inst_load], sampled_image), + ) { + desc_reqs.memory_read = stage.into(); + desc_reqs.sampler_no_unnormalized_coordinates = true; + desc_reqs.sampler_no_ycbcr_conversion = true; + } + } + + Instruction::ImageSampleImplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSampleProjImplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSparseSampleProjImplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSparseSampleImplicitLod { + sampled_image, + image_operands, + .. + } => { + if let Some(desc_reqs) = desc_reqs( + self.instruction_chain([inst_sampled_image, inst_load], sampled_image), + ) { + desc_reqs.memory_read = stage.into(); + desc_reqs.sampler_no_unnormalized_coordinates = true; + + if image_operands.as_ref().map_or(false, |image_operands| { + image_operands.const_offset.is_some() + || image_operands.offset.is_some() + }) { + desc_reqs.sampler_no_ycbcr_conversion = true; + } + } + } + + Instruction::ImageSampleProjExplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSparseSampleProjExplicitLod { + sampled_image, + image_operands, + .. 
+ } => { + if let Some(desc_reqs) = desc_reqs( + self.instruction_chain([inst_sampled_image, inst_load], sampled_image), + ) { + desc_reqs.memory_read = stage.into(); + desc_reqs.sampler_no_unnormalized_coordinates = true; + + if image_operands.const_offset.is_some() + || image_operands.offset.is_some() + { + desc_reqs.sampler_no_ycbcr_conversion = true; + } + } + } + + Instruction::ImageSampleDrefImplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSampleProjDrefImplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSparseSampleDrefImplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSparseSampleProjDrefImplicitLod { + sampled_image, + image_operands, + .. + } => { + if let Some(desc_reqs) = desc_reqs( + self.instruction_chain([inst_sampled_image, inst_load], sampled_image), + ) { + desc_reqs.memory_read = stage.into(); + desc_reqs.sampler_no_unnormalized_coordinates = true; + desc_reqs.sampler_compare = true; + + if image_operands.as_ref().map_or(false, |image_operands| { + image_operands.const_offset.is_some() + || image_operands.offset.is_some() + }) { + desc_reqs.sampler_no_ycbcr_conversion = true; + } + } + } + + Instruction::ImageSampleDrefExplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSampleProjDrefExplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSparseSampleDrefExplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSparseSampleProjDrefExplicitLod { + sampled_image, + image_operands, + .. 
+ } => { + if let Some(desc_reqs) = desc_reqs( + self.instruction_chain([inst_sampled_image, inst_load], sampled_image), + ) { + desc_reqs.memory_read = stage.into(); + desc_reqs.sampler_no_unnormalized_coordinates = true; + desc_reqs.sampler_compare = true; + + if image_operands.const_offset.is_some() + || image_operands.offset.is_some() + { + desc_reqs.sampler_no_ycbcr_conversion = true; + } + } + } + + Instruction::ImageSampleExplicitLod { + sampled_image, + image_operands, + .. + } + | Instruction::ImageSparseSampleExplicitLod { + sampled_image, + image_operands, + .. + } => { + if let Some(desc_reqs) = desc_reqs( + self.instruction_chain([inst_sampled_image, inst_load], sampled_image), + ) { + desc_reqs.memory_read = stage.into(); + + if image_operands.bias.is_some() + || image_operands.const_offset.is_some() + || image_operands.offset.is_some() + { + desc_reqs.sampler_no_unnormalized_coordinates = true; + } + + if image_operands.const_offset.is_some() + || image_operands.offset.is_some() + { + desc_reqs.sampler_no_ycbcr_conversion = true; + } + } + } + + Instruction::ImageTexelPointer { image, .. } => { + self.instruction_chain([], image); + } + + Instruction::ImageRead { image, .. } => { + if let Some(desc_reqs) = + desc_reqs(self.instruction_chain([inst_load], image)) + { + desc_reqs.memory_read = stage.into(); + } + } + + Instruction::ImageWrite { image, .. } => { + if let Some(desc_reqs) = + desc_reqs(self.instruction_chain([inst_load], image)) + { + desc_reqs.memory_write = stage.into(); + } + } + + Instruction::Load { pointer, .. } => { + if let Some((binding_variable, index)) = self.instruction_chain([], pointer) + { + // Only loads on buffers access memory directly. + // Loads on images load the image object itself, but don't touch + // the texels in memory yet. 
+ if binding_variable.reqs.descriptor_types.iter().any(|ty| { + matches!( + ty, + DescriptorType::UniformBuffer + | DescriptorType::UniformBufferDynamic + | DescriptorType::StorageBuffer + | DescriptorType::StorageBufferDynamic + | DescriptorType::InlineUniformBlock + ) + }) { + if let Some(desc_reqs) = desc_reqs(Some((binding_variable, index))) + { + desc_reqs.memory_read = stage.into(); + } + } + } + } + + Instruction::SampledImage { image, sampler, .. } => { + let identifier = match self.instruction_chain([inst_load], image) { + Some((variable, Some(index))) => DescriptorIdentifier { + set: variable.set, + binding: variable.binding, + index, + }, + _ => continue, + }; + + if let Some(desc_reqs) = + desc_reqs(self.instruction_chain([inst_load], sampler)) + { + desc_reqs.sampler_with_images.insert(identifier); + } + } + + Instruction::Store { pointer, .. } => { + // This can only apply to buffers, right? + if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer)) { + desc_reqs.memory_write = stage.into(); + } + } + + _ => (), } } } @@ -1252,7 +1055,7 @@ fn push_constant_requirements(spirv: &Spirv, stage: ShaderStage) -> Option HashMap { +pub(super) fn specialization_constants(spirv: &Spirv) -> HashMap { let get_constant_id = |result_id| { spirv .id(result_id) @@ -1618,8 +1421,7 @@ fn shader_interface_type_of( } else { let mut ty = shader_interface_type_of(spirv, element_type, false); let num_elements = spirv - .instructions() - .iter() + .iter_global() .find_map(|instruction| match *instruction { Instruction::Constant { result_id, diff --git a/vulkano/src/shader/spirv.rs b/vulkano/src/shader/spirv.rs deleted file mode 100644 index b0c43cb3..00000000 --- a/vulkano/src/shader/spirv.rs +++ /dev/null @@ -1,837 +0,0 @@ -// Copyright (c) 2021 The Vulkano developers -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , -// at your option. 
All files in the project carrying such -// notice may not be copied, modified, or distributed except -// according to those terms. - -//! Parsing and analysis utilities for SPIR-V shader binaries. -//! -//! This can be used to inspect and validate a SPIR-V module at runtime. The `Spirv` type does some -//! validation, but you should not assume that code that is read successfully is valid. -//! -//! For more information about SPIR-V modules, instructions and types, see the -//! [SPIR-V specification](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html). - -use crate::Version; -use ahash::{HashMap, HashMapExt}; -use std::{ - borrow::Cow, - error::Error, - fmt::{Display, Error as FmtError, Formatter}, - ops::Range, - string::FromUtf8Error, -}; - -// Generated by build.rs -include!(concat!(env!("OUT_DIR"), "/spirv_parse.rs")); - -/// A parsed and analyzed SPIR-V module. -#[derive(Clone, Debug)] -pub struct Spirv { - version: Version, - bound: u32, - instructions: Vec, - ids: HashMap, - - // Items described in the spec section "Logical Layout of a Module" - range_capability: Range, - range_extension: Range, - range_ext_inst_import: Range, - memory_model: usize, - range_entry_point: Range, - range_execution_mode: Range, - range_name: Range, - range_decoration: Range, - range_global: Range, -} - -impl Spirv { - /// Parses a SPIR-V document from a list of words. 
- pub fn new(words: &[u32]) -> Result { - if words.len() < 5 { - return Err(SpirvError::InvalidHeader); - } - - if words[0] != 0x07230203 { - return Err(SpirvError::InvalidHeader); - } - - let version = Version { - major: (words[1] & 0x00ff0000) >> 16, - minor: (words[1] & 0x0000ff00) >> 8, - patch: words[1] & 0x000000ff, - }; - - let bound = words[3]; - - let instructions = { - let mut ret = Vec::new(); - let mut rest = &words[5..]; - while !rest.is_empty() { - let word_count = (rest[0] >> 16) as usize; - assert!(word_count >= 1); - - if rest.len() < word_count { - return Err(ParseError { - instruction: ret.len(), - word: rest.len(), - error: ParseErrors::UnexpectedEOF, - words: rest.to_owned(), - } - .into()); - } - - let mut reader = InstructionReader::new(&rest[0..word_count], ret.len()); - let instruction = Instruction::parse(&mut reader)?; - - if !reader.is_empty() { - return Err(reader.map_err(ParseErrors::LeftoverOperands).into()); - } - - ret.push(instruction); - rest = &rest[word_count..]; - } - ret - }; - - // It is impossible for a valid SPIR-V file to contain more Ids than instructions, so put - // a sane upper limit on the allocation. This prevents a malicious file from causing huge - // memory allocations. 
- let mut ids = HashMap::with_capacity(instructions.len().min(bound as usize)); - let mut range_capability: Option> = None; - let mut range_extension: Option> = None; - let mut range_ext_inst_import: Option> = None; - let mut range_memory_model: Option> = None; - let mut range_entry_point: Option> = None; - let mut range_execution_mode: Option> = None; - let mut range_name: Option> = None; - let mut range_decoration: Option> = None; - let mut range_global: Option> = None; - let mut in_function = false; - - fn set_range(range: &mut Option>, index: usize) -> Result<(), SpirvError> { - if let Some(range) = range { - if range.end != index { - return Err(SpirvError::BadLayout { index }); - } - - range.end = index + 1; - } else { - *range = Some(Range { - start: index, - end: index + 1, - }); - } - - Ok(()) - } - - for (index, instruction) in instructions.iter().enumerate() { - if let Some(id) = instruction.result_id() { - if u32::from(id) >= bound { - return Err(SpirvError::IdOutOfBounds { id, index, bound }); - } - - let members = if let Instruction::TypeStruct { member_types, .. } = instruction { - member_types - .iter() - .map(|_| StructMemberDataIndices::default()) - .collect() - } else { - Vec::new() - }; - let data = IdDataIndices { - index, - names: Vec::new(), - decorations: Vec::new(), - members, - }; - if let Some(first) = ids.insert(id, data) { - return Err(SpirvError::DuplicateId { - id, - first_index: first.index, - second_index: index, - }); - } - } - - match instruction { - Instruction::Capability { .. } => set_range(&mut range_capability, index)?, - Instruction::Extension { .. } => set_range(&mut range_extension, index)?, - Instruction::ExtInstImport { .. } => set_range(&mut range_ext_inst_import, index)?, - Instruction::MemoryModel { .. } => set_range(&mut range_memory_model, index)?, - Instruction::EntryPoint { .. } => set_range(&mut range_entry_point, index)?, - Instruction::ExecutionMode { .. } | Instruction::ExecutionModeId { .. 
} => { - set_range(&mut range_execution_mode, index)? - } - Instruction::Name { .. } | Instruction::MemberName { .. } => { - set_range(&mut range_name, index)? - } - Instruction::Decorate { .. } - | Instruction::MemberDecorate { .. } - | Instruction::DecorationGroup { .. } - | Instruction::GroupDecorate { .. } - | Instruction::GroupMemberDecorate { .. } - | Instruction::DecorateId { .. } - | Instruction::DecorateString { .. } - | Instruction::MemberDecorateString { .. } => { - set_range(&mut range_decoration, index)? - } - Instruction::TypeVoid { .. } - | Instruction::TypeBool { .. } - | Instruction::TypeInt { .. } - | Instruction::TypeFloat { .. } - | Instruction::TypeVector { .. } - | Instruction::TypeMatrix { .. } - | Instruction::TypeImage { .. } - | Instruction::TypeSampler { .. } - | Instruction::TypeSampledImage { .. } - | Instruction::TypeArray { .. } - | Instruction::TypeRuntimeArray { .. } - | Instruction::TypeStruct { .. } - | Instruction::TypeOpaque { .. } - | Instruction::TypePointer { .. } - | Instruction::TypeFunction { .. } - | Instruction::TypeEvent { .. } - | Instruction::TypeDeviceEvent { .. } - | Instruction::TypeReserveId { .. } - | Instruction::TypeQueue { .. } - | Instruction::TypePipe { .. } - | Instruction::TypeForwardPointer { .. } - | Instruction::TypePipeStorage { .. } - | Instruction::TypeNamedBarrier { .. } - | Instruction::TypeRayQueryKHR { .. } - | Instruction::TypeAccelerationStructureKHR { .. } - | Instruction::TypeCooperativeMatrixNV { .. } - | Instruction::TypeVmeImageINTEL { .. } - | Instruction::TypeAvcImePayloadINTEL { .. } - | Instruction::TypeAvcRefPayloadINTEL { .. } - | Instruction::TypeAvcSicPayloadINTEL { .. } - | Instruction::TypeAvcMcePayloadINTEL { .. } - | Instruction::TypeAvcMceResultINTEL { .. } - | Instruction::TypeAvcImeResultINTEL { .. } - | Instruction::TypeAvcImeResultSingleReferenceStreamoutINTEL { .. } - | Instruction::TypeAvcImeResultDualReferenceStreamoutINTEL { .. 
} - | Instruction::TypeAvcImeSingleReferenceStreaminINTEL { .. } - | Instruction::TypeAvcImeDualReferenceStreaminINTEL { .. } - | Instruction::TypeAvcRefResultINTEL { .. } - | Instruction::TypeAvcSicResultINTEL { .. } - | Instruction::ConstantTrue { .. } - | Instruction::ConstantFalse { .. } - | Instruction::Constant { .. } - | Instruction::ConstantComposite { .. } - | Instruction::ConstantSampler { .. } - | Instruction::ConstantNull { .. } - | Instruction::ConstantPipeStorage { .. } - | Instruction::SpecConstantTrue { .. } - | Instruction::SpecConstantFalse { .. } - | Instruction::SpecConstant { .. } - | Instruction::SpecConstantComposite { .. } - | Instruction::SpecConstantOp { .. } => set_range(&mut range_global, index)?, - Instruction::Undef { .. } if !in_function => set_range(&mut range_global, index)?, - Instruction::Variable { storage_class, .. } - if *storage_class != StorageClass::Function => - { - set_range(&mut range_global, index)? - } - Instruction::Function { .. } => { - in_function = true; - } - Instruction::Line { .. } | Instruction::NoLine { .. } => { - if !in_function { - set_range(&mut range_global, index)? 
- } - } - _ => (), - } - } - - let mut spirv = Spirv { - version, - bound, - instructions, - ids, - - range_capability: range_capability.unwrap_or_default(), - range_extension: range_extension.unwrap_or_default(), - range_ext_inst_import: range_ext_inst_import.unwrap_or_default(), - memory_model: if let Some(range) = range_memory_model { - if range.end - range.start != 1 { - return Err(SpirvError::MemoryModelInvalid); - } - - range.start - } else { - return Err(SpirvError::MemoryModelInvalid); - }, - range_entry_point: range_entry_point.unwrap_or_default(), - range_execution_mode: range_execution_mode.unwrap_or_default(), - range_name: range_name.unwrap_or_default(), - range_decoration: range_decoration.unwrap_or_default(), - range_global: range_global.unwrap_or_default(), - }; - - for index in spirv.range_name.clone() { - match &spirv.instructions[index] { - Instruction::Name { target, .. } => { - spirv.ids.get_mut(target).unwrap().names.push(index); - } - Instruction::MemberName { ty, member, .. } => { - spirv.ids.get_mut(ty).unwrap().members[*member as usize] - .names - .push(index); - } - _ => unreachable!(), - } - } - - // First handle all regular decorations, including those targeting decoration groups. - for index in spirv.range_decoration.clone() { - match &spirv.instructions[index] { - Instruction::Decorate { target, .. } - | Instruction::DecorateId { target, .. } - | Instruction::DecorateString { target, .. } => { - spirv.ids.get_mut(target).unwrap().decorations.push(index); - } - Instruction::MemberDecorate { - structure_type: target, - member, - .. - } - | Instruction::MemberDecorateString { - struct_type: target, - member, - .. - } => { - spirv.ids.get_mut(target).unwrap().members[*member as usize] - .decorations - .push(index); - } - _ => (), - } - } - - // Then, with decoration groups having their lists complete, handle group decorates. 
- for index in spirv.range_decoration.clone() { - match &spirv.instructions[index] { - Instruction::GroupDecorate { - decoration_group, - targets, - .. - } => { - let indices = { - let data = &spirv.ids[decoration_group]; - if !matches!( - spirv.instructions[data.index], - Instruction::DecorationGroup { .. } - ) { - return Err(SpirvError::GroupDecorateNotGroup { index }); - }; - data.decorations.clone() - }; - - for target in targets { - spirv - .ids - .get_mut(target) - .unwrap() - .decorations - .extend(&indices); - } - } - Instruction::GroupMemberDecorate { - decoration_group, - targets, - .. - } => { - let indices = { - let data = &spirv.ids[decoration_group]; - if !matches!( - spirv.instructions[data.index], - Instruction::DecorationGroup { .. } - ) { - return Err(SpirvError::GroupDecorateNotGroup { index }); - }; - data.decorations.clone() - }; - - for (target, member) in targets { - spirv.ids.get_mut(target).unwrap().members[*member as usize] - .decorations - .extend(&indices); - } - } - _ => (), - } - } - - Ok(spirv) - } - - /// Returns a reference to the instructions in the module. - #[inline] - pub fn instructions(&self) -> &[Instruction] { - &self.instructions - } - - /// Returns the SPIR-V version that the module is compiled for. - #[inline] - pub fn version(&self) -> Version { - self.version - } - - /// Returns the upper bound of `Id`s. All `Id`s should have a numeric value strictly less than - /// this value. - #[inline] - pub fn bound(&self) -> u32 { - self.bound - } - - /// Returns information about an `Id`. - /// - /// # Panics - /// - /// - Panics if `id` is not defined in this module. This can in theory only happpen if you are - /// mixing `Id`s from different modules. - #[inline] - pub fn id(&self, id: Id) -> IdInfo<'_> { - IdInfo { - data_indices: &self.ids[&id], - instructions: &self.instructions, - } - } - - /// Returns an iterator over all `Capability` instructions. 
- #[inline] - pub fn iter_capability(&self) -> impl ExactSizeIterator { - self.instructions[self.range_capability.clone()].iter() - } - - /// Returns an iterator over all `Extension` instructions. - #[inline] - pub fn iter_extension(&self) -> impl ExactSizeIterator { - self.instructions[self.range_extension.clone()].iter() - } - - /// Returns an iterator over all `ExtInstImport` instructions. - #[inline] - pub fn iter_ext_inst_import(&self) -> impl ExactSizeIterator { - self.instructions[self.range_ext_inst_import.clone()].iter() - } - - /// Returns the `MemoryModel` instruction. - #[inline] - pub fn memory_model(&self) -> &Instruction { - &self.instructions[self.memory_model] - } - - /// Returns an iterator over all `EntryPoint` instructions. - #[inline] - pub fn iter_entry_point(&self) -> impl ExactSizeIterator { - self.instructions[self.range_entry_point.clone()].iter() - } - - /// Returns an iterator over all execution mode instructions. - #[inline] - pub fn iter_execution_mode(&self) -> impl ExactSizeIterator { - self.instructions[self.range_execution_mode.clone()].iter() - } - - /// Returns an iterator over all name debug instructions. - #[inline] - pub fn iter_name(&self) -> impl ExactSizeIterator { - self.instructions[self.range_name.clone()].iter() - } - - /// Returns an iterator over all decoration instructions. - #[inline] - pub fn iter_decoration(&self) -> impl ExactSizeIterator { - self.instructions[self.range_decoration.clone()].iter() - } - - /// Returns an iterator over all global declaration instructions: types, - /// constants and global variables. - /// - /// Note: This can also include `Line` and `NoLine` instructions. 
- #[inline] - pub fn iter_global(&self) -> impl ExactSizeIterator { - self.instructions[self.range_global.clone()].iter() - } -} - -#[derive(Clone, Debug)] -struct IdDataIndices { - index: usize, - names: Vec, - decorations: Vec, - members: Vec, -} - -#[derive(Clone, Debug, Default)] -struct StructMemberDataIndices { - names: Vec, - decorations: Vec, -} - -/// Information associated with an `Id`. -#[derive(Clone, Debug)] -pub struct IdInfo<'a> { - data_indices: &'a IdDataIndices, - instructions: &'a [Instruction], -} - -impl<'a> IdInfo<'a> { - /// Returns the instruction that defines this `Id` with a `result_id` operand. - #[inline] - pub fn instruction(&self) -> &'a Instruction { - &self.instructions[self.data_indices.index] - } - - /// Returns an iterator over all name debug instructions that target this `Id`. - #[inline] - pub fn iter_name(&self) -> impl ExactSizeIterator { - let instructions = self.instructions; - self.data_indices - .names - .iter() - .map(move |&index| &instructions[index]) - } - - /// Returns an iterator over all decorate instructions, that target this `Id`. This includes any - /// decorate instructions that target this `Id` indirectly via a `DecorationGroup`. - #[inline] - pub fn iter_decoration(&self) -> impl ExactSizeIterator { - let instructions = self.instructions; - self.data_indices - .decorations - .iter() - .map(move |&index| &instructions[index]) - } - - /// If this `Id` refers to a `TypeStruct`, returns an iterator of information about each member - /// of the struct. Empty otherwise. - #[inline] - pub fn iter_members(&self) -> impl ExactSizeIterator> { - let instructions = self.instructions; - self.data_indices - .members - .iter() - .map(move |data_indices| StructMemberInfo { - data_indices, - instructions, - }) - } -} - -/// Information associated with a member of a `TypeStruct` instruction. 
-#[derive(Clone, Debug)] -pub struct StructMemberInfo<'a> { - data_indices: &'a StructMemberDataIndices, - instructions: &'a [Instruction], -} - -impl<'a> StructMemberInfo<'a> { - /// Returns an iterator over all name debug instructions that target this struct member. - #[inline] - pub fn iter_name(&self) -> impl ExactSizeIterator { - let instructions = self.instructions; - self.data_indices - .names - .iter() - .map(move |&index| &instructions[index]) - } - - /// Returns an iterator over all decorate instructions that target this struct member. This - /// includes any decorate instructions that target this member indirectly via a - /// `DecorationGroup`. - #[inline] - pub fn iter_decoration(&self) -> impl ExactSizeIterator { - let instructions = self.instructions; - self.data_indices - .decorations - .iter() - .map(move |&index| &instructions[index]) - } -} - -/// Used in SPIR-V to refer to the result of another instruction. -/// -/// Ids are global across a module, and are always assigned by exactly one instruction. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -#[repr(transparent)] -pub struct Id(u32); - -impl Id { - // Returns the raw numeric value of this Id. - #[inline] - pub const fn as_raw(self) -> u32 { - self.0 - } -} - -impl From for u32 { - #[inline] - fn from(id: Id) -> u32 { - id.as_raw() - } -} - -impl Display for Id { - #[inline] - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { - write!(f, "%{}", self.0) - } -} - -/// Helper type for parsing the words of an instruction. -#[derive(Debug)] -struct InstructionReader<'a> { - words: &'a [u32], - next_word: usize, - instruction: usize, -} - -impl<'a> InstructionReader<'a> { - /// Constructs a new reader from a slice of words for a single instruction, including the opcode - /// word. `instruction` is the number of the instruction currently being read, and is used for - /// error reporting. 
- fn new(words: &'a [u32], instruction: usize) -> Self { - debug_assert!(!words.is_empty()); - Self { - words, - next_word: 0, - instruction, - } - } - - /// Returns whether the reader has reached the end of the current instruction. - fn is_empty(&self) -> bool { - self.next_word >= self.words.len() - } - - /// Converts the `ParseErrors` enum to the `ParseError` struct, adding contextual information. - fn map_err(&self, error: ParseErrors) -> ParseError { - ParseError { - instruction: self.instruction, - word: self.next_word - 1, // -1 because the word has already been read - error, - words: self.words.to_owned(), - } - } - - /// Returns the next word in the sequence. - fn next_u32(&mut self) -> Result { - let word = *self.words.get(self.next_word).ok_or(ParseError { - instruction: self.instruction, - word: self.next_word, // No -1 because we didn't advance yet - error: ParseErrors::MissingOperands, - words: self.words.to_owned(), - })?; - self.next_word += 1; - - Ok(word) - } - - /* - /// Returns the next two words as a single `u64`. - #[inline] - fn next_u64(&mut self) -> Result { - Ok(self.next_u32()? as u64 | (self.next_u32()? as u64) << 32) - } - */ - - /// Reads a nul-terminated string. - fn next_string(&mut self) -> Result { - let mut bytes = Vec::new(); - loop { - let word = self.next_u32()?.to_le_bytes(); - - if let Some(nul) = word.iter().position(|&b| b == 0) { - bytes.extend(&word[0..nul]); - break; - } else { - bytes.extend(word); - } - } - String::from_utf8(bytes).map_err(|err| self.map_err(ParseErrors::FromUtf8Error(err))) - } - - /// Reads all remaining words. - fn remainder(&mut self) -> Vec { - let vec = self.words[self.next_word..].to_owned(); - self.next_word = self.words.len(); - vec - } -} - -/// Error that can happen when reading a SPIR-V module. 
-#[derive(Clone, Debug)] -pub enum SpirvError { - BadLayout { - index: usize, - }, - DuplicateId { - id: Id, - first_index: usize, - second_index: usize, - }, - GroupDecorateNotGroup { - index: usize, - }, - IdOutOfBounds { - id: Id, - index: usize, - bound: u32, - }, - InvalidHeader, - MemoryModelInvalid, - ParseError(ParseError), -} - -impl Display for SpirvError { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { - match self { - Self::BadLayout { index } => write!( - f, - "the instruction at index {} does not follow the logical layout of a module", - index, - ), - Self::DuplicateId { - id, - first_index, - second_index, - } => write!( - f, - "id {} is assigned more than once, by instructions {} and {}", - id, first_index, second_index, - ), - Self::GroupDecorateNotGroup { index } => write!( - f, - "a GroupDecorate or GroupMemberDecorate instruction at index {} referred to an Id \ - that was not a DecorationGroup", - index, - ), - Self::IdOutOfBounds { id, bound, index } => write!( - f, - "id {}, assigned at instruction {}, is not below the maximum bound {}", - id, index, bound, - ), - Self::InvalidHeader => write!(f, "the SPIR-V module header is invalid"), - Self::MemoryModelInvalid => { - write!(f, "the MemoryModel instruction is not present exactly once") - } - Self::ParseError(_) => write!(f, "parse error"), - } - } -} - -impl Error for SpirvError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match self { - Self::ParseError(err) => Some(err), - _ => None, - } - } -} - -impl From for SpirvError { - fn from(err: ParseError) -> Self { - Self::ParseError(err) - } -} - -/// Error that can happen when parsing SPIR-V instructions into Rust data structures. -#[derive(Clone, Debug)] -pub struct ParseError { - /// The instruction number the error happened at, starting from 0. - pub instruction: usize, - /// The word from the start of the instruction that the error happened at, starting from 0. - pub word: usize, - /// The error. 
- pub error: ParseErrors, - /// The words of the instruction. - pub words: Vec, -} - -impl Display for ParseError { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { - write!( - f, - "at instruction {}, word {}: {}", - self.instruction, self.word, self.error, - ) - } -} - -impl Error for ParseError {} - -/// Individual types of parse error that can happen. -#[derive(Clone, Debug)] -pub enum ParseErrors { - FromUtf8Error(FromUtf8Error), - LeftoverOperands, - MissingOperands, - UnexpectedEOF, - UnknownEnumerant(&'static str, u32), - UnknownOpcode(u16), - UnknownSpecConstantOpcode(u16), -} - -impl Display for ParseErrors { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { - match self { - Self::FromUtf8Error(_) => write!(f, "invalid UTF-8 in string literal"), - Self::LeftoverOperands => write!(f, "unparsed operands remaining"), - Self::MissingOperands => write!( - f, - "the instruction and its operands require more words than are present in the \ - instruction", - ), - Self::UnexpectedEOF => write!(f, "encountered unexpected end of file"), - Self::UnknownEnumerant(ty, enumerant) => { - write!(f, "invalid enumerant {} for enum {}", enumerant, ty) - } - Self::UnknownOpcode(opcode) => write!(f, "invalid instruction opcode {}", opcode), - Self::UnknownSpecConstantOpcode(opcode) => { - write!(f, "invalid spec constant instruction opcode {}", opcode) - } - } - } -} - -/// Converts SPIR-V bytes to words. If necessary, the byte order is swapped from little-endian -/// to native-endian. -pub fn bytes_to_words(bytes: &[u8]) -> Result, SpirvBytesNotMultipleOf4> { - // If the current target is little endian, and the slice already has the right size and - // alignment, then we can just transmute the slice with bytemuck. 
- #[cfg(target_endian = "little")] - if let Ok(words) = bytemuck::try_cast_slice(bytes) { - return Ok(Cow::Borrowed(words)); - } - - if bytes.len() % 4 != 0 { - return Err(SpirvBytesNotMultipleOf4); - } - - // TODO: Use `slice::array_chunks` once it's stable. - let words: Vec = bytes - .chunks_exact(4) - .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) - .collect(); - - Ok(Cow::Owned(words)) -} - -#[derive(Clone, Copy, Debug, Default)] -pub struct SpirvBytesNotMultipleOf4; - -impl Error for SpirvBytesNotMultipleOf4 {} - -impl Display for SpirvBytesNotMultipleOf4 { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "the length of the provided slice is not a multiple of 4") - } -} diff --git a/vulkano/src/shader/spirv/mod.rs b/vulkano/src/shader/spirv/mod.rs new file mode 100644 index 00000000..06acfcbe --- /dev/null +++ b/vulkano/src/shader/spirv/mod.rs @@ -0,0 +1,871 @@ +// Copyright (c) 2021 The Vulkano developers +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , +// at your option. All files in the project carrying such +// notice may not be copied, modified, or distributed except +// according to those terms. + +//! Parsing and analysis utilities for SPIR-V shader binaries. +//! +//! This can be used to inspect and validate a SPIR-V module at runtime. The `Spirv` type does some +//! validation, but you should not assume that code that is read successfully is valid. +//! +//! For more information about SPIR-V modules, instructions and types, see the +//! [SPIR-V specification](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html). 
+ +use crate::{shader::SpecializationConstant, Version}; +use ahash::HashMap; +use smallvec::{smallvec, SmallVec}; +use std::{ + borrow::Cow, + error::Error, + fmt::{Display, Error as FmtError, Formatter}, + string::FromUtf8Error, +}; + +mod specialization; + +// Generated by build.rs +include!(concat!(env!("OUT_DIR"), "/spirv_parse.rs")); + +/// A parsed and analyzed SPIR-V module. +#[derive(Clone, Debug)] +pub struct Spirv { + version: Version, + bound: u32, + ids: HashMap, + + // Items described in the spec section "Logical Layout of a Module" + instructions_capability: Vec, + instructions_extension: Vec, + instructions_ext_inst_import: Vec, + instruction_memory_model: Instruction, + instructions_entry_point: Vec, + instructions_execution_mode: Vec, + instructions_name: Vec, + instructions_decoration: Vec, + instructions_global: Vec, + functions: HashMap, +} + +impl Spirv { + /// Parses a SPIR-V document from a list of words. + pub fn new(words: &[u32]) -> Result { + if words.len() < 5 { + return Err(SpirvError::InvalidHeader); + } + + if words[0] != 0x07230203 { + return Err(SpirvError::InvalidHeader); + } + + let version = Version { + major: (words[1] & 0x00ff0000) >> 16, + minor: (words[1] & 0x0000ff00) >> 8, + patch: words[1] & 0x000000ff, + }; + + // For safety, we recalculate the bound ourselves. 
+ let mut bound = 0; + let mut ids = HashMap::default(); + + let mut instructions_capability = Vec::new(); + let mut instructions_extension = Vec::new(); + let mut instructions_ext_inst_import = Vec::new(); + let mut instructions_memory_model = Vec::new(); + let mut instructions_entry_point = Vec::new(); + let mut instructions_execution_mode = Vec::new(); + let mut instructions_name = Vec::new(); + let mut instructions_decoration = Vec::new(); + let mut instructions_global = Vec::new(); + + let mut functions = HashMap::default(); + let mut current_function: Option<&mut Vec> = None; + + for instruction in iter_instructions(&words[5..]) { + let instruction = instruction?; + + if let Some(id) = instruction.result_id() { + bound = bound.max(u32::from(id) + 1); + + let members = if let Instruction::TypeStruct { + ref member_types, .. + } = instruction + { + member_types + .iter() + .map(|_| StructMemberInfo::default()) + .collect() + } else { + Vec::new() + }; + + let data = IdInfo { + instruction: instruction.clone(), + names: Vec::new(), + decorations: Vec::new(), + members, + }; + + if ids.insert(id, data).is_some() { + return Err(SpirvError::DuplicateId { id }); + } + } + + if matches!( + instruction, + Instruction::Line { .. } | Instruction::NoLine { .. } + ) { + continue; + } + + if current_function.is_some() { + match instruction { + Instruction::FunctionEnd { .. } => { + current_function.take().unwrap().push(instruction); + } + _ => current_function.as_mut().unwrap().push(instruction), + } + } else { + let destination = match instruction { + Instruction::Function { result_id, .. } => { + current_function = None; + let function = functions.entry(result_id).or_insert(FunctionInfo { + instructions: Vec::new(), + }); + current_function.insert(&mut function.instructions) + } + Instruction::Capability { .. } => &mut instructions_capability, + Instruction::Extension { .. } => &mut instructions_extension, + Instruction::ExtInstImport { .. 
} => &mut instructions_ext_inst_import, + Instruction::MemoryModel { .. } => &mut instructions_memory_model, + Instruction::EntryPoint { .. } => &mut instructions_entry_point, + Instruction::ExecutionMode { .. } | Instruction::ExecutionModeId { .. } => { + &mut instructions_execution_mode + } + Instruction::Name { .. } | Instruction::MemberName { .. } => { + &mut instructions_name + } + Instruction::Decorate { .. } + | Instruction::MemberDecorate { .. } + | Instruction::DecorationGroup { .. } + | Instruction::GroupDecorate { .. } + | Instruction::GroupMemberDecorate { .. } + | Instruction::DecorateId { .. } + | Instruction::DecorateString { .. } + | Instruction::MemberDecorateString { .. } => &mut instructions_decoration, + Instruction::TypeVoid { .. } + | Instruction::TypeBool { .. } + | Instruction::TypeInt { .. } + | Instruction::TypeFloat { .. } + | Instruction::TypeVector { .. } + | Instruction::TypeMatrix { .. } + | Instruction::TypeImage { .. } + | Instruction::TypeSampler { .. } + | Instruction::TypeSampledImage { .. } + | Instruction::TypeArray { .. } + | Instruction::TypeRuntimeArray { .. } + | Instruction::TypeStruct { .. } + | Instruction::TypeOpaque { .. } + | Instruction::TypePointer { .. } + | Instruction::TypeFunction { .. } + | Instruction::TypeEvent { .. } + | Instruction::TypeDeviceEvent { .. } + | Instruction::TypeReserveId { .. } + | Instruction::TypeQueue { .. } + | Instruction::TypePipe { .. } + | Instruction::TypeForwardPointer { .. } + | Instruction::TypePipeStorage { .. } + | Instruction::TypeNamedBarrier { .. } + | Instruction::TypeRayQueryKHR { .. } + | Instruction::TypeAccelerationStructureKHR { .. } + | Instruction::TypeCooperativeMatrixNV { .. } + | Instruction::TypeVmeImageINTEL { .. } + | Instruction::TypeAvcImePayloadINTEL { .. } + | Instruction::TypeAvcRefPayloadINTEL { .. } + | Instruction::TypeAvcSicPayloadINTEL { .. } + | Instruction::TypeAvcMcePayloadINTEL { .. } + | Instruction::TypeAvcMceResultINTEL { .. 
} + | Instruction::TypeAvcImeResultINTEL { .. } + | Instruction::TypeAvcImeResultSingleReferenceStreamoutINTEL { .. } + | Instruction::TypeAvcImeResultDualReferenceStreamoutINTEL { .. } + | Instruction::TypeAvcImeSingleReferenceStreaminINTEL { .. } + | Instruction::TypeAvcImeDualReferenceStreaminINTEL { .. } + | Instruction::TypeAvcRefResultINTEL { .. } + | Instruction::TypeAvcSicResultINTEL { .. } + | Instruction::ConstantTrue { .. } + | Instruction::ConstantFalse { .. } + | Instruction::Constant { .. } + | Instruction::ConstantComposite { .. } + | Instruction::ConstantSampler { .. } + | Instruction::ConstantNull { .. } + | Instruction::ConstantPipeStorage { .. } + | Instruction::SpecConstantTrue { .. } + | Instruction::SpecConstantFalse { .. } + | Instruction::SpecConstant { .. } + | Instruction::SpecConstantComposite { .. } + | Instruction::SpecConstantOp { .. } + | Instruction::Variable { .. } + | Instruction::Undef { .. } => &mut instructions_global, + _ => continue, + }; + + destination.push(instruction); + } + } + + let instruction_memory_model = instructions_memory_model.drain(..).next().unwrap(); + + // Add decorations to ids, + // while also expanding decoration groups into individual decorations. + let mut decoration_groups: HashMap> = HashMap::default(); + let instructions_decoration = instructions_decoration + .into_iter() + .flat_map(|instruction| -> SmallVec<[Instruction; 1]> { + match instruction { + Instruction::Decorate { target, .. } + | Instruction::DecorateId { target, .. } + | Instruction::DecorateString { target, .. } => { + let id_info = ids.get_mut(&target).unwrap(); + + if matches!(id_info.instruction(), Instruction::DecorationGroup { .. }) { + decoration_groups + .entry(target) + .or_default() + .push(instruction); + smallvec![] + } else { + id_info.decorations.push(instruction.clone()); + smallvec![instruction] + } + } + Instruction::MemberDecorate { + structure_type: target, + member, + .. 
+ } + | Instruction::MemberDecorateString { + struct_type: target, + member, + .. + } => { + ids.get_mut(&target).unwrap().members[member as usize] + .decorations + .push(instruction.clone()); + smallvec![instruction] + } + Instruction::DecorationGroup { result_id } => { + // Drop the instruction altogether. + decoration_groups.entry(result_id).or_default(); + ids.remove(&result_id); + smallvec![] + } + Instruction::GroupDecorate { + decoration_group, + ref targets, + } => { + let decorations = &decoration_groups[&decoration_group]; + + (targets.iter().copied()) + .flat_map(|target| { + decorations + .iter() + .map(move |instruction| (target, instruction)) + }) + .map(|(target, instruction)| { + let id_info = ids.get_mut(&target).unwrap(); + + match instruction { + Instruction::Decorate { ref decoration, .. } => { + let instruction = Instruction::Decorate { + target, + decoration: decoration.clone(), + }; + id_info.decorations.push(instruction.clone()); + instruction + } + Instruction::DecorateId { ref decoration, .. } => { + let instruction = Instruction::DecorateId { + target, + decoration: decoration.clone(), + }; + id_info.decorations.push(instruction.clone()); + instruction + } + _ => unreachable!(), + } + }) + .collect() + } + Instruction::GroupMemberDecorate { + decoration_group, + ref targets, + } => { + let decorations = &decoration_groups[&decoration_group]; + + (targets.iter().copied()) + .flat_map(|target| { + decorations + .iter() + .map(move |instruction| (target, instruction)) + }) + .map(|((structure_type, member), instruction)| { + let member_info = + &mut ids.get_mut(&structure_type).unwrap().members + [member as usize]; + + match instruction { + Instruction::Decorate { ref decoration, .. } => { + let instruction = Instruction::MemberDecorate { + structure_type, + member, + decoration: decoration.clone(), + }; + member_info.decorations.push(instruction.clone()); + instruction + } + Instruction::DecorateId { .. 
} => { + panic!( + "a DecorateId instruction targets a decoration group, \ + and that decoration group is applied using a \ + GroupMemberDecorate instruction, but there is no \ + MemberDecorateId instruction" + ); + } + _ => unreachable!(), + } + }) + .collect() + } + _ => smallvec![instruction], + } + }) + .collect(); + + instructions_name.retain(|instruction| match *instruction { + Instruction::Name { target, .. } => { + if let Some(id_info) = ids.get_mut(&target) { + id_info.names.push(instruction.clone()); + true + } else { + false + } + } + Instruction::MemberName { ty, member, .. } => { + if let Some(id_info) = ids.get_mut(&ty) { + id_info.members[member as usize] + .names + .push(instruction.clone()); + true + } else { + false + } + } + _ => unreachable!(), + }); + + Ok(Spirv { + version, + bound, + ids, + + instructions_capability, + instructions_extension, + instructions_ext_inst_import, + instruction_memory_model, + instructions_entry_point, + instructions_execution_mode, + instructions_name, + instructions_decoration, + instructions_global, + functions, + }) + } + + /// Returns the SPIR-V version that the module is compiled for. + #[inline] + pub fn version(&self) -> Version { + self.version + } + + /// Returns information about an `Id`. + /// + /// # Panics + /// + /// - Panics if `id` is not defined in this module. This can in theory only happen if you are + /// mixing `Id`s from different modules. + #[inline] + pub fn id(&self, id: Id) -> &IdInfo { + &self.ids[&id] + } + + /// Returns information about the function with the given `id`. + /// + /// # Panics + /// + /// - Panics if `id` is not the `Id` of a function defined in this module, for example if you + /// are mixing `Id`s from different modules. + #[inline] + pub fn function(&self, id: Id) -> &FunctionInfo { + &self.functions[&id] + } + + /// Returns an iterator over all `Capability` instructions.
+ #[inline] + pub fn iter_capability(&self) -> impl ExactSizeIterator { + self.instructions_capability.iter() + } + + /// Returns an iterator over all `Extension` instructions. + #[inline] + pub fn iter_extension(&self) -> impl ExactSizeIterator { + self.instructions_extension.iter() + } + + /// Returns an iterator over all `ExtInstImport` instructions. + #[inline] + pub fn iter_ext_inst_import(&self) -> impl ExactSizeIterator { + self.instructions_ext_inst_import.iter() + } + + /// Returns the `MemoryModel` instruction. + #[inline] + pub fn memory_model(&self) -> &Instruction { + &self.instruction_memory_model + } + + /// Returns an iterator over all `EntryPoint` instructions. + #[inline] + pub fn iter_entry_point(&self) -> impl ExactSizeIterator { + self.instructions_entry_point.iter() + } + + /// Returns an iterator over all execution mode instructions. + #[inline] + pub fn iter_execution_mode(&self) -> impl ExactSizeIterator { + self.instructions_execution_mode.iter() + } + + /// Returns an iterator over all name debug instructions. + #[inline] + pub fn iter_name(&self) -> impl ExactSizeIterator { + self.instructions_name.iter() + } + + /// Returns an iterator over all decoration instructions. + #[inline] + pub fn iter_decoration(&self) -> impl ExactSizeIterator { + self.instructions_decoration.iter() + } + + /// Returns an iterator over all global declaration instructions: types, + /// constants and global variables. + #[inline] + pub fn iter_global(&self) -> impl ExactSizeIterator { + self.instructions_global.iter() + } + + /// Returns an iterator over all functions. 
+ #[inline] + pub fn iter_functions(&self) -> impl ExactSizeIterator { + self.functions.values() + } + + pub fn apply_specialization( + &mut self, + specialization_info: &HashMap, + ) { + self.instructions_global = specialization::replace_specialization_instructions( + specialization_info, + self.instructions_global.drain(..), + &self.ids, + self.bound, + ); + + for instruction in &self.instructions_global { + if let Some(id) = instruction.result_id() { + if let Some(id_info) = self.ids.get_mut(&id) { + id_info.instruction = instruction.clone(); + id_info.decorations.retain(|instruction| { + !matches!( + instruction, + Instruction::Decorate { + decoration: Decoration::SpecId { .. }, + .. + } + ) + }); + } else { + self.ids.insert( + id, + IdInfo { + instruction: instruction.clone(), + names: Vec::new(), + decorations: Vec::new(), + members: Vec::new(), + }, + ); + self.bound = self.bound.max(u32::from(id) + 1); + } + } + } + + self.instructions_decoration.retain(|instruction| { + !matches!( + instruction, + Instruction::Decorate { + decoration: Decoration::SpecId { .. }, + .. + } + ) + }); + } +} + +/// Used in SPIR-V to refer to the result of another instruction. +/// +/// Ids are global across a module, and are always assigned by exactly one instruction. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct Id(u32); + +impl Id { + /// Returns the raw numeric value of this `Id`. + #[inline] + pub const fn as_raw(self) -> u32 { + self.0 + } +} + +impl From for u32 { + #[inline] + fn from(id: Id) -> u32 { + id.as_raw() + } +} + +impl Display for Id { + #[inline] + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { + write!(f, "%{}", self.0) + } +} + +/// Information associated with an `Id`. +#[derive(Clone, Debug)] +pub struct IdInfo { + instruction: Instruction, + names: Vec, + decorations: Vec, + members: Vec, +} + +impl IdInfo { + /// Returns the instruction that defines this `Id` with a `result_id` operand.
+ #[inline] + pub fn instruction(&self) -> &Instruction { + &self.instruction + } + + /// Returns an iterator over all name debug instructions that target this `Id`. + #[inline] + pub fn iter_name(&self) -> impl ExactSizeIterator { + self.names.iter() + } + + /// Returns an iterator over all decorate instructions, that target this `Id`. + #[inline] + pub fn iter_decoration(&self) -> impl ExactSizeIterator { + self.decorations.iter() + } + + /// If this `Id` refers to a `TypeStruct`, returns an iterator of information about each member + /// of the struct. Empty otherwise. + #[inline] + pub fn iter_members(&self) -> impl ExactSizeIterator { + self.members.iter() + } +} + +/// Information associated with a member of a `TypeStruct` instruction. +#[derive(Clone, Debug, Default)] +pub struct StructMemberInfo { + names: Vec, + decorations: Vec, +} + +impl StructMemberInfo { + /// Returns an iterator over all name debug instructions that target this struct member. + #[inline] + pub fn iter_name(&self) -> impl ExactSizeIterator { + self.names.iter() + } + + /// Returns an iterator over all decorate instructions that target this struct member. + #[inline] + pub fn iter_decoration(&self) -> impl ExactSizeIterator { + self.decorations.iter() + } +} + +/// Information associated with a function. +#[derive(Clone, Debug, Default)] +pub struct FunctionInfo { + instructions: Vec, +} + +impl FunctionInfo { + /// Returns an iterator over all instructions in the function. 
#[inline]
pub fn iter_instructions(&self) -> impl ExactSizeIterator<Item = &Instruction> {
    self.instructions.iter()
}
}

/// Returns an iterator that parses `words` into instructions, one instruction per item.
///
/// The upper 16 bits of the first word of each instruction give the number of words that the
/// instruction occupies, including that first word.
///
/// The iterator yields `Err` when an instruction cannot be parsed:
///
/// - `MissingOperands` if the instruction claims a word count of zero.
/// - `UnexpectedEOF` if the instruction claims more words than remain in `words`.
/// - `LeftoverOperands` if words remain after the instruction was parsed.
/// - Any error returned by `Instruction::parse`.
///
/// After yielding an error, the iterator yields `None`, because the position of the next
/// instruction can no longer be determined.
fn iter_instructions(
    mut words: &[u32],
) -> impl Iterator<Item = Result<Instruction, ParseError>> + '_ {
    let mut index = 0;
    let next = move || -> Option<Result<Instruction, ParseError>> {
        let &header = words.first()?;
        let word_count = (header >> 16) as usize;

        let error = if word_count == 0 {
            // Even the opcode word itself does not fit in a zero-length instruction. Report
            // this as an error rather than panicking on malformed input.
            ParseError {
                instruction: index,
                word: 0,
                error: ParseErrors::MissingOperands,
                words: vec![header],
            }
        } else if words.len() < word_count {
            ParseError {
                instruction: index,
                word: words.len(),
                error: ParseErrors::UnexpectedEOF,
                words: words.to_owned(),
            }
        } else {
            let mut reader = InstructionReader::new(&words[..word_count], index);

            match Instruction::parse(&mut reader) {
                Ok(instruction) if reader.is_empty() => {
                    words = &words[word_count..];
                    index += 1;
                    return Some(Ok(instruction));
                }
                Ok(_) => reader.map_err(ParseErrors::LeftoverOperands),
                Err(err) => err,
            }
        };

        // Stop iterating, so that the same error is not yielded over and over.
        words = &[];
        Some(Err(error))
    };

    std::iter::from_fn(next)
}

/// Helper type for parsing the words of an instruction.
#[derive(Debug)]
struct InstructionReader<'a> {
    words: &'a [u32],
    next_word: usize,
    instruction: usize,
}

impl<'a> InstructionReader<'a> {
    /// Constructs a new reader from a slice of words for a single instruction, including the opcode
    /// word. `instruction` is the number of the instruction currently being read, and is used for
    /// error reporting.
    fn new(words: &'a [u32], instruction: usize) -> Self {
        debug_assert!(!words.is_empty());
        Self {
            words,
            next_word: 0,
            instruction,
        }
    }

    /// Returns whether the reader has reached the end of the current instruction.
    fn is_empty(&self) -> bool {
        self.next_word >= self.words.len()
    }

    /// Converts the `ParseErrors` enum to the `ParseError` struct, adding contextual information.
+ fn map_err(&self, error: ParseErrors) -> ParseError { + ParseError { + instruction: self.instruction, + word: self.next_word - 1, // -1 because the word has already been read + error, + words: self.words.to_owned(), + } + } + + /// Returns the next word in the sequence. + fn next_word(&mut self) -> Result { + let word = *self.words.get(self.next_word).ok_or(ParseError { + instruction: self.instruction, + word: self.next_word, // No -1 because we didn't advance yet + error: ParseErrors::MissingOperands, + words: self.words.to_owned(), + })?; + self.next_word += 1; + + Ok(word) + } + + /* + /// Returns the next two words as a single `u64`. + #[inline] + fn next_u64(&mut self) -> Result { + Ok(self.next_word()? as u64 | (self.next_word()? as u64) << 32) + } + */ + + /// Reads a nul-terminated string. + fn next_string(&mut self) -> Result { + let mut bytes = Vec::new(); + loop { + let word = self.next_word()?.to_le_bytes(); + + if let Some(nul) = word.iter().position(|&b| b == 0) { + bytes.extend(&word[0..nul]); + break; + } else { + bytes.extend(word); + } + } + String::from_utf8(bytes).map_err(|err| self.map_err(ParseErrors::FromUtf8Error(err))) + } + + /// Reads all remaining words. + fn remainder(&mut self) -> Vec { + let vec = self.words[self.next_word..].to_owned(); + self.next_word = self.words.len(); + vec + } +} + +/// Error that can happen when reading a SPIR-V module. 
+#[derive(Clone, Debug)] +pub enum SpirvError { + DuplicateId { id: Id }, + InvalidHeader, + ParseError(ParseError), +} + +impl Display for SpirvError { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { + match self { + Self::DuplicateId { id } => write!(f, "id {} is assigned more than once", id,), + Self::InvalidHeader => write!(f, "the SPIR-V module header is invalid"), + Self::ParseError(_) => write!(f, "parse error"), + } + } +} + +impl Error for SpirvError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::ParseError(err) => Some(err), + _ => None, + } + } +} + +impl From for SpirvError { + fn from(err: ParseError) -> Self { + Self::ParseError(err) + } +} + +/// Error that can happen when parsing SPIR-V instructions into Rust data structures. +#[derive(Clone, Debug)] +pub struct ParseError { + /// The instruction number the error happened at, starting from 0. + pub instruction: usize, + /// The word from the start of the instruction that the error happened at, starting from 0. + pub word: usize, + /// The error. + pub error: ParseErrors, + /// The words of the instruction. + pub words: Vec, +} + +impl Display for ParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { + write!( + f, + "at instruction {}, word {}: {}", + self.instruction, self.word, self.error, + ) + } +} + +impl Error for ParseError {} + +/// Individual types of parse error that can happen. 
+#[derive(Clone, Debug)] +pub enum ParseErrors { + FromUtf8Error(FromUtf8Error), + LeftoverOperands, + MissingOperands, + UnexpectedEOF, + UnknownEnumerant(&'static str, u32), + UnknownOpcode(u16), + UnknownSpecConstantOpcode(u16), +} + +impl Display for ParseErrors { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> { + match self { + Self::FromUtf8Error(_) => write!(f, "invalid UTF-8 in string literal"), + Self::LeftoverOperands => write!(f, "unparsed operands remaining"), + Self::MissingOperands => write!( + f, + "the instruction and its operands require more words than are present in the \ + instruction", + ), + Self::UnexpectedEOF => write!(f, "encountered unexpected end of file"), + Self::UnknownEnumerant(ty, enumerant) => { + write!(f, "invalid enumerant {} for enum {}", enumerant, ty) + } + Self::UnknownOpcode(opcode) => write!(f, "invalid instruction opcode {}", opcode), + Self::UnknownSpecConstantOpcode(opcode) => { + write!(f, "invalid spec constant instruction opcode {}", opcode) + } + } + } +} + +/// Converts SPIR-V bytes to words. If necessary, the byte order is swapped from little-endian +/// to native-endian. +pub fn bytes_to_words(bytes: &[u8]) -> Result, SpirvBytesNotMultipleOf4> { + // If the current target is little endian, and the slice already has the right size and + // alignment, then we can just transmute the slice with bytemuck. + #[cfg(target_endian = "little")] + if let Ok(words) = bytemuck::try_cast_slice(bytes) { + return Ok(Cow::Borrowed(words)); + } + + if bytes.len() % 4 != 0 { + return Err(SpirvBytesNotMultipleOf4); + } + + // TODO: Use `slice::array_chunks` once it's stable. 
+ let words: Vec = bytes + .chunks_exact(4) + .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) + .collect(); + + Ok(Cow::Owned(words)) +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct SpirvBytesNotMultipleOf4; + +impl Error for SpirvBytesNotMultipleOf4 {} + +impl Display for SpirvBytesNotMultipleOf4 { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "the length of the provided slice is not a multiple of 4") + } +} diff --git a/vulkano/src/shader/spirv/specialization.rs b/vulkano/src/shader/spirv/specialization.rs new file mode 100644 index 00000000..4f990e26 --- /dev/null +++ b/vulkano/src/shader/spirv/specialization.rs @@ -0,0 +1,730 @@ +// Copyright (c) 2023 The Vulkano developers +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , +// at your option. All files in the project carrying such +// notice may not be copied, modified, or distributed except +// according to those terms. + +use crate::shader::{ + spirv::{Decoration, Id, IdInfo, Instruction, SpecConstantInstruction}, + SpecializationConstant, +}; +use ahash::HashMap; +use half::f16; +use smallvec::{smallvec, SmallVec}; + +/// Go through all the specialization constant instructions, +/// and updates their values and replaces them with regular constants. +pub(super) fn replace_specialization_instructions( + specialization_info: &HashMap, + instructions_global: impl IntoIterator, + ids: &HashMap, + mut next_new_id: u32, +) -> Vec { + let get_specialization = |id: Id| -> Option { + ids[&id] + .decorations + .iter() + .find_map(|instruction| match instruction { + Instruction::Decorate { + decoration: + Decoration::SpecId { + specialization_constant_id, + }, + .. + } => specialization_info.get(specialization_constant_id).copied(), + _ => None, + }) + }; + + // This stores the constants we've seen so far. Since composite and op constants must + // use constants that were defined earlier, this works. 
+ let mut constants: HashMap = HashMap::default(); + + instructions_global + .into_iter() + .flat_map(|instruction| -> SmallVec<[Instruction; 1]> { + let new_instructions: SmallVec<[Instruction; 1]> = match instruction { + Instruction::SpecConstantFalse { + result_type_id, + result_id, + } + | Instruction::SpecConstantTrue { + result_type_id, + result_id, + } => { + let value = get_specialization(result_id).map_or_else( + || matches!(instruction, Instruction::SpecConstantTrue { .. }), + |sc| matches!(sc, SpecializationConstant::Bool(true)), + ); + let new_instruction = if value { + Instruction::ConstantTrue { + result_type_id, + result_id, + } + } else { + Instruction::ConstantFalse { + result_type_id, + result_id, + } + }; + + smallvec![new_instruction] + } + Instruction::SpecConstant { + result_type_id, + result_id, + ref value, + } => { + if let Some(specialization) = get_specialization(result_id) { + smallvec![Instruction::Constant { + result_type_id, + result_id, + value: match specialization { + SpecializationConstant::Bool(_) => unreachable!(), + SpecializationConstant::U8(num) => vec![num as u32], + SpecializationConstant::U16(num) => vec![num as u32], + SpecializationConstant::U32(num) => vec![num], + SpecializationConstant::U64(num) => + vec![num as u32, (num >> 32) as u32], + SpecializationConstant::I8(num) => vec![num as u32], + SpecializationConstant::I16(num) => vec![num as u32], + SpecializationConstant::I32(num) => vec![num as u32], + SpecializationConstant::I64(num) => + vec![num as u32, (num >> 32) as u32], + SpecializationConstant::F16(num) => vec![num.to_bits() as u32], + SpecializationConstant::F32(num) => vec![num.to_bits()], + SpecializationConstant::F64(num) => { + let num = num.to_bits(); + vec![num as u32, (num >> 32) as u32] + } + }, + }] + } else { + smallvec![Instruction::Constant { + result_type_id, + result_id, + value: value.clone(), + }] + } + } + Instruction::SpecConstantComposite { + result_type_id, + result_id, + ref 
constituents, + } => { + smallvec![Instruction::ConstantComposite { + result_type_id, + result_id, + constituents: constituents.clone(), + }] + } + Instruction::SpecConstantOp { + result_type_id, + result_id, + ref opcode, + } => evaluate_spec_constant_op( + &mut next_new_id, + ids, + &constants, + result_type_id, + result_id, + opcode, + ), + _ => smallvec![instruction], + }; + + for instruction in &new_instructions { + match *instruction { + Instruction::ConstantFalse { + result_type_id, + result_id, + .. + } => { + constants.insert( + result_id, + Constant { + type_id: result_type_id, + value: ConstantValue::Scalar(0), + }, + ); + } + Instruction::ConstantTrue { + result_type_id, + result_id, + .. + } => { + constants.insert( + result_id, + Constant { + type_id: result_type_id, + value: ConstantValue::Scalar(1), + }, + ); + } + Instruction::Constant { + result_type_id, + result_id, + ref value, + } => { + let constant_value = match *ids[&result_type_id].instruction() { + Instruction::TypeInt { + width, signedness, .. + } => { + if width == 64 { + assert!(value.len() == 2); + } else { + assert!(value.len() == 1); + } + + match (signedness, width) { + (0, 8) => value[0] as u64, + (0, 16) => value[0] as u64, + (0, 32) => value[0] as u64, + (0, 64) => (value[0] as u64) | ((value[1] as u64) << 32), + (1, 8) => value[0] as i32 as u64, + (1, 16) => value[0] as i32 as u64, + (1, 32) => value[0] as i32 as u64, + (1, 64) => { + ((value[0] as i64) | ((value[1] as i64) << 32)) as u64 + } + _ => unimplemented!(), + } + } + Instruction::TypeFloat { width, .. 
} => { + if width == 64 { + assert!(value.len() == 2); + } else { + assert!(value.len() == 1); + } + + match width { + 16 => f16::from_bits(value[0] as u16).to_f64() as u64, + 32 => f32::from_bits(value[0]) as f64 as u64, + 64 => f64::from_bits( + (value[0] as u64) | ((value[1] as u64) << 32), + ) as u64, + _ => unimplemented!(), + } + } + _ => unreachable!(), + }; + + constants.insert( + result_id, + Constant { + type_id: result_type_id, + value: ConstantValue::Scalar(constant_value), + }, + ); + } + Instruction::ConstantComposite { + result_type_id, + result_id, + ref constituents, + } => { + constants.insert( + result_id, + Constant { + type_id: result_type_id, + value: ConstantValue::Composite(constituents.as_slice().into()), + }, + ); + } + _ => (), + } + } + + new_instructions + }) + .collect() +} + +struct Constant { + type_id: Id, + value: ConstantValue, +} + +#[derive(Clone)] +enum ConstantValue { + // All scalar constants are stored as u64, regardless of their original type. + // They are converted from and to their actual representation + // when they are first read or when they are written back. + // Signed integers are sign extended to i64 first, floats are cast to f64 and then + // bit-converted. + Scalar(u64), + + Composite(SmallVec<[Id; 4]>), +} + +impl ConstantValue { + fn as_scalar(&self) -> u64 { + match self { + Self::Scalar(val) => *val, + Self::Composite(_) => panic!("called `as_scalar` on a composite value"), + } + } + + fn as_composite(&self) -> &[Id] { + match self { + Self::Scalar(_) => panic!("called `as_composite` on a scalar value"), + Self::Composite(val) => val, + } + } +} + +fn numeric_constant_to_words( + constant_type: &Instruction, + constant_value: u64, +) -> SmallVec<[u32; 2]> { + match *constant_type { + Instruction::TypeInt { + width, signedness, .. 
+ } => match (signedness, width) { + (0, 8) => smallvec![constant_value as u8 as u32], + (0, 16) => smallvec![constant_value as u16 as u32], + (0, 32) => smallvec![constant_value as u32], + (0, 64) => smallvec![constant_value as u32, (constant_value >> 32) as u32], + (1, 8) => smallvec![constant_value as i8 as u32], + (1, 16) => smallvec![constant_value as i16 as u32], + (1, 32) => smallvec![constant_value as u32], + (1, 64) => smallvec![constant_value as u32, (constant_value >> 32) as u32], + _ => unimplemented!(), + }, + Instruction::TypeFloat { width, .. } => match width { + 16 => smallvec![f16::from_f64(f64::from_bits(constant_value)).to_bits() as u32], + 32 => smallvec![(f64::from_bits(constant_value) as f32).to_bits()], + 64 => smallvec![constant_value as u32, (constant_value >> 32) as u32], + _ => unimplemented!(), + }, + _ => unreachable!(), + } +} + +// Evaluate a SpecConstantInstruction. +fn evaluate_spec_constant_op( + next_new_id: &mut u32, + ids: &HashMap, + constants: &HashMap, + result_type_id: Id, + result_id: Id, + opcode: &SpecConstantInstruction, +) -> SmallVec<[Instruction; 1]> { + let scalar_constant_to_instruction = + |constant_type_id: Id, constant_id: Id, constant_value: u64| -> Instruction { + match *ids[&constant_type_id].instruction() { + Instruction::TypeBool { .. } => { + if constant_value != 0 { + Instruction::ConstantTrue { + result_type_id: constant_type_id, + result_id: constant_id, + } + } else { + Instruction::ConstantFalse { + result_type_id: constant_type_id, + result_id: constant_id, + } + } + } + ref result_type @ (Instruction::TypeInt { .. } | Instruction::TypeFloat { .. 
}) => { + Instruction::Constant { + result_type_id: constant_type_id, + result_id: constant_id, + value: numeric_constant_to_words(result_type, constant_value).to_vec(), + } + } + _ => unreachable!(), + } + }; + + let constant_to_instruction = |constant_id: Id| -> SmallVec<[Instruction; 1]> { + let constant = &constants[&constant_id]; + debug_assert_eq!(constant.type_id, result_type_id); + + match constant.value { + ConstantValue::Scalar(value) => smallvec![scalar_constant_to_instruction( + result_type_id, + result_id, + value + )], + ConstantValue::Composite(ref constituents) => { + smallvec![Instruction::ConstantComposite { + result_type_id, + result_id, + constituents: constituents.to_vec(), + }] + } + } + }; + + match *opcode { + SpecConstantInstruction::VectorShuffle { + vector_1, + vector_2, + ref components, + } => { + let vector_1 = constants[&vector_1].value.as_composite(); + let vector_2 = constants[&vector_2].value.as_composite(); + let concatenated: SmallVec<[Id; 8]> = + vector_1.iter().chain(vector_2.iter()).copied().collect(); + let constituents: SmallVec<[Id; 4]> = components + .iter() + .map(|&component| { + concatenated[if component == 0xFFFFFFFF { + 0 // Spec says the value is undefined, so we can pick anything. + } else { + component as usize + }] + }) + .collect(); + + smallvec![Instruction::ConstantComposite { + result_type_id, + result_id, + constituents: constituents.to_vec(), + }] + } + SpecConstantInstruction::CompositeExtract { + composite, + ref indexes, + } => { + // Go through the index chain to find the Id to extract. 
+ let id = indexes.iter().fold(composite, |current_id, &index| { + constants[¤t_id].value.as_composite()[index as usize] + }); + + constant_to_instruction(id) + } + SpecConstantInstruction::CompositeInsert { + object, + composite, + ref indexes, + } => { + let new_id_count = indexes.len() as u32 - 1; + let new_ids = (0..new_id_count).map(|i| Id(*next_new_id + i)); + + // Go down the type tree, starting from the top-level type `composite`. + let mut old_constituent_id = composite; + + let new_result_ids = std::iter::once(result_id).chain(new_ids.clone()); + let new_constituent_ids = new_ids.chain(std::iter::once(object)); + + let mut new_instructions: SmallVec<_> = (indexes.iter().copied()) + .zip(new_result_ids.zip(new_constituent_ids)) + .map(|(index, (new_result_id, new_constituent_id))| { + let constant = &constants[&old_constituent_id]; + + // Get the Id of the original constituent value to iterate further, + // then replace it with the new Id. + let mut constituents = constant.value.as_composite().to_vec(); + old_constituent_id = constituents[index as usize]; + constituents[index as usize] = new_constituent_id; + + Instruction::ConstantComposite { + result_type_id: constant.type_id, + result_id: new_result_id, + constituents, + } + }) + .collect(); + + *next_new_id += new_id_count; + new_instructions.reverse(); // so that new constants are defined before use + new_instructions + } + SpecConstantInstruction::Select { + condition, + object_1, + object_2, + } => match constants[&condition].value { + ConstantValue::Scalar(condition) => { + let result = if condition != 0 { object_1 } else { object_2 }; + + constant_to_instruction(result) + } + ConstantValue::Composite(ref conditions) => { + let object_1 = constants[&object_1].value.as_composite(); + let object_2 = constants[&object_2].value.as_composite(); + + assert_eq!(conditions.len(), object_1.len()); + assert_eq!(conditions.len(), object_2.len()); + + let constituents: SmallVec<[Id; 4]> = + 
(conditions.iter().map(|c| constants[c].value.as_scalar())) + .zip(object_1.iter().zip(object_2.iter())) + .map( + |(condition, (&object_1, &object_2))| { + if condition != 0 { + object_1 + } else { + object_2 + } + }, + ) + .collect(); + + smallvec![Instruction::ConstantComposite { + result_type_id, + result_id, + constituents: constituents.to_vec(), + }] + } + }, + SpecConstantInstruction::UConvert { + unsigned_value: value, + } + | SpecConstantInstruction::SConvert { + signed_value: value, + } + | SpecConstantInstruction::FConvert { float_value: value } => { + constant_to_instruction(value) + } + _ => { + let result = evaluate_spec_constant_calculation_op(opcode, constants); + + if let &[result] = result.as_slice() { + smallvec![scalar_constant_to_instruction( + result_type_id, + result_id, + result, + )] + } else { + let component_type_id = match *ids[&result_type_id].instruction() { + Instruction::TypeVector { component_type, .. } => component_type, + _ => unreachable!(), + }; + + // We have to create new constants with new ids, + // to hold each component of the result. + // In theory, we could go digging among the existing constants to see if any + // of them already fit... + let new_id_count = result.len() as u32; + let new_instructions = result + .into_iter() + .enumerate() + .map(|(i, result)| { + scalar_constant_to_instruction( + component_type_id, + Id(*next_new_id + i as u32), + result, + ) + }) + .chain(std::iter::once(Instruction::ConstantComposite { + result_type_id, + result_id, + constituents: (0..new_id_count).map(|i| Id(*next_new_id + i)).collect(), + })) + .collect(); + *next_new_id += new_id_count; + new_instructions + } + } + } +} + +// Evaluate a SpecConstantInstruction that does calculations on scalars or paired vector components. 
+fn evaluate_spec_constant_calculation_op( + instruction: &SpecConstantInstruction, + constants: &HashMap, +) -> SmallVec<[u64; 4]> { + let unary_op = |operand: Id, op: fn(u64) -> u64| -> SmallVec<[u64; 4]> { + match constants[&operand].value { + ConstantValue::Scalar(operand) => smallvec![op(operand)], + ConstantValue::Composite(ref constituents) => constituents + .iter() + .map(|constituent| { + let operand = constants[constituent].value.as_scalar(); + op(operand) + }) + .collect(), + } + }; + + let binary_op = |operand1: Id, operand2: Id, op: fn(u64, u64) -> u64| -> SmallVec<[u64; 4]> { + match (&constants[&operand1].value, &constants[&operand2].value) { + (&ConstantValue::Scalar(operand1), &ConstantValue::Scalar(operand2)) => { + smallvec![op(operand1, operand2)] + } + (ConstantValue::Composite(constituents1), ConstantValue::Composite(constituents2)) => { + assert_eq!(constituents1.len(), constituents2.len()); + (constituents1.iter()) + .zip(constituents2.iter()) + .map(|(constituent1, constituent2)| { + let operand1 = constants[constituent1].value.as_scalar(); + let operand2 = constants[constituent2].value.as_scalar(); + op(operand1, operand2) + }) + .collect() + } + _ => unreachable!(), + } + }; + + match *instruction { + SpecConstantInstruction::VectorShuffle { .. } + | SpecConstantInstruction::CompositeExtract { .. } + | SpecConstantInstruction::CompositeInsert { .. } + | SpecConstantInstruction::Select { .. } + | SpecConstantInstruction::UConvert { .. } + | SpecConstantInstruction::SConvert { .. } + | SpecConstantInstruction::FConvert { .. 
} => unreachable!(), + SpecConstantInstruction::SNegate { operand } => { + unary_op(operand, |operand| operand.wrapping_neg()) + } + SpecConstantInstruction::IAdd { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + operand1.wrapping_add(operand2) + }) + } + SpecConstantInstruction::ISub { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + operand1.wrapping_sub(operand2) + }) + } + SpecConstantInstruction::IMul { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + operand1.wrapping_mul(operand2) + }) + } + SpecConstantInstruction::UDiv { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + operand1.wrapping_div(operand2) + }) + } + SpecConstantInstruction::UMod { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + operand1.wrapping_rem(operand2) + }) + } + SpecConstantInstruction::SDiv { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + let operand1 = operand1 as i64; + let operand2 = operand2 as i64; + operand1.wrapping_div(operand2) as u64 + }) + } + SpecConstantInstruction::SRem { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + let operand1 = operand1 as i64; + let operand2 = operand2 as i64; + operand1.wrapping_rem(operand2) as u64 + }) + } + SpecConstantInstruction::SMod { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + let operand1 = operand1 as i64; + let operand2 = operand2 as i64; + ((operand1.wrapping_rem(operand2) + operand2) % operand2) as u64 + }) + } + SpecConstantInstruction::LogicalEqual { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + ((operand1 != 0) == (operand2 != 0)) as u64 + }) + } + SpecConstantInstruction::LogicalNotEqual { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + ((operand1 != 0) != (operand2 != 0)) as u64 + 
}) + } + SpecConstantInstruction::LogicalOr { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + (operand1 != 0 || operand2 != 0) as u64 + }) + } + SpecConstantInstruction::LogicalAnd { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + (operand1 != 0 && operand2 != 0) as u64 + }) + } + SpecConstantInstruction::LogicalNot { operand } => { + unary_op(operand, |operand| (operand == 0) as u64) + } + SpecConstantInstruction::IEqual { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + (operand1 == operand2) as u64 + }) + } + SpecConstantInstruction::INotEqual { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + (operand1 != operand2) as u64 + }) + } + SpecConstantInstruction::UGreaterThan { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + (operand1 > operand2) as u64 + }) + } + SpecConstantInstruction::SGreaterThan { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + let operand1 = operand1 as i64; + let operand2 = operand2 as i64; + (operand1 > operand2) as u64 + }) + } + SpecConstantInstruction::UGreaterThanEqual { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + (operand1 >= operand2) as u64 + }) + } + SpecConstantInstruction::SGreaterThanEqual { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + let operand1 = operand1 as i64; + let operand2 = operand2 as i64; + (operand1 >= operand2) as u64 + }) + } + SpecConstantInstruction::ULessThan { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + (operand1 < operand2) as u64 + }) + } + SpecConstantInstruction::SLessThan { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + let operand1 = operand1 as i64; + let operand2 = operand2 as i64; + (operand1 < operand2) as u64 + }) + } + 
SpecConstantInstruction::ULessThanEqual { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + (operand1 <= operand2) as u64 + }) + } + SpecConstantInstruction::SLessThanEqual { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| { + let operand1 = operand1 as i64; + let operand2 = operand2 as i64; + (operand1 <= operand2) as u64 + }) + } + SpecConstantInstruction::ShiftRightLogical { base, shift } => { + binary_op(base, shift, |base, shift| base >> shift) + } + SpecConstantInstruction::ShiftRightArithmetic { base, shift } => { + binary_op(base, shift, |base, shift| { + let base = base as i64; + (base >> shift) as u64 + }) + } + SpecConstantInstruction::ShiftLeftLogical { base, shift } => { + binary_op(base, shift, |base, shift| base << shift) + } + SpecConstantInstruction::BitwiseOr { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| operand1 | operand2) + } + SpecConstantInstruction::BitwiseXor { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| operand1 ^ operand2) + } + SpecConstantInstruction::BitwiseAnd { operand1, operand2 } => { + binary_op(operand1, operand2, |operand1, operand2| operand1 & operand2) + } + SpecConstantInstruction::Not { operand } => unary_op(operand, |operand| !operand), + SpecConstantInstruction::QuantizeToF16 { value } => unary_op(value, |value| { + let value = f64::from_bits(value); + f16::from_f64(value).to_f64().to_bits() + }), + } +}