move out ID to resource mapping code from Device.create_bind_group

This commit is contained in:
teoxoy 2024-07-02 12:20:15 +02:00 committed by Teodor Tanasoaia
parent 1be51946e3
commit b61be30e53
3 changed files with 208 additions and 92 deletions

View File

@ -7,8 +7,8 @@ use crate::{
init_tracker::{BufferInitTrackerAction, TextureInitTrackerAction},
pipeline::{ComputePipeline, RenderPipeline},
resource::{
DestroyedResourceError, Labeled, MissingBufferUsageError, MissingTextureUsageError,
ResourceErrorIdent, TrackingData,
Buffer, DestroyedResourceError, Labeled, MissingBufferUsageError, MissingTextureUsageError,
ResourceErrorIdent, Sampler, TextureView, TrackingData,
},
resource_log,
snatch::{SnatchGuard, Snatchable},
@ -414,6 +414,16 @@ pub struct BindGroupEntry<'a> {
pub resource: BindingResource<'a>,
}
/// Bindable resource and the slot to bind it to.
#[derive(Clone, Debug)]
pub struct ResolvedBindGroupEntry<'a, A: HalApi> {
/// Slot for which binding provides resource. Corresponds to an entry of the same
/// binding index in the [`BindGroupLayoutDescriptor`].
pub binding: u32,
/// Resource to attach to the binding
pub resource: ResolvedBindingResource<'a, A>,
}
/// Describes a group of bindings and the resources to be bound.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@ -428,6 +438,19 @@ pub struct BindGroupDescriptor<'a> {
pub entries: Cow<'a, [BindGroupEntry<'a>]>,
}
/// Describes a group of bindings and the resources to be bound.
#[derive(Clone, Debug)]
pub struct ResolvedBindGroupDescriptor<'a, A: HalApi> {
/// Debug label of the bind group.
///
/// This will show up in graphics debuggers for easy identification.
pub label: Label<'a>,
/// The [`BindGroupLayout`] that corresponds to this bind group.
pub layout: Arc<BindGroupLayout<A>>,
/// The resources to bind to this bind group.
pub entries: Cow<'a, [ResolvedBindGroupEntry<'a, A>]>,
}
/// Describes a [`BindGroupLayout`].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@ -757,6 +780,13 @@ pub struct BufferBinding {
pub size: Option<wgt::BufferSize>,
}
#[derive(Clone, Debug)]
pub struct ResolvedBufferBinding<A: HalApi> {
pub buffer: Arc<Buffer<A>>,
pub offset: wgt::BufferAddress,
pub size: Option<wgt::BufferSize>,
}
// Note: Duplicated in `wgpu-rs` as `BindingResource`
// They're different enough that it doesn't make sense to share a common type
#[derive(Debug, Clone)]
@ -770,6 +800,18 @@ pub enum BindingResource<'a> {
TextureViewArray(Cow<'a, [TextureViewId]>),
}
// Note: Duplicated in `wgpu-rs` as `BindingResource`
// They're different enough that it doesn't make sense to share a common type
#[derive(Debug, Clone)]
pub enum ResolvedBindingResource<'a, A: HalApi> {
Buffer(ResolvedBufferBinding<A>),
BufferArray(Cow<'a, [ResolvedBufferBinding<A>]>),
Sampler(Arc<Sampler<A>>),
SamplerArray(Cow<'a, [Arc<Sampler<A>>]>),
TextureView(Arc<TextureView<A>>),
TextureViewArray(Cow<'a, [Arc<TextureView<A>>]>),
}
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum BindError {

View File

@ -1,7 +1,12 @@
#[cfg(feature = "trace")]
use crate::device::trace;
use crate::{
api_log, binding_model, command, conv,
api_log,
binding_model::{
self, BindGroupEntry, BindingResource, BufferBinding, ResolvedBindGroupDescriptor,
ResolvedBindGroupEntry, ResolvedBindingResource, ResolvedBufferBinding,
},
command, conv,
device::{
bgl, life::WaitIdleError, map_buffer, queue, DeviceError, DeviceLostClosure,
DeviceLostReason, HostMap,
@ -21,6 +26,7 @@ use crate::{
self, BufferAccessError, BufferAccessResult, BufferMapOperation, CreateBufferError,
Trackable,
},
storage::Storage,
Label,
};
@ -1157,12 +1163,96 @@ impl Global {
trace.add(trace::Action::CreateBindGroup(fid.id(), desc.clone()));
}
let bind_group_layout = match hub.bind_group_layouts.get(desc.layout) {
let layout = match hub.bind_group_layouts.get(desc.layout) {
Ok(layout) => layout,
Err(..) => break 'error binding_model::CreateBindGroupError::InvalidLayout,
};
let bind_group = match device.create_bind_group(&bind_group_layout, desc, hub) {
fn map_entry<'a, A: HalApi>(
e: &BindGroupEntry<'a>,
buffer_storage: &Storage<resource::Buffer<A>>,
sampler_storage: &Storage<resource::Sampler<A>>,
texture_view_storage: &Storage<resource::TextureView<A>>,
) -> Result<ResolvedBindGroupEntry<'a, A>, binding_model::CreateBindGroupError>
{
let map_buffer = |bb: &BufferBinding| {
buffer_storage
.get_owned(bb.buffer_id)
.map(|buffer| ResolvedBufferBinding {
buffer,
offset: bb.offset,
size: bb.size,
})
.map_err(|_| {
binding_model::CreateBindGroupError::InvalidBufferId(bb.buffer_id)
})
};
let map_sampler = |id: &id::SamplerId| {
sampler_storage
.get_owned(*id)
.map_err(|_| binding_model::CreateBindGroupError::InvalidSamplerId(*id))
};
let map_view = |id: &id::TextureViewId| {
texture_view_storage
.get_owned(*id)
.map_err(|_| binding_model::CreateBindGroupError::InvalidTextureViewId(*id))
};
let resource = match e.resource {
BindingResource::Buffer(ref buffer) => {
ResolvedBindingResource::Buffer(map_buffer(buffer)?)
}
BindingResource::BufferArray(ref buffers) => {
let buffers = buffers
.iter()
.map(map_buffer)
.collect::<Result<Vec<_>, _>>()?;
ResolvedBindingResource::BufferArray(Cow::Owned(buffers))
}
BindingResource::Sampler(ref sampler) => {
ResolvedBindingResource::Sampler(map_sampler(sampler)?)
}
BindingResource::SamplerArray(ref samplers) => {
let samplers = samplers
.iter()
.map(map_sampler)
.collect::<Result<Vec<_>, _>>()?;
ResolvedBindingResource::SamplerArray(Cow::Owned(samplers))
}
BindingResource::TextureView(ref view) => {
ResolvedBindingResource::TextureView(map_view(view)?)
}
BindingResource::TextureViewArray(ref views) => {
let views = views.iter().map(map_view).collect::<Result<Vec<_>, _>>()?;
ResolvedBindingResource::TextureViewArray(Cow::Owned(views))
}
};
Ok(ResolvedBindGroupEntry {
binding: e.binding,
resource,
})
}
let entries = {
let buffer_guard = hub.buffers.read();
let texture_view_guard = hub.texture_views.read();
let sampler_guard = hub.samplers.read();
desc.entries
.iter()
.map(|e| map_entry(e, &buffer_guard, &sampler_guard, &texture_view_guard))
.collect::<Result<Vec<_>, _>>()
};
let entries = match entries {
Ok(entries) => Cow::Owned(entries),
Err(e) => break 'error e,
};
let desc = ResolvedBindGroupDescriptor {
label: desc.label.clone(),
layout,
entries,
};
let bind_group = match device.create_bind_group(desc) {
Ok(bind_group) => bind_group,
Err(e) => break 'error e,
};

View File

@ -12,8 +12,6 @@ use crate::{
},
hal_api::HalApi,
hal_label,
hub::Hub,
id,
init_tracker::{
BufferInitTracker, BufferInitTrackerAction, MemoryInitKind, TextureInitRange,
TextureInitTracker, TextureInitTrackerAction,
@ -28,7 +26,6 @@ use crate::{
},
resource_log,
snatch::{SnatchGuard, SnatchLock, Snatchable},
storage::Storage,
track::{
BindGroupStates, TextureSelector, Tracker, TrackerIndexAllocators, UsageScope,
UsageScopePool,
@ -1847,14 +1844,13 @@ impl<A: HalApi> Device<A> {
pub(crate) fn create_buffer_binding<'a>(
self: &Arc<Self>,
bb: &binding_model::BufferBinding,
bb: &'a binding_model::ResolvedBufferBinding<A>,
binding: u32,
decl: &wgt::BindGroupLayoutEntry,
used_buffer_ranges: &mut Vec<BufferInitTrackerAction<A>>,
dynamic_binding_info: &mut Vec<binding_model::BindGroupDynamicBindingData>,
late_buffer_binding_sizes: &mut FastHashMap<u32, wgt::BufferSize>,
used: &mut BindGroupStates<A>,
storage: &'a Storage<Buffer<A>>,
limits: &wgt::Limits,
snatch_guard: &'a SnatchGuard<'a>,
) -> Result<hal::BufferBinding<'a, A>, binding_model::CreateBindGroupError> {
@ -1902,9 +1898,7 @@ impl<A: HalApi> Device<A> {
));
}
let buffer = storage
.get(bb.buffer_id)
.map_err(|_| Error::InvalidBufferId(bb.buffer_id))?;
let buffer = &bb.buffer;
used.buffers.add_single(buffer, internal_use);
@ -1988,34 +1982,61 @@ impl<A: HalApi> Device<A> {
fn create_sampler_binding<'a>(
self: &Arc<Self>,
used: &BindGroupStates<A>,
storage: &'a Storage<Sampler<A>>,
id: id::Id<id::markers::Sampler>,
) -> Result<&'a Sampler<A>, binding_model::CreateBindGroupError> {
binding: u32,
decl: &wgt::BindGroupLayoutEntry,
sampler: &'a Arc<Sampler<A>>,
) -> Result<&'a A::Sampler, binding_model::CreateBindGroupError> {
use crate::binding_model::CreateBindGroupError as Error;
let sampler = storage.get(id).map_err(|_| Error::InvalidSamplerId(id))?;
used.samplers.add_single(sampler);
sampler.same_device(self)?;
Ok(sampler)
match decl.ty {
wgt::BindingType::Sampler(ty) => {
let (allowed_filtering, allowed_comparison) = match ty {
wgt::SamplerBindingType::Filtering => (None, false),
wgt::SamplerBindingType::NonFiltering => (Some(false), false),
wgt::SamplerBindingType::Comparison => (None, true),
};
if let Some(allowed_filtering) = allowed_filtering {
if allowed_filtering != sampler.filtering {
return Err(Error::WrongSamplerFiltering {
binding,
layout_flt: allowed_filtering,
sampler_flt: sampler.filtering,
});
}
}
if allowed_comparison != sampler.comparison {
return Err(Error::WrongSamplerComparison {
binding,
layout_cmp: allowed_comparison,
sampler_cmp: sampler.comparison,
});
}
}
_ => {
return Err(Error::WrongBindingType {
binding,
actual: decl.ty,
expected: "Sampler",
})
}
}
Ok(sampler.raw())
}
pub(crate) fn create_texture_binding<'a>(
self: &Arc<Self>,
binding: u32,
decl: &wgt::BindGroupLayoutEntry,
storage: &'a Storage<TextureView<A>>,
id: id::Id<id::markers::TextureView>,
view: &'a Arc<TextureView<A>>,
used: &mut BindGroupStates<A>,
used_texture_ranges: &mut Vec<TextureInitTrackerAction<A>>,
snatch_guard: &'a SnatchGuard<'a>,
) -> Result<hal::TextureBinding<'a, A>, binding_model::CreateBindGroupError> {
use crate::binding_model::CreateBindGroupError as Error;
let view = storage
.get(id)
.map_err(|_| Error::InvalidTextureViewId(id))?;
used.views.add_single(view);
view.same_device(self)?;
@ -2058,11 +2079,11 @@ impl<A: HalApi> Device<A> {
// (not passing a duplicate) beforehand.
pub(crate) fn create_bind_group(
self: &Arc<Self>,
layout: &Arc<BindGroupLayout<A>>,
desc: &binding_model::BindGroupDescriptor,
hub: &Hub<A>,
desc: binding_model::ResolvedBindGroupDescriptor<A>,
) -> Result<BindGroup<A>, binding_model::CreateBindGroupError> {
use crate::binding_model::{BindingResource as Br, CreateBindGroupError as Error};
use crate::binding_model::{CreateBindGroupError as Error, ResolvedBindingResource as Br};
let layout = desc.layout;
self.check_is_valid()?;
layout.same_device(self)?;
@ -2087,10 +2108,6 @@ impl<A: HalApi> Device<A> {
// fill out the descriptors
let mut used = BindGroupStates::new();
let buffer_guard = hub.buffers.read();
let texture_view_guard = hub.texture_views.read();
let sampler_guard = hub.samplers.read();
let mut used_buffer_ranges = Vec::new();
let mut used_texture_ranges = Vec::new();
let mut hal_entries = Vec::with_capacity(desc.entries.len());
@ -2115,7 +2132,6 @@ impl<A: HalApi> Device<A> {
&mut dynamic_binding_info,
&mut late_buffer_binding_sizes,
&mut used,
&*buffer_guard,
&self.limits,
&snatch_guard,
)?;
@ -2138,7 +2154,6 @@ impl<A: HalApi> Device<A> {
&mut dynamic_binding_info,
&mut late_buffer_binding_sizes,
&mut used,
&*buffer_guard,
&self.limits,
&snatch_guard,
)?;
@ -2146,63 +2161,31 @@ impl<A: HalApi> Device<A> {
}
(res_index, num_bindings)
}
Br::Sampler(id) => match decl.ty {
wgt::BindingType::Sampler(ty) => {
let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?;
Br::Sampler(ref sampler) => {
let sampler = self.create_sampler_binding(&used, binding, decl, sampler)?;
let (allowed_filtering, allowed_comparison) = match ty {
wgt::SamplerBindingType::Filtering => (None, false),
wgt::SamplerBindingType::NonFiltering => (Some(false), false),
wgt::SamplerBindingType::Comparison => (None, true),
};
if let Some(allowed_filtering) = allowed_filtering {
if allowed_filtering != sampler.filtering {
return Err(Error::WrongSamplerFiltering {
binding,
layout_flt: allowed_filtering,
sampler_flt: sampler.filtering,
});
}
}
if allowed_comparison != sampler.comparison {
return Err(Error::WrongSamplerComparison {
binding,
layout_cmp: allowed_comparison,
sampler_cmp: sampler.comparison,
});
}
let res_index = hal_samplers.len();
hal_samplers.push(sampler.raw());
(res_index, 1)
}
_ => {
return Err(Error::WrongBindingType {
binding,
actual: decl.ty,
expected: "Sampler",
})
}
},
Br::SamplerArray(ref bindings_array) => {
let num_bindings = bindings_array.len();
let res_index = hal_samplers.len();
hal_samplers.push(sampler);
(res_index, 1)
}
Br::SamplerArray(ref samplers) => {
let num_bindings = samplers.len();
Self::check_array_binding(self.features, decl.count, num_bindings)?;
let res_index = hal_samplers.len();
for &id in bindings_array.iter() {
let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?;
for sampler in samplers.iter() {
let sampler = self.create_sampler_binding(&used, binding, decl, sampler)?;
hal_samplers.push(sampler.raw());
hal_samplers.push(sampler);
}
(res_index, num_bindings)
}
Br::TextureView(id) => {
Br::TextureView(ref view) => {
let tb = self.create_texture_binding(
binding,
decl,
&texture_view_guard,
id,
view,
&mut used,
&mut used_texture_ranges,
&snatch_guard,
@ -2211,17 +2194,16 @@ impl<A: HalApi> Device<A> {
hal_textures.push(tb);
(res_index, 1)
}
Br::TextureViewArray(ref bindings_array) => {
let num_bindings = bindings_array.len();
Br::TextureViewArray(ref views) => {
let num_bindings = views.len();
Self::check_array_binding(self.features, decl.count, num_bindings)?;
let res_index = hal_textures.len();
for &id in bindings_array.iter() {
for view in views.iter() {
let tb = self.create_texture_binding(
binding,
decl,
&texture_view_guard,
id,
view,
&mut used,
&mut used_texture_ranges,
&snatch_guard,
@ -2266,22 +2248,24 @@ impl<A: HalApi> Device<A> {
.map_err(DeviceError::from)?
};
// collect in the order of BGL iteration
let late_buffer_binding_sizes = layout
.entries
.indices()
.flat_map(|binding| late_buffer_binding_sizes.get(&binding).cloned())
.collect();
Ok(BindGroup {
raw: Snatchable::new(raw),
device: self.clone(),
layout: layout.clone(),
layout,
label: desc.label.to_string(),
tracking_data: TrackingData::new(self.tracker_indices.bind_groups.clone()),
used,
used_buffer_ranges,
used_texture_ranges,
dynamic_binding_info,
// collect in the order of BGL iteration
late_buffer_binding_sizes: layout
.entries
.indices()
.flat_map(|binding| late_buffer_binding_sizes.get(&binding).cloned())
.collect(),
late_buffer_binding_sizes,
})
}