870: Implicit layout r=cwfitzgerald a=kvark

**Connections**
Closes #868

**Description**
The implementation can be split into three parts:
  1. reflecting the shaders for their binding expectations and building a bind entry map from each, merging the maps between stages. This is only done for shaders that can be reflected; for now, we error on the rest.
  2. based on this info, creating new bind group layouts and pipeline layouts. The tricky part here is that we can't generate the IDs out of thin air, so we have to pass them into the `create_xx_pipeline` functions, which now also return the number of IDs they consumed, allowing the client to free the rest.
  3. API changes in the descriptors, plus new methods to obtain the bind group layouts from a pipeline (sketched below).
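
A sketch of the resulting client flow, using only signatures that appear in this diff (`global`, `device_id`, a descriptor with `layout: None`, and the pre-reserved ID inputs are assumed to already exist; this is illustrative, not a compilable test):

```rust
// The client reserves the IDs up front, since wgpu-core can't mint them.
let implicit = ImplicitPipelineIds {
    root_id: layout_id_in,    // reserved pipeline layout ID
    group_ids: &group_ids_in, // reserved bind group layout IDs
};
let (pipeline_id, derived_count) = global
    .device_create_compute_pipeline::<B>(device_id, &desc, pipeline_id_in, Some(implicit))
    .unwrap();
// Only `derived_count` of the reserved bind group layout IDs were
// actually consumed; the client is free to release the rest.
```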

**Testing**
This isn't tested, but I think it's fine: it doesn't affect the old path, and we'll be testing the new path while improving Naga and our reflection anyway.

Co-authored-by: Dzmitry Malyshau <kvarkus@gmail.com>
Commit 12352035f0, merged by bors[bot] on 2020-08-10 18:38:34 +00:00
7 changed files with 704 additions and 265 deletions


@@ -248,7 +248,7 @@ impl GlobalPlay for wgc::hub::Global<IdentityPassThroughFactory> {
}
A::CreateComputePipeline(id, desc) => {
self.device_maintain_ids::<B>(device).unwrap();
self.device_create_compute_pipeline::<B>(device, &desc, id)
self.device_create_compute_pipeline::<B>(device, &desc, id, None)
.unwrap();
}
A::DestroyComputePipeline(id) => {
@@ -256,7 +256,7 @@ impl GlobalPlay for wgc::hub::Global<IdentityPassThroughFactory> {
}
A::CreateRenderPipeline(id, desc) => {
self.device_maintain_ids::<B>(device).unwrap();
self.device_create_render_pipeline::<B>(device, &desc, id)
self.device_create_render_pipeline::<B>(device, &desc, id, None)
.unwrap();
}
A::DestroyRenderPipeline(id) => {


@@ -523,3 +523,11 @@ impl<B: hal::Backend> Borrow<()> for BindGroup<B> {
&DUMMY_SELECTOR
}
}
#[derive(Clone, Debug, Error)]
pub enum GetBindGroupLayoutError {
#[error("pipeline is invalid")]
InvalidPipeline,
#[error("invalid group index {0}")]
InvalidGroupIndex(u32),
}


@@ -6,7 +6,9 @@ use crate::{
binding_model::{self, CreateBindGroupError, CreatePipelineLayoutError},
command, conv,
device::life::WaitIdleError,
hub::{GfxBackend, Global, GlobalIdentityHandlerFactory, Hub, Input, InvalidId, Token},
hub::{
GfxBackend, Global, GlobalIdentityHandlerFactory, Hub, Input, InvalidId, Storage, Token,
},
id, pipeline, resource, span, swap_chain,
track::{BufferState, TextureState, TrackerSet},
validation::{self, check_buffer_usage, check_texture_usage},
@@ -601,6 +603,216 @@ impl<B: GfxBackend> Device<B> {
unsafe { self.raw.create_render_pass(all, iter::once(subpass), &[]) }
}
fn deduplicate_bind_group_layout(
self_id: id::DeviceId,
entry_map: &binding_model::BindEntryMap,
guard: &Storage<binding_model::BindGroupLayout<B>, id::BindGroupLayoutId>,
) -> Option<id::BindGroupLayoutId> {
guard
.iter(self_id.backend())
.find(|(_, bgl)| bgl.device_id.value.0 == self_id && bgl.entries == *entry_map)
.map(|(id, value)| {
value.multi_ref_count.inc();
id
})
}
fn get_introspection_bind_group_layouts<'a>(
pipeline_layout: &binding_model::PipelineLayout<B>,
bgl_guard: &'a Storage<binding_model::BindGroupLayout<B>, id::BindGroupLayoutId>,
) -> validation::IntrospectionBindGroupLayouts<'a> {
validation::IntrospectionBindGroupLayouts::Given(
pipeline_layout
.bind_group_layout_ids
.iter()
.map(|&id| &bgl_guard[id].entries)
.collect(),
)
}
fn create_bind_group_layout(
&self,
self_id: id::DeviceId,
label: Option<&str>,
entry_map: binding_model::BindEntryMap,
) -> Result<binding_model::BindGroupLayout<B>, binding_model::CreateBindGroupLayoutError> {
// Validate the count parameter
for binding in entry_map.values() {
if let Some(count) = binding.count {
if count == 0 {
return Err(binding_model::CreateBindGroupLayoutError::ZeroCount);
}
match binding.ty {
wgt::BindingType::SampledTexture { .. } => {
if !self
.features
.contains(wgt::Features::SAMPLED_TEXTURE_BINDING_ARRAY)
{
return Err(binding_model::CreateBindGroupLayoutError::MissingFeature(
wgt::Features::SAMPLED_TEXTURE_BINDING_ARRAY,
));
}
}
_ => return Err(binding_model::CreateBindGroupLayoutError::ArrayUnsupported),
}
}
}
let raw_bindings = entry_map
.values()
.map(|entry| hal::pso::DescriptorSetLayoutBinding {
binding: entry.binding,
ty: conv::map_binding_type(entry),
count: entry
.count
.map_or(1, |v| v as hal::pso::DescriptorArrayIndex), //TODO: consolidate
stage_flags: conv::map_shader_stage_flags(entry.visibility),
immutable_samplers: false, // TODO
})
.collect::<Vec<_>>(); //TODO: avoid heap allocation
let raw = unsafe {
let mut raw_layout = self
.raw
.create_descriptor_set_layout(&raw_bindings, &[])
.or(Err(DeviceError::OutOfMemory))?;
if let Some(label) = label {
self.raw
.set_descriptor_set_layout_name(&mut raw_layout, label);
}
raw_layout
};
let mut count_validator = binding_model::BindingTypeMaxCountValidator::default();
for entry in entry_map.values() {
count_validator.add_binding(entry);
}
// If a single bind group layout violates limits, the pipeline layout is definitely
// going to violate limits too, let's catch it now.
count_validator
.validate(&self.limits)
.map_err(binding_model::CreateBindGroupLayoutError::TooManyBindings)?;
Ok(binding_model::BindGroupLayout {
raw,
device_id: Stored {
value: id::Valid(self_id),
ref_count: self.life_guard.add_ref(),
},
multi_ref_count: MultiRefCount::new(),
desc_counts: raw_bindings.iter().cloned().collect(),
dynamic_count: entry_map
.values()
.filter(|b| b.has_dynamic_offset())
.count(),
count_validator,
entries: entry_map,
})
}
fn create_pipeline_layout(
&self,
self_id: id::DeviceId,
desc: &wgt::PipelineLayoutDescriptor<id::BindGroupLayoutId>,
bgl_guard: &Storage<binding_model::BindGroupLayout<B>, id::BindGroupLayoutId>,
) -> Result<binding_model::PipelineLayout<B>, CreatePipelineLayoutError> {
let bind_group_layouts_count = desc.bind_group_layouts.len();
let device_max_bind_groups = self.limits.max_bind_groups as usize;
if bind_group_layouts_count > device_max_bind_groups {
return Err(CreatePipelineLayoutError::TooManyGroups {
actual: bind_group_layouts_count,
max: device_max_bind_groups,
});
}
if !desc.push_constant_ranges.is_empty()
&& !self.features.contains(wgt::Features::PUSH_CONSTANTS)
{
return Err(CreatePipelineLayoutError::MissingFeature(
wgt::Features::PUSH_CONSTANTS,
));
}
let mut used_stages = wgt::ShaderStage::empty();
for (index, pc) in desc.push_constant_ranges.iter().enumerate() {
if pc.stages.intersects(used_stages) {
return Err(
CreatePipelineLayoutError::MoreThanOnePushConstantRangePerStage {
index,
provided: pc.stages,
intersected: pc.stages & used_stages,
},
);
}
used_stages |= pc.stages;
let device_max_pc_size = self.limits.max_push_constant_size;
if device_max_pc_size < pc.range.end {
return Err(CreatePipelineLayoutError::PushConstantRangeTooLarge {
index,
range: pc.range.clone(),
max: device_max_pc_size,
});
}
if pc.range.start % wgt::PUSH_CONSTANT_ALIGNMENT != 0 {
return Err(CreatePipelineLayoutError::MisalignedPushConstantRange {
index,
bound: pc.range.start,
});
}
if pc.range.end % wgt::PUSH_CONSTANT_ALIGNMENT != 0 {
return Err(CreatePipelineLayoutError::MisalignedPushConstantRange {
index,
bound: pc.range.end,
});
}
}
let mut count_validator = binding_model::BindingTypeMaxCountValidator::default();
// validate total resource counts
for &id in desc.bind_group_layouts.iter() {
let bind_group_layout = bgl_guard
.get(id)
.map_err(|_| CreatePipelineLayoutError::InvalidBindGroupLayout(id))?;
count_validator.merge(&bind_group_layout.count_validator);
}
count_validator
.validate(&self.limits)
.map_err(CreatePipelineLayoutError::TooManyBindings)?;
let descriptor_set_layouts = desc
.bind_group_layouts
.iter()
.map(|&id| &bgl_guard.get(id).unwrap().raw);
let push_constants = desc
.push_constant_ranges
.iter()
.map(|pc| (conv::map_shader_stage_flags(pc.stages), pc.range.clone()));
Ok(binding_model::PipelineLayout {
raw: unsafe {
self.raw
.create_pipeline_layout(descriptor_set_layouts, push_constants)
.or(Err(DeviceError::OutOfMemory))?
},
device_id: Stored {
value: id::Valid(self_id),
ref_count: self.life_guard.add_ref(),
},
life_guard: LifeGuard::new(),
bind_group_layout_ids: desc
.bind_group_layouts
.iter()
.map(|&id| {
bgl_guard.get(id).unwrap().multi_ref_count.inc();
id::Valid(id)
})
.collect(),
push_constant_ranges: desc.push_constant_ranges.iter().cloned().collect(),
})
}
fn wait_for_submit(
&self,
submission_index: SubmissionIndex,
@@ -707,6 +919,11 @@ impl DeviceError {
}
}
pub struct ImplicitPipelineIds<'a, G: GlobalIdentityHandlerFactory> {
pub root_id: Input<G, id::PipelineLayoutId>,
pub group_ids: &'a [Input<G, id::BindGroupLayoutId>],
}
impl<G: GlobalIdentityHandlerFactory> Global<G> {
pub fn device_features<B: GfxBackend>(
&self,
@@ -1400,97 +1617,18 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// so their inputs are `PhantomData` of size 0.
if mem::size_of::<Input<G, id::BindGroupLayoutId>>() == 0 {
let (bgl_guard, _) = hub.bind_group_layouts.read(&mut token);
let bind_group_layout_id = bgl_guard
.iter(device_id.backend())
.find(|(_, bgl)| bgl.device_id.value.0 == device_id && bgl.entries == entry_map);
if let Some((id, value)) = bind_group_layout_id {
value.multi_ref_count.inc();
if let Some(id) =
Device::deduplicate_bind_group_layout(device_id, &entry_map, &*bgl_guard)
{
return Ok(id);
}
}
// Validate the count parameter
for binding in desc
.entries
.iter()
.filter(|binding| binding.count.is_some())
{
if let Some(count) = binding.count {
if count == 0 {
return Err(binding_model::CreateBindGroupLayoutError::ZeroCount);
}
match binding.ty {
wgt::BindingType::SampledTexture { .. } => {
if !device
.features
.contains(wgt::Features::SAMPLED_TEXTURE_BINDING_ARRAY)
{
return Err(binding_model::CreateBindGroupLayoutError::MissingFeature(
wgt::Features::SAMPLED_TEXTURE_BINDING_ARRAY,
));
}
}
_ => return Err(binding_model::CreateBindGroupLayoutError::ArrayUnsupported),
}
} else {
unreachable!() // programming bug
}
}
let raw_bindings = desc
.entries
.iter()
.map(|binding| hal::pso::DescriptorSetLayoutBinding {
binding: binding.binding,
ty: conv::map_binding_type(binding),
count: binding
.count
.map_or(1, |v| v as hal::pso::DescriptorArrayIndex), //TODO: consolidate
stage_flags: conv::map_shader_stage_flags(binding.visibility),
immutable_samplers: false, // TODO
})
.collect::<Vec<_>>(); //TODO: avoid heap allocation
let raw = unsafe {
let mut raw_layout = device
.raw
.create_descriptor_set_layout(&raw_bindings, &[])
.or(Err(DeviceError::OutOfMemory))?;
if let Some(label) = desc.label.as_ref() {
device
.raw
.set_descriptor_set_layout_name(&mut raw_layout, label);
}
raw_layout
};
let mut count_validator = binding_model::BindingTypeMaxCountValidator::default();
desc.entries
.iter()
.for_each(|b| count_validator.add_binding(b));
// If a single bind group layout violates limits, the pipeline layout is definitely
// going to violate limits too, let's catch it now.
count_validator
.validate(&device.limits)
.map_err(binding_model::CreateBindGroupLayoutError::TooManyBindings)?;
let layout = binding_model::BindGroupLayout {
raw,
device_id: Stored {
value: id::Valid(device_id),
ref_count: device.life_guard.add_ref(),
},
multi_ref_count: MultiRefCount::new(),
entries: entry_map,
desc_counts: raw_bindings.iter().cloned().collect(),
dynamic_count: desc
.entries
.iter()
.filter(|b| b.has_dynamic_offset())
.count(),
count_validator,
};
let layout = device.create_bind_group_layout(
device_id,
desc.label.as_ref().map(|cow| cow.as_ref()),
entry_map,
)?;
let id = hub
.bind_group_layouts
@@ -1557,108 +1695,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let device = device_guard
.get(device_id)
.map_err(|_| DeviceError::Invalid)?;
let bind_group_layouts_count = desc.bind_group_layouts.len();
let device_max_bind_groups = device.limits.max_bind_groups as usize;
if bind_group_layouts_count > device_max_bind_groups {
return Err(CreatePipelineLayoutError::TooManyGroups {
actual: bind_group_layouts_count,
max: device_max_bind_groups,
});
}
if !desc.push_constant_ranges.is_empty()
&& !device.features.contains(wgt::Features::PUSH_CONSTANTS)
{
return Err(CreatePipelineLayoutError::MissingFeature(
wgt::Features::PUSH_CONSTANTS,
));
}
let mut used_stages = wgt::ShaderStage::empty();
for (index, pc) in desc.push_constant_ranges.iter().enumerate() {
if pc.stages.intersects(used_stages) {
return Err(
CreatePipelineLayoutError::MoreThanOnePushConstantRangePerStage {
index,
provided: pc.stages,
intersected: pc.stages & used_stages,
},
);
}
used_stages |= pc.stages;
let device_max_pc_size = device.limits.max_push_constant_size;
if device_max_pc_size < pc.range.end {
return Err(CreatePipelineLayoutError::PushConstantRangeTooLarge {
index,
range: pc.range.clone(),
max: device_max_pc_size,
});
}
if pc.range.start % wgt::PUSH_CONSTANT_ALIGNMENT != 0 {
return Err(CreatePipelineLayoutError::MisalignedPushConstantRange {
index,
bound: pc.range.start,
});
}
if pc.range.end % wgt::PUSH_CONSTANT_ALIGNMENT != 0 {
return Err(CreatePipelineLayoutError::MisalignedPushConstantRange {
index,
bound: pc.range.end,
});
}
}
let layout = {
let mut count_validator = binding_model::BindingTypeMaxCountValidator::default();
let (bind_group_layout_guard, _) = hub.bind_group_layouts.read(&mut token);
// validate total resource counts
for &id in desc.bind_group_layouts.iter() {
let bind_group_layout = bind_group_layout_guard
.get(id)
.map_err(|_| CreatePipelineLayoutError::InvalidBindGroupLayout(id))?;
count_validator.merge(&bind_group_layout.count_validator);
}
count_validator
.validate(&device.limits)
.map_err(CreatePipelineLayoutError::TooManyBindings)?;
let descriptor_set_layouts = desc
.bind_group_layouts
.iter()
.map(|&id| &bind_group_layout_guard.get(id).unwrap().raw);
let push_constants = desc
.push_constant_ranges
.iter()
.map(|pc| (conv::map_shader_stage_flags(pc.stages), pc.range.clone()));
binding_model::PipelineLayout {
raw: unsafe {
device
.raw
.create_pipeline_layout(descriptor_set_layouts, push_constants)
.or(Err(DeviceError::OutOfMemory))?
},
device_id: Stored {
value: id::Valid(device_id),
ref_count: device.life_guard.add_ref(),
},
life_guard: LifeGuard::new(),
bind_group_layout_ids: desc
.bind_group_layouts
.iter()
.map(|&id| {
bind_group_layout_guard
.get(id)
.unwrap()
.multi_ref_count
.inc();
id::Valid(id)
})
.collect(),
push_constant_ranges: desc.push_constant_ranges.iter().cloned().collect(),
}
let (bgl_guard, _) = hub.bind_group_layouts.read(&mut token);
device.create_pipeline_layout(device_id, desc, &*bgl_guard)?
};
let id = hub
@@ -1714,6 +1754,98 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
});
}
fn derive_pipeline_layout<B: GfxBackend>(
&self,
device: &Device<B>,
device_id: id::DeviceId,
implicit_pipeline_ids: Option<ImplicitPipelineIds<G>>,
mut derived_group_layouts: ArrayVec<[binding_model::BindEntryMap; MAX_BIND_GROUPS]>,
bgl_guard: &mut Storage<binding_model::BindGroupLayout<B>, id::BindGroupLayoutId>,
pipeline_layout_guard: &mut Storage<binding_model::PipelineLayout<B>, id::PipelineLayoutId>,
) -> Result<
(id::PipelineLayoutId, pipeline::ImplicitBindGroupCount),
pipeline::ImplicitLayoutError,
> {
let hub = B::hub(self);
let derived_bind_group_count =
derived_group_layouts.len() as pipeline::ImplicitBindGroupCount;
while derived_group_layouts
.last()
.map_or(false, |map| map.is_empty())
{
derived_group_layouts.pop();
}
let ids = implicit_pipeline_ids
.as_ref()
.ok_or(pipeline::ImplicitLayoutError::MissingIds(0))?;
if ids.group_ids.len() < derived_group_layouts.len() {
tracing::error!(
"Not enough bind group IDs ({}) specified for the implicit layout ({})",
ids.group_ids.len(),
derived_group_layouts.len()
);
return Err(pipeline::ImplicitLayoutError::MissingIds(
derived_bind_group_count,
));
}
let mut derived_group_layout_ids =
ArrayVec::<[id::BindGroupLayoutId; MAX_BIND_GROUPS]>::new();
for (bgl_id, map) in ids.group_ids.iter().zip(derived_group_layouts) {
let processed_id =
match Device::deduplicate_bind_group_layout(device_id, &map, bgl_guard) {
Some(dedup_id) => dedup_id,
None => {
#[cfg(feature = "trace")]
let bgl_desc = wgt::BindGroupLayoutDescriptor {
label: None,
entries: if device.trace.is_some() {
Cow::Owned(map.values().cloned().collect())
} else {
Cow::Borrowed(&[])
},
};
let bgl = device.create_bind_group_layout(device_id, None, map)?;
let out_id = hub.bind_group_layouts.register_identity_locked(
bgl_id.clone(),
bgl,
bgl_guard,
);
#[cfg(feature = "trace")]
match device.trace {
Some(ref trace) => trace
.lock()
.add(trace::Action::CreateBindGroupLayout(out_id.0, bgl_desc)),
None => (),
};
out_id.0
}
};
derived_group_layout_ids.push(processed_id);
}
let layout_desc = wgt::PipelineLayoutDescriptor {
bind_group_layouts: Cow::Borrowed(&derived_group_layout_ids),
push_constant_ranges: Cow::Borrowed(&[]), //TODO?
};
let layout = device.create_pipeline_layout(device_id, &layout_desc, bgl_guard)?;
let layout_id = hub.pipeline_layouts.register_identity_locked(
ids.root_id.clone(),
layout,
pipeline_layout_guard,
);
#[cfg(feature = "trace")]
match device.trace {
Some(ref trace) => trace.lock().add(trace::Action::CreatePipelineLayout(
layout_id.0,
layout_desc,
)),
None => (),
};
Ok((layout_id.0, derived_bind_group_count))
}
pub fn device_create_bind_group<B: GfxBackend>(
&self,
device_id: id::DeviceId,
@@ -2390,7 +2522,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
device_id: id::DeviceId,
desc: &pipeline::RenderPipelineDescriptor,
id_in: Input<G, id::RenderPipelineId>,
) -> Result<id::RenderPipelineId, pipeline::CreateRenderPipelineError> {
implicit_pipeline_ids: Option<ImplicitPipelineIds<G>>,
) -> Result<
(id::RenderPipelineId, pipeline::ImplicitBindGroupCount),
pipeline::CreateRenderPipelineError,
> {
span!(_guard, INFO, "Device::create_render_pipeline");
let hub = B::hub(self);
@@ -2516,17 +2652,18 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
));
}
let (raw_pipeline, layout_ref_count) = {
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token);
let layout = pipeline_layout_guard
.get(desc.layout)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?;
let group_layouts = layout
.bind_group_layout_ids
.iter()
.map(|&id| &bgl_guard[id].entries)
.collect::<ArrayVec<[&binding_model::BindEntryMap; MAX_BIND_GROUPS]>>();
let (raw_pipeline, layout_id, layout_ref_count, derived_bind_group_count) = {
//TODO: only lock mutable if the layout is derived
let (mut pipeline_layout_guard, mut token) = hub.pipeline_layouts.write(&mut token);
let (mut bgl_guard, mut token) = hub.bind_group_layouts.write(&mut token);
let mut derived_group_layouts =
ArrayVec::<[binding_model::BindEntryMap; MAX_BIND_GROUPS]>::new();
if desc.layout.is_none() {
for _ in 0..device.limits.max_bind_groups {
derived_group_layouts.push(binding_model::BindEntryMap::default());
}
}
let (shader_module_guard, _) = hub.shader_modules.read(&mut token);
@@ -2580,9 +2717,21 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
})?;
if let Some(ref module) = shader_module.module {
let group_layouts = match desc.layout {
Some(pipeline_layout_id) => Device::get_introspection_bind_group_layouts(
pipeline_layout_guard
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?,
&*bgl_guard,
),
None => validation::IntrospectionBindGroupLayouts::Derived(
&mut derived_group_layouts,
),
};
interface = validation::check_stage(
module,
&group_layouts,
group_layouts,
&entry_point_name,
naga::ShaderStage::Vertex,
interface,
@@ -2610,11 +2759,23 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
})?;
let group_layouts = match desc.layout {
Some(pipeline_layout_id) => Device::get_introspection_bind_group_layouts(
pipeline_layout_guard
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?,
&*bgl_guard,
),
None => validation::IntrospectionBindGroupLayouts::Derived(
&mut derived_group_layouts,
),
};
if validated_stages == wgt::ShaderStage::VERTEX {
if let Some(ref module) = shader_module.module {
interface = validation::check_stage(
module,
&group_layouts,
group_layouts,
&entry_point_name,
naga::ShaderStage::Fragment,
interface,
@@ -2653,6 +2814,13 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
}
}
let last_stage = match desc.fragment_stage {
Some(_) => wgt::ShaderStage::FRAGMENT,
None => wgt::ShaderStage::VERTEX,
};
if desc.layout.is_none() && !validated_stages.contains(last_stage) {
Err(pipeline::ImplicitLayoutError::ReflectionError(last_stage))?
}
let shaders = hal::pso::GraphicsShaderSet {
vertex,
@@ -2664,6 +2832,23 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// TODO
let flags = hal::pso::PipelineCreationFlags::empty();
// TODO
let parent = hal::pso::BasePipeline::None;
let (pipeline_layout_id, derived_bind_group_count) = match desc.layout {
Some(id) => (id, 0),
None => self.derive_pipeline_layout(
device,
device_id,
implicit_pipeline_ids,
derived_group_layouts,
&mut *bgl_guard,
&mut *pipeline_layout_guard,
)?,
};
let layout = pipeline_layout_guard
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?;
let mut render_pass_cache = device.render_passes.lock();
let pipeline_desc = hal::pso::GraphicsPipelineDesc {
@@ -2690,7 +2875,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
},
},
flags,
parent: hal::pso::BasePipeline::None,
parent,
};
// TODO: cache
let pipeline = unsafe {
@@ -2703,7 +2888,12 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
})?
};
(pipeline, layout.life_guard.add_ref())
(
pipeline,
pipeline_layout_id,
layout.life_guard.add_ref(),
derived_bind_group_count,
)
};
let pass_context = RenderPassContext {
@@ -2735,7 +2925,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let pipeline = pipeline::RenderPipeline {
raw: raw_pipeline,
layout_id: Stored {
value: id::Valid(desc.layout),
value: id::Valid(layout_id),
ref_count: layout_ref_count,
},
device_id: Stored {
@@ -2755,11 +2945,42 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
#[cfg(feature = "trace")]
match device.trace {
Some(ref trace) => trace
.lock()
.add(trace::Action::CreateRenderPipeline(id.0, desc.clone())),
Some(ref trace) => trace.lock().add(trace::Action::CreateRenderPipeline(
id.0,
wgt::RenderPipelineDescriptor {
layout: Some(layout_id),
..desc.clone()
},
)),
None => (),
};
Ok((id.0, derived_bind_group_count))
}
/// Get an ID of one of the bind group layouts. The ID adds a refcount,
/// which needs to be released by calling `bind_group_layout_drop`.
pub fn render_pipeline_get_bind_group_layout<B: GfxBackend>(
&self,
pipeline_id: id::RenderPipelineId,
index: u32,
) -> Result<id::BindGroupLayoutId, binding_model::GetBindGroupLayoutError> {
let hub = B::hub(self);
let mut token = Token::root();
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token);
let (_, mut token) = hub.bind_groups.read(&mut token);
let (pipeline_guard, _) = hub.render_pipelines.read(&mut token);
let pipeline = pipeline_guard
.get(pipeline_id)
.map_err(|_| binding_model::GetBindGroupLayoutError::InvalidPipeline)?;
let id = pipeline_layout_guard[pipeline.layout_id.value]
.bind_group_layout_ids
.get(index as usize)
.ok_or(binding_model::GetBindGroupLayoutError::InvalidGroupIndex(
index,
))?;
bgl_guard[*id].multi_ref_count.inc();
Ok(id.0)
}
@@ -2810,7 +3031,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
device_id: id::DeviceId,
desc: &pipeline::ComputePipelineDescriptor,
id_in: Input<G, id::ComputePipelineId>,
) -> Result<id::ComputePipelineId, pipeline::CreateComputePipelineError> {
implicit_pipeline_ids: Option<ImplicitPipelineIds<G>>,
) -> Result<
(id::ComputePipelineId, pipeline::ImplicitBindGroupCount),
pipeline::CreateComputePipelineError,
> {
span!(_guard, INFO, "Device::create_compute_pipeline");
let hub = B::hub(self);
@@ -2820,24 +3045,19 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let device = device_guard
.get(device_id)
.map_err(|_| DeviceError::Invalid)?;
let (raw_pipeline, layout_ref_count) = {
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token);
let layout = pipeline_layout_guard
.get(desc.layout)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
let group_layouts = layout
.bind_group_layout_ids
.iter()
.map(|&id| &bgl_guard[id].entries)
.collect::<ArrayVec<[&binding_model::BindEntryMap; MAX_BIND_GROUPS]>>();
let (raw_pipeline, layout_id, layout_ref_count, derived_bind_group_count) = {
//TODO: only lock mutable if the layout is derived
let (mut pipeline_layout_guard, mut token) = hub.pipeline_layouts.write(&mut token);
let (mut bgl_guard, mut token) = hub.bind_group_layouts.write(&mut token);
let mut derived_group_layouts =
ArrayVec::<[binding_model::BindEntryMap; MAX_BIND_GROUPS]>::new();
let interface = validation::StageInterface::default();
let pipeline_stage = &desc.compute_stage;
let (shader_module_guard, _) = hub.shader_modules.read(&mut token);
let entry_point_name = &pipeline_stage.entry_point;
let shader_module = shader_module_guard
.get(pipeline_stage.module)
.map_err(|_| {
@@ -2847,14 +3067,34 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
})?;
if let Some(ref module) = shader_module.module {
let group_layouts = match desc.layout {
Some(pipeline_layout_id) => Device::get_introspection_bind_group_layouts(
pipeline_layout_guard
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?,
&*bgl_guard,
),
None => {
for _ in 0..device.limits.max_bind_groups {
derived_group_layouts.push(binding_model::BindEntryMap::default());
}
validation::IntrospectionBindGroupLayouts::Derived(
&mut derived_group_layouts,
)
}
};
let _ = validation::check_stage(
module,
&group_layouts,
group_layouts,
&entry_point_name,
naga::ShaderStage::Compute,
interface,
)
.map_err(pipeline::CreateComputePipelineError::Stage)?;
} else if desc.layout.is_none() {
Err(pipeline::ImplicitLayoutError::ReflectionError(
wgt::ShaderStage::COMPUTE,
))?
}
let shader = hal::pso::EntryPoint::<B> {
@@ -2868,6 +3108,21 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// TODO
let parent = hal::pso::BasePipeline::None;
let (pipeline_layout_id, derived_bind_group_count) = match desc.layout {
Some(id) => (id, 0),
None => self.derive_pipeline_layout(
device,
device_id,
implicit_pipeline_ids,
derived_group_layouts,
&mut *bgl_guard,
&mut *pipeline_layout_guard,
)?,
};
let layout = pipeline_layout_guard
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
let pipeline_desc = hal::pso::ComputePipelineDesc {
shader,
layout: &layout.raw,
@@ -2876,7 +3131,12 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
};
match unsafe { device.raw.create_compute_pipeline(&pipeline_desc, None) } {
Ok(pipeline) => (pipeline, layout.life_guard.add_ref()),
Ok(pipeline) => (
pipeline,
pipeline_layout_id,
layout.life_guard.add_ref(),
derived_bind_group_count,
),
Err(hal::pso::CreationError::OutOfMemory(_)) => {
return Err(pipeline::CreateComputePipelineError::Device(
DeviceError::OutOfMemory,
@@ -2889,7 +3149,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let pipeline = pipeline::ComputePipeline {
raw: raw_pipeline,
layout_id: Stored {
value: id::Valid(desc.layout),
value: id::Valid(layout_id),
ref_count: layout_ref_count,
},
device_id: Stored {
@@ -2904,11 +3164,42 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
#[cfg(feature = "trace")]
match device.trace {
Some(ref trace) => trace
.lock()
.add(trace::Action::CreateComputePipeline(id.0, desc.clone())),
Some(ref trace) => trace.lock().add(trace::Action::CreateComputePipeline(
id.0,
wgt::ComputePipelineDescriptor {
layout: Some(layout_id),
..desc.clone()
},
)),
None => (),
};
Ok((id.0, derived_bind_group_count))
}
/// Get an ID of one of the bind group layouts. The ID adds a refcount,
/// which needs to be released by calling `bind_group_layout_drop`.
pub fn compute_pipeline_get_bind_group_layout<B: GfxBackend>(
&self,
pipeline_id: id::ComputePipelineId,
index: u32,
) -> Result<id::BindGroupLayoutId, binding_model::GetBindGroupLayoutError> {
let hub = B::hub(self);
let mut token = Token::root();
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token);
let (_, mut token) = hub.bind_groups.read(&mut token);
let (pipeline_guard, _) = hub.compute_pipelines.read(&mut token);
let pipeline = pipeline_guard
.get(pipeline_id)
.map_err(|_| binding_model::GetBindGroupLayoutError::InvalidPipeline)?;
let id = pipeline_layout_guard[pipeline.layout_id.value]
.bind_group_layout_ids
.get(index as usize)
.ok_or(binding_model::GetBindGroupLayoutError::InvalidGroupIndex(
index,
))?;
bgl_guard[*id].multi_ref_count.inc();
Ok(id.0)
}
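
A usage sketch for these getters (handles assumed to exist; per the doc comment, the returned ID carries an extra refcount that must be balanced with `bind_group_layout_drop`):

```rust
use wgc::binding_model::GetBindGroupLayoutError;

match global.compute_pipeline_get_bind_group_layout::<B>(pipeline_id, group_index) {
    Ok(bgl_id) => {
        // ... create bind groups against `bgl_id`, then release the
        // refcount the getter added:
        global.bind_group_layout_drop::<B>(bgl_id);
    }
    Err(GetBindGroupLayoutError::InvalidPipeline) => {
        // the ID didn't resolve to a live pipeline
    }
    Err(GetBindGroupLayoutError::InvalidGroupIndex(i)) => {
        // `i` is beyond the pipeline layout's bind group count
    }
}
```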


@@ -429,6 +429,17 @@ impl<T, I: TypedId + Copy, F: IdentityHandlerFactory<I>> Registry<T, I, F> {
Valid(id)
}
pub(crate) fn register_identity_locked(
&self,
id_in: <F::Filter as IdentityHandler<I>>::Input,
value: T,
guard: &mut Storage<T, I>,
) -> Valid<I> {
let id = self.identity.process(id_in, self.backend);
guard.insert(id, value);
Valid(id)
}
pub fn register_error<A: Access<T>>(
&self,
id_in: <F::Filter as IdentityHandler<I>>::Input,
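
Why a `_locked` variant: `derive_pipeline_layout` already holds write guards on the layout storages, so it has to insert through the guard it was handed rather than re-locking storage it already has a write guard on (which would deadlock with a non-reentrant lock). The call shape, as used in the device.rs hunk above:

```rust
// The caller already holds `bgl_guard: &mut Storage<BindGroupLayout<B>, BindGroupLayoutId>`,
// so registration writes through that guard instead of taking the
// registry's own lock a second time.
let out_id = hub
    .bind_group_layouts
    .register_identity_locked(bgl_id.clone(), bgl, bgl_guard);
```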


@@ -3,6 +3,7 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
use crate::{
binding_model::{CreateBindGroupLayoutError, CreatePipelineLayoutError},
device::{DeviceError, RenderPassContext},
id::{DeviceId, PipelineLayoutId, ShaderModuleId},
validation::StageError,
@@ -37,6 +38,21 @@ pub enum CreateShaderModuleError {
pub type ProgrammableStageDescriptor<'a> = wgt::ProgrammableStageDescriptor<'a, ShaderModuleId>;
/// Number of implicit bind groups derived at pipeline creation.
pub type ImplicitBindGroupCount = u8;
#[derive(Clone, Debug, Error)]
pub enum ImplicitLayoutError {
#[error("missing IDs for deriving {0} bind groups")]
MissingIds(ImplicitBindGroupCount),
#[error("unable to reflect the shader {0:?} interface")]
ReflectionError(wgt::ShaderStage),
#[error(transparent)]
BindGroup(#[from] CreateBindGroupLayoutError),
#[error(transparent)]
Pipeline(#[from] CreatePipelineLayoutError),
}
pub type ComputePipelineDescriptor<'a> =
wgt::ComputePipelineDescriptor<PipelineLayoutId, ProgrammableStageDescriptor<'a>>;
@@ -44,8 +60,10 @@ pub type ComputePipelineDescriptor<'a> =
pub enum CreateComputePipelineError {
#[error(transparent)]
Device(#[from] DeviceError),
#[error("pipelie layout is invalid")]
#[error("pipeline layout is invalid")]
InvalidLayout,
#[error("unable to derive an implicit layout")]
Implicit(#[from] ImplicitLayoutError),
#[error(transparent)]
Stage(StageError),
}
@@ -73,6 +91,8 @@ pub enum CreateRenderPipelineError {
Device(#[from] DeviceError),
#[error("pipelie layout is invalid")]
InvalidLayout,
#[error("unable to derive an implicit layout")]
Implicit(#[from] ImplicitLayoutError),
#[error("incompatible output format at index {index}")]
IncompatibleOutputFormat { index: u8 },
#[error("invalid sample count {0}")]


@@ -2,7 +2,9 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
use crate::{binding_model::BindEntryMap, FastHashMap};
use crate::{binding_model::BindEntryMap, FastHashMap, MAX_BIND_GROUPS};
use arrayvec::ArrayVec;
use std::collections::hash_map::Entry;
use thiserror::Error;
use wgt::{BindGroupLayoutEntry, BindingType};
@@ -71,6 +73,8 @@ pub enum BindingError {
WrongTextureMultisampled,
#[error("comparison flag doesn't match the shader")]
WrongSamplerComparison,
#[error("derived bind group layout type is not consistent between stages")]
InconsistentlyDerivedType,
}
#[derive(Clone, Debug, Error)]
@@ -159,13 +163,12 @@ fn get_aligned_type_size(
}
}
fn check_binding(
fn check_binding_use(
module: &naga::Module,
var: &naga::GlobalVariable,
entry: &BindGroupLayoutEntry,
usage: naga::GlobalUse,
) -> Result<(), BindingError> {
let allowed_usage = match module.types[var.ty].inner {
) -> Result<naga::GlobalUse, BindingError> {
match module.types[var.ty].inner {
naga::TypeInner::Struct { ref members } => {
let (allowed_usage, min_size) = match entry.ty {
BindingType::UniformBuffer {
@@ -196,17 +199,17 @@
}
_ => (),
}
allowed_usage
Ok(allowed_usage)
}
naga::TypeInner::Sampler { comparison } => match entry.ty {
BindingType::Sampler { comparison: cmp } => {
if cmp == comparison {
naga::GlobalUse::empty()
Ok(naga::GlobalUse::empty())
} else {
return Err(BindingError::WrongSamplerComparison);
Err(BindingError::WrongSamplerComparison)
}
}
_ => return Err(BindingError::WrongType),
_ => Err(BindingError::WrongType),
},
naga::TypeInner::Image { base, dim, flags } => {
if flags.contains(naga::ImageFlags::MULTISAMPLED) {
@@ -284,14 +287,9 @@ fn check_binding(
if is_sampled != flags.contains(naga::ImageFlags::SAMPLED) {
return Err(BindingError::WrongTextureSampled);
}
allowed_usage
Ok(allowed_usage)
}
_ => return Err(BindingError::WrongType),
};
if allowed_usage.contains(usage) {
Ok(())
} else {
Err(BindingError::WrongUsage(usage))
_ => Err(BindingError::WrongType),
}
}
@@ -685,9 +683,80 @@ pub fn check_texture_format(format: wgt::TextureFormat, output: &naga::TypeInner
pub type StageInterface<'a> = FastHashMap<wgt::ShaderLocation, MaybeOwned<'a, naga::TypeInner>>;
pub enum IntrospectionBindGroupLayouts<'a> {
Given(ArrayVec<[&'a BindEntryMap; MAX_BIND_GROUPS]>),
Derived(&'a mut [BindEntryMap]),
}
fn derive_binding_type(
module: &naga::Module,
var: &naga::GlobalVariable,
usage: naga::GlobalUse,
) -> Result<BindingType, BindingError> {
let ty = &module.types[var.ty];
Ok(match ty.inner {
naga::TypeInner::Struct { ref members } => {
let dynamic = false;
let mut actual_size = 0;
for (i, member) in members.iter().enumerate() {
actual_size += get_aligned_type_size(module, member.ty, i + 1 == members.len());
}
match var.class {
naga::StorageClass::Uniform => BindingType::UniformBuffer {
dynamic,
min_binding_size: wgt::BufferSize::new(actual_size),
},
naga::StorageClass::StorageBuffer => BindingType::StorageBuffer {
dynamic,
min_binding_size: wgt::BufferSize::new(actual_size),
readonly: !usage.contains(naga::GlobalUse::STORE), //TODO: clarify
},
_ => return Err(BindingError::WrongType),
}
}
naga::TypeInner::Sampler { comparison } => BindingType::Sampler { comparison },
naga::TypeInner::Image { base, dim, flags } => {
let array = flags.contains(naga::ImageFlags::ARRAYED);
let dimension = match dim {
naga::ImageDimension::D1 => wgt::TextureViewDimension::D1,
naga::ImageDimension::D2 if array => wgt::TextureViewDimension::D2Array,
naga::ImageDimension::D2 => wgt::TextureViewDimension::D2,
naga::ImageDimension::D3 => wgt::TextureViewDimension::D3,
naga::ImageDimension::Cube if array => wgt::TextureViewDimension::CubeArray,
naga::ImageDimension::Cube => wgt::TextureViewDimension::Cube,
};
if flags.contains(naga::ImageFlags::SAMPLED) {
BindingType::SampledTexture {
dimension,
component_type: match module.types[base].inner {
naga::TypeInner::Scalar { kind, .. }
| naga::TypeInner::Vector { kind, .. } => match kind {
naga::ScalarKind::Float => wgt::TextureComponentType::Float,
naga::ScalarKind::Sint => wgt::TextureComponentType::Sint,
naga::ScalarKind::Uint => wgt::TextureComponentType::Uint,
other => {
return Err(BindingError::WrongTextureComponentType(Some(other)))
}
},
_ => return Err(BindingError::WrongTextureComponentType(None)),
},
multisampled: flags.contains(naga::ImageFlags::MULTISAMPLED),
}
} else {
BindingType::StorageTexture {
dimension,
format: wgt::TextureFormat::Rgba8Unorm, //TODO
readonly: !flags.contains(naga::ImageFlags::CAN_STORE),
}
}
}
_ => return Err(BindingError::WrongType),
})
}
pub fn check_stage<'a>(
module: &'a naga::Module,
group_layouts: &[&BindEntryMap],
mut group_layouts: IntrospectionBindGroupLayouts,
entry_point_name: &str,
stage: naga::ShaderStage,
inputs: StageInterface<'a>,
@@ -713,18 +782,49 @@ pub fn check_stage<'a>(
}
match var.binding {
Some(naga::Binding::Descriptor { set, binding }) => {
let result = group_layouts
.get(set as usize)
.and_then(|map| map.get(&binding))
.ok_or(BindingError::Missing)
.and_then(|entry| {
if entry.visibility.contains(stage_bit) {
Ok(entry)
} else {
Err(BindingError::Invisible)
}
})
.and_then(|entry| check_binding(module, var, entry, usage));
let result = match group_layouts {
IntrospectionBindGroupLayouts::Given(ref layouts) => layouts
.get(set as usize)
.and_then(|map| map.get(&binding))
.ok_or(BindingError::Missing)
.and_then(|entry| {
if entry.visibility.contains(stage_bit) {
Ok(entry)
} else {
Err(BindingError::Invisible)
}
})
.and_then(|entry| check_binding_use(module, var, entry))
.and_then(|allowed_usage| {
if allowed_usage.contains(usage) {
Ok(())
} else {
Err(BindingError::WrongUsage(usage))
}
}),
IntrospectionBindGroupLayouts::Derived(ref mut layouts) => layouts
.get_mut(set as usize)
.ok_or(BindingError::Missing)
.and_then(|set| {
let ty = derive_binding_type(module, var, usage)?;
Ok(match set.entry(binding) {
Entry::Occupied(e) if e.get().ty != ty => {
return Err(BindingError::InconsistentlyDerivedType)
}
Entry::Occupied(e) => {
e.into_mut().visibility |= stage_bit;
}
Entry::Vacant(e) => {
e.insert(BindGroupLayoutEntry {
binding,
ty,
visibility: stage_bit,
count: None,
});
}
})
}),
};
if let Err(error) = result {
return Err(StageError::Binding {
set,
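
The `Derived` arm above is the heart of implicit layouts: each stage contributes entries, visibilities are OR-ed together, and a type disagreement between stages yields `InconsistentlyDerivedType`. A minimal, self-contained model of that merge, using plain std types in place of `wgt::BindingType` and `wgt::ShaderStage`:

```rust
use std::collections::{hash_map::Entry, HashMap};

struct DerivedEntry {
    ty: &'static str, // stands in for wgt::BindingType
    visibility: u32,  // stands in for wgt::ShaderStage bits
}

fn merge(
    map: &mut HashMap<u32, DerivedEntry>,
    binding: u32,
    ty: &'static str,
    stage_bit: u32,
) -> Result<(), &'static str> {
    match map.entry(binding) {
        // Same binding derived with a different type in another stage.
        Entry::Occupied(e) if e.get().ty != ty => Err("InconsistentlyDerivedType"),
        // Same type: just widen the visibility to include this stage.
        Entry::Occupied(e) => {
            e.into_mut().visibility |= stage_bit;
            Ok(())
        }
        // First stage to use this binding.
        Entry::Vacant(e) => {
            e.insert(DerivedEntry { ty, visibility: stage_bit });
            Ok(())
        }
    }
}

fn main() {
    let mut map = HashMap::new();
    merge(&mut map, 0, "UniformBuffer", 0b01).unwrap(); // vertex stage
    merge(&mut map, 0, "UniformBuffer", 0b10).unwrap(); // fragment: merged
    assert_eq!(map[&0].visibility, 0b11);
    // Same binding, different type across stages -> error.
    assert!(merge(&mut map, 0, "Sampler", 0b10).is_err());
}
```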


@@ -2104,7 +2104,7 @@ impl<'a, M> ProgrammableStageDescriptor<'a, M> {
#[cfg_attr(feature = "replay", derive(serde::Deserialize))]
pub struct RenderPipelineDescriptor<'a, L, D> {
/// The layout of bind groups for this pipeline.
pub layout: L,
pub layout: Option<L>,
/// The compiled vertex stage and its entry point.
pub vertex_stage: D,
/// The compiled fragment stage and its entry point, if any.
@@ -2136,14 +2136,13 @@ pub struct RenderPipelineDescriptor<'a, L, D> {
impl<'a, L, D> RenderPipelineDescriptor<'a, L, D> {
pub fn new(
layout: L,
vertex_stage: D,
primitive_topology: PrimitiveTopology,
color_states: impl IntoCow<'a, [ColorStateDescriptor]>,
vertex_state: VertexStateDescriptor<'a>,
) -> Self {
Self {
layout,
layout: None,
vertex_stage,
fragment_stage: None,
rasterization_state: None,
@@ -2157,6 +2156,11 @@ impl<'a, L, D> RenderPipelineDescriptor<'a, L, D> {
}
}
pub fn layout(&mut self, layout: L) -> &mut Self {
self.layout = Some(layout);
self
}
pub fn fragment_stage(&mut self, stage: D) -> &mut Self {
self.fragment_stage = Some(stage);
self
@@ -2194,18 +2198,23 @@ impl<'a, L, D> RenderPipelineDescriptor<'a, L, D> {
#[cfg_attr(feature = "replay", derive(serde::Deserialize))]
pub struct ComputePipelineDescriptor<L, D> {
/// The layout of bind groups for this pipeline.
pub layout: L,
pub layout: Option<L>,
/// The compiled compute stage and its entry point.
pub compute_stage: D,
}
impl<L, D> ComputePipelineDescriptor<L, D> {
pub fn new(layout: L, compute_stage: D) -> Self {
pub fn new(compute_stage: D) -> Self {
Self {
layout,
layout: None,
compute_stage,
}
}
pub fn layout(&mut self, layout: L) -> &mut Self {
self.layout = Some(layout);
self
}
}
/// Describes a [`CommandBuffer`].
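
A sketch of how the changed builders read after this diff (hypothetical `vs_stage`, `cs_stage`, `color_states`, `vertex_state`, and `layout_id` values): `layout` moved out of `new()`, so leaving it unset requests an implicit layout, and `.layout(...)` opts back into an explicit one.

```rust
// Render: `new()` no longer takes a layout.
let mut render_desc = RenderPipelineDescriptor::new(
    vs_stage,
    PrimitiveTopology::TriangleList,
    color_states,
    vertex_state,
);
render_desc.layout(layout_id); // optional now; omit it for an implicit layout

// Compute: same pattern.
let mut compute_desc = ComputePipelineDescriptor::new(cs_stage);
compute_desc.layout(layout_id);
```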