diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 78c28f362..a1e55bcd5 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -2582,7 +2582,7 @@ impl Global { if let Some(ref module) = shader_module.module { interface = validation::check_stage( module, - &group_layouts, + validation::IntrospectionBindGroupLayouts::Given(&group_layouts), &entry_point_name, naga::ShaderStage::Vertex, interface, @@ -2614,7 +2614,7 @@ impl Global { if let Some(ref module) = shader_module.module { interface = validation::check_stage( module, - &group_layouts, + validation::IntrospectionBindGroupLayouts::Given(&group_layouts), &entry_point_name, naga::ShaderStage::Fragment, interface, @@ -2820,17 +2820,32 @@ impl Global { let device = device_guard .get(device_id) .map_err(|_| DeviceError::Invalid)?; - let (raw_pipeline, layout_ref_count) = { + let (raw_pipeline, layout_id, layout_ref_count) = { let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token); let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token); - let layout = pipeline_layout_guard - .get(desc.layout) - .map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?; - let group_layouts = layout - .bind_group_layout_ids - .iter() - .map(|&id| &bgl_guard[id].entries) - .collect::>(); + + let mut derived_group_layouts = + ArrayVec::<[binding_model::BindEntryMap; MAX_BIND_GROUPS]>::new(); + let given_group_layouts: ArrayVec<[&binding_model::BindEntryMap; MAX_BIND_GROUPS]>; + let group_layouts = match desc.layout { + Some(pipeline_layout_id) => { + let layout = pipeline_layout_guard + .get(pipeline_layout_id) + .map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?; + given_group_layouts = layout + .bind_group_layout_ids + .iter() + .map(|&id| &bgl_guard[id].entries) + .collect(); + validation::IntrospectionBindGroupLayouts::Given(&given_group_layouts) + } + None => { + for _ in 0..device.limits.max_bind_groups { + derived_group_layouts.push(binding_model::BindEntryMap::default()); + } + validation::IntrospectionBindGroupLayouts::Derived(&mut derived_group_layouts) + } + }; let interface = validation::StageInterface::default(); let pipeline_stage = &desc.compute_stage; @@ -2849,7 +2864,7 @@ impl Global { if let Some(ref module) = shader_module.module { let _ = validation::check_stage( module, - &group_layouts, + group_layouts, &entry_point_name, naga::ShaderStage::Compute, interface, @@ -2868,6 +2883,17 @@ impl Global { // TODO let parent = hal::pso::BasePipeline::None; + let pipeline_layout_id = match desc.layout { + Some(id) => id, + None => { + //TODO: create a new pipeline layout + unimplemented!() + } + }; + let layout = pipeline_layout_guard + .get(pipeline_layout_id) + .map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?; + let pipeline_desc = hal::pso::ComputePipelineDesc { shader, layout: &layout.raw, @@ -2876,7 +2902,7 @@ impl Global { }; match unsafe { device.raw.create_compute_pipeline(&pipeline_desc, None) } { - Ok(pipeline) => (pipeline, layout.life_guard.add_ref()), + Ok(pipeline) => (pipeline, pipeline_layout_id, layout.life_guard.add_ref()), Err(hal::pso::CreationError::OutOfMemory(_)) => { return Err(pipeline::CreateComputePipelineError::Device( DeviceError::OutOfMemory, @@ -2889,7 +2915,7 @@ impl Global { let pipeline = pipeline::ComputePipeline { raw: raw_pipeline, layout_id: Stored { - value: id::Valid(desc.layout), + value: id::Valid(layout_id), ref_count: layout_ref_count, }, device_id: Stored { diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index 859136b3b..2a279e316 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -3,6 +3,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use crate::{binding_model::BindEntryMap, FastHashMap}; +use std::collections::hash_map::Entry; use thiserror::Error; use wgt::{BindGroupLayoutEntry, BindingType}; @@ -71,6 +72,8 @@ pub enum BindingError { WrongTextureMultisampled, #[error("comparison flag doesn't match the shader")] WrongSamplerComparison, + #[error("derived bind group layout type is not consistent between stages")] + InconsistentlyDerivedType, } #[derive(Clone, Debug, Error)] @@ -159,13 +162,12 @@ fn get_aligned_type_size( } } -fn check_binding( +fn check_binding_use( module: &naga::Module, var: &naga::GlobalVariable, entry: &BindGroupLayoutEntry, - usage: naga::GlobalUse, -) -> Result<(), BindingError> { - let allowed_usage = match module.types[var.ty].inner { +) -> Result { + match module.types[var.ty].inner { naga::TypeInner::Struct { ref members } => { let (allowed_usage, min_size) = match entry.ty { BindingType::UniformBuffer { @@ -196,17 +198,17 @@ fn check_binding( } _ => (), } - allowed_usage + Ok(allowed_usage) } naga::TypeInner::Sampler { comparison } => match entry.ty { BindingType::Sampler { comparison: cmp } => { if cmp == comparison { - naga::GlobalUse::empty() + Ok(naga::GlobalUse::empty()) } else { - return Err(BindingError::WrongSamplerComparison); + Err(BindingError::WrongSamplerComparison) } } - _ => return Err(BindingError::WrongType), + _ => Err(BindingError::WrongType), }, naga::TypeInner::Image { base, dim, flags } => { if flags.contains(naga::ImageFlags::MULTISAMPLED) { @@ -284,14 +286,9 @@ fn check_binding( if is_sampled != flags.contains(naga::ImageFlags::SAMPLED) { return Err(BindingError::WrongTextureSampled); } - allowed_usage + Ok(allowed_usage) } - _ => return Err(BindingError::WrongType), - }; - if allowed_usage.contains(usage) { - Ok(()) - } else { - Err(BindingError::WrongUsage(usage)) + _ => Err(BindingError::WrongType), } } @@ -685,9 +682,80 @@ pub fn check_texture_format(format: wgt::TextureFormat, output: &naga::TypeInner pub type StageInterface<'a> = FastHashMap>; +pub enum IntrospectionBindGroupLayouts<'a> { + Given(&'a [&'a BindEntryMap]), + Derived(&'a mut [BindEntryMap]), +} + +fn derive_binding_type( + module: &naga::Module, + var: &naga::GlobalVariable, + usage: naga::GlobalUse, +) -> Result { + let ty = &module.types[var.ty]; + Ok(match ty.inner { + naga::TypeInner::Struct { ref members } => { + let dynamic = false; + let mut actual_size = 0; + for (i, member) in members.iter().enumerate() { + actual_size += get_aligned_type_size(module, member.ty, i + 1 == members.len()); + } + match var.class { + naga::StorageClass::Uniform => BindingType::UniformBuffer { + dynamic, + min_binding_size: wgt::BufferSize::new(actual_size), + }, + naga::StorageClass::StorageBuffer => BindingType::StorageBuffer { + dynamic, + min_binding_size: wgt::BufferSize::new(actual_size), + readonly: !usage.contains(naga::GlobalUse::STORE), //TODO: clarify + }, + _ => return Err(BindingError::WrongType), + } + } + naga::TypeInner::Sampler { comparison } => BindingType::Sampler { comparison }, + naga::TypeInner::Image { base, dim, flags } => { + let array = flags.contains(naga::ImageFlags::ARRAYED); + let dimension = match dim { + naga::ImageDimension::D1 => wgt::TextureViewDimension::D1, + naga::ImageDimension::D2 if array => wgt::TextureViewDimension::D2Array, + naga::ImageDimension::D2 => wgt::TextureViewDimension::D2, + naga::ImageDimension::D3 => wgt::TextureViewDimension::D3, + naga::ImageDimension::Cube if array => wgt::TextureViewDimension::CubeArray, + naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube, + }; + if flags.contains(naga::ImageFlags::SAMPLED) { + BindingType::SampledTexture { + dimension, + component_type: match module.types[base].inner { + naga::TypeInner::Scalar { kind, .. } + | naga::TypeInner::Vector { kind, .. } => match kind { + naga::ScalarKind::Float => wgt::TextureComponentType::Float, + naga::ScalarKind::Sint => wgt::TextureComponentType::Sint, + naga::ScalarKind::Uint => wgt::TextureComponentType::Uint, + other => { + return Err(BindingError::WrongTextureComponentType(Some(other))) + } + }, + _ => return Err(BindingError::WrongTextureComponentType(None)), + }, + multisampled: flags.contains(naga::ImageFlags::MULTISAMPLED), + } + } else { + BindingType::StorageTexture { + dimension, + format: wgt::TextureFormat::Rgba8Unorm, //TODO + readonly: !flags.contains(naga::ImageFlags::CAN_STORE), + } + } + } + _ => return Err(BindingError::WrongType), + }) +} + pub fn check_stage<'a>( module: &'a naga::Module, - group_layouts: &[&BindEntryMap], + mut group_layouts: IntrospectionBindGroupLayouts<'a>, entry_point_name: &str, stage: naga::ShaderStage, inputs: StageInterface<'a>, @@ -713,18 +781,49 @@ pub fn check_stage<'a>( } match var.binding { Some(naga::Binding::Descriptor { set, binding }) => { - let result = group_layouts - .get(set as usize) - .and_then(|map| map.get(&binding)) - .ok_or(BindingError::Missing) - .and_then(|entry| { - if entry.visibility.contains(stage_bit) { - Ok(entry) - } else { - Err(BindingError::Invisible) - } - }) - .and_then(|entry| check_binding(module, var, entry, usage)); + let result = match group_layouts { + IntrospectionBindGroupLayouts::Given(layouts) => layouts + .get(set as usize) + .and_then(|map| map.get(&binding)) + .ok_or(BindingError::Missing) + .and_then(|entry| { + if entry.visibility.contains(stage_bit) { + Ok(entry) + } else { + Err(BindingError::Invisible) + } + }) + .and_then(|entry| check_binding_use(module, var, entry)) + .and_then(|allowed_usage| { + if allowed_usage.contains(usage) { + Ok(()) + } else { + Err(BindingError::WrongUsage(usage)) + } + }), + IntrospectionBindGroupLayouts::Derived(ref mut layouts) => layouts + .get_mut(set as usize) + .ok_or(BindingError::Missing) + .and_then(|set| { + let ty = derive_binding_type(module, var, usage)?; + Ok(match set.entry(binding) { + Entry::Occupied(e) if e.get().ty != ty => { + return Err(BindingError::InconsistentlyDerivedType) + } + Entry::Occupied(e) => { + e.into_mut().visibility |= stage_bit; + } + Entry::Vacant(e) => { + e.insert(BindGroupLayoutEntry { + binding, + ty, + visibility: stage_bit, + count: None, + }); + } + }) + }), + }; if let Err(error) = result { return Err(StageError::Binding { set, diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index edf2f54a0..0ce51d9e5 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -2194,18 +2194,23 @@ impl<'a, L, D> RenderPipelineDescriptor<'a, L, D> { #[cfg_attr(feature = "replay", derive(serde::Deserialize))] pub struct ComputePipelineDescriptor { /// The layout of bind groups for this pipeline. - pub layout: L, + pub layout: Option, /// The compiled compute stage and its entry point. pub compute_stage: D, } impl ComputePipelineDescriptor { - pub fn new(layout: L, compute_stage: D) -> Self { + pub fn new(compute_stage: D) -> Self { Self { - layout, + layout: None, compute_stage, } } + + pub fn layout(&mut self, layout: L) -> &mut Self { + self.layout = Some(layout); + self + } } /// Describes a [`CommandBuffer`].