implement exclusive pipeline validation

This gets the bind_group_layout_dedup test passing again.
This commit is contained in:
teoxoy 2024-06-28 17:38:15 +02:00 committed by Teodor Tanasoaia
parent 4a19ac279c
commit 9d3d4ee297
4 changed files with 115 additions and 26 deletions

View File

@ -6,6 +6,7 @@ use crate::{
hal_api::HalApi, hal_api::HalApi,
id::{BindGroupLayoutId, BufferId, SamplerId, TextureViewId}, id::{BindGroupLayoutId, BufferId, SamplerId, TextureViewId},
init_tracker::{BufferInitTrackerAction, TextureInitTrackerAction}, init_tracker::{BufferInitTrackerAction, TextureInitTrackerAction},
pipeline::{ComputePipeline, RenderPipeline},
resource::{ resource::{
DestroyedResourceError, Labeled, MissingBufferUsageError, MissingTextureUsageError, DestroyedResourceError, Labeled, MissingBufferUsageError, MissingTextureUsageError,
ResourceErrorIdent, TrackingData, ResourceErrorIdent, TrackingData,
@ -18,12 +19,17 @@ use crate::{
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
use once_cell::sync::OnceCell;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::Deserialize; use serde::Deserialize;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::Serialize; use serde::Serialize;
use std::{borrow::Cow, ops::Range, sync::Arc}; use std::{
borrow::Cow,
ops::Range,
sync::{Arc, Weak},
};
use thiserror::Error; use thiserror::Error;
@ -437,6 +443,38 @@ pub struct BindGroupLayoutDescriptor<'a> {
pub entries: Cow<'a, [wgt::BindGroupLayoutEntry]>, pub entries: Cow<'a, [wgt::BindGroupLayoutEntry]>,
} }
/// Used by [`BindGroupLayout`]. It indicates whether the BGL must be
/// used with a specific pipeline. This constraint only happens when
/// the BGLs have been derived from a pipeline without a layout.
#[derive(Debug)]
pub(crate) enum ExclusivePipeline<A: HalApi> {
None,
Render(Weak<RenderPipeline<A>>),
Compute(Weak<ComputePipeline<A>>),
}
impl<A: HalApi> std::fmt::Display for ExclusivePipeline<A> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExclusivePipeline::None => f.write_str("None"),
ExclusivePipeline::Render(p) => {
if let Some(p) = p.upgrade() {
p.error_ident().fmt(f)
} else {
f.write_str("RenderPipeline")
}
}
ExclusivePipeline::Compute(p) => {
if let Some(p) = p.upgrade() {
p.error_ident().fmt(f)
} else {
f.write_str("ComputePipeline")
}
}
}
}
}
/// Bind group layout. /// Bind group layout.
#[derive(Debug)] #[derive(Debug)]
pub struct BindGroupLayout<A: HalApi> { pub struct BindGroupLayout<A: HalApi> {
@ -450,6 +488,7 @@ pub struct BindGroupLayout<A: HalApi> {
/// We cannot unconditionally remove from the pool, as BGLs that don't come from the pool /// We cannot unconditionally remove from the pool, as BGLs that don't come from the pool
/// (derived BGLs) must not be removed. /// (derived BGLs) must not be removed.
pub(crate) origin: bgl::Origin, pub(crate) origin: bgl::Origin,
pub(crate) exclusive_pipeline: OnceCell<ExclusivePipeline<A>>,
#[allow(unused)] #[allow(unused)]
pub(crate) binding_count_validator: BindingTypeMaxCountValidator, pub(crate) binding_count_validator: BindingTypeMaxCountValidator,
/// The `label` from the descriptor used to create the resource. /// The `label` from the descriptor used to create the resource.

View File

@ -18,12 +18,15 @@ mod compat {
use crate::{ use crate::{
binding_model::BindGroupLayout, binding_model::BindGroupLayout,
device::bgl,
error::MultiError, error::MultiError,
hal_api::HalApi, hal_api::HalApi,
resource::{Labeled, ParentDevice, ResourceErrorIdent}, resource::{Labeled, ParentDevice, ResourceErrorIdent},
}; };
use std::{num::NonZeroU32, ops::Range, sync::Arc}; use std::{
num::NonZeroU32,
ops::Range,
sync::{Arc, Weak},
};
pub(crate) enum Error { pub(crate) enum Error {
Incompatible { Incompatible {
@ -74,28 +77,41 @@ mod compat {
Ok(()) Ok(())
} else { } else {
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug, Error)]
#[error("Expected an {expected_bgl_type} bind group layout, got an {assigned_bgl_type} bind group layout")] #[error(
struct IncompatibleTypes { "Exclusive pipelines don't match: expected {expected}, got {assigned}"
expected_bgl_type: &'static str, )]
assigned_bgl_type: &'static str, struct IncompatibleExclusivePipelines {
expected: String,
assigned: String,
} }
if expected_bgl.origin != assigned_bgl.origin { use crate::binding_model::ExclusivePipeline;
fn get_bgl_type(origin: bgl::Origin) -> &'static str { match (
match origin { expected_bgl.exclusive_pipeline.get().unwrap(),
bgl::Origin::Derived => "implicit", assigned_bgl.exclusive_pipeline.get().unwrap(),
bgl::Origin::Pool => "explicit", ) {
} (ExclusivePipeline::None, ExclusivePipeline::None) => {}
(
ExclusivePipeline::Render(e_pipeline),
ExclusivePipeline::Render(a_pipeline),
) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
(
ExclusivePipeline::Compute(e_pipeline),
ExclusivePipeline::Compute(a_pipeline),
) if Weak::ptr_eq(e_pipeline, a_pipeline) => {}
(expected, assigned) => {
return Err(Error::Incompatible {
expected_bgl: expected_bgl.error_ident(),
assigned_bgl: assigned_bgl.error_ident(),
inner: MultiError::new(core::iter::once(
IncompatibleExclusivePipelines {
expected: expected.to_string(),
assigned: assigned.to_string(),
},
))
.unwrap(),
});
} }
return Err(Error::Incompatible {
expected_bgl: expected_bgl.error_ident(),
assigned_bgl: assigned_bgl.error_ident(),
inner: MultiError::new(core::iter::once(IncompatibleTypes {
expected_bgl_type: get_bgl_type(expected_bgl.origin),
assigned_bgl_type: get_bgl_type(assigned_bgl.origin),
}))
.unwrap(),
});
} }
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug, Error)]

View File

@ -1014,6 +1014,9 @@ impl Global {
let bgl_result = device.bgl_pool.get_or_init(entry_map, |entry_map| { let bgl_result = device.bgl_pool.get_or_init(entry_map, |entry_map| {
let bgl = let bgl =
device.create_bind_group_layout(&desc.label, entry_map, bgl::Origin::Pool)?; device.create_bind_group_layout(&desc.label, entry_map, bgl::Origin::Pool)?;
bgl.exclusive_pipeline
.set(binding_model::ExclusivePipeline::None)
.unwrap();
let (id_inner, arc) = fid.take().unwrap().assign(Arc::new(bgl)); let (id_inner, arc) = fid.take().unwrap().assign(Arc::new(bgl));
id = Some(id_inner); id = Some(id_inner);
@ -1633,7 +1636,7 @@ impl Global {
Err(e) => break 'error e, Err(e) => break 'error e,
}; };
let (id, resource) = fid.assign(Arc::new(pipeline)); let (id, resource) = fid.assign(pipeline);
api_log!("Device::create_render_pipeline -> {id:?}"); api_log!("Device::create_render_pipeline -> {id:?}");
device device
@ -1772,7 +1775,7 @@ impl Global {
Err(e) => break 'error e, Err(e) => break 'error e,
}; };
let (id, resource) = fid.assign(Arc::new(pipeline)); let (id, resource) = fid.assign(pipeline);
api_log!("Device::create_compute_pipeline -> {id:?}"); api_log!("Device::create_compute_pipeline -> {id:?}");
device device

View File

@ -1840,6 +1840,7 @@ impl<A: HalApi> Device<A> {
device: self.clone(), device: self.clone(),
entries: entry_map, entries: entry_map,
origin, origin,
exclusive_pipeline: OnceCell::new(),
binding_count_validator: count_validator, binding_count_validator: count_validator,
label: label.to_string(), label: label.to_string(),
tracking_data: TrackingData::new(self.tracker_indices.bind_group_layouts.clone()), tracking_data: TrackingData::new(self.tracker_indices.bind_group_layouts.clone()),
@ -2607,7 +2608,7 @@ impl<A: HalApi> Device<A> {
desc: &pipeline::ComputePipelineDescriptor, desc: &pipeline::ComputePipelineDescriptor,
implicit_context: Option<ImplicitPipelineContext>, implicit_context: Option<ImplicitPipelineContext>,
hub: &Hub<A>, hub: &Hub<A>,
) -> Result<pipeline::ComputePipeline<A>, pipeline::CreateComputePipelineError> { ) -> Result<Arc<pipeline::ComputePipeline<A>>, pipeline::CreateComputePipelineError> {
self.check_is_valid()?; self.check_is_valid()?;
// This has to be done first, or otherwise the IDs may be pointing to entries // This has to be done first, or otherwise the IDs may be pointing to entries
@ -2630,6 +2631,8 @@ impl<A: HalApi> Device<A> {
shader_module.same_device(self)?; shader_module.same_device(self)?;
let is_auto_layout = desc.layout.is_none();
// Get the pipeline layout from the desc if it is provided. // Get the pipeline layout from the desc if it is provided.
let pipeline_layout = match desc.layout { let pipeline_layout = match desc.layout {
Some(pipeline_layout_id) => { Some(pipeline_layout_id) => {
@ -2744,6 +2747,19 @@ impl<A: HalApi> Device<A> {
label: desc.label.to_string(), label: desc.label.to_string(),
tracking_data: TrackingData::new(self.tracker_indices.compute_pipelines.clone()), tracking_data: TrackingData::new(self.tracker_indices.compute_pipelines.clone()),
}; };
let pipeline = Arc::new(pipeline);
if is_auto_layout {
for bgl in pipeline.layout.bind_group_layouts.iter() {
bgl.exclusive_pipeline
.set(binding_model::ExclusivePipeline::Compute(Arc::downgrade(
&pipeline,
)))
.unwrap();
}
}
Ok(pipeline) Ok(pipeline)
} }
@ -2753,7 +2769,7 @@ impl<A: HalApi> Device<A> {
desc: &pipeline::RenderPipelineDescriptor, desc: &pipeline::RenderPipelineDescriptor,
implicit_context: Option<ImplicitPipelineContext>, implicit_context: Option<ImplicitPipelineContext>,
hub: &Hub<A>, hub: &Hub<A>,
) -> Result<pipeline::RenderPipeline<A>, pipeline::CreateRenderPipelineError> { ) -> Result<Arc<pipeline::RenderPipeline<A>>, pipeline::CreateRenderPipelineError> {
use wgt::TextureFormatFeatureFlags as Tfff; use wgt::TextureFormatFeatureFlags as Tfff;
self.check_is_valid()?; self.check_is_valid()?;
@ -3070,6 +3086,8 @@ impl<A: HalApi> Device<A> {
return Err(pipeline::CreateRenderPipelineError::NoTargetSpecified); return Err(pipeline::CreateRenderPipelineError::NoTargetSpecified);
} }
let is_auto_layout = desc.layout.is_none();
// Get the pipeline layout from the desc if it is provided. // Get the pipeline layout from the desc if it is provided.
let pipeline_layout = match desc.layout { let pipeline_layout = match desc.layout {
Some(pipeline_layout_id) => { Some(pipeline_layout_id) => {
@ -3396,6 +3414,19 @@ impl<A: HalApi> Device<A> {
label: desc.label.to_string(), label: desc.label.to_string(),
tracking_data: TrackingData::new(self.tracker_indices.render_pipelines.clone()), tracking_data: TrackingData::new(self.tracker_indices.render_pipelines.clone()),
}; };
let pipeline = Arc::new(pipeline);
if is_auto_layout {
for bgl in pipeline.layout.bind_group_layouts.iter() {
bgl.exclusive_pipeline
.set(binding_model::ExclusivePipeline::Render(Arc::downgrade(
&pipeline,
)))
.unwrap();
}
}
Ok(pipeline) Ok(pipeline)
} }