Derive bind group layout entries in Naga validation

This commit is contained in:
Dzmitry Malyshau 2020-08-08 01:09:41 -04:00
parent 430b29d781
commit f18fa7ef9b
3 changed files with 175 additions and 45 deletions

View File

@ -2582,7 +2582,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
if let Some(ref module) = shader_module.module {
interface = validation::check_stage(
module,
&group_layouts,
validation::IntrospectionBindGroupLayouts::Given(&group_layouts),
&entry_point_name,
naga::ShaderStage::Vertex,
interface,
@ -2614,7 +2614,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
if let Some(ref module) = shader_module.module {
interface = validation::check_stage(
module,
&group_layouts,
validation::IntrospectionBindGroupLayouts::Given(&group_layouts),
&entry_point_name,
naga::ShaderStage::Fragment,
interface,
@ -2820,17 +2820,32 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let device = device_guard
.get(device_id)
.map_err(|_| DeviceError::Invalid)?;
let (raw_pipeline, layout_ref_count) = {
let (raw_pipeline, layout_id, layout_ref_count) = {
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token);
let layout = pipeline_layout_guard
.get(desc.layout)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
let group_layouts = layout
.bind_group_layout_ids
.iter()
.map(|&id| &bgl_guard[id].entries)
.collect::<ArrayVec<[&binding_model::BindEntryMap; MAX_BIND_GROUPS]>>();
let mut derived_group_layouts =
ArrayVec::<[binding_model::BindEntryMap; MAX_BIND_GROUPS]>::new();
let given_group_layouts: ArrayVec<[&binding_model::BindEntryMap; MAX_BIND_GROUPS]>;
let group_layouts = match desc.layout {
Some(pipeline_layout_id) => {
let layout = pipeline_layout_guard
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
given_group_layouts = layout
.bind_group_layout_ids
.iter()
.map(|&id| &bgl_guard[id].entries)
.collect();
validation::IntrospectionBindGroupLayouts::Given(&given_group_layouts)
}
None => {
for _ in 0..device.limits.max_bind_groups {
derived_group_layouts.push(binding_model::BindEntryMap::default());
}
validation::IntrospectionBindGroupLayouts::Derived(&mut derived_group_layouts)
}
};
let interface = validation::StageInterface::default();
let pipeline_stage = &desc.compute_stage;
@ -2849,7 +2864,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
if let Some(ref module) = shader_module.module {
let _ = validation::check_stage(
module,
&group_layouts,
group_layouts,
&entry_point_name,
naga::ShaderStage::Compute,
interface,
@ -2868,6 +2883,17 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// TODO
let parent = hal::pso::BasePipeline::None;
let pipeline_layout_id = match desc.layout {
Some(id) => id,
None => {
//TODO: create a new pipeline layout
unimplemented!()
}
};
let layout = pipeline_layout_guard
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
let pipeline_desc = hal::pso::ComputePipelineDesc {
shader,
layout: &layout.raw,
@ -2876,7 +2902,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
};
match unsafe { device.raw.create_compute_pipeline(&pipeline_desc, None) } {
Ok(pipeline) => (pipeline, layout.life_guard.add_ref()),
Ok(pipeline) => (pipeline, pipeline_layout_id, layout.life_guard.add_ref()),
Err(hal::pso::CreationError::OutOfMemory(_)) => {
return Err(pipeline::CreateComputePipelineError::Device(
DeviceError::OutOfMemory,
@ -2889,7 +2915,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let pipeline = pipeline::ComputePipeline {
raw: raw_pipeline,
layout_id: Stored {
value: id::Valid(desc.layout),
value: id::Valid(layout_id),
ref_count: layout_ref_count,
},
device_id: Stored {

View File

@ -3,6 +3,7 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
use crate::{binding_model::BindEntryMap, FastHashMap};
use std::collections::hash_map::Entry;
use thiserror::Error;
use wgt::{BindGroupLayoutEntry, BindingType};
@ -71,6 +72,8 @@ pub enum BindingError {
WrongTextureMultisampled,
#[error("comparison flag doesn't match the shader")]
WrongSamplerComparison,
#[error("derived bind group layout type is not consistent between stages")]
InconsistentlyDerivedType,
}
#[derive(Clone, Debug, Error)]
@ -159,13 +162,12 @@ fn get_aligned_type_size(
}
}
fn check_binding(
fn check_binding_use(
module: &naga::Module,
var: &naga::GlobalVariable,
entry: &BindGroupLayoutEntry,
usage: naga::GlobalUse,
) -> Result<(), BindingError> {
let allowed_usage = match module.types[var.ty].inner {
) -> Result<naga::GlobalUse, BindingError> {
match module.types[var.ty].inner {
naga::TypeInner::Struct { ref members } => {
let (allowed_usage, min_size) = match entry.ty {
BindingType::UniformBuffer {
@ -196,17 +198,17 @@ fn check_binding(
}
_ => (),
}
allowed_usage
Ok(allowed_usage)
}
naga::TypeInner::Sampler { comparison } => match entry.ty {
BindingType::Sampler { comparison: cmp } => {
if cmp == comparison {
naga::GlobalUse::empty()
Ok(naga::GlobalUse::empty())
} else {
return Err(BindingError::WrongSamplerComparison);
Err(BindingError::WrongSamplerComparison)
}
}
_ => return Err(BindingError::WrongType),
_ => Err(BindingError::WrongType),
},
naga::TypeInner::Image { base, dim, flags } => {
if flags.contains(naga::ImageFlags::MULTISAMPLED) {
@ -284,14 +286,9 @@ fn check_binding(
if is_sampled != flags.contains(naga::ImageFlags::SAMPLED) {
return Err(BindingError::WrongTextureSampled);
}
allowed_usage
Ok(allowed_usage)
}
_ => return Err(BindingError::WrongType),
};
if allowed_usage.contains(usage) {
Ok(())
} else {
Err(BindingError::WrongUsage(usage))
_ => Err(BindingError::WrongType),
}
}
@ -685,9 +682,80 @@ pub fn check_texture_format(format: wgt::TextureFormat, output: &naga::TypeInner
pub type StageInterface<'a> = FastHashMap<wgt::ShaderLocation, MaybeOwned<'a, naga::TypeInner>>;
pub enum IntrospectionBindGroupLayouts<'a> {
Given(&'a [&'a BindEntryMap]),
Derived(&'a mut [BindEntryMap]),
}
fn derive_binding_type(
module: &naga::Module,
var: &naga::GlobalVariable,
usage: naga::GlobalUse,
) -> Result<BindingType, BindingError> {
let ty = &module.types[var.ty];
Ok(match ty.inner {
naga::TypeInner::Struct { ref members } => {
let dynamic = false;
let mut actual_size = 0;
for (i, member) in members.iter().enumerate() {
actual_size += get_aligned_type_size(module, member.ty, i + 1 == members.len());
}
match var.class {
naga::StorageClass::Uniform => BindingType::UniformBuffer {
dynamic,
min_binding_size: wgt::BufferSize::new(actual_size),
},
naga::StorageClass::StorageBuffer => BindingType::StorageBuffer {
dynamic,
min_binding_size: wgt::BufferSize::new(actual_size),
readonly: !usage.contains(naga::GlobalUse::STORE), //TODO: clarify
},
_ => return Err(BindingError::WrongType),
}
}
naga::TypeInner::Sampler { comparison } => BindingType::Sampler { comparison },
naga::TypeInner::Image { base, dim, flags } => {
let array = flags.contains(naga::ImageFlags::ARRAYED);
let dimension = match dim {
naga::ImageDimension::D1 => wgt::TextureViewDimension::D1,
naga::ImageDimension::D2 if array => wgt::TextureViewDimension::D2Array,
naga::ImageDimension::D2 => wgt::TextureViewDimension::D2,
naga::ImageDimension::D3 => wgt::TextureViewDimension::D3,
naga::ImageDimension::Cube if array => wgt::TextureViewDimension::CubeArray,
naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube,
};
if flags.contains(naga::ImageFlags::SAMPLED) {
BindingType::SampledTexture {
dimension,
component_type: match module.types[base].inner {
naga::TypeInner::Scalar { kind, .. }
| naga::TypeInner::Vector { kind, .. } => match kind {
naga::ScalarKind::Float => wgt::TextureComponentType::Float,
naga::ScalarKind::Sint => wgt::TextureComponentType::Sint,
naga::ScalarKind::Uint => wgt::TextureComponentType::Uint,
other => {
return Err(BindingError::WrongTextureComponentType(Some(other)))
}
},
_ => return Err(BindingError::WrongTextureComponentType(None)),
},
multisampled: flags.contains(naga::ImageFlags::MULTISAMPLED),
}
} else {
BindingType::StorageTexture {
dimension,
format: wgt::TextureFormat::Rgba8Unorm, //TODO
readonly: !flags.contains(naga::ImageFlags::CAN_STORE),
}
}
}
_ => return Err(BindingError::WrongType),
})
}
pub fn check_stage<'a>(
module: &'a naga::Module,
group_layouts: &[&BindEntryMap],
mut group_layouts: IntrospectionBindGroupLayouts<'a>,
entry_point_name: &str,
stage: naga::ShaderStage,
inputs: StageInterface<'a>,
@ -713,18 +781,49 @@ pub fn check_stage<'a>(
}
match var.binding {
Some(naga::Binding::Descriptor { set, binding }) => {
let result = group_layouts
.get(set as usize)
.and_then(|map| map.get(&binding))
.ok_or(BindingError::Missing)
.and_then(|entry| {
if entry.visibility.contains(stage_bit) {
Ok(entry)
} else {
Err(BindingError::Invisible)
}
})
.and_then(|entry| check_binding(module, var, entry, usage));
let result = match group_layouts {
IntrospectionBindGroupLayouts::Given(layouts) => layouts
.get(set as usize)
.and_then(|map| map.get(&binding))
.ok_or(BindingError::Missing)
.and_then(|entry| {
if entry.visibility.contains(stage_bit) {
Ok(entry)
} else {
Err(BindingError::Invisible)
}
})
.and_then(|entry| check_binding_use(module, var, entry))
.and_then(|allowed_usage| {
if allowed_usage.contains(usage) {
Ok(())
} else {
Err(BindingError::WrongUsage(usage))
}
}),
IntrospectionBindGroupLayouts::Derived(ref mut layouts) => layouts
.get_mut(set as usize)
.ok_or(BindingError::Missing)
.and_then(|set| {
let ty = derive_binding_type(module, var, usage)?;
Ok(match set.entry(binding) {
Entry::Occupied(e) if e.get().ty != ty => {
return Err(BindingError::InconsistentlyDerivedType)
}
Entry::Occupied(e) => {
e.into_mut().visibility |= stage_bit;
}
Entry::Vacant(e) => {
e.insert(BindGroupLayoutEntry {
binding,
ty,
visibility: stage_bit,
count: None,
});
}
})
}),
};
if let Err(error) = result {
return Err(StageError::Binding {
set,

View File

@ -2194,18 +2194,23 @@ impl<'a, L, D> RenderPipelineDescriptor<'a, L, D> {
#[cfg_attr(feature = "replay", derive(serde::Deserialize))]
pub struct ComputePipelineDescriptor<L, D> {
/// The layout of bind groups for this pipeline.
pub layout: L,
pub layout: Option<L>,
/// The compiled compute stage and its entry point.
pub compute_stage: D,
}
impl<L, D> ComputePipelineDescriptor<L, D> {
pub fn new(layout: L, compute_stage: D) -> Self {
pub fn new(compute_stage: D) -> Self {
Self {
layout,
layout: None,
compute_stage,
}
}
pub fn layout(&mut self, layout: L) -> &mut Self {
self.layout = Some(layout);
self
}
}
/// Describes a [`CommandBuffer`].