From 5a326db54651c1be6c53813607ff49600cb08afb Mon Sep 17 00:00:00 2001 From: Rua Date: Mon, 25 Oct 2021 19:16:54 +0200 Subject: [PATCH] Add `DescriptorRequirements` (#1729) * Add DescriptorRequirements * Doc fix --- examples/src/bin/runtime_array/main.rs | 17 +- vulkano-shaders/Cargo.toml | 1 + vulkano-shaders/src/descriptor_sets.rs | 399 ++++++++--------- vulkano-shaders/src/entry_point.rs | 15 +- vulkano-shaders/src/lib.rs | 14 +- vulkano/src/descriptor_set/layout/desc.rs | 417 ++++++++++-------- vulkano/src/descriptor_set/layout/mod.rs | 1 + vulkano/src/format.rs | 50 +++ vulkano/src/pipeline/compute_pipeline.rs | 79 ++-- .../src/pipeline/graphics_pipeline/builder.rs | 107 ++++- vulkano/src/pipeline/graphics_pipeline/mod.rs | 13 +- vulkano/src/pipeline/layout/sys.rs | 78 ++-- vulkano/src/pipeline/shader.rs | 251 +++++++---- 13 files changed, 852 insertions(+), 590 deletions(-) diff --git a/examples/src/bin/runtime_array/main.rs b/examples/src/bin/runtime_array/main.rs index 4e7330ca..ff814208 100644 --- a/examples/src/bin/runtime_array/main.rs +++ b/examples/src/bin/runtime_array/main.rs @@ -12,7 +12,9 @@ use std::io::Cursor; use std::sync::Arc; use vulkano::buffer::{BufferUsage, CpuAccessibleBuffer, TypedBufferAccess}; use vulkano::command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage, SubpassContents}; -use vulkano::descriptor_set::layout::{DescriptorSetLayout, DescriptorSetLayoutError}; +use vulkano::descriptor_set::layout::{ + DescriptorSetDesc, DescriptorSetLayout, DescriptorSetLayoutError, +}; use vulkano::descriptor_set::PersistentDescriptorSet; use vulkano::device::physical::{PhysicalDevice, PhysicalDeviceType}; use vulkano::device::{Device, DeviceExtensions, Features}; @@ -23,7 +25,6 @@ use vulkano::image::{ use vulkano::instance::Instance; use vulkano::pipeline::color_blend::ColorBlendState; use vulkano::pipeline::layout::PipelineLayout; -use vulkano::pipeline::shader::EntryPointAbstract; use vulkano::pipeline::viewport::{Viewport, ViewportState}; use vulkano::pipeline::{GraphicsPipeline, PipelineBindPoint}; use vulkano::render_pass::{Framebuffer, FramebufferAbstract, RenderPass, Subpass}; @@ -281,11 +282,8 @@ fn main() { .unwrap(); let pipeline_layout = { - let mut descriptor_set_descs: Vec<_> = (&fs.main_entry_point() as &dyn EntryPointAbstract) - .descriptor_set_layout_descs() - .iter() - .cloned() - .collect(); + let mut descriptor_set_descs: Vec<_> = + DescriptorSetDesc::from_requirements(fs.main_entry_point().descriptor_requirements()); // Set 0, Binding 0 descriptor_set_descs[0].set_variable_descriptor_count(0, 2); @@ -305,10 +303,7 @@ fn main() { PipelineLayout::new( device.clone(), descriptor_set_layouts, - (&fs.main_entry_point() as &dyn EntryPointAbstract) - .push_constant_range() - .iter() - .cloned(), + fs.main_entry_point().push_constant_range().iter().cloned(), ) .unwrap(), ) diff --git a/vulkano-shaders/Cargo.toml b/vulkano-shaders/Cargo.toml index 20604b24..f4cb5e07 100644 --- a/vulkano-shaders/Cargo.toml +++ b/vulkano-shaders/Cargo.toml @@ -15,6 +15,7 @@ categories = ["rendering::graphics-api"] proc-macro = true [dependencies] +fnv = "1.0" proc-macro2 = "1.0" quote = "1.0" shaderc = "0.7" diff --git a/vulkano-shaders/src/descriptor_sets.rs b/vulkano-shaders/src/descriptor_sets.rs index 95fcac8e..a099a1fc7 100644 --- a/vulkano-shaders/src/descriptor_sets.rs +++ b/vulkano-shaders/src/descriptor_sets.rs @@ -8,80 +8,99 @@ // according to those terms. use crate::TypesMeta; +use fnv::FnvHashMap; use proc_macro2::TokenStream; use std::cmp; use std::collections::HashSet; -use vulkano::spirv::{Decoration, Dim, Id, ImageFormat, Instruction, Spirv, StorageClass}; +use vulkano::{ + descriptor_set::layout::DescriptorType, + format::Format, + image::view::ImageViewType, + pipeline::shader::DescriptorRequirements, + spirv::{Decoration, Dim, Id, ImageFormat, Instruction, Spirv, StorageClass}, +}; -#[derive(Debug)] -struct Descriptor { - set_num: u32, - binding_num: u32, - desc_ty: TokenStream, - descriptor_count: u64, - variable_count: bool, - mutable: bool, -} - -pub(super) fn write_descriptor_set_layout_descs( +pub(super) fn write_descriptor_requirements( spirv: &Spirv, entrypoint_id: Id, interface: &[Id], exact_entrypoint_interface: bool, stages: &TokenStream, ) -> TokenStream { - // TODO: somewhat implemented correctly + let descriptor_requirements = + find_descriptors(spirv, entrypoint_id, interface, exact_entrypoint_interface); - // Finding all the descriptors. - let descriptors = find_descriptors(spirv, entrypoint_id, interface, exact_entrypoint_interface); - let num_sets = descriptors.iter().map(|d| d.set_num + 1).max().unwrap_or(0); - let sets: Vec<_> = (0..num_sets) - .map(|set_num| { - let num_bindings = descriptors - .iter() - .filter(|d| d.set_num == set_num) - .map(|d| d.binding_num + 1) - .max() - .unwrap_or(0); - let bindings: Vec<_> = (0..num_bindings) - .map(|binding_num| { - match descriptors - .iter() - .find(|d| d.set_num == set_num && d.binding_num == binding_num) - { - Some(d) => { - let desc_ty = &d.desc_ty; - let descriptor_count = d.descriptor_count as u32; - let mutable = d.mutable; - let variable_count = d.variable_count; - quote! { - Some(DescriptorDesc { - ty: #desc_ty, - descriptor_count: #descriptor_count, - stages: #stages, - variable_count: #variable_count, - mutable: #mutable, - }), - } - } - None => quote! { - None, - }, - } - }) - .collect(); + let descriptor_requirements = descriptor_requirements.into_iter().map(|(loc, reqs)| { + let (set_num, binding_num) = loc; + let DescriptorRequirements { + descriptor_types, + descriptor_count, + format, + image_view_type, + multisampled, + mutable, + stages: _, + } = reqs; + + let descriptor_types = descriptor_types.into_iter().map(|ty| { + let ident = format_ident!("{}", format!("{:?}", ty)); + quote! { DescriptorType::#ident } + }); + let format = match format { + Some(format) => { + let ident = format_ident!("{}", format!("{:?}", format)); + quote! { Some(Format::#ident) } + } + None => quote! { None }, + }; + let image_view_type = match image_view_type { + Some(image_view_type) => { + let ident = format_ident!("{}", format!("{:?}", image_view_type)); + quote! { Some(ImageViewType::#ident) } + } + None => quote! { None }, + }; + /*let stages = { + let ShaderStages { + vertex, + tessellation_control, + tessellation_evaluation, + geometry, + fragment, + compute, + } = stages; quote! { - DescriptorSetDesc::new( - [#( #bindings )*] - ), + ShaderStages { + vertex: #vertex, + tessellation_control: #tessellation_control, + tessellation_evaluation: #tessellation_evaluation, + geometry: #geometry, + fragment: #fragment, + compute: #compute, + } } - }) - .collect(); + };*/ + + quote! { + ( + (#set_num, #binding_num), + DescriptorRequirements { + descriptor_types: vec![#(#descriptor_types),*], + descriptor_count: #descriptor_count, + format: #format, + image_view_type: #image_view_type, + multisampled: #multisampled, + mutable: #mutable, + stages: #stages, + }, + ), + } + }); quote! { [ - #( #sets )* + #( #descriptor_requirements )* ] } } @@ -134,7 +153,7 @@ fn find_descriptors( entrypoint_id: Id, interface: &[Id], exact: bool, -) -> Vec { +) -> FnvHashMap<(u32, u32), DescriptorRequirements> { // For SPIR-V 1.4+, the entrypoint interface can specify variables of all storage classes, // and most tools will put all used variables in the entrypoint interface. However, // SPIR-V 1.0-1.3 do not specify variables other than Input/Output ones in the interface, @@ -228,20 +247,15 @@ fn find_descriptors( }); // Find information about the kind of binding for this descriptor. - let (desc_ty, mutable, descriptor_count, variable_count) = - descriptor_infos(spirv, variable_type_id, storage_class, false).expect(&format!( - "Couldn't find relevant type for uniform `{}` (type {}, maybe unimplemented)", + let mut reqs = + descriptor_requirements(spirv, variable_type_id, storage_class, false).expect(&format!( + "Couldn't find relevant type for global variable `{}` (type {}, maybe unimplemented)", name, variable_type_id, )); - Some(Descriptor { - desc_ty, - set_num, - binding_num, - descriptor_count, - mutable: !nonwritable && mutable, - variable_count: variable_count, - }) + reqs.mutable &= !nonwritable; + + Some(((set_num, binding_num), reqs)) }) .collect() } @@ -359,16 +373,15 @@ fn find_variables_in_function( } } -/// Returns a `DescriptorDescTy` constructor, a bool indicating whether the descriptor is -/// read-only, and the number of array elements. +/// Returns a `DescriptorRequirements` value for the pointed type. /// /// See also section 14.5.2 of the Vulkan specs: Descriptor Set Interface -fn descriptor_infos( +fn descriptor_requirements( spirv: &Spirv, pointed_ty: Id, pointer_storage: &StorageClass, force_combined_image_sampled: bool, -) -> Option<(TokenStream, bool, u64, bool)> { +) -> Option { let id_info = spirv.id(pointed_ty); match id_info.instruction() { @@ -398,10 +411,14 @@ fn descriptor_infos( "Structs in shader interface are expected to be decorated with one of Block or BufferBlock" ); - let (ty, mutable) = if decoration_buffer_block + let mut reqs = DescriptorRequirements { + descriptor_count: 1, + ..Default::default() + }; + + if decoration_buffer_block || decoration_block && *pointer_storage == StorageClass::StorageBuffer { - // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER // Determine whether all members have a NonWritable decoration. let nonwritable = id_info.iter_members().all(|member_info| { member_info.iter_decoration().any(|instruction| { @@ -415,13 +432,19 @@ fn descriptor_infos( }) }); - (quote! { DescriptorDescTy::StorageBuffer }, !nonwritable) + reqs.descriptor_types = vec![ + DescriptorType::StorageBuffer, + DescriptorType::StorageBufferDynamic, + ]; + reqs.mutable = !nonwritable; } else { - // VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER - (quote! { DescriptorDescTy::UniformBuffer }, false) // Uniforms are never mutable. + reqs.descriptor_types = vec![ + DescriptorType::UniformBuffer, + DescriptorType::UniformBufferDynamic, + ]; }; - Some((ty, mutable, 1, false)) + Some(reqs) } &Instruction::TypeImage { ref dim, @@ -431,13 +454,12 @@ fn descriptor_infos( ref image_format, .. } => { - let ms = ms != 0; + let multisampled = ms != 0; assert!(sampled != 0, "Vulkan requires that variables of type OpTypeImage have a Sampled operand of 1 or 2"); - let vulkan_format = to_vulkan_format(image_format); + let format: Option = image_format.clone().into(); match dim { Dim::SubpassData => { - // VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT assert!( !force_combined_image_sampled, "An OpTypeSampledImage can't point to \ @@ -451,138 +473,118 @@ fn descriptor_infos( assert!(sampled == 2, "If Dim is SubpassData, Sampled must be 2"); assert!(arrayed == 0, "If Dim is SubpassData, Arrayed must be 0"); - let desc = quote! { - DescriptorDescTy::InputAttachment { - multisampled: #ms, - } - }; - - Some((desc, true, 1, false)) // Never writable. + Some(DescriptorRequirements { + descriptor_types: vec![DescriptorType::InputAttachment], + descriptor_count: 1, + multisampled, + ..Default::default() + }) } Dim::Buffer => { - let (ty, mutable) = if sampled == 1 { - // VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER - (quote! { DescriptorDescTy::UniformTexelBuffer }, false) - // Uniforms are never mutable. + let mut reqs = DescriptorRequirements { + descriptor_count: 1, + format, + ..Default::default() + }; + + if sampled == 1 { + reqs.descriptor_types = vec![DescriptorType::UniformTexelBuffer]; } else { - // VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER - (quote! { DescriptorDescTy::StorageTexelBuffer }, true) - }; + reqs.descriptor_types = vec![DescriptorType::StorageTexelBuffer]; + reqs.mutable = true; + } - let desc = quote! { - #ty { - format: #vulkan_format, - } - }; - - Some((desc, mutable, 1, false)) + Some(reqs) } _ => { - let view_type = match (dim, arrayed) { - (Dim::Dim1D, 0) => quote! { ImageViewType::Dim1d }, - (Dim::Dim1D, 1) => quote! { ImageViewType::Dim1dArray }, - (Dim::Dim2D, 0) => quote! { ImageViewType::Dim2d }, - (Dim::Dim2D, 1) => quote! { ImageViewType::Dim2dArray }, - (Dim::Dim3D, 0) => quote! { ImageViewType::Dim3d }, + let image_view_type = Some(match (dim, arrayed) { + (Dim::Dim1D, 0) => ImageViewType::Dim1d, + (Dim::Dim1D, 1) => ImageViewType::Dim1dArray, + (Dim::Dim2D, 0) => ImageViewType::Dim2d, + (Dim::Dim2D, 1) => ImageViewType::Dim2dArray, + (Dim::Dim3D, 0) => ImageViewType::Dim3d, (Dim::Dim3D, 1) => panic!("Vulkan doesn't support arrayed 3D textures"), - (Dim::Cube, 0) => quote! { ImageViewType::Cube }, - (Dim::Cube, 1) => quote! { ImageViewType::CubeArray }, + (Dim::Cube, 0) => ImageViewType::Cube, + (Dim::Cube, 1) => ImageViewType::CubeArray, (Dim::Rect, _) => panic!("Vulkan doesn't support rectangle textures"), _ => unreachable!(), + }); + + let mut reqs = DescriptorRequirements { + descriptor_count: 1, + format, + multisampled, + image_view_type, + ..Default::default() }; - let image_desc = quote! { - DescriptorDescImage { - format: #vulkan_format, - multisampled: #ms, - view_type: #view_type, - } - }; - - let (desc, mutable) = if force_combined_image_sampled { - // VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER - // Never writable. + if force_combined_image_sampled { assert!( sampled == 1, "A combined image sampler must not reference a storage image" ); - ( - quote! { - DescriptorDescTy::CombinedImageSampler { - image_desc: #image_desc, - immutable_samplers: Vec::new(), - } - }, - false, // Sampled images are never mutable. - ) + reqs.descriptor_types = vec![DescriptorType::CombinedImageSampler]; } else { - let (ty, mutable) = if sampled == 1 { - // VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE - (quote! { DescriptorDescTy::SampledImage }, false) // Sampled images are never mutable. + if sampled == 1 { + reqs.descriptor_types = vec![DescriptorType::SampledImage]; } else { - // VK_DESCRIPTOR_TYPE_STORAGE_IMAGE - (quote! { DescriptorDescTy::StorageImage }, true) - }; - - ( - quote! { - #ty { - image_desc: #image_desc, - } - }, - mutable, - ) + reqs.descriptor_types = vec![DescriptorType::StorageImage]; + reqs.mutable = true; + } }; - Some((desc, mutable, 1, false)) + Some(reqs) } } } &Instruction::TypeSampledImage { image_type, .. } => { - descriptor_infos(spirv, image_type, pointer_storage, true) + descriptor_requirements(spirv, image_type, pointer_storage, true) } - &Instruction::TypeSampler { .. } => { - let desc = quote! { DescriptorDescTy::Sampler { immutable_samplers: Vec::new() } }; - Some((desc, false, 1, false)) - } + &Instruction::TypeSampler { .. } => Some(DescriptorRequirements { + descriptor_types: vec![DescriptorType::Sampler], + descriptor_count: 1, + ..Default::default() + }), &Instruction::TypeArray { element_type, length, .. } => { - let (desc, mutable, arr, variable_count) = - match descriptor_infos(spirv, element_type, pointer_storage, false) { - None => return None, - Some(v) => v, - }; - assert_eq!(arr, 1); // TODO: implement? - assert!(!variable_count); // TODO: Is this even a thing? + let reqs = match descriptor_requirements(spirv, element_type, pointer_storage, false) { + None => return None, + Some(v) => v, + }; + assert_eq!(reqs.descriptor_count, 1); // TODO: implement? let len = match spirv.id(length).instruction() { &Instruction::Constant { ref value, .. } => value, _ => panic!("failed to find array length"), }; let len = len.iter().rev().fold(0, |a, &b| (a << 32) | b as u64); - Some((desc, mutable, len, false)) + Some(DescriptorRequirements { + descriptor_count: len as u32, + ..reqs + }) } &Instruction::TypeRuntimeArray { element_type, .. } => { - let (desc, mutable, arr, variable_count) = - match descriptor_infos(spirv, element_type, pointer_storage, false) { - None => return None, - Some(v) => v, - }; - assert_eq!(arr, 1); // TODO: implement? - assert!(!variable_count); // TODO: Don't think this is possible? + let reqs = match descriptor_requirements(spirv, element_type, pointer_storage, false) { + None => return None, + Some(v) => v, + }; + assert_eq!(reqs.descriptor_count, 1); // TODO: implement? - Some((desc, mutable, 1, true)) + Some(DescriptorRequirements { + descriptor_count: 0, + ..reqs + }) } - _ => None, // TODO: other types + _ => None, } } @@ -682,8 +684,8 @@ mod tests { // Check first entrypoint let e1_descriptors = descriptors.get(0).expect("Could not find entrypoint1"); let mut e1_bindings = Vec::new(); - for d in e1_descriptors { - e1_bindings.push((d.set_num, d.binding_num)); + for (loc, _reqs) in e1_descriptors { + e1_bindings.push(*loc); } assert_eq!(e1_bindings.len(), 5); assert!(e1_bindings.contains(&(0, 0))); @@ -695,8 +697,8 @@ mod tests { // Check second entrypoint let e2_descriptors = descriptors.get(1).expect("Could not find entrypoint2"); let mut e2_bindings = Vec::new(); - for d in e2_descriptors { - e2_bindings.push((d.set_num, d.binding_num)); + for (loc, _reqs) in e2_descriptors { + e2_bindings.push(*loc); } assert_eq!(e2_bindings.len(), 3); assert!(e2_bindings.contains(&(0, 0))); @@ -755,8 +757,8 @@ mod tests { { let descriptors = find_descriptors(&spirv, entry_point, interface, true); let mut bindings = Vec::new(); - for d in descriptors { - bindings.push((d.set_num, d.binding_num)); + for (loc, _reqs) in descriptors { + bindings.push(loc); } assert_eq!(bindings.len(), 4); assert!(bindings.contains(&(1, 0))); @@ -770,50 +772,3 @@ mod tests { panic!("Could not find entrypoint"); } } - -fn to_vulkan_format(spirv_format: &ImageFormat) -> TokenStream { - match spirv_format { - ImageFormat::Unknown => quote! { None }, - ImageFormat::Rgba32f => quote! { Some(Format::R32G32B32A32_SFLOAT) }, - ImageFormat::Rgba16f => quote! { Some(Format::R16G16B16A16_SFLOAT) }, - ImageFormat::R32f => quote! { Some(Format::R32_SFLOAT) }, - ImageFormat::Rgba8 => quote! { Some(Format::R8G8B8A8_UNORM) }, - ImageFormat::Rgba8Snorm => quote! { Some(Format::R8G8B8A8_SNORM) }, - ImageFormat::Rg32f => quote! { Some(Format::R32G32_SFLOAT) }, - ImageFormat::Rg16f => quote! { Some(Format::R16G16_SFLOAT) }, - ImageFormat::R11fG11fB10f => quote! { Some(Format::B10G11R11_UFLOAT_PACK32) }, - ImageFormat::R16f => quote! { Some(Format::R16_SFLOAT) }, - ImageFormat::Rgba16 => quote! { Some(Format::R16G16B16A16_UNORM) }, - ImageFormat::Rgb10A2 => quote! { Some(Format::A2B10G10R10_UNORMPack32) }, - ImageFormat::Rg16 => quote! { Some(Format::R16G16_UNORM) }, - ImageFormat::Rg8 => quote! { Some(Format::R8G8_UNORM) }, - ImageFormat::R16 => quote! { Some(Format::R16_UNORM) }, - ImageFormat::R8 => quote! { Some(Format::R8_UNORM) }, - ImageFormat::Rgba16Snorm => quote! { Some(Format::R16G16B16A16_SNORM) }, - ImageFormat::Rg16Snorm => quote! { Some(Format::R16G16_SNORM) }, - ImageFormat::Rg8Snorm => quote! { Some(Format::R8G8_SNORM) }, - ImageFormat::R16Snorm => quote! { Some(Format::R16_SNORM) }, - ImageFormat::R8Snorm => quote! { Some(Format::R8_SNORM) }, - ImageFormat::Rgba32i => quote! { Some(Format::R32G32B32A32_SINT) }, - ImageFormat::Rgba16i => quote! { Some(Format::R16G16B16A16_SINT) }, - ImageFormat::Rgba8i => quote! { Some(Format::R8G8B8A8_SINT) }, - ImageFormat::R32i => quote! { Some(Format::R32_SINT) }, - ImageFormat::Rg32i => quote! { Some(Format::R32G32_SINT) }, - ImageFormat::Rg16i => quote! { Some(Format::R16G16_SINT) }, - ImageFormat::Rg8i => quote! { Some(Format::R8G8_SINT) }, - ImageFormat::R16i => quote! { Some(Format::R16_SINT) }, - ImageFormat::R8i => quote! { Some(Format::R8_SINT) }, - ImageFormat::Rgba32ui => quote! { Some(Format::R32G32B32A32_UINT) }, - ImageFormat::Rgba16ui => quote! { Some(Format::R16G16B16A16_UINT) }, - ImageFormat::Rgba8ui => quote! { Some(Format::R8G8B8A8_UINT) }, - ImageFormat::R32ui => quote! { Some(Format::R32_UINT) }, - ImageFormat::Rgb10a2ui => quote! { Some(Format::A2B10G10R10_UINT_PACK32) }, - ImageFormat::Rg32ui => quote! { Some(Format::R32G32_UINT) }, - ImageFormat::Rg16ui => quote! { Some(Format::R16G16_UINT) }, - ImageFormat::Rg8ui => quote! { Some(Format::R8G8_UINT) }, - ImageFormat::R16ui => quote! { Some(Format::R16_UINT) }, - ImageFormat::R8ui => quote! { Some(Format::R8_UINT) }, - ImageFormat::R64ui => quote! { Some(Format::R64_UINT) }, - ImageFormat::R64i => quote! { Some(Format::R64_SINT) }, - } -} diff --git a/vulkano-shaders/src/entry_point.rs b/vulkano-shaders/src/entry_point.rs index faf9ecf3..bee4ecfa 100644 --- a/vulkano-shaders/src/entry_point.rs +++ b/vulkano-shaders/src/entry_point.rs @@ -7,7 +7,7 @@ // notice may not be copied, modified, or distributed except // according to those terms. -use crate::descriptor_sets::{write_descriptor_set_layout_descs, write_push_constant_ranges}; +use crate::descriptor_sets::{write_descriptor_requirements, write_push_constant_ranges}; use crate::{spirv_search, TypesMeta}; use proc_macro2::{Span, TokenStream}; use syn::Ident; @@ -84,13 +84,8 @@ pub(super) fn write_entry_point( } }; - let descriptor_set_layout_descs = write_descriptor_set_layout_descs( - &spirv, - id, - interface, - exact_entrypoint_interface, - &stage, - ); + let descriptor_requirements = + write_descriptor_requirements(&spirv, id, interface, exact_entrypoint_interface, &stage); let push_constant_ranges = write_push_constant_ranges(shader, &spirv, &stage, &types_meta); let spec_consts_struct = if crate::spec_consts::has_specialization_constants(spirv) { @@ -112,7 +107,7 @@ pub(super) fn write_entry_point( quote! { ::vulkano::pipeline::shader::ComputeEntryPoint }, quote! { compute_entry_point( ::std::ffi::CStr::from_ptr(NAME.as_ptr() as *const _), - #descriptor_set_layout_descs, + #descriptor_requirements, #push_constant_ranges, <#spec_consts_struct>::descriptors(), )}, @@ -185,7 +180,7 @@ pub(super) fn write_entry_point( let f_call = quote! { graphics_entry_point( ::std::ffi::CStr::from_ptr(NAME.as_ptr() as *const _), - #descriptor_set_layout_descs, + #descriptor_requirements, #push_constant_ranges, <#spec_consts_struct>::descriptors(), #input_interface, diff --git a/vulkano-shaders/src/lib.rs b/vulkano-shaders/src/lib.rs index 5d0b7817..0842e801 100644 --- a/vulkano-shaders/src/lib.rs +++ b/vulkano-shaders/src/lib.rs @@ -920,17 +920,7 @@ pub fn shader(input: proc_macro::TokenStream) -> proc_macro::TokenStream { #[allow(unused_imports)] use vulkano::device::Device; #[allow(unused_imports)] - use vulkano::descriptor_set::layout::DescriptorDesc; - #[allow(unused_imports)] - use vulkano::descriptor_set::layout::DescriptorDescTy; - #[allow(unused_imports)] - use vulkano::descriptor_set::layout::DescriptorDescImage; - #[allow(unused_imports)] - use vulkano::descriptor_set::layout::DescriptorSetDesc; - #[allow(unused_imports)] - use vulkano::descriptor_set::layout::DescriptorSetLayout; - #[allow(unused_imports)] - use vulkano::descriptor_set::DescriptorSet; + use vulkano::descriptor_set::layout::DescriptorType; #[allow(unused_imports)] use vulkano::format::Format; #[allow(unused_imports)] @@ -940,6 +930,8 @@ pub fn shader(input: proc_macro::TokenStream) -> proc_macro::TokenStream { #[allow(unused_imports)] use vulkano::pipeline::layout::PipelineLayoutPcRange; #[allow(unused_imports)] + use vulkano::pipeline::shader::DescriptorRequirements; + #[allow(unused_imports)] use vulkano::pipeline::shader::ShaderStages; #[allow(unused_imports)] use vulkano::pipeline::shader::SpecializationConstants as SpecConstsTrait; diff --git a/vulkano/src/descriptor_set/layout/desc.rs b/vulkano/src/descriptor_set/layout/desc.rs index 35d814ce..f4b51b0f 100644 --- a/vulkano/src/descriptor_set/layout/desc.rs +++ b/vulkano/src/descriptor_set/layout/desc.rs @@ -43,6 +43,7 @@ use crate::format::Format; use crate::image::view::ImageViewType; +use crate::pipeline::shader::DescriptorRequirements; use crate::pipeline::shader::ShaderStages; use crate::sampler::Sampler; use crate::sync::AccessFlags; @@ -66,16 +67,44 @@ impl DescriptorSetDesc { /// at bind point 0 first, then descriptor at bind point 1, and so on. If a binding must remain /// empty, you can make the iterator yield `None` for an element. #[inline] - pub fn new(descriptors: I) -> DescriptorSetDesc + pub fn new(descriptors: I) -> Self where I: IntoIterator>, { - DescriptorSetDesc { + Self { descriptors: descriptors.into_iter().collect(), push_descriptor: false, } } + /// Builds a list of `DescriptorSetDesc` from an iterator of `DescriptorRequirement` originating + /// from a shader. + #[inline] + pub fn from_requirements<'a>( + descriptor_requirements: impl IntoIterator, + ) -> Vec { + let mut descriptor_sets: Vec = Vec::new(); + + for ((set_num, binding_num), reqs) in descriptor_requirements { + let set_num = set_num as usize; + let binding_num = binding_num as usize; + + if set_num >= descriptor_sets.len() { + descriptor_sets.resize(set_num + 1, Self::default()); + } + + let descriptors = &mut descriptor_sets[set_num].descriptors; + + if binding_num >= descriptors.len() { + descriptors.resize(binding_num + 1, None); + } + + descriptors[binding_num] = Some(reqs.into()); + } + + descriptor_sets + } + /// Builds a new empty `DescriptorSetDesc`. #[inline] pub fn empty() -> DescriptorSetDesc { @@ -103,54 +132,6 @@ impl DescriptorSetDesc { self.push_descriptor } - /// Builds the union of this layout description and another. - #[inline] - pub fn union( - first: &DescriptorSetDesc, - second: &DescriptorSetDesc, - ) -> Result { - let num_bindings = cmp::max(first.descriptors.len(), second.descriptors.len()); - let descriptors = (0..num_bindings) - .map(|binding_num| { - DescriptorDesc::union( - first - .descriptors - .get(binding_num) - .map(|desc| desc.as_ref()) - .flatten(), - second - .descriptors - .get(binding_num) - .map(|desc| desc.as_ref()) - .flatten(), - ) - }) - .collect::>()?; - Ok(DescriptorSetDesc { - descriptors, - push_descriptor: false, - }) - } - - /// Builds the union of multiple descriptor sets. - pub fn union_multiple( - first: &[DescriptorSetDesc], - second: &[DescriptorSetDesc], - ) -> Result, ()> { - // Ewwwwwww - let empty = DescriptorSetDesc::empty(); - let num_sets = cmp::max(first.len(), second.len()); - - (0..num_sets) - .map(|set_num| { - Ok(DescriptorSetDesc::union( - first.get(set_num).unwrap_or_else(|| &empty), - second.get(set_num).unwrap_or_else(|| &empty), - )?) - }) - .collect() - } - /// Changes a buffer descriptor's type to dynamic. /// /// # Panics @@ -253,11 +234,8 @@ impl DescriptorSetDesc { .and_then(|b| b.as_mut()) { Some(desc) => { - if desc.variable_count { - desc.descriptor_count = descriptor_count; - } else { - panic!("descriptor isn't variable count") - } + desc.variable_count = true; + desc.descriptor_count = descriptor_count; } None => panic!("descriptor is empty"), } @@ -284,50 +262,6 @@ impl DescriptorSetDesc { }) } - /// Checks whether the descriptor of a pipeline layout `self` is compatible with the descriptor - /// of a shader `other`. - pub fn ensure_compatible_with_shader( - &self, - other: &DescriptorSetDesc, - ) -> Result<(), DescriptorSetCompatibilityError> { - // Don't care about push descriptors. - - if self.descriptors.len() < other.descriptors.len() { - return Err(DescriptorSetCompatibilityError::DescriptorsCountMismatch { - self_num: self.descriptors.len() as u32, - other_num: other.descriptors.len() as u32, - }); - } - - for binding_num in 0..other.descriptors.len() as u32 { - let self_desc = self.descriptor(binding_num); - let other_desc = self.descriptor(binding_num); - - match (self_desc, other_desc) { - (Some(mine), Some(other)) => { - if let Err(err) = mine.ensure_compatible_with_shader(&other) { - return Err(DescriptorSetCompatibilityError::IncompatibleDescriptors { - error: err, - binding_num: binding_num as u32, - }); - } - } - (None, Some(_)) => { - return Err(DescriptorSetCompatibilityError::IncompatibleDescriptors { - error: DescriptorCompatibilityError::Empty { - first: true, - second: false, - }, - binding_num: binding_num as u32, - }) - } - _ => (), - } - } - - Ok(()) - } - /// Checks whether the descriptor set of a pipeline layout `self` is compatible with the /// descriptor set being bound `other`. /// @@ -439,39 +373,65 @@ impl DescriptorDesc { #[inline] pub fn ensure_compatible_with_shader( &self, - other: &DescriptorDesc, - ) -> Result<(), DescriptorCompatibilityError> { - match (self.ty.ty(), other.ty.ty()) { - (DescriptorType::UniformBufferDynamic, DescriptorType::UniformBuffer) => (), - (DescriptorType::StorageBufferDynamic, DescriptorType::StorageBuffer) => (), - _ => self.ty.ensure_superset_of(&other.ty)?, - } + descriptor_requirements: &DescriptorRequirements, + ) -> Result<(), DescriptorRequirementsNotMet> { + let DescriptorRequirements { + descriptor_types, + descriptor_count, + format, + image_view_type, + multisampled, + mutable, + stages, + } = descriptor_requirements; - if !self.stages.is_superset_of(&other.stages) { - return Err(DescriptorCompatibilityError::ShaderStages { - first: self.stages, - second: other.stages, + if !descriptor_types.contains(&self.ty.ty()) { + return Err(DescriptorRequirementsNotMet::DescriptorType { + required: descriptor_types.clone(), + obtained: self.ty.ty(), }); } - if self.descriptor_count < other.descriptor_count { - return Err(DescriptorCompatibilityError::DescriptorCount { - first: self.descriptor_count, - second: other.descriptor_count, + if self.descriptor_count < *descriptor_count { + return Err(DescriptorRequirementsNotMet::DescriptorCount { + required: *descriptor_count, + obtained: self.descriptor_count, }); } - if self.variable_count != other.variable_count { - return Err(DescriptorCompatibilityError::VariableCount { - first: self.variable_count, - second: other.variable_count, + if let Some(format) = *format { + if self.ty.format() != Some(format) { + return Err(DescriptorRequirementsNotMet::Format { + required: format, + obtained: self.ty.format(), + }); + } + } + + if let Some(image_view_type) = *image_view_type { + if self.ty.image_view_type() != Some(image_view_type) { + return Err(DescriptorRequirementsNotMet::ImageViewType { + required: image_view_type, + obtained: self.ty.image_view_type(), + }); + } + } + + if *multisampled != self.ty.multisampled() { + return Err(DescriptorRequirementsNotMet::Multisampling { + required: *multisampled, + obtained: self.ty.multisampled(), }); } - if !self.mutable && other.mutable { - return Err(DescriptorCompatibilityError::Mutability { - first: self.mutable, - second: other.mutable, + if *mutable && !self.mutable { + return Err(DescriptorRequirementsNotMet::Mutability); + } + + if !self.stages.is_superset_of(stages) { + return Err(DescriptorRequirementsNotMet::ShaderStages { + required: *stages, + obtained: self.stages, }); } @@ -518,70 +478,6 @@ impl DescriptorDesc { Ok(()) } - /// Builds a `DescriptorDesc` that is the union of `self` and `other`, if possible. - /// - /// The returned value will be a superset of both `self` and `other`, or `None` if both were - /// `None`. - /// - /// `Err` is returned if the descriptors are not compatible. - /// - ///# Example - ///``` - ///use vulkano::descriptor_set::layout::DescriptorDesc; - ///use vulkano::descriptor_set::layout::DescriptorDescTy::*; - ///use vulkano::pipeline::shader::ShaderStages; - /// - ///let desc_part1 = DescriptorDesc{ ty: Sampler { immutable_samplers: vec![] }, descriptor_count: 2, stages: ShaderStages{ - /// vertex: true, - /// tessellation_control: true, - /// tessellation_evaluation: false, - /// geometry: true, - /// fragment: false, - /// compute: true - ///}, mutable: true, variable_count: false }; - /// - ///let desc_part2 = DescriptorDesc{ ty: Sampler { immutable_samplers: vec![] }, descriptor_count: 1, stages: ShaderStages{ - /// vertex: true, - /// tessellation_control: false, - /// tessellation_evaluation: true, - /// geometry: false, - /// fragment: true, - /// compute: true - ///}, mutable: false, variable_count: false }; - /// - ///let desc_union = DescriptorDesc{ ty: Sampler { immutable_samplers: vec![] }, descriptor_count: 2, stages: ShaderStages{ - /// vertex: true, - /// tessellation_control: true, - /// tessellation_evaluation: true, - /// geometry: true, - /// fragment: true, - /// compute: true - ///}, mutable: true, variable_count: false }; - /// - ///assert_eq!(DescriptorDesc::union(Some(&desc_part1), Some(&desc_part2)), Ok(Some(desc_union))); - ///``` - #[inline] - pub fn union( - first: Option<&DescriptorDesc>, - second: Option<&DescriptorDesc>, - ) -> Result, ()> { - if let (Some(first), Some(second)) = (first, second) { - if first.ty != second.ty { - return Err(()); - } - - Ok(Some(DescriptorDesc { - ty: first.ty.clone(), - descriptor_count: cmp::max(first.descriptor_count, second.descriptor_count), - stages: first.stages | second.stages, - mutable: first.mutable || second.mutable, - variable_count: first.variable_count && second.variable_count, // TODO: What is the correct behavior here? - })) - } else { - Ok(first.or(second).cloned()) - } - } - /// Returns the pipeline stages and access flags corresponding to the usage of this descriptor. /// /// # Panic @@ -626,6 +522,59 @@ impl DescriptorDesc { } } +impl From<&DescriptorRequirements> for DescriptorDesc { + fn from(reqs: &DescriptorRequirements) -> Self { + let ty = match reqs.descriptor_types[0] { + DescriptorType::Sampler => DescriptorDescTy::Sampler { + immutable_samplers: Vec::new(), + }, + DescriptorType::CombinedImageSampler => DescriptorDescTy::CombinedImageSampler { + image_desc: DescriptorDescImage { + format: reqs.format, + multisampled: reqs.multisampled, + view_type: reqs.image_view_type.unwrap(), + }, + immutable_samplers: Vec::new(), + }, + DescriptorType::SampledImage => DescriptorDescTy::SampledImage { + image_desc: DescriptorDescImage { + format: reqs.format, + multisampled: reqs.multisampled, + view_type: reqs.image_view_type.unwrap(), + }, + }, + DescriptorType::StorageImage => DescriptorDescTy::StorageImage { + image_desc: DescriptorDescImage { + format: reqs.format, + multisampled: reqs.multisampled, + view_type: reqs.image_view_type.unwrap(), + }, + }, + DescriptorType::UniformTexelBuffer => DescriptorDescTy::UniformTexelBuffer { + format: reqs.format, + }, + DescriptorType::StorageTexelBuffer => DescriptorDescTy::StorageTexelBuffer { + format: reqs.format, + }, + DescriptorType::UniformBuffer => DescriptorDescTy::UniformBuffer, + DescriptorType::StorageBuffer => DescriptorDescTy::StorageBuffer, + DescriptorType::UniformBufferDynamic => DescriptorDescTy::UniformBufferDynamic, + DescriptorType::StorageBufferDynamic => DescriptorDescTy::StorageBufferDynamic, + DescriptorType::InputAttachment => DescriptorDescTy::InputAttachment { + multisampled: reqs.multisampled, + }, + }; + + Self { + ty, + descriptor_count: reqs.descriptor_count, + stages: reqs.stages, + variable_count: false, + mutable: reqs.mutable, + } + } +} + /// Describes what kind of resource may later be bound to a descriptor. /// /// This is mostly the same as a `DescriptorDescTy` but with less precise information. @@ -652,6 +601,86 @@ impl From for ash::vk::DescriptorType { } } +/// Error when checking whether the requirements for a descriptor have been met. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DescriptorRequirementsNotMet { + /// The descriptor's type is not one of those required. + DescriptorType { + required: Vec, + obtained: DescriptorType, + }, + + /// The descriptor count is less than what is required. + DescriptorCount { required: u32, obtained: u32 }, + + /// The descriptor's format does not match what is required. + Format { + required: Format, + obtained: Option, + }, + + /// The descriptor's image view type does not match what is required. + ImageViewType { + required: ImageViewType, + obtained: Option, + }, + + /// The descriptor's multisampling does not match what is required. + Multisampling { required: bool, obtained: bool }, + + /// The descriptor is marked as read-only, but mutability is required. + Mutability, + + /// The descriptor's shader stages do not contain the stages that are required. + ShaderStages { + required: ShaderStages, + obtained: ShaderStages, + }, +} + +impl error::Error for DescriptorRequirementsNotMet {} + +impl fmt::Display for DescriptorRequirementsNotMet { + #[inline] + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + match self { + Self::DescriptorType { required, obtained } => write!( + fmt, + "the descriptor's type ({:?}) is not one of those required ({:?})", + obtained, required + ), + Self::DescriptorCount { required, obtained } => write!( + fmt, + "the descriptor count ({}) is less than what is required ({})", + obtained, required + ), + Self::Format { required, obtained } => write!( + fmt, + "the descriptor's format ({:?}) does not match what is required ({:?})", + obtained, required + ), + Self::ImageViewType { required, obtained } => write!( + fmt, + "the descriptor's image view type ({:?}) does not match what is required ({:?})", + obtained, required + ), + Self::Multisampling { required, obtained } => write!( + fmt, + "the descriptor's multisampling ({}) does not match what is required ({})", + obtained, required + ), + Self::Mutability => write!( + fmt, + "the descriptor is marked as read-only, but mutability is required", + ), + Self::ShaderStages { required, obtained } => write!( + fmt, + "the descriptor's shader stages do not contain the stages that are required", + ), + } + } +} + /// Describes the content and layout of each array element of a descriptor. #[derive(Debug, Clone, PartialEq, Eq)] pub enum DescriptorDescTy { @@ -723,6 +752,9 @@ impl DescriptorDescTy { #[inline] fn format(&self) -> Option { match self { + Self::CombinedImageSampler { image_desc, .. } + | Self::SampledImage { image_desc, .. } + | Self::StorageImage { image_desc, .. } => image_desc.format, Self::UniformTexelBuffer { format } | Self::StorageTexelBuffer { format } => *format, _ => None, } @@ -738,6 +770,16 @@ impl DescriptorDescTy { } } + #[inline] + fn image_view_type(&self) -> Option { + match self { + Self::CombinedImageSampler { image_desc, .. } + | Self::SampledImage { image_desc, .. } + | Self::StorageImage { image_desc, .. } => Some(image_desc.view_type), + _ => None, + } + } + #[inline] pub(super) fn immutable_samplers(&self) -> &[Arc] { match self { @@ -754,6 +796,9 @@ impl DescriptorDescTy { #[inline] fn multisampled(&self) -> bool { match self { + Self::CombinedImageSampler { image_desc, .. } + | Self::SampledImage { image_desc, .. } + | Self::StorageImage { image_desc, .. } => image_desc.multisampled, DescriptorDescTy::InputAttachment { multisampled } => *multisampled, _ => false, } diff --git a/vulkano/src/descriptor_set/layout/mod.rs b/vulkano/src/descriptor_set/layout/mod.rs index f9961659..f969ea61 100644 --- a/vulkano/src/descriptor_set/layout/mod.rs +++ b/vulkano/src/descriptor_set/layout/mod.rs @@ -17,6 +17,7 @@ pub use self::desc::DescriptorCompatibilityError; pub use self::desc::DescriptorDesc; pub use self::desc::DescriptorDescImage; pub use self::desc::DescriptorDescTy; +pub use self::desc::DescriptorRequirementsNotMet; pub use self::desc::DescriptorSetCompatibilityError; pub use self::desc::DescriptorSetDesc; pub use self::desc::DescriptorType; diff --git a/vulkano/src/format.rs b/vulkano/src/format.rs index 4aebb083..aa7f71a3 100644 --- a/vulkano/src/format.rs +++ b/vulkano/src/format.rs @@ -95,6 +95,7 @@ use crate::device::physical::PhysicalDevice; use crate::image::ImageAspects; +use crate::spirv::ImageFormat; use crate::DeviceSize; use crate::VulkanObject; use half::f16; @@ -183,6 +184,55 @@ impl From for ash::vk::Format { } } +impl From for Option { + fn from(val: ImageFormat) -> Self { + match val { + ImageFormat::Unknown => None, + ImageFormat::Rgba32f => Some(Format::R32G32B32A32_SFLOAT), + ImageFormat::Rgba16f => Some(Format::R16G16B16A16_SFLOAT), + ImageFormat::R32f => Some(Format::R32_SFLOAT), + ImageFormat::Rgba8 => Some(Format::R8G8B8A8_UNORM), + ImageFormat::Rgba8Snorm => Some(Format::R8G8B8A8_SNORM), + ImageFormat::Rg32f => Some(Format::R32G32_SFLOAT), + ImageFormat::Rg16f => Some(Format::R16G16_SFLOAT), + ImageFormat::R11fG11fB10f => Some(Format::B10G11R11_UFLOAT_PACK32), + ImageFormat::R16f => Some(Format::R16_SFLOAT), + ImageFormat::Rgba16 => Some(Format::R16G16B16A16_UNORM), + ImageFormat::Rgb10A2 => Some(Format::A2B10G10R10_UNORM_PACK32), + ImageFormat::Rg16 => Some(Format::R16G16_UNORM), + ImageFormat::Rg8 => Some(Format::R8G8_UNORM), + ImageFormat::R16 => Some(Format::R16_UNORM), + ImageFormat::R8 => Some(Format::R8_UNORM), + ImageFormat::Rgba16Snorm => Some(Format::R16G16B16A16_SNORM), + ImageFormat::Rg16Snorm => Some(Format::R16G16_SNORM), + ImageFormat::Rg8Snorm => Some(Format::R8G8_SNORM), + ImageFormat::R16Snorm => Some(Format::R16_SNORM), + ImageFormat::R8Snorm => Some(Format::R8_SNORM), + ImageFormat::Rgba32i => Some(Format::R32G32B32A32_SINT), + ImageFormat::Rgba16i => Some(Format::R16G16B16A16_SINT), + ImageFormat::Rgba8i => Some(Format::R8G8B8A8_SINT), + ImageFormat::R32i => Some(Format::R32_SINT), + ImageFormat::Rg32i => Some(Format::R32G32_SINT), + ImageFormat::Rg16i => Some(Format::R16G16_SINT), + ImageFormat::Rg8i => Some(Format::R8G8_SINT), + ImageFormat::R16i => Some(Format::R16_SINT), + ImageFormat::R8i => Some(Format::R8_SINT), + ImageFormat::Rgba32ui => Some(Format::R32G32B32A32_UINT), + ImageFormat::Rgba16ui => Some(Format::R16G16B16A16_UINT), + ImageFormat::Rgba8ui => Some(Format::R8G8B8A8_UINT), + ImageFormat::R32ui => Some(Format::R32_UINT), + ImageFormat::Rgb10a2ui => Some(Format::A2B10G10R10_UINT_PACK32), + ImageFormat::Rg32ui => Some(Format::R32G32_UINT), + ImageFormat::Rg16ui => Some(Format::R16G16_UINT), + ImageFormat::Rg8ui => Some(Format::R8G8_UINT), + ImageFormat::R16ui => Some(Format::R16_UINT), + ImageFormat::R8ui => Some(Format::R8_UINT), + ImageFormat::R64ui => Some(Format::R64_UINT), + ImageFormat::R64i => Some(Format::R64_SINT), + } + } +} + /// The block compression scheme used in a format. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum CompressionType { diff --git a/vulkano/src/pipeline/compute_pipeline.rs b/vulkano/src/pipeline/compute_pipeline.rs index 1726f2b5..c561bbb2 100644 --- a/vulkano/src/pipeline/compute_pipeline.rs +++ b/vulkano/src/pipeline/compute_pipeline.rs @@ -8,19 +8,17 @@ // according to those terms. use crate::check_errors; -use crate::descriptor_set::layout::DescriptorSetDesc; -use crate::descriptor_set::layout::DescriptorSetLayout; -use crate::device::Device; -use crate::device::DeviceOwned; +use crate::descriptor_set::layout::{DescriptorSetDesc, DescriptorSetLayout}; +use crate::device::{Device, DeviceOwned}; use crate::pipeline::cache::PipelineCache; -use crate::pipeline::layout::PipelineLayout; -use crate::pipeline::layout::PipelineLayoutCreationError; -use crate::pipeline::layout::PipelineLayoutSupersetError; -use crate::pipeline::shader::EntryPointAbstract; -use crate::pipeline::shader::SpecializationConstants; +use crate::pipeline::layout::{ + PipelineLayout, PipelineLayoutCreationError, PipelineLayoutSupersetError, +}; +use crate::pipeline::shader::{ComputeEntryPoint, DescriptorRequirements, SpecializationConstants}; use crate::Error; use crate::OomError; use crate::VulkanObject; +use fnv::FnvHashMap; use std::error; use std::fmt; use std::marker::PhantomData; @@ -43,6 +41,7 @@ use std::sync::Arc; pub struct ComputePipeline { inner: Inner, pipeline_layout: Arc, + descriptor_requirements: FnvHashMap<(u32, u32), DescriptorRequirements>, } struct Inner { @@ -56,19 +55,19 @@ impl ComputePipeline { /// `func` is a closure that is given a mutable reference to the inferred descriptor set /// definitions. This can be used to make changes to the layout before it's created, for example /// to add dynamic buffers or immutable samplers. - pub fn new( + pub fn new( device: Arc, - shader: &Cs, + shader: &ComputeEntryPoint, spec_constants: &Css, cache: Option>, func: F, ) -> Result where - Cs: EntryPointAbstract, Css: SpecializationConstants, F: FnOnce(&mut [DescriptorSetDesc]), { - let mut descriptor_set_layout_descs = shader.descriptor_set_layout_descs().to_owned(); + let mut descriptor_set_layout_descs = + DescriptorSetDesc::from_requirements(shader.descriptor_requirements()); func(&mut descriptor_set_layout_descs); let descriptor_set_layouts = descriptor_set_layout_descs @@ -101,15 +100,14 @@ impl ComputePipeline { /// /// An error will be returned if the pipeline layout isn't a superset of what the shader /// uses. - pub fn with_pipeline_layout( + pub fn with_pipeline_layout( device: Arc, - shader: &Cs, + shader: &ComputeEntryPoint, spec_constants: &Css, pipeline_layout: Arc, cache: Option>, ) -> Result where - Cs: EntryPointAbstract, Css: SpecializationConstants, { if Css::descriptors() != shader.spec_constants() { @@ -118,7 +116,7 @@ impl ComputePipeline { unsafe { pipeline_layout.ensure_compatible_with_shader( - shader.descriptor_set_layout_descs(), + shader.descriptor_requirements(), shader.push_constant_range(), )?; ComputePipeline::with_unchecked_pipeline_layout( @@ -133,15 +131,14 @@ impl ComputePipeline { /// Same as `with_pipeline_layout`, but doesn't check whether the pipeline layout is a /// superset of what the shader expects. - pub unsafe fn with_unchecked_pipeline_layout( + pub unsafe fn with_unchecked_pipeline_layout( device: Arc, - shader: &Cs, + shader: &ComputeEntryPoint, spec_constants: &Css, pipeline_layout: Arc, cache: Option>, ) -> Result where - Cs: EntryPointAbstract, Css: SpecializationConstants, { let fns = device.fns(); @@ -200,6 +197,10 @@ impl ComputePipeline { pipeline: pipeline, }, pipeline_layout: pipeline_layout, + descriptor_requirements: shader + .descriptor_requirements() + .map(|(loc, reqs)| (loc, reqs.clone())) + .collect(), }) } @@ -214,6 +215,16 @@ impl ComputePipeline { pub fn layout(&self) -> &Arc { &self.pipeline_layout } + + /// Returns an iterator over the descriptor requirements for this pipeline. + #[inline] + pub fn descriptor_requirements( + &self, + ) -> impl ExactSizeIterator { + self.descriptor_requirements + .iter() + .map(|(loc, reqs)| (*loc, reqs)) + } } impl fmt::Debug for ComputePipeline { @@ -362,10 +373,9 @@ mod tests { use crate::buffer::CpuAccessibleBuffer; use crate::command_buffer::AutoCommandBufferBuilder; use crate::command_buffer::CommandBufferUsage; - use crate::descriptor_set::layout::DescriptorDesc; - use crate::descriptor_set::layout::DescriptorDescTy; - use crate::descriptor_set::layout::DescriptorSetDesc; + use crate::descriptor_set::layout::DescriptorType; use crate::descriptor_set::PersistentDescriptorSet; + use crate::pipeline::shader::DescriptorRequirements; use crate::pipeline::shader::ShaderModule; use crate::pipeline::shader::ShaderStages; use crate::pipeline::shader::SpecializationConstants; @@ -432,16 +442,21 @@ mod tests { static NAME: [u8; 5] = [109, 97, 105, 110, 0]; // "main" module.compute_entry_point( CStr::from_ptr(NAME.as_ptr() as *const _), - [DescriptorSetDesc::new([Some(DescriptorDesc { - ty: DescriptorDescTy::StorageBuffer, - descriptor_count: 1, - stages: ShaderStages { - compute: true, - ..ShaderStages::none() + [( + (0, 0), + DescriptorRequirements { + descriptor_types: vec![DescriptorType::StorageBuffer], + descriptor_count: 1, + format: None, + image_view_type: None, + multisampled: false, + mutable: false, + stages: ShaderStages { + compute: true, + ..ShaderStages::none() + }, }, - mutable: false, - variable_count: false, - })])], + )], None, SpecConsts::descriptors(), ) diff --git a/vulkano/src/pipeline/graphics_pipeline/builder.rs b/vulkano/src/pipeline/graphics_pipeline/builder.rs index 82d07b8f..030cd5aa 100644 --- a/vulkano/src/pipeline/graphics_pipeline/builder.rs +++ b/vulkano/src/pipeline/graphics_pipeline/builder.rs @@ -28,7 +28,7 @@ use crate::pipeline::layout::{PipelineLayout, PipelineLayoutCreationError, Pipel use crate::pipeline::multisample::MultisampleState; use crate::pipeline::rasterization::{CullMode, FrontFace, PolygonMode, RasterizationState}; use crate::pipeline::shader::{ - EntryPointAbstract, GraphicsEntryPoint, GraphicsShaderType, ShaderStage, + DescriptorRequirements, GraphicsEntryPoint, GraphicsShaderType, ShaderStage, SpecializationConstants, }; use crate::pipeline::tessellation::TessellationState; @@ -163,12 +163,39 @@ where } } - let mut descriptor_set_layout_descs = stages + // Produce `DescriptorRequirements` for each binding, by iterating over all shaders + // and adding the requirements of each. + let mut descriptor_requirements: FnvHashMap<(u32, u32), DescriptorRequirements> = + HashMap::default(); + + for (loc, reqs) in stages .iter() - .try_fold(vec![], |total, shader| -> Result<_, ()> { - DescriptorSetDesc::union_multiple(&total, shader.descriptor_set_layout_descs()) - }) - .expect("Can't be union'd"); + .map(|shader| shader.descriptor_requirements()) + .flatten() + { + match descriptor_requirements.entry(loc) { + Entry::Occupied(entry) => { + // Previous shaders already added requirements, so we produce the + // intersection of the previous requirements and those of the + // current shader. + let previous = entry.into_mut(); + *previous = previous.intersection(reqs).expect("Could not produce an intersection of the shader descriptor requirements"); + } + Entry::Vacant(entry) => { + // No previous shader had this descriptor yet, so we just insert the + // requirements. + entry.insert(reqs.clone()); + } + } + } + + // Build a description of a descriptor set layout from the shader requirements, then + // feed it to the user-provided closure to allow tweaking. + let mut descriptor_set_layout_descs = DescriptorSetDesc::from_requirements( + descriptor_requirements + .iter() + .map(|(&loc, reqs)| (loc, reqs)), + ); func(&mut descriptor_set_layout_descs); // We want to union each push constant range into a set of ranges that do not have intersecting stage flags. @@ -227,47 +254,104 @@ where // Checking that the pipeline layout matches the shader stages. // TODO: more details in the errors + let mut descriptor_requirements: FnvHashMap<(u32, u32), DescriptorRequirements> = + HashMap::default(); { let shader = &self.vertex_shader.as_ref().unwrap().0; pipeline_layout.ensure_compatible_with_shader( - shader.descriptor_set_layout_descs(), + shader.descriptor_requirements(), shader.push_constant_range(), )?; + for (loc, reqs) in shader.descriptor_requirements() { + match descriptor_requirements.entry(loc) { + Entry::Occupied(entry) => { + let previous = entry.into_mut(); + *previous = previous.intersection(reqs).expect("Could not produce an intersection of the shader descriptor requirements"); + } + Entry::Vacant(entry) => { + entry.insert(reqs.clone()); + } + } + } } if let Some(ref geometry_shader) = self.geometry_shader { let shader = &geometry_shader.0; pipeline_layout.ensure_compatible_with_shader( - shader.descriptor_set_layout_descs(), + shader.descriptor_requirements(), shader.push_constant_range(), )?; + for (loc, reqs) in shader.descriptor_requirements() { + match descriptor_requirements.entry(loc) { + Entry::Occupied(entry) => { + let previous = entry.into_mut(); + *previous = previous.intersection(reqs).expect("Could not produce an intersection of the shader descriptor requirements"); + } + Entry::Vacant(entry) => { + entry.insert(reqs.clone()); + } + } + } } if let Some(ref tess) = self.tessellation_shaders { { let shader = &tess.tessellation_control_shader.0; pipeline_layout.ensure_compatible_with_shader( - shader.descriptor_set_layout_descs(), + shader.descriptor_requirements(), shader.push_constant_range(), )?; + for (loc, reqs) in shader.descriptor_requirements() { + match descriptor_requirements.entry(loc) { + Entry::Occupied(entry) => { + let previous = entry.into_mut(); + *previous = previous.intersection(reqs).expect("Could not produce an intersection of the shader descriptor requirements"); + } + Entry::Vacant(entry) => { + entry.insert(reqs.clone()); + } + } + } } { let shader = &tess.tessellation_evaluation_shader.0; pipeline_layout.ensure_compatible_with_shader( - shader.descriptor_set_layout_descs(), + shader.descriptor_requirements(), shader.push_constant_range(), )?; + for (loc, reqs) in shader.descriptor_requirements() { + match descriptor_requirements.entry(loc) { + Entry::Occupied(entry) => { + let previous = entry.into_mut(); + *previous = previous.intersection(reqs).expect("Could not produce an intersection of the shader descriptor requirements"); + } + Entry::Vacant(entry) => { + entry.insert(reqs.clone()); + } + } + } } } if let Some(ref fragment_shader) = self.fragment_shader { let shader = &fragment_shader.0; pipeline_layout.ensure_compatible_with_shader( - shader.descriptor_set_layout_descs(), + shader.descriptor_requirements(), shader.push_constant_range(), )?; + for (loc, reqs) in shader.descriptor_requirements() { + match descriptor_requirements.entry(loc) { + Entry::Occupied(entry) => { + let previous = entry.into_mut(); + *previous = previous.intersection(reqs).expect("Could not produce an intersection of the shader descriptor requirements"); + } + Entry::Vacant(entry) => { + entry.insert(reqs.clone()); + } + } + } // Check that the subpass can accept the output of the fragment shader. // TODO: If there is no fragment shader, what should be checked then? The previous stage? @@ -938,6 +1022,7 @@ where layout: pipeline_layout, subpass, shaders, + descriptor_requirements, vertex_input, // Can be None if there's a mesh shader, but we don't support that yet input_assembly_state: self.input_assembly_state, // Can be None if there's a mesh shader, but we don't support that yet diff --git a/vulkano/src/pipeline/graphics_pipeline/mod.rs b/vulkano/src/pipeline/graphics_pipeline/mod.rs index 59727763..da5cc5d7 100644 --- a/vulkano/src/pipeline/graphics_pipeline/mod.rs +++ b/vulkano/src/pipeline/graphics_pipeline/mod.rs @@ -17,7 +17,7 @@ use crate::pipeline::input_assembly::InputAssemblyState; use crate::pipeline::layout::PipelineLayout; use crate::pipeline::multisample::MultisampleState; use crate::pipeline::rasterization::RasterizationState; -use crate::pipeline::shader::ShaderStage; +use crate::pipeline::shader::{DescriptorRequirements, ShaderStage}; use crate::pipeline::tessellation::TessellationState; use crate::pipeline::vertex::{BuffersDefinition, VertexInput}; use crate::pipeline::viewport::ViewportState; @@ -47,6 +47,7 @@ pub struct GraphicsPipeline { subpass: Subpass, // TODO: replace () with an object that describes the shaders in some way. shaders: FnvHashMap, + descriptor_requirements: FnvHashMap<(u32, u32), DescriptorRequirements>, vertex_input: VertexInput, input_assembly_state: InputAssemblyState, @@ -114,6 +115,16 @@ impl GraphicsPipeline { self.shaders.get(&stage).copied() } + /// Returns an iterator over the descriptor requirements for this pipeline. + #[inline] + pub fn descriptor_requirements( + &self, + ) -> impl ExactSizeIterator { + self.descriptor_requirements + .iter() + .map(|(loc, reqs)| (*loc, reqs)) + } + /// Returns the vertex input state used to create this pipeline. #[inline] pub fn vertex_input(&self) -> &VertexInput { diff --git a/vulkano/src/pipeline/layout/sys.rs b/vulkano/src/pipeline/layout/sys.rs index 2cacef73..4a683c9b 100644 --- a/vulkano/src/pipeline/layout/sys.rs +++ b/vulkano/src/pipeline/layout/sys.rs @@ -9,19 +9,18 @@ use super::limits_check; use crate::check_errors; -use crate::descriptor_set::layout::DescriptorSetCompatibilityError; -use crate::descriptor_set::layout::DescriptorSetDesc; +use crate::descriptor_set::layout::DescriptorRequirementsNotMet; use crate::descriptor_set::layout::DescriptorSetLayout; use crate::descriptor_set::layout::DescriptorSetLayoutError; use crate::device::Device; use crate::device::DeviceOwned; use crate::pipeline::layout::PipelineLayoutLimitsError; +use crate::pipeline::shader::DescriptorRequirements; use crate::pipeline::shader::ShaderStages; use crate::Error; use crate::OomError; use crate::VulkanObject; use smallvec::SmallVec; -use std::cmp; use std::error; use std::fmt; use std::mem::MaybeUninit; @@ -190,32 +189,32 @@ impl PipelineLayout { /// Makes sure that `self` is a superset of the provided descriptor set layouts and push /// constant ranges. Returns an `Err` if this is not the case. - pub fn ensure_compatible_with_shader( + pub fn ensure_compatible_with_shader<'a>( &self, - descriptor_set_layout_descs: &[DescriptorSetDesc], + descriptor_requirements: impl IntoIterator, push_constant_range: &Option, ) -> Result<(), PipelineLayoutSupersetError> { - // Ewwwwwww - let empty = DescriptorSetDesc::empty(); - let num_sets = cmp::max( - self.descriptor_set_layouts.len(), - descriptor_set_layout_descs.len(), - ); - - for set_num in 0..num_sets { - let first = self + for ((set_num, binding_num), reqs) in descriptor_requirements.into_iter() { + let descriptor_desc = self .descriptor_set_layouts - .get(set_num) - .map(|set| set.desc()) - .unwrap_or_else(|| &empty); - let second = descriptor_set_layout_descs - .get(set_num) - .unwrap_or_else(|| &empty); + .get(set_num as usize) + .and_then(|set_desc| set_desc.descriptor(binding_num)); - if let Err(error) = first.ensure_compatible_with_shader(second) { - return Err(PipelineLayoutSupersetError::DescriptorSet { + let descriptor_desc = match descriptor_desc { + Some(x) => x, + None => { + return Err(PipelineLayoutSupersetError::DescriptorMissing { + set_num, + binding_num, + }) + } + }; + + if let Err(error) = descriptor_desc.ensure_compatible_with_shader(reqs) { + return Err(PipelineLayoutSupersetError::DescriptorRequirementsNotMet { + set_num, + binding_num, error, - set_num: set_num as u32, }); } } @@ -379,9 +378,14 @@ impl From for PipelineLayoutCreationError { /// Error when checking whether a pipeline layout is a superset of another one. #[derive(Clone, Debug, PartialEq, Eq)] pub enum PipelineLayoutSupersetError { - DescriptorSet { - error: DescriptorSetCompatibilityError, + DescriptorMissing { set_num: u32, + binding_num: u32, + }, + DescriptorRequirementsNotMet { + set_num: u32, + binding_num: u32, + error: DescriptorRequirementsNotMet, }, PushConstantRange { first_range: PipelineLayoutPcRange, @@ -393,8 +397,10 @@ impl error::Error for PipelineLayoutSupersetError { #[inline] fn source(&self) -> Option<&(dyn error::Error + 'static)> { match *self { - PipelineLayoutSupersetError::DescriptorSet { ref error, .. } => Some(error), - ref error @ PipelineLayoutSupersetError::PushConstantRange { .. } => Some(error), + PipelineLayoutSupersetError::DescriptorRequirementsNotMet { ref error, .. } => { + Some(error) + } + _ => None, } } } @@ -402,10 +408,20 @@ impl error::Error for PipelineLayoutSupersetError { impl fmt::Display for PipelineLayoutSupersetError { #[inline] fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { - match *self { - PipelineLayoutSupersetError::DescriptorSet { .. } => { - write!(fmt, "the descriptor set was not a superset of the other") - } + match self { + PipelineLayoutSupersetError::DescriptorRequirementsNotMet { set_num, binding_num, .. } => write!( + fmt, + "the descriptor at set {} binding {} does not meet the requirements", + set_num, binding_num + ), + PipelineLayoutSupersetError::DescriptorMissing { + set_num, + binding_num, + } => write!( + fmt, + "a descriptor at set {} binding {} is required by the shaders, but is missing from the pipeline layout", + set_num, binding_num + ), PipelineLayoutSupersetError::PushConstantRange { first_range, second_range, diff --git a/vulkano/src/pipeline/shader.rs b/vulkano/src/pipeline/shader.rs index e3423c78..b35d98b8 100644 --- a/vulkano/src/pipeline/shader.rs +++ b/vulkano/src/pipeline/shader.rs @@ -18,18 +18,22 @@ //! `vulkano-shaders` crate that will generate Rust code that wraps around vulkano's shaders API. use crate::check_errors; -use crate::descriptor_set::layout::DescriptorSetDesc; +use crate::descriptor_set::layout::DescriptorType; use crate::device::Device; use crate::format::Format; +use crate::image::view::ImageViewType; use crate::pipeline::input_assembly::PrimitiveTopology; use crate::pipeline::layout::PipelineLayoutPcRange; use crate::sync::PipelineStages; use crate::OomError; use crate::VulkanObject; +use fnv::FnvHashMap; use std::borrow::Cow; use std::error; +use std::error::Error; use std::ffi::CStr; use std::fmt; +use std::fmt::Display; use std::mem; use std::mem::MaybeUninit; use std::ops::BitOr; @@ -128,23 +132,20 @@ impl ShaderModule { /// - The input, output and layout must correctly describe the input, output and layout used /// by this stage. /// - pub unsafe fn graphics_entry_point<'a, D>( + pub unsafe fn graphics_entry_point<'a>( &'a self, name: &'a CStr, - descriptor_set_layout_descs: D, + descriptor_requirements: impl IntoIterator, push_constant_range: Option, spec_constants: &'static [SpecializationMapEntry], input: ShaderInterface, output: ShaderInterface, ty: GraphicsShaderType, - ) -> GraphicsEntryPoint<'a> - where - D: IntoIterator, - { + ) -> GraphicsEntryPoint<'a> { GraphicsEntryPoint { module: self, name, - descriptor_set_layout_descs: descriptor_set_layout_descs.into_iter().collect(), + descriptor_requirements: descriptor_requirements.into_iter().collect(), push_constant_range, spec_constants, input, @@ -165,20 +166,17 @@ impl ShaderModule { /// - The layout must correctly describe the layout used by this stage. /// #[inline] - pub unsafe fn compute_entry_point<'a, D>( + pub unsafe fn compute_entry_point<'a>( &'a self, name: &'a CStr, - descriptor_set_layout_descs: D, + descriptor_requirements: impl IntoIterator, push_constant_range: Option, spec_constants: &'static [SpecializationMapEntry], - ) -> ComputeEntryPoint<'a> - where - D: IntoIterator, - { + ) -> ComputeEntryPoint<'a> { ComputeEntryPoint { module: self, name, - descriptor_set_layout_descs: descriptor_set_layout_descs.into_iter().collect(), + descriptor_requirements: descriptor_requirements.into_iter().collect(), push_constant_range, spec_constants, } @@ -205,23 +203,6 @@ impl Drop for ShaderModule { } } -pub unsafe trait EntryPointAbstract { - /// Returns the module this entry point comes from. - fn module(&self) -> &ShaderModule; - - /// Returns the name of the entry point. - fn name(&self) -> &CStr; - - /// Returns a description of the descriptor set layouts. - fn descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc]; - - /// Returns the push constant ranges. - fn push_constant_range(&self) -> &Option; - - /// Returns the layout of the specialization constants. - fn spec_constants(&self) -> &[SpecializationMapEntry]; -} - /// Represents a shader entry point in a shader module. /// /// Can be obtained by calling `entry_point()` on the shader module. @@ -230,7 +211,7 @@ pub struct GraphicsEntryPoint<'a> { module: &'a ShaderModule, name: &'a CStr, - descriptor_set_layout_descs: Vec, + descriptor_requirements: FnvHashMap<(u32, u32), DescriptorRequirements>, push_constant_range: Option, spec_constants: &'static [SpecializationMapEntry], input: ShaderInterface, @@ -239,6 +220,38 @@ pub struct GraphicsEntryPoint<'a> { } impl<'a> GraphicsEntryPoint<'a> { + /// Returns the module this entry point comes from. + #[inline] + pub fn module(&self) -> &ShaderModule { + self.module + } + + /// Returns the name of the entry point. + #[inline] + pub fn name(&self) -> &CStr { + self.name + } + + /// Returns the descriptor requirements. + #[inline] + pub fn descriptor_requirements( + &self, + ) -> impl ExactSizeIterator { + self.descriptor_requirements.iter().map(|(k, v)| (*k, v)) + } + + /// Returns the push constant ranges. + #[inline] + pub fn push_constant_range(&self) -> &Option { + &self.push_constant_range + } + + /// Returns the layout of the specialization constants. + #[inline] + pub fn spec_constants(&self) -> &[SpecializationMapEntry] { + self.spec_constants + } + /// Returns the input attributes used by the shader stage. #[inline] pub fn input(&self) -> &ShaderInterface { @@ -258,33 +271,6 @@ impl<'a> GraphicsEntryPoint<'a> { } } -unsafe impl<'a> EntryPointAbstract for GraphicsEntryPoint<'a> { - #[inline] - fn module(&self) -> &ShaderModule { - self.module - } - - #[inline] - fn name(&self) -> &CStr { - self.name - } - - #[inline] - fn descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc] { - &self.descriptor_set_layout_descs - } - - #[inline] - fn push_constant_range(&self) -> &Option { - &self.push_constant_range - } - - #[inline] - fn spec_constants(&self) -> &[SpecializationMapEntry] { - self.spec_constants - } -} - #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum GraphicsShaderType { Vertex, @@ -343,34 +329,41 @@ impl GeometryShaderExecutionMode { pub struct ComputeEntryPoint<'a> { module: &'a ShaderModule, name: &'a CStr, - descriptor_set_layout_descs: Vec, + descriptor_requirements: FnvHashMap<(u32, u32), DescriptorRequirements>, push_constant_range: Option, spec_constants: &'static [SpecializationMapEntry], } -unsafe impl<'a> EntryPointAbstract for ComputeEntryPoint<'a> { +impl<'a> ComputeEntryPoint<'a> { + /// Returns the module this entry point comes from. #[inline] - fn module(&self) -> &ShaderModule { + pub fn module(&self) -> &ShaderModule { self.module } + /// Returns the name of the entry point. #[inline] - fn name(&self) -> &CStr { + pub fn name(&self) -> &CStr { self.name } + /// Returns the descriptor requirements. #[inline] - fn descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc] { - &self.descriptor_set_layout_descs + pub fn descriptor_requirements( + &self, + ) -> impl ExactSizeIterator { + self.descriptor_requirements.iter().map(|(k, v)| (*k, v)) } + /// Returns the push constant ranges. #[inline] - fn push_constant_range(&self) -> &Option { + pub fn push_constant_range(&self) -> &Option { &self.push_constant_range } + /// Returns the layout of the specialization constants. #[inline] - fn spec_constants(&self) -> &[SpecializationMapEntry] { + pub fn spec_constants(&self) -> &[SpecializationMapEntry] { self.spec_constants } } @@ -529,12 +522,6 @@ impl fmt::Display for ShaderInterfaceMismatchError { /// /// This trait is implemented on `()` for shaders that don't have any specialization constant. /// -/// Note that it is the shader module that chooses which type that implements -/// `SpecializationConstants` it is possible to pass when creating the pipeline, through [the -/// `EntryPointAbstract` trait](crate::pipeline::shader::EntryPointAbstract). Therefore there is generally no -/// point to implement this trait yourself, unless you are also writing your own implementation of -/// `EntryPointAbstract`. -/// /// # Example /// /// ```rust @@ -637,7 +624,7 @@ impl From for ash::vk::ShaderStageFlags { /// A set of shader stages. // TODO: add example with BitOr -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] pub struct ShaderStages { pub vertex: bool, pub tessellation_control: bool, @@ -813,3 +800,117 @@ impl From for PipelineStages { } } } + +/// The requirements imposed by a shader on a descriptor within a descriptor set layout, and on any +/// resource that is bound to that descriptor. +#[derive(Clone, Debug, Default)] +pub struct DescriptorRequirements { + /// The descriptor types that are allowed. + pub descriptor_types: Vec, + + /// The number of descriptors (array elements) that the shader requires. The descriptor set + /// layout can declare more than this, but never less. + pub descriptor_count: u32, + + /// The image format that is required for image views bound to this descriptor. If this is + /// `None`, then any image format is allowed. + pub format: Option, + + /// The view type that is required for image views bound to this descriptor. This is `None` for + /// non-image descriptors. + pub image_view_type: Option, + + /// Whether image views bound to this descriptor must have multisampling enabled or disabled. + pub multisampled: bool, + + /// Whether the shader requires mutable (exclusive) access to the resource bound to this + /// descriptor. + pub mutable: bool, + + /// The shader stages that the descriptor must be declared for. + pub stages: ShaderStages, +} + +impl DescriptorRequirements { + /// Produces the intersection of two descriptor requirements, so that the requirements of both + /// are satisfied. An error is returned if the requirements conflict. + pub fn intersection(&self, other: &Self) -> Result { + let descriptor_types: Vec<_> = self + .descriptor_types + .iter() + .copied() + .filter(|ty| other.descriptor_types.contains(&ty)) + .collect(); + + if descriptor_types.is_empty() { + return Err(DescriptorRequirementsIncompatible::DescriptorType); + } + + if let (Some(first), Some(second)) = (self.format, other.format) { + if first != second { + return Err(DescriptorRequirementsIncompatible::Format); + } + } + + if let (Some(first), Some(second)) = (self.image_view_type, other.image_view_type) { + if first != second { + return Err(DescriptorRequirementsIncompatible::ImageViewType); + } + } + + if self.multisampled != other.multisampled { + return Err(DescriptorRequirementsIncompatible::Multisampled); + } + + Ok(Self { + descriptor_types, + descriptor_count: self.descriptor_count.max(other.descriptor_count), + format: self.format.or(other.format), + image_view_type: self.image_view_type.or(other.image_view_type), + multisampled: self.multisampled, + mutable: self.mutable || other.mutable, + stages: self.stages | other.stages, + }) + } +} + +/// An error that can be returned when trying to create the intersection of two +/// `DescriptorRequirements` values. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DescriptorRequirementsIncompatible { + /// The allowed descriptor types of the descriptors do not overlap. + DescriptorType, + /// The descriptors require different formats. + Format, + /// The descriptors require different image view types. + ImageViewType, + /// The multisampling requirements of the descriptors differ. + Multisampled, +} + +impl Error for DescriptorRequirementsIncompatible {} + +impl Display for DescriptorRequirementsIncompatible { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DescriptorRequirementsIncompatible::DescriptorType => { + write!( + f, + "the allowed descriptor types of the two descriptors do not overlap" + ) + } + DescriptorRequirementsIncompatible::Format => { + write!(f, "the descriptors require different formats") + } + DescriptorRequirementsIncompatible::ImageViewType => { + write!(f, "the descriptors require different image view types") + } + DescriptorRequirementsIncompatible::Multisampled => { + write!( + f, + "the multisampling requirements of the descriptors differ" + ) + } + } + } +}