[wgpu-core] make the Registry generic over T: Clone

This commit is contained in:
teoxoy 2024-09-05 12:28:46 +02:00 committed by Teodor Tanasoaia
parent d550342f47
commit 70a9c01b48
6 changed files with 73 additions and 48 deletions

View File

@ -578,7 +578,7 @@ impl RenderBundleEncoder {
fn set_bind_group(
state: &mut State,
bind_group_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<BindGroup>>,
bind_group_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Arc<BindGroup>>>,
dynamic_offsets: &[u32],
index: u32,
num_dynamic_offsets: usize,
@ -630,7 +630,7 @@ fn set_bind_group(
fn set_pipeline(
state: &mut State,
pipeline_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<RenderPipeline>>,
pipeline_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Arc<RenderPipeline>>>,
context: &RenderPassContext,
is_depth_read_only: bool,
is_stencil_read_only: bool,
@ -673,7 +673,7 @@ fn set_pipeline(
fn set_index_buffer(
state: &mut State,
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer>>,
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Arc<Buffer>>>,
buffer_id: id::Id<id::markers::Buffer>,
index_format: wgt::IndexFormat,
offset: u64,
@ -708,7 +708,7 @@ fn set_index_buffer(
fn set_vertex_buffer(
state: &mut State,
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer>>,
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Arc<Buffer>>>,
slot: u32,
buffer_id: id::Id<id::markers::Buffer>,
offset: u64,
@ -852,7 +852,7 @@ fn draw_indexed(
fn multi_draw_indirect(
state: &mut State,
dynamic_offsets: &[u32],
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Buffer>>,
buffer_guard: &crate::lock::RwLockReadGuard<crate::storage::Storage<Arc<Buffer>>>,
buffer_id: id::Id<id::markers::Buffer>,
offset: u64,
indexed: bool,

View File

@ -26,7 +26,11 @@ use crate::{
use wgt::{BufferAddress, TextureFormat};
use std::{borrow::Cow, ptr::NonNull, sync::atomic::Ordering};
use std::{
borrow::Cow,
ptr::NonNull,
sync::{atomic::Ordering, Arc},
};
use super::{ImplicitPipelineIds, UserClosures};
@ -803,9 +807,9 @@ impl Global {
fn resolve_entry<'a>(
e: &BindGroupEntry<'a>,
buffer_storage: &Storage<resource::Buffer>,
sampler_storage: &Storage<resource::Sampler>,
texture_view_storage: &Storage<resource::TextureView>,
buffer_storage: &Storage<Arc<resource::Buffer>>,
sampler_storage: &Storage<Arc<resource::Sampler>>,
texture_view_storage: &Storage<Arc<resource::TextureView>>,
) -> Result<ResolvedBindGroupEntry<'a>, binding_model::CreateBindGroupError>
{
let resolve_buffer = |bb: &BufferBinding| {

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use crate::{
hal_api::HalApi,
hub::{Hub, HubReport},
@ -23,7 +25,7 @@ impl GlobalReport {
pub struct Global {
pub instance: Instance,
pub(crate) surfaces: Registry<Surface>,
pub(crate) surfaces: Registry<Arc<Surface>>,
pub(crate) hub: Hub,
}

View File

@ -110,7 +110,7 @@ use crate::{
resource::{Buffer, QuerySet, Sampler, StagingBuffer, Texture, TextureView},
storage::{Element, Storage},
};
use std::fmt::Debug;
use std::{fmt::Debug, sync::Arc};
#[derive(Debug, PartialEq, Eq)]
pub struct HubReport {
@ -162,24 +162,24 @@ impl HubReport {
///
/// [`A::hub(global)`]: HalApi::hub
pub struct Hub {
pub(crate) adapters: Registry<Adapter>,
pub(crate) devices: Registry<Device>,
pub(crate) queues: Registry<Queue>,
pub(crate) pipeline_layouts: Registry<PipelineLayout>,
pub(crate) shader_modules: Registry<ShaderModule>,
pub(crate) bind_group_layouts: Registry<BindGroupLayout>,
pub(crate) bind_groups: Registry<BindGroup>,
pub(crate) command_buffers: Registry<CommandBuffer>,
pub(crate) render_bundles: Registry<RenderBundle>,
pub(crate) render_pipelines: Registry<RenderPipeline>,
pub(crate) compute_pipelines: Registry<ComputePipeline>,
pub(crate) pipeline_caches: Registry<PipelineCache>,
pub(crate) query_sets: Registry<QuerySet>,
pub(crate) buffers: Registry<Buffer>,
pub(crate) staging_buffers: Registry<StagingBuffer>,
pub(crate) textures: Registry<Texture>,
pub(crate) texture_views: Registry<TextureView>,
pub(crate) samplers: Registry<Sampler>,
pub(crate) adapters: Registry<Arc<Adapter>>,
pub(crate) devices: Registry<Arc<Device>>,
pub(crate) queues: Registry<Arc<Queue>>,
pub(crate) pipeline_layouts: Registry<Arc<PipelineLayout>>,
pub(crate) shader_modules: Registry<Arc<ShaderModule>>,
pub(crate) bind_group_layouts: Registry<Arc<BindGroupLayout>>,
pub(crate) bind_groups: Registry<Arc<BindGroup>>,
pub(crate) command_buffers: Registry<Arc<CommandBuffer>>,
pub(crate) render_bundles: Registry<Arc<RenderBundle>>,
pub(crate) render_pipelines: Registry<Arc<RenderPipeline>>,
pub(crate) compute_pipelines: Registry<Arc<ComputePipeline>>,
pub(crate) pipeline_caches: Registry<Arc<PipelineCache>>,
pub(crate) query_sets: Registry<Arc<QuerySet>>,
pub(crate) buffers: Registry<Arc<Buffer>>,
pub(crate) staging_buffers: Registry<Arc<StagingBuffer>>,
pub(crate) textures: Registry<Arc<Texture>>,
pub(crate) texture_views: Registry<Arc<TextureView>>,
pub(crate) samplers: Registry<Arc<Sampler>>,
}
impl Hub {
@ -206,7 +206,7 @@ impl Hub {
}
}
pub(crate) fn clear(&self, surface_guard: &Storage<Surface>) {
pub(crate) fn clear(&self, surface_guard: &Storage<Arc<Surface>>) {
let mut devices = self.devices.write();
for element in devices.map.iter() {
if let Element::Occupied(ref device, _) = *element {

View File

@ -68,7 +68,7 @@ impl<T: StorageItem> FutureId<'_, T> {
/// Assign a new resource to this ID.
///
/// Registers it with the registry.
pub fn assign(self, value: Arc<T>) -> Id<T::Marker> {
pub fn assign(self, value: T) -> Id<T::Marker> {
let mut data = self.data.write();
data.insert(self.id, value);
self.id
@ -98,9 +98,6 @@ impl<T: StorageItem> Registry<T> {
}
}
pub(crate) fn get(&self, id: Id<T::Marker>) -> Result<Arc<T>, InvalidId> {
self.read().get_owned(id)
}
#[track_caller]
pub(crate) fn read<'a>(&'a self) -> RwLockReadGuard<'a, Storage<T>> {
self.storage.read()
@ -114,7 +111,7 @@ impl<T: StorageItem> Registry<T> {
storage.remove(id);
storage.insert_error(id);
}
pub(crate) fn unregister(&self, id: Id<T::Marker>) -> Option<Arc<T>> {
pub(crate) fn unregister(&self, id: Id<T::Marker>) -> Option<T> {
let value = self.storage.write().remove(id);
// This needs to happen *after* removing it from the storage, to maintain the
// invariant that `self.identity` only contains ids which are actually available
@ -142,6 +139,12 @@ impl<T: StorageItem> Registry<T> {
}
}
impl<T: StorageItem + Clone> Registry<T> {
pub(crate) fn get(&self, id: Id<T::Marker>) -> Result<T, InvalidId> {
self.read().get_owned(id)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;

View File

@ -8,13 +8,16 @@ use crate::{Epoch, Index};
/// An entry in a `Storage::map` table.
#[derive(Debug)]
pub(crate) enum Element<T> {
pub(crate) enum Element<T>
where
T: StorageItem,
{
/// There are no live ids with this index.
Vacant,
/// There is one live id with this index, allocated at the given
/// epoch.
Occupied(Arc<T>, Epoch),
Occupied(T, Epoch),
/// Like `Occupied`, but an error occurred when creating the
/// resource.
@ -28,6 +31,14 @@ pub(crate) trait StorageItem: ResourceType {
type Marker: Marker;
}
impl<T: ResourceType> ResourceType for Arc<T> {
const TYPE: &'static str = T::TYPE;
}
impl<T: StorageItem> StorageItem for Arc<T> {
type Marker = T::Marker;
}
#[macro_export]
macro_rules! impl_storage_item {
($ty:ident) => {
@ -72,7 +83,7 @@ where
{
/// Get a reference to an item behind a potentially invalid ID.
/// Panics if there is an epoch mismatch, or the entry is empty.
pub(crate) fn get(&self, id: Id<T::Marker>) -> Result<&Arc<T>, InvalidId> {
pub(crate) fn get(&self, id: Id<T::Marker>) -> Result<&T, InvalidId> {
let (index, epoch, _) = id.unzip();
let (result, storage_epoch) = match self.map.get(index as usize) {
Some(&Element::Occupied(ref v, epoch)) => (Ok(v), epoch),
@ -87,12 +98,6 @@ where
result
}
/// Get an owned reference to an item behind a potentially invalid ID.
/// Panics if there is an epoch mismatch, or the entry is empty.
pub(crate) fn get_owned(&self, id: Id<T::Marker>) -> Result<Arc<T>, InvalidId> {
Ok(Arc::clone(self.get(id)?))
}
fn insert_impl(&mut self, index: usize, epoch: Epoch, element: Element<T>) {
if index >= self.map.len() {
self.map.resize_with(index + 1, || Element::Vacant);
@ -118,7 +123,7 @@ where
}
}
pub(crate) fn insert(&mut self, id: Id<T::Marker>, value: Arc<T>) {
pub(crate) fn insert(&mut self, id: Id<T::Marker>, value: T) {
let (index, epoch, _backend) = id.unzip();
self.insert_impl(index as usize, epoch, Element::Occupied(value, epoch))
}
@ -128,7 +133,7 @@ where
self.insert_impl(index as usize, epoch, Element::Error(epoch))
}
pub(crate) fn replace_with_error(&mut self, id: Id<T::Marker>) -> Result<Arc<T>, InvalidId> {
pub(crate) fn replace_with_error(&mut self, id: Id<T::Marker>) -> Result<T, InvalidId> {
let (index, epoch, _) = id.unzip();
match std::mem::replace(&mut self.map[index as usize], Element::Error(epoch)) {
Element::Vacant => panic!("Cannot access vacant resource"),
@ -140,7 +145,7 @@ where
}
}
pub(crate) fn remove(&mut self, id: Id<T::Marker>) -> Option<Arc<T>> {
pub(crate) fn remove(&mut self, id: Id<T::Marker>) -> Option<T> {
let (index, epoch, _) = id.unzip();
match std::mem::replace(&mut self.map[index as usize], Element::Vacant) {
Element::Occupied(value, storage_epoch) => {
@ -152,7 +157,7 @@ where
}
}
pub(crate) fn iter(&self, backend: Backend) -> impl Iterator<Item = (Id<T::Marker>, &Arc<T>)> {
pub(crate) fn iter(&self, backend: Backend) -> impl Iterator<Item = (Id<T::Marker>, &T)> {
self.map
.iter()
.enumerate()
@ -168,3 +173,14 @@ where
self.map.len()
}
}
impl<T> Storage<T>
where
T: StorageItem + Clone,
{
/// Get an owned reference to an item behind a potentially invalid ID.
/// Panics if there is an epoch mismatch, or the entry is empty.
pub(crate) fn get_owned(&self, id: Id<T::Marker>) -> Result<T, InvalidId> {
Ok(self.get(id)?.clone())
}
}