mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-22 14:55:05 +00:00
Derive bind group layout entries in Naga validation
This commit is contained in:
parent
430b29d781
commit
f18fa7ef9b
@ -2582,7 +2582,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
if let Some(ref module) = shader_module.module {
|
||||
interface = validation::check_stage(
|
||||
module,
|
||||
&group_layouts,
|
||||
validation::IntrospectionBindGroupLayouts::Given(&group_layouts),
|
||||
&entry_point_name,
|
||||
naga::ShaderStage::Vertex,
|
||||
interface,
|
||||
@ -2614,7 +2614,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
if let Some(ref module) = shader_module.module {
|
||||
interface = validation::check_stage(
|
||||
module,
|
||||
&group_layouts,
|
||||
validation::IntrospectionBindGroupLayouts::Given(&group_layouts),
|
||||
&entry_point_name,
|
||||
naga::ShaderStage::Fragment,
|
||||
interface,
|
||||
@ -2820,17 +2820,32 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
let device = device_guard
|
||||
.get(device_id)
|
||||
.map_err(|_| DeviceError::Invalid)?;
|
||||
let (raw_pipeline, layout_ref_count) = {
|
||||
let (raw_pipeline, layout_id, layout_ref_count) = {
|
||||
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
|
||||
let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token);
|
||||
let layout = pipeline_layout_guard
|
||||
.get(desc.layout)
|
||||
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
|
||||
let group_layouts = layout
|
||||
.bind_group_layout_ids
|
||||
.iter()
|
||||
.map(|&id| &bgl_guard[id].entries)
|
||||
.collect::<ArrayVec<[&binding_model::BindEntryMap; MAX_BIND_GROUPS]>>();
|
||||
|
||||
let mut derived_group_layouts =
|
||||
ArrayVec::<[binding_model::BindEntryMap; MAX_BIND_GROUPS]>::new();
|
||||
let given_group_layouts: ArrayVec<[&binding_model::BindEntryMap; MAX_BIND_GROUPS]>;
|
||||
let group_layouts = match desc.layout {
|
||||
Some(pipeline_layout_id) => {
|
||||
let layout = pipeline_layout_guard
|
||||
.get(pipeline_layout_id)
|
||||
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
|
||||
given_group_layouts = layout
|
||||
.bind_group_layout_ids
|
||||
.iter()
|
||||
.map(|&id| &bgl_guard[id].entries)
|
||||
.collect();
|
||||
validation::IntrospectionBindGroupLayouts::Given(&given_group_layouts)
|
||||
}
|
||||
None => {
|
||||
for _ in 0..device.limits.max_bind_groups {
|
||||
derived_group_layouts.push(binding_model::BindEntryMap::default());
|
||||
}
|
||||
validation::IntrospectionBindGroupLayouts::Derived(&mut derived_group_layouts)
|
||||
}
|
||||
};
|
||||
|
||||
let interface = validation::StageInterface::default();
|
||||
let pipeline_stage = &desc.compute_stage;
|
||||
@ -2849,7 +2864,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
if let Some(ref module) = shader_module.module {
|
||||
let _ = validation::check_stage(
|
||||
module,
|
||||
&group_layouts,
|
||||
group_layouts,
|
||||
&entry_point_name,
|
||||
naga::ShaderStage::Compute,
|
||||
interface,
|
||||
@ -2868,6 +2883,17 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
// TODO
|
||||
let parent = hal::pso::BasePipeline::None;
|
||||
|
||||
let pipeline_layout_id = match desc.layout {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
//TODO: create a new pipeline layout
|
||||
unimplemented!()
|
||||
}
|
||||
};
|
||||
let layout = pipeline_layout_guard
|
||||
.get(pipeline_layout_id)
|
||||
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
|
||||
|
||||
let pipeline_desc = hal::pso::ComputePipelineDesc {
|
||||
shader,
|
||||
layout: &layout.raw,
|
||||
@ -2876,7 +2902,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
};
|
||||
|
||||
match unsafe { device.raw.create_compute_pipeline(&pipeline_desc, None) } {
|
||||
Ok(pipeline) => (pipeline, layout.life_guard.add_ref()),
|
||||
Ok(pipeline) => (pipeline, pipeline_layout_id, layout.life_guard.add_ref()),
|
||||
Err(hal::pso::CreationError::OutOfMemory(_)) => {
|
||||
return Err(pipeline::CreateComputePipelineError::Device(
|
||||
DeviceError::OutOfMemory,
|
||||
@ -2889,7 +2915,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
let pipeline = pipeline::ComputePipeline {
|
||||
raw: raw_pipeline,
|
||||
layout_id: Stored {
|
||||
value: id::Valid(desc.layout),
|
||||
value: id::Valid(layout_id),
|
||||
ref_count: layout_ref_count,
|
||||
},
|
||||
device_id: Stored {
|
||||
|
@ -3,6 +3,7 @@
|
||||
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
||||
|
||||
use crate::{binding_model::BindEntryMap, FastHashMap};
|
||||
use std::collections::hash_map::Entry;
|
||||
use thiserror::Error;
|
||||
use wgt::{BindGroupLayoutEntry, BindingType};
|
||||
|
||||
@ -71,6 +72,8 @@ pub enum BindingError {
|
||||
WrongTextureMultisampled,
|
||||
#[error("comparison flag doesn't match the shader")]
|
||||
WrongSamplerComparison,
|
||||
#[error("derived bind group layout type is not consistent between stages")]
|
||||
InconsistentlyDerivedType,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Error)]
|
||||
@ -159,13 +162,12 @@ fn get_aligned_type_size(
|
||||
}
|
||||
}
|
||||
|
||||
fn check_binding(
|
||||
fn check_binding_use(
|
||||
module: &naga::Module,
|
||||
var: &naga::GlobalVariable,
|
||||
entry: &BindGroupLayoutEntry,
|
||||
usage: naga::GlobalUse,
|
||||
) -> Result<(), BindingError> {
|
||||
let allowed_usage = match module.types[var.ty].inner {
|
||||
) -> Result<naga::GlobalUse, BindingError> {
|
||||
match module.types[var.ty].inner {
|
||||
naga::TypeInner::Struct { ref members } => {
|
||||
let (allowed_usage, min_size) = match entry.ty {
|
||||
BindingType::UniformBuffer {
|
||||
@ -196,17 +198,17 @@ fn check_binding(
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
allowed_usage
|
||||
Ok(allowed_usage)
|
||||
}
|
||||
naga::TypeInner::Sampler { comparison } => match entry.ty {
|
||||
BindingType::Sampler { comparison: cmp } => {
|
||||
if cmp == comparison {
|
||||
naga::GlobalUse::empty()
|
||||
Ok(naga::GlobalUse::empty())
|
||||
} else {
|
||||
return Err(BindingError::WrongSamplerComparison);
|
||||
Err(BindingError::WrongSamplerComparison)
|
||||
}
|
||||
}
|
||||
_ => return Err(BindingError::WrongType),
|
||||
_ => Err(BindingError::WrongType),
|
||||
},
|
||||
naga::TypeInner::Image { base, dim, flags } => {
|
||||
if flags.contains(naga::ImageFlags::MULTISAMPLED) {
|
||||
@ -284,14 +286,9 @@ fn check_binding(
|
||||
if is_sampled != flags.contains(naga::ImageFlags::SAMPLED) {
|
||||
return Err(BindingError::WrongTextureSampled);
|
||||
}
|
||||
allowed_usage
|
||||
Ok(allowed_usage)
|
||||
}
|
||||
_ => return Err(BindingError::WrongType),
|
||||
};
|
||||
if allowed_usage.contains(usage) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(BindingError::WrongUsage(usage))
|
||||
_ => Err(BindingError::WrongType),
|
||||
}
|
||||
}
|
||||
|
||||
@ -685,9 +682,80 @@ pub fn check_texture_format(format: wgt::TextureFormat, output: &naga::TypeInner
|
||||
|
||||
pub type StageInterface<'a> = FastHashMap<wgt::ShaderLocation, MaybeOwned<'a, naga::TypeInner>>;
|
||||
|
||||
pub enum IntrospectionBindGroupLayouts<'a> {
|
||||
Given(&'a [&'a BindEntryMap]),
|
||||
Derived(&'a mut [BindEntryMap]),
|
||||
}
|
||||
|
||||
fn derive_binding_type(
|
||||
module: &naga::Module,
|
||||
var: &naga::GlobalVariable,
|
||||
usage: naga::GlobalUse,
|
||||
) -> Result<BindingType, BindingError> {
|
||||
let ty = &module.types[var.ty];
|
||||
Ok(match ty.inner {
|
||||
naga::TypeInner::Struct { ref members } => {
|
||||
let dynamic = false;
|
||||
let mut actual_size = 0;
|
||||
for (i, member) in members.iter().enumerate() {
|
||||
actual_size += get_aligned_type_size(module, member.ty, i + 1 == members.len());
|
||||
}
|
||||
match var.class {
|
||||
naga::StorageClass::Uniform => BindingType::UniformBuffer {
|
||||
dynamic,
|
||||
min_binding_size: wgt::BufferSize::new(actual_size),
|
||||
},
|
||||
naga::StorageClass::StorageBuffer => BindingType::StorageBuffer {
|
||||
dynamic,
|
||||
min_binding_size: wgt::BufferSize::new(actual_size),
|
||||
readonly: !usage.contains(naga::GlobalUse::STORE), //TODO: clarify
|
||||
},
|
||||
_ => return Err(BindingError::WrongType),
|
||||
}
|
||||
}
|
||||
naga::TypeInner::Sampler { comparison } => BindingType::Sampler { comparison },
|
||||
naga::TypeInner::Image { base, dim, flags } => {
|
||||
let array = flags.contains(naga::ImageFlags::ARRAYED);
|
||||
let dimension = match dim {
|
||||
naga::ImageDimension::D1 => wgt::TextureViewDimension::D1,
|
||||
naga::ImageDimension::D2 if array => wgt::TextureViewDimension::D2Array,
|
||||
naga::ImageDimension::D2 => wgt::TextureViewDimension::D2,
|
||||
naga::ImageDimension::D3 => wgt::TextureViewDimension::D3,
|
||||
naga::ImageDimension::Cube if array => wgt::TextureViewDimension::CubeArray,
|
||||
naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube,
|
||||
};
|
||||
if flags.contains(naga::ImageFlags::SAMPLED) {
|
||||
BindingType::SampledTexture {
|
||||
dimension,
|
||||
component_type: match module.types[base].inner {
|
||||
naga::TypeInner::Scalar { kind, .. }
|
||||
| naga::TypeInner::Vector { kind, .. } => match kind {
|
||||
naga::ScalarKind::Float => wgt::TextureComponentType::Float,
|
||||
naga::ScalarKind::Sint => wgt::TextureComponentType::Sint,
|
||||
naga::ScalarKind::Uint => wgt::TextureComponentType::Uint,
|
||||
other => {
|
||||
return Err(BindingError::WrongTextureComponentType(Some(other)))
|
||||
}
|
||||
},
|
||||
_ => return Err(BindingError::WrongTextureComponentType(None)),
|
||||
},
|
||||
multisampled: flags.contains(naga::ImageFlags::MULTISAMPLED),
|
||||
}
|
||||
} else {
|
||||
BindingType::StorageTexture {
|
||||
dimension,
|
||||
format: wgt::TextureFormat::Rgba8Unorm, //TODO
|
||||
readonly: !flags.contains(naga::ImageFlags::CAN_STORE),
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => return Err(BindingError::WrongType),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn check_stage<'a>(
|
||||
module: &'a naga::Module,
|
||||
group_layouts: &[&BindEntryMap],
|
||||
mut group_layouts: IntrospectionBindGroupLayouts<'a>,
|
||||
entry_point_name: &str,
|
||||
stage: naga::ShaderStage,
|
||||
inputs: StageInterface<'a>,
|
||||
@ -713,18 +781,49 @@ pub fn check_stage<'a>(
|
||||
}
|
||||
match var.binding {
|
||||
Some(naga::Binding::Descriptor { set, binding }) => {
|
||||
let result = group_layouts
|
||||
.get(set as usize)
|
||||
.and_then(|map| map.get(&binding))
|
||||
.ok_or(BindingError::Missing)
|
||||
.and_then(|entry| {
|
||||
if entry.visibility.contains(stage_bit) {
|
||||
Ok(entry)
|
||||
} else {
|
||||
Err(BindingError::Invisible)
|
||||
}
|
||||
})
|
||||
.and_then(|entry| check_binding(module, var, entry, usage));
|
||||
let result = match group_layouts {
|
||||
IntrospectionBindGroupLayouts::Given(layouts) => layouts
|
||||
.get(set as usize)
|
||||
.and_then(|map| map.get(&binding))
|
||||
.ok_or(BindingError::Missing)
|
||||
.and_then(|entry| {
|
||||
if entry.visibility.contains(stage_bit) {
|
||||
Ok(entry)
|
||||
} else {
|
||||
Err(BindingError::Invisible)
|
||||
}
|
||||
})
|
||||
.and_then(|entry| check_binding_use(module, var, entry))
|
||||
.and_then(|allowed_usage| {
|
||||
if allowed_usage.contains(usage) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(BindingError::WrongUsage(usage))
|
||||
}
|
||||
}),
|
||||
IntrospectionBindGroupLayouts::Derived(ref mut layouts) => layouts
|
||||
.get_mut(set as usize)
|
||||
.ok_or(BindingError::Missing)
|
||||
.and_then(|set| {
|
||||
let ty = derive_binding_type(module, var, usage)?;
|
||||
Ok(match set.entry(binding) {
|
||||
Entry::Occupied(e) if e.get().ty != ty => {
|
||||
return Err(BindingError::InconsistentlyDerivedType)
|
||||
}
|
||||
Entry::Occupied(e) => {
|
||||
e.into_mut().visibility |= stage_bit;
|
||||
}
|
||||
Entry::Vacant(e) => {
|
||||
e.insert(BindGroupLayoutEntry {
|
||||
binding,
|
||||
ty,
|
||||
visibility: stage_bit,
|
||||
count: None,
|
||||
});
|
||||
}
|
||||
})
|
||||
}),
|
||||
};
|
||||
if let Err(error) = result {
|
||||
return Err(StageError::Binding {
|
||||
set,
|
||||
|
@ -2194,18 +2194,23 @@ impl<'a, L, D> RenderPipelineDescriptor<'a, L, D> {
|
||||
#[cfg_attr(feature = "replay", derive(serde::Deserialize))]
|
||||
pub struct ComputePipelineDescriptor<L, D> {
|
||||
/// The layout of bind groups for this pipeline.
|
||||
pub layout: L,
|
||||
pub layout: Option<L>,
|
||||
/// The compiled compute stage and its entry point.
|
||||
pub compute_stage: D,
|
||||
}
|
||||
|
||||
impl<L, D> ComputePipelineDescriptor<L, D> {
|
||||
pub fn new(layout: L, compute_stage: D) -> Self {
|
||||
pub fn new(compute_stage: D) -> Self {
|
||||
Self {
|
||||
layout,
|
||||
layout: None,
|
||||
compute_stage,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn layout(&mut self, layout: L) -> &mut Self {
|
||||
self.layout = Some(layout);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes a [`CommandBuffer`].
|
||||
|
Loading…
Reference in New Issue
Block a user