From b61be30e53d576e3161a1c26fae6a1f058e1eb42 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Tue, 2 Jul 2024 12:20:15 +0200 Subject: [PATCH] move out ID to resource mapping code from `Device.create_bind_group` --- wgpu-core/src/binding_model.rs | 46 ++++++++- wgpu-core/src/device/global.rs | 96 ++++++++++++++++++- wgpu-core/src/device/resource.rs | 158 ++++++++++++++----------------- 3 files changed, 208 insertions(+), 92 deletions(-) diff --git a/wgpu-core/src/binding_model.rs b/wgpu-core/src/binding_model.rs index 4a8b0d8a1..729618995 100644 --- a/wgpu-core/src/binding_model.rs +++ b/wgpu-core/src/binding_model.rs @@ -7,8 +7,8 @@ use crate::{ init_tracker::{BufferInitTrackerAction, TextureInitTrackerAction}, pipeline::{ComputePipeline, RenderPipeline}, resource::{ - DestroyedResourceError, Labeled, MissingBufferUsageError, MissingTextureUsageError, - ResourceErrorIdent, TrackingData, + Buffer, DestroyedResourceError, Labeled, MissingBufferUsageError, MissingTextureUsageError, + ResourceErrorIdent, Sampler, TextureView, TrackingData, }, resource_log, snatch::{SnatchGuard, Snatchable}, @@ -414,6 +414,16 @@ pub struct BindGroupEntry<'a> { pub resource: BindingResource<'a>, } +/// Bindable resource and the slot to bind it to. +#[derive(Clone, Debug)] +pub struct ResolvedBindGroupEntry<'a, A: HalApi> { + /// Slot for which binding provides resource. Corresponds to an entry of the same + /// binding index in the [`BindGroupLayoutDescriptor`]. + pub binding: u32, + /// Resource to attach to the binding + pub resource: ResolvedBindingResource<'a, A>, +} + /// Describes a group of bindings and the resources to be bound. #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -428,6 +438,19 @@ pub struct BindGroupDescriptor<'a> { pub entries: Cow<'a, [BindGroupEntry<'a>]>, } +/// Describes a group of bindings and the resources to be bound. +#[derive(Clone, Debug)] +pub struct ResolvedBindGroupDescriptor<'a, A: HalApi> { + /// Debug label of the bind group. + /// + /// This will show up in graphics debuggers for easy identification. + pub label: Label<'a>, + /// The [`BindGroupLayout`] that corresponds to this bind group. + pub layout: Arc>, + /// The resources to bind to this bind group. + pub entries: Cow<'a, [ResolvedBindGroupEntry<'a, A>]>, +} + /// Describes a [`BindGroupLayout`]. #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -757,6 +780,13 @@ pub struct BufferBinding { pub size: Option, } +#[derive(Clone, Debug)] +pub struct ResolvedBufferBinding { + pub buffer: Arc>, + pub offset: wgt::BufferAddress, + pub size: Option, +} + // Note: Duplicated in `wgpu-rs` as `BindingResource` // They're different enough that it doesn't make sense to share a common type #[derive(Debug, Clone)] @@ -770,6 +800,18 @@ pub enum BindingResource<'a> { TextureViewArray(Cow<'a, [TextureViewId]>), } +// Note: Duplicated in `wgpu-rs` as `BindingResource` +// They're different enough that it doesn't make sense to share a common type +#[derive(Debug, Clone)] +pub enum ResolvedBindingResource<'a, A: HalApi> { + Buffer(ResolvedBufferBinding), + BufferArray(Cow<'a, [ResolvedBufferBinding]>), + Sampler(Arc>), + SamplerArray(Cow<'a, [Arc>]>), + TextureView(Arc>), + TextureViewArray(Cow<'a, [Arc>]>), +} + #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum BindError { diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 7f4864c3e..04e43a143 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -1,7 +1,12 @@ #[cfg(feature = "trace")] use crate::device::trace; use crate::{ - api_log, binding_model, command, conv, + api_log, + binding_model::{ + self, BindGroupEntry, BindingResource, BufferBinding, ResolvedBindGroupDescriptor, + ResolvedBindGroupEntry, ResolvedBindingResource, ResolvedBufferBinding, + }, + command, conv, device::{ bgl, life::WaitIdleError, map_buffer, queue, DeviceError, DeviceLostClosure, DeviceLostReason, HostMap, @@ -21,6 +26,7 @@ use crate::{ self, BufferAccessError, BufferAccessResult, BufferMapOperation, CreateBufferError, Trackable, }, + storage::Storage, Label, }; @@ -1157,12 +1163,96 @@ impl Global { trace.add(trace::Action::CreateBindGroup(fid.id(), desc.clone())); } - let bind_group_layout = match hub.bind_group_layouts.get(desc.layout) { + let layout = match hub.bind_group_layouts.get(desc.layout) { Ok(layout) => layout, Err(..) => break 'error binding_model::CreateBindGroupError::InvalidLayout, }; - let bind_group = match device.create_bind_group(&bind_group_layout, desc, hub) { + fn map_entry<'a, A: HalApi>( + e: &BindGroupEntry<'a>, + buffer_storage: &Storage>, + sampler_storage: &Storage>, + texture_view_storage: &Storage>, + ) -> Result, binding_model::CreateBindGroupError> + { + let map_buffer = |bb: &BufferBinding| { + buffer_storage + .get_owned(bb.buffer_id) + .map(|buffer| ResolvedBufferBinding { + buffer, + offset: bb.offset, + size: bb.size, + }) + .map_err(|_| { + binding_model::CreateBindGroupError::InvalidBufferId(bb.buffer_id) + }) + }; + let map_sampler = |id: &id::SamplerId| { + sampler_storage + .get_owned(*id) + .map_err(|_| binding_model::CreateBindGroupError::InvalidSamplerId(*id)) + }; + let map_view = |id: &id::TextureViewId| { + texture_view_storage + .get_owned(*id) + .map_err(|_| binding_model::CreateBindGroupError::InvalidTextureViewId(*id)) + }; + let resource = match e.resource { + BindingResource::Buffer(ref buffer) => { + ResolvedBindingResource::Buffer(map_buffer(buffer)?) + } + BindingResource::BufferArray(ref buffers) => { + let buffers = buffers + .iter() + .map(map_buffer) + .collect::, _>>()?; + ResolvedBindingResource::BufferArray(Cow::Owned(buffers)) + } + BindingResource::Sampler(ref sampler) => { + ResolvedBindingResource::Sampler(map_sampler(sampler)?) + } + BindingResource::SamplerArray(ref samplers) => { + let samplers = samplers + .iter() + .map(map_sampler) + .collect::, _>>()?; + ResolvedBindingResource::SamplerArray(Cow::Owned(samplers)) + } + BindingResource::TextureView(ref view) => { + ResolvedBindingResource::TextureView(map_view(view)?) + } + BindingResource::TextureViewArray(ref views) => { + let views = views.iter().map(map_view).collect::, _>>()?; + ResolvedBindingResource::TextureViewArray(Cow::Owned(views)) + } + }; + Ok(ResolvedBindGroupEntry { + binding: e.binding, + resource, + }) + } + + let entries = { + let buffer_guard = hub.buffers.read(); + let texture_view_guard = hub.texture_views.read(); + let sampler_guard = hub.samplers.read(); + desc.entries + .iter() + .map(|e| map_entry(e, &buffer_guard, &sampler_guard, &texture_view_guard)) + .collect::, _>>() + }; + let entries = match entries { + Ok(entries) => Cow::Owned(entries), + Err(e) => break 'error e, + }; + + let desc = ResolvedBindGroupDescriptor { + label: desc.label.clone(), + layout, + entries, + }; + + let bind_group = match device.create_bind_group(desc) { Ok(bind_group) => bind_group, Err(e) => break 'error e, }; diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index e9983495e..fc9c6a3f0 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -12,8 +12,6 @@ use crate::{ }, hal_api::HalApi, hal_label, - hub::Hub, - id, init_tracker::{ BufferInitTracker, BufferInitTrackerAction, MemoryInitKind, TextureInitRange, TextureInitTracker, TextureInitTrackerAction, @@ -28,7 +26,6 @@ use crate::{ }, resource_log, snatch::{SnatchGuard, SnatchLock, Snatchable}, - storage::Storage, track::{ BindGroupStates, TextureSelector, Tracker, TrackerIndexAllocators, UsageScope, UsageScopePool, @@ -1847,14 +1844,13 @@ impl Device { pub(crate) fn create_buffer_binding<'a>( self: &Arc, - bb: &binding_model::BufferBinding, + bb: &'a binding_model::ResolvedBufferBinding, binding: u32, decl: &wgt::BindGroupLayoutEntry, used_buffer_ranges: &mut Vec>, dynamic_binding_info: &mut Vec, late_buffer_binding_sizes: &mut FastHashMap, used: &mut BindGroupStates, - storage: &'a Storage>, limits: &wgt::Limits, snatch_guard: &'a SnatchGuard<'a>, ) -> Result, binding_model::CreateBindGroupError> { @@ -1902,9 +1898,7 @@ impl Device { )); } - let buffer = storage - .get(bb.buffer_id) - .map_err(|_| Error::InvalidBufferId(bb.buffer_id))?; + let buffer = &bb.buffer; used.buffers.add_single(buffer, internal_use); @@ -1988,34 +1982,61 @@ impl Device { fn create_sampler_binding<'a>( self: &Arc, used: &BindGroupStates, - storage: &'a Storage>, - id: id::Id, - ) -> Result<&'a Sampler, binding_model::CreateBindGroupError> { + binding: u32, + decl: &wgt::BindGroupLayoutEntry, + sampler: &'a Arc>, + ) -> Result<&'a A::Sampler, binding_model::CreateBindGroupError> { use crate::binding_model::CreateBindGroupError as Error; - let sampler = storage.get(id).map_err(|_| Error::InvalidSamplerId(id))?; used.samplers.add_single(sampler); sampler.same_device(self)?; - Ok(sampler) + match decl.ty { + wgt::BindingType::Sampler(ty) => { + let (allowed_filtering, allowed_comparison) = match ty { + wgt::SamplerBindingType::Filtering => (None, false), + wgt::SamplerBindingType::NonFiltering => (Some(false), false), + wgt::SamplerBindingType::Comparison => (None, true), + }; + if let Some(allowed_filtering) = allowed_filtering { + if allowed_filtering != sampler.filtering { + return Err(Error::WrongSamplerFiltering { + binding, + layout_flt: allowed_filtering, + sampler_flt: sampler.filtering, + }); + } + } + if allowed_comparison != sampler.comparison { + return Err(Error::WrongSamplerComparison { + binding, + layout_cmp: allowed_comparison, + sampler_cmp: sampler.comparison, + }); + } + } + _ => { + return Err(Error::WrongBindingType { + binding, + actual: decl.ty, + expected: "Sampler", + }) + } + } + + Ok(sampler.raw()) } pub(crate) fn create_texture_binding<'a>( self: &Arc, binding: u32, decl: &wgt::BindGroupLayoutEntry, - storage: &'a Storage>, - id: id::Id, + view: &'a Arc>, used: &mut BindGroupStates, used_texture_ranges: &mut Vec>, snatch_guard: &'a SnatchGuard<'a>, ) -> Result, binding_model::CreateBindGroupError> { - use crate::binding_model::CreateBindGroupError as Error; - - let view = storage - .get(id) - .map_err(|_| Error::InvalidTextureViewId(id))?; used.views.add_single(view); view.same_device(self)?; @@ -2058,11 +2079,11 @@ impl Device { // (not passing a duplicate) beforehand. pub(crate) fn create_bind_group( self: &Arc, - layout: &Arc>, - desc: &binding_model::BindGroupDescriptor, - hub: &Hub, + desc: binding_model::ResolvedBindGroupDescriptor, ) -> Result, binding_model::CreateBindGroupError> { - use crate::binding_model::{BindingResource as Br, CreateBindGroupError as Error}; + use crate::binding_model::{CreateBindGroupError as Error, ResolvedBindingResource as Br}; + + let layout = desc.layout; self.check_is_valid()?; layout.same_device(self)?; @@ -2087,10 +2108,6 @@ impl Device { // fill out the descriptors let mut used = BindGroupStates::new(); - let buffer_guard = hub.buffers.read(); - let texture_view_guard = hub.texture_views.read(); - let sampler_guard = hub.samplers.read(); - let mut used_buffer_ranges = Vec::new(); let mut used_texture_ranges = Vec::new(); let mut hal_entries = Vec::with_capacity(desc.entries.len()); @@ -2115,7 +2132,6 @@ impl Device { &mut dynamic_binding_info, &mut late_buffer_binding_sizes, &mut used, - &*buffer_guard, &self.limits, &snatch_guard, )?; @@ -2138,7 +2154,6 @@ impl Device { &mut dynamic_binding_info, &mut late_buffer_binding_sizes, &mut used, - &*buffer_guard, &self.limits, &snatch_guard, )?; @@ -2146,63 +2161,31 @@ impl Device { } (res_index, num_bindings) } - Br::Sampler(id) => match decl.ty { - wgt::BindingType::Sampler(ty) => { - let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?; + Br::Sampler(ref sampler) => { + let sampler = self.create_sampler_binding(&used, binding, decl, sampler)?; - let (allowed_filtering, allowed_comparison) = match ty { - wgt::SamplerBindingType::Filtering => (None, false), - wgt::SamplerBindingType::NonFiltering => (Some(false), false), - wgt::SamplerBindingType::Comparison => (None, true), - }; - if let Some(allowed_filtering) = allowed_filtering { - if allowed_filtering != sampler.filtering { - return Err(Error::WrongSamplerFiltering { - binding, - layout_flt: allowed_filtering, - sampler_flt: sampler.filtering, - }); - } - } - if allowed_comparison != sampler.comparison { - return Err(Error::WrongSamplerComparison { - binding, - layout_cmp: allowed_comparison, - sampler_cmp: sampler.comparison, - }); - } - - let res_index = hal_samplers.len(); - hal_samplers.push(sampler.raw()); - (res_index, 1) - } - _ => { - return Err(Error::WrongBindingType { - binding, - actual: decl.ty, - expected: "Sampler", - }) - } - }, - Br::SamplerArray(ref bindings_array) => { - let num_bindings = bindings_array.len(); + let res_index = hal_samplers.len(); + hal_samplers.push(sampler); + (res_index, 1) + } + Br::SamplerArray(ref samplers) => { + let num_bindings = samplers.len(); Self::check_array_binding(self.features, decl.count, num_bindings)?; let res_index = hal_samplers.len(); - for &id in bindings_array.iter() { - let sampler = self.create_sampler_binding(&used, &sampler_guard, id)?; + for sampler in samplers.iter() { + let sampler = self.create_sampler_binding(&used, binding, decl, sampler)?; - hal_samplers.push(sampler.raw()); + hal_samplers.push(sampler); } (res_index, num_bindings) } - Br::TextureView(id) => { + Br::TextureView(ref view) => { let tb = self.create_texture_binding( binding, decl, - &texture_view_guard, - id, + view, &mut used, &mut used_texture_ranges, &snatch_guard, @@ -2211,17 +2194,16 @@ impl Device { hal_textures.push(tb); (res_index, 1) } - Br::TextureViewArray(ref bindings_array) => { - let num_bindings = bindings_array.len(); + Br::TextureViewArray(ref views) => { + let num_bindings = views.len(); Self::check_array_binding(self.features, decl.count, num_bindings)?; let res_index = hal_textures.len(); - for &id in bindings_array.iter() { + for view in views.iter() { let tb = self.create_texture_binding( binding, decl, - &texture_view_guard, - id, + view, &mut used, &mut used_texture_ranges, &snatch_guard, @@ -2266,22 +2248,24 @@ impl Device { .map_err(DeviceError::from)? }; + // collect in the order of BGL iteration + let late_buffer_binding_sizes = layout + .entries + .indices() + .flat_map(|binding| late_buffer_binding_sizes.get(&binding).cloned()) + .collect(); + Ok(BindGroup { raw: Snatchable::new(raw), device: self.clone(), - layout: layout.clone(), + layout, label: desc.label.to_string(), tracking_data: TrackingData::new(self.tracker_indices.bind_groups.clone()), used, used_buffer_ranges, used_texture_ranges, dynamic_binding_info, - // collect in the order of BGL iteration - late_buffer_binding_sizes: layout - .entries - .indices() - .flat_map(|binding| late_buffer_binding_sizes.get(&binding).cloned()) - .collect(), + late_buffer_binding_sizes, }) }