move PendingWrites into the Queue

The `Device` should not contain any `Arc`s to resources, as that creates reference cycles (all resources hold strong references back to the `Device`).
Note that `PendingWrites` internally holds `Arc`s to resources.

I think this change also makes more sense conceptually, since most operations that use `PendingWrites` are on the `Queue`.
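To illustrate the cycle, here is a minimal sketch with simplified types (not the actual wgpu-core definitions; it assumes `get_queue` upgrades a `Weak` back-reference, which is what the `set_queue`/`get_queue` pairing in this diff suggests):

```rust
use std::sync::{Arc, OnceLock, Weak};

struct Device {
    // Holding `Arc<Queue>` here would form a cycle, because `Queue`
    // holds `Arc<Device>`; a `Weak` back-reference does not keep the
    // queue alive, so both can still be freed.
    queue: OnceLock<Weak<Queue>>,
}

struct Queue {
    device: Arc<Device>, // strong reference, as in wgpu-core
}

impl Device {
    // The upgrade succeeds only while the queue is still alive.
    fn get_queue(&self) -> Option<Arc<Queue>> {
        self.queue.get()?.upgrade()
    }
}

fn main() {
    let device = Arc::new(Device { queue: OnceLock::new() });
    let queue = Arc::new(Queue { device: device.clone() });
    let _ = device.queue.set(Arc::downgrade(&queue));
    assert!(device.get_queue().is_some());
    drop(queue); // dropping the last strong queue ref actually frees it
    assert!(device.get_queue().is_none());
}
```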
Author: teoxoy
Date: 2024-10-15 11:23:34 +02:00
Committed by: Teodor Tanasoaia
Parent: 97acfd26ce
Commit: 6c7dbba399
6 changed files with 169 additions and 136 deletions


@@ -1,5 +1,5 @@
 use crate::{
-    device::queue::TempResource,
+    device::{queue::TempResource, Device},
     global::Global,
     hub::Hub,
     id::CommandEncoderId,
@@ -20,10 +20,7 @@ use crate::{
 use wgt::{math::align_to, BufferUsages, Features};
 
 use super::CommandBufferMutable;
-use crate::device::queue::PendingWrites;
 use hal::BufferUses;
-use std::mem::ManuallyDrop;
-use std::ops::DerefMut;
 use std::{
     cmp::max,
     num::NonZeroU64,
@@ -184,7 +181,7 @@ impl Global {
             build_command_index,
             &mut buf_storage,
             hub,
-            device.pending_writes.lock().deref_mut(),
+            device,
         )?;
 
         let snatch_guard = device.snatchable_lock.read();
@@ -248,7 +245,9 @@
                 .get()
                 .map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
             cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
-            device.pending_writes.lock().insert_tlas(&tlas);
+            if let Some(queue) = device.get_queue() {
+                queue.pending_writes.lock().insert_tlas(&tlas);
+            }
 
             cmd_buf_data.tlas_actions.push(TlasAction {
                 tlas: tlas.clone(),
@@ -349,10 +348,12 @@
             }
         }
 
-        device
-            .pending_writes
-            .lock()
-            .consume_temp(TempResource::ScratchBuffer(scratch_buffer));
+        if let Some(queue) = device.get_queue() {
+            queue
+                .pending_writes
+                .lock()
+                .consume_temp(TempResource::ScratchBuffer(scratch_buffer));
+        }
 
         Ok(())
     }
@@ -495,7 +496,7 @@
             build_command_index,
             &mut buf_storage,
             hub,
-            device.pending_writes.lock().deref_mut(),
+            device,
         )?;
 
         let snatch_guard = device.snatchable_lock.read();
@@ -516,7 +517,9 @@
             .get(package.tlas_id)
             .get()
             .map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
-        device.pending_writes.lock().insert_tlas(&tlas);
+        if let Some(queue) = device.get_queue() {
+            queue.pending_writes.lock().insert_tlas(&tlas);
+        }
         cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
 
         tlas_lock_store.push((Some(package), tlas.clone()))
@@ -742,17 +745,21 @@
             }
 
             if let Some(staging_buffer) = staging_buffer {
-                device
-                    .pending_writes
-                    .lock()
-                    .consume_temp(TempResource::StagingBuffer(staging_buffer));
+                if let Some(queue) = device.get_queue() {
+                    queue
+                        .pending_writes
+                        .lock()
+                        .consume_temp(TempResource::StagingBuffer(staging_buffer));
+                }
             }
         }
 
-        device
-            .pending_writes
-            .lock()
-            .consume_temp(TempResource::ScratchBuffer(scratch_buffer));
+        if let Some(queue) = device.get_queue() {
+            queue
+                .pending_writes
+                .lock()
+                .consume_temp(TempResource::ScratchBuffer(scratch_buffer));
+        }
 
         Ok(())
     }
@@ -839,7 +846,7 @@ fn iter_blas<'a>(
     build_command_index: NonZeroU64,
     buf_storage: &mut Vec<TriangleBufferStore<'a>>,
     hub: &Hub,
-    pending_writes: &mut ManuallyDrop<PendingWrites>,
+    device: &Device,
 ) -> Result<(), BuildAccelerationStructureError> {
     let mut temp_buffer = Vec::new();
     for entry in blas_iter {
@@ -849,7 +856,9 @@ fn iter_blas<'a>(
             .get()
             .map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?;
         cmd_buf_data.trackers.blas_s.set_single(blas.clone());
-        pending_writes.insert_blas(&blas);
+        if let Some(queue) = device.get_queue() {
+            queue.pending_writes.lock().insert_blas(&blas);
+        }
 
         cmd_buf_data.blas_actions.push(BlasAction {
            blas: blas.clone(),
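Note the shape change in `iter_blas` above: instead of threading a pre-locked `&mut ManuallyDrop<PendingWrites>` through the whole loop, the function now takes `&Device` and the queue's pending-writes lock is taken briefly per insertion. A schematic comparison of the two locking shapes, with hypothetical types:

```rust
use std::sync::Mutex;

struct PendingWrites;
impl PendingWrites {
    fn insert(&mut self, _id: u32) {}
}

// Before: the caller locks once and holds the lock across the loop.
fn insert_all_prelocked(pending: &mut PendingWrites, ids: &[u32]) {
    for &id in ids {
        pending.insert(id);
    }
}

// After: the lock is taken briefly per insertion, so other work that
// needs the pending-writes lock can interleave with the iteration.
fn insert_all_lock_per_item(pending: &Mutex<PendingWrites>, ids: &[u32]) {
    for &id in ids {
        pending.lock().unwrap().insert(id);
    }
}

fn main() {
    let mut pw = PendingWrites;
    insert_all_prelocked(&mut pw, &[1, 2, 3]);
    let pw = Mutex::new(PendingWrites);
    insert_all_lock_per_item(&pw, &[1, 2, 3]);
}
```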


@@ -14,7 +14,7 @@ use crate::{
     hal_label,
     id::{self, QueueId},
     init_tracker::{has_copy_partial_init_tracker_coverage, TextureInitRange},
-    lock::RwLockWriteGuard,
+    lock::{rank, Mutex, RwLockWriteGuard},
     resource::{
         Buffer, BufferAccessError, BufferMapState, DestroyedBuffer, DestroyedResourceError,
         DestroyedTexture, Fallible, FlushedStagingBuffer, InvalidResourceError, Labeled,
@@ -42,14 +42,59 @@ use super::Device;
 pub struct Queue {
     raw: ManuallyDrop<Box<dyn hal::DynQueue>>,
     pub(crate) device: Arc<Device>,
+    pub(crate) pending_writes: Mutex<ManuallyDrop<PendingWrites>>,
 }
 
 impl Queue {
-    pub(crate) fn new(device: Arc<Device>, raw: Box<dyn hal::DynQueue>) -> Self {
-        Queue {
+    pub(crate) fn new(
+        device: Arc<Device>,
+        raw: Box<dyn hal::DynQueue>,
+    ) -> Result<Self, DeviceError> {
+        let pending_encoder = device
+            .command_allocator
+            .acquire_encoder(device.raw(), raw.as_ref())
+            .map_err(DeviceError::from_hal);
+
+        let pending_encoder = match pending_encoder {
+            Ok(pending_encoder) => pending_encoder,
+            Err(e) => {
+                device.release_queue(raw);
+                return Err(e);
+            }
+        };
+
+        let mut pending_writes = PendingWrites::new(pending_encoder);
+
+        let zero_buffer = device.zero_buffer.as_ref();
+        pending_writes.activate();
+        unsafe {
+            pending_writes
+                .command_encoder
+                .transition_buffers(&[hal::BufferBarrier {
+                    buffer: zero_buffer,
+                    usage: hal::BufferUses::empty()..hal::BufferUses::COPY_DST,
+                }]);
+            pending_writes
+                .command_encoder
+                .clear_buffer(zero_buffer, 0..super::ZERO_BUFFER_SIZE);
+            pending_writes
+                .command_encoder
+                .transition_buffers(&[hal::BufferBarrier {
+                    buffer: zero_buffer,
+                    usage: hal::BufferUses::COPY_DST..hal::BufferUses::COPY_SRC,
+                }]);
+        }
+
+        let pending_writes = Mutex::new(
+            rank::QUEUE_PENDING_WRITES,
+            ManuallyDrop::new(pending_writes),
+        );
+
+        Ok(Queue {
             raw: ManuallyDrop::new(raw),
             device,
-        }
+            pending_writes,
+        })
     }
 
     pub(crate) fn raw(&self) -> &dyn hal::DynQueue {
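One detail in `Queue::new` above: the `match` on the encoder acquisition (rather than `?`) exists so the raw hal queue can be handed back via `device.release_queue` before the error propagates; a bare `?` would just drop it. A reduced sketch of that pattern, with all types hypothetical:

```rust
struct RawQueue;
struct Encoder;
struct Error;
struct Device;

impl Device {
    fn acquire_encoder(&self, _raw: &RawQueue) -> Result<Encoder, Error> {
        Err(Error) // pretend acquisition failed
    }
    fn release_queue(&self, _raw: RawQueue) {
        // return the queue to the backend instead of silently dropping it
    }
}

fn queue_new(device: &Device, raw: RawQueue) -> Result<Encoder, Error> {
    // A bare `device.acquire_encoder(&raw)?` would drop `raw` on error.
    match device.acquire_encoder(&raw) {
        Ok(encoder) => Ok(encoder),
        Err(e) => {
            device.release_queue(raw);
            Err(e)
        }
    }
}

fn main() {
    let device = Device;
    assert!(queue_new(&device, RawQueue).is_err());
}
```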
@@ -70,6 +115,9 @@ crate::impl_storage_item!(Queue);
 impl Drop for Queue {
     fn drop(&mut self) {
         resource_log!("Drop {}", self.error_ident());
+        // SAFETY: We are in the Drop impl and we don't use self.pending_writes anymore after this point.
+        let pending_writes = unsafe { ManuallyDrop::take(&mut self.pending_writes.lock()) };
+        pending_writes.dispose(self.device.raw());
         // SAFETY: we never access `self.raw` beyond this point.
         let queue = unsafe { ManuallyDrop::take(&mut self.raw) };
         self.device.release_queue(queue);
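This `Drop` impl reuses the `ManuallyDrop`-in-a-lock pattern that `Device::drop` used before this change: the wrapped value is moved out exactly once, at a point where the device needed to dispose it is still alive. A minimal sketch of the pattern, using `std::sync::Mutex` where wgpu-core uses its own ranked locks:

```rust
use std::mem::ManuallyDrop;
use std::sync::Mutex;

struct Payload;
impl Payload {
    // Imagine this needs a still-alive device handle to free GPU objects.
    fn dispose(self) {}
}

struct Queue {
    payload: Mutex<ManuallyDrop<Payload>>,
}

impl Drop for Queue {
    fn drop(&mut self) {
        // SAFETY: we are in Drop and never touch `payload` again, so
        // moving the value out of the ManuallyDrop exactly once is sound.
        let payload = unsafe { ManuallyDrop::take(&mut self.payload.lock().unwrap()) };
        payload.dispose();
    }
}

fn main() {
    let queue = Queue { payload: Mutex::new(ManuallyDrop::new(Payload)) };
    drop(queue); // runs Drop, which disposes the payload exactly once
}
```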
@@ -418,7 +466,7 @@ impl Queue {
         // freed, even if an error occurs. All paths from here must call
         // `device.pending_writes.consume`.
         let mut staging_buffer = StagingBuffer::new(&self.device, data_size)?;
-        let mut pending_writes = self.device.pending_writes.lock();
+        let mut pending_writes = self.pending_writes.lock();
 
         let staging_buffer = {
             profiling::scope!("copy");
@@ -460,7 +508,7 @@
         let buffer = buffer.get()?;
 
-        let mut pending_writes = self.device.pending_writes.lock();
+        let mut pending_writes = self.pending_writes.lock();
 
         // At this point, we have taken ownership of the staging_buffer from the
         // user. Platform validation requires that the staging buffer always
@@ -636,7 +684,7 @@
                 .map_err(TransferError::from)?;
         }
 
-        let mut pending_writes = self.device.pending_writes.lock();
+        let mut pending_writes = self.pending_writes.lock();
         let encoder = pending_writes.activate();
 
         // If the copy does not fully cover the layers, we need to initialize to
@@ -888,7 +936,7 @@
         let (selector, dst_base) = extract_texture_selector(&destination, &size, &dst)?;
 
-        let mut pending_writes = self.device.pending_writes.lock();
+        let mut pending_writes = self.pending_writes.lock();
         let encoder = pending_writes.activate();
 
         // If the copy does not fully cover the layers, we need to initialize to
@@ -1157,7 +1205,7 @@
             }
         }
 
-        let mut pending_writes = self.device.pending_writes.lock();
+        let mut pending_writes = self.pending_writes.lock();
 
         {
             used_surface_textures.set_size(self.device.tracker_indices.textures.size());


@@ -6,10 +6,8 @@ use crate::{
     device::{
         bgl, create_validator,
         life::{LifetimeTracker, WaitIdleError},
-        map_buffer,
-        queue::PendingWrites,
-        AttachmentData, DeviceLostInvocation, HostMap, MissingDownlevelFlags, MissingFeatures,
-        RenderPassContext, CLEANUP_WAIT_MS,
+        map_buffer, AttachmentData, DeviceLostInvocation, HostMap, MissingDownlevelFlags,
+        MissingFeatures, RenderPassContext, CLEANUP_WAIT_MS,
     },
     hal_label,
     init_tracker::{
@@ -141,7 +139,6 @@ pub struct Device {
     pub(crate) features: wgt::Features,
     pub(crate) downlevel: wgt::DownlevelCapabilities,
     pub(crate) instance_flags: wgt::InstanceFlags,
-    pub(crate) pending_writes: Mutex<ManuallyDrop<PendingWrites>>,
     pub(crate) deferred_destroy: Mutex<Vec<DeferredDestroy>>,
     pub(crate) usage_scopes: UsageScopePool,
     pub(crate) last_acceleration_structure_build_command_index: AtomicU64,
@@ -181,11 +178,8 @@ impl Drop for Device {
         let raw = unsafe { ManuallyDrop::take(&mut self.raw) };
         // SAFETY: We are in the Drop impl and we don't use self.zero_buffer anymore after this point.
         let zero_buffer = unsafe { ManuallyDrop::take(&mut self.zero_buffer) };
-        // SAFETY: We are in the Drop impl and we don't use self.pending_writes anymore after this point.
-        let pending_writes = unsafe { ManuallyDrop::take(&mut self.pending_writes.lock()) };
         // SAFETY: We are in the Drop impl and we don't use self.fence anymore after this point.
         let fence = unsafe { ManuallyDrop::take(&mut self.fence.write()) };
-        pending_writes.dispose(raw.as_ref());
         self.command_allocator.dispose(raw.as_ref());
         #[cfg(feature = "indirect-validation")]
         self.indirect_validation
@@ -228,7 +222,6 @@
 impl Device {
     pub(crate) fn new(
         raw_device: Box<dyn hal::DynDevice>,
-        raw_queue: &dyn hal::DynQueue,
         adapter: &Arc<Adapter>,
         desc: &DeviceDescriptor,
         trace_path: Option<&std::path::Path>,
@@ -241,10 +234,6 @@
         let fence = unsafe { raw_device.create_fence() }.map_err(DeviceError::from_hal)?;
 
         let command_allocator = command::CommandAllocator::new();
-        let pending_encoder = command_allocator
-            .acquire_encoder(raw_device.as_ref(), raw_queue)
-            .map_err(DeviceError::from_hal)?;
-        let mut pending_writes = PendingWrites::new(pending_encoder);
 
         // Create zeroed buffer used for texture clears.
         let zero_buffer = unsafe {
@@ -256,24 +245,6 @@
             })
         }
         .map_err(DeviceError::from_hal)?;
-        pending_writes.activate();
-        unsafe {
-            pending_writes
-                .command_encoder
-                .transition_buffers(&[hal::BufferBarrier {
-                    buffer: zero_buffer.as_ref(),
-                    usage: hal::BufferUses::empty()..hal::BufferUses::COPY_DST,
-                }]);
-            pending_writes
-                .command_encoder
-                .clear_buffer(zero_buffer.as_ref(), 0..ZERO_BUFFER_SIZE);
-            pending_writes
-                .command_encoder
-                .transition_buffers(&[hal::BufferBarrier {
-                    buffer: zero_buffer.as_ref(),
-                    usage: hal::BufferUses::COPY_DST..hal::BufferUses::COPY_SRC,
-                }]);
-        }
 
         let alignments = adapter.raw.capabilities.alignments.clone();
         let downlevel = adapter.raw.capabilities.downlevel.clone();
@@ -336,10 +307,6 @@
             features: desc.required_features,
             downlevel,
             instance_flags,
-            pending_writes: Mutex::new(
-                rank::DEVICE_PENDING_WRITES,
-                ManuallyDrop::new(pending_writes),
-            ),
             deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()),
             usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()),
             // By starting at one, we can put the result in a NonZeroU64.


@@ -573,18 +573,14 @@ impl Adapter {
     ) -> Result<(Arc<Device>, Arc<Queue>), RequestDeviceError> {
         api_log!("Adapter::create_device");
 
-        let device = Device::new(
-            hal_device.device,
-            hal_device.queue.as_ref(),
-            self,
-            desc,
-            trace_path,
-            instance_flags,
-        )?;
+        let device = Device::new(hal_device.device, self, desc, trace_path, instance_flags)?;
 
         let device = Arc::new(device);
-        let queue = Arc::new(Queue::new(device.clone(), hal_device.queue));
+        let queue = Queue::new(device.clone(), hal_device.queue)?;
+        let queue = Arc::new(queue);
         device.set_queue(&queue);
         Ok((device, queue))
     }


@@ -111,11 +111,11 @@ define_lock_ranks! {
         // COMMAND_BUFFER_DATA,
     }
     rank BUFFER_MAP_STATE "Buffer::map_state" followed by {
-        DEVICE_PENDING_WRITES,
+        QUEUE_PENDING_WRITES,
         SHARED_TRACKER_INDEX_ALLOCATOR_INNER,
         DEVICE_TRACE,
     }
-    rank DEVICE_PENDING_WRITES "Device::pending_writes" followed by {
+    rank QUEUE_PENDING_WRITES "Queue::pending_writes" followed by {
         COMMAND_ALLOCATOR_FREE_ENCODERS,
         SHARED_TRACKER_INDEX_ALLOCATOR_INNER,
         DEVICE_LIFE_TRACKER,
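The rename keeps the lock-order data intact: `define_lock_ranks!` records, for each rank, which locks may still be acquired while it is held, and instrumented builds check those edges at runtime. A toy illustration of the idea, using a single numeric order instead of wgpu's per-rank `followed by` sets:

```rust
use std::cell::Cell;

thread_local! {
    // Highest rank currently held on this thread (0 = none).
    static CURRENT_RANK: Cell<u32> = const { Cell::new(0) };
}

/// Panic if `rank` would be acquired out of order, e.g. taking
/// BUFFER_MAP_STATE (rank 1) while QUEUE_PENDING_WRITES (rank 2) is held.
fn check_acquire(rank: u32, name: &str) {
    CURRENT_RANK.with(|held| {
        assert!(
            rank > held.get(),
            "lock order violation: acquiring {name} (rank {rank}) while rank {} is held",
            held.get()
        );
        held.set(rank);
    });
}

fn main() {
    check_acquire(1, "BUFFER_MAP_STATE");
    check_acquire(2, "QUEUE_PENDING_WRITES"); // ok: rank increases
    // check_acquire(1, "BUFFER_MAP_STATE");  // would panic: out of order
}
```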


@@ -676,10 +676,9 @@ impl Buffer {
                 });
             }
 
-            let mut pending_writes = device.pending_writes.lock();
-
             let staging_buffer = staging_buffer.flush();
 
+            if let Some(queue) = device.get_queue() {
                 let region = wgt::BufferSize::new(self.size).map(|size| hal::BufferCopy {
                     src_offset: 0,
                     dst_offset: 0,
@@ -693,6 +692,7 @@ impl Buffer {
                     buffer: raw_buf,
                     usage: hal::BufferUses::empty()..hal::BufferUses::COPY_DST,
                 };
+                let mut pending_writes = queue.pending_writes.lock();
                 let encoder = pending_writes.activate();
                 unsafe {
                     encoder.transition_buffers(&[transition_src, transition_dst]);
@@ -707,6 +707,7 @@
                     pending_writes.consume(staging_buffer);
                     pending_writes.insert_buffer(self);
+                }
             }
             BufferMapState::Idle => {
                 return Err(BufferAccessError::NotMapped);
             }
@@ -778,16 +779,19 @@
             })
         };
 
-        let mut pending_writes = device.pending_writes.lock();
-        if pending_writes.contains_buffer(self) {
-            pending_writes.consume_temp(temp);
-        } else {
+        if let Some(queue) = device.get_queue() {
+            let mut pending_writes = queue.pending_writes.lock();
+            if pending_writes.contains_buffer(self) {
+                pending_writes.consume_temp(temp);
+                return Ok(());
+            }
+        }
+
         let mut life_lock = device.lock_life();
         let last_submit_index = life_lock.get_buffer_latest_submission_index(self);
         if let Some(last_submit_index) = last_submit_index {
             life_lock.schedule_resource_destruction(temp, last_submit_index);
         }
-        }
 
         Ok(())
     }
@@ -1244,16 +1248,19 @@
             })
         };
 
-        let mut pending_writes = device.pending_writes.lock();
-        if pending_writes.contains_texture(self) {
-            pending_writes.consume_temp(temp);
-        } else {
+        if let Some(queue) = device.get_queue() {
+            let mut pending_writes = queue.pending_writes.lock();
+            if pending_writes.contains_texture(self) {
+                pending_writes.consume_temp(temp);
+                return Ok(());
+            }
+        }
+
         let mut life_lock = device.lock_life();
         let last_submit_index = life_lock.get_texture_latest_submission_index(self);
         if let Some(last_submit_index) = last_submit_index {
             life_lock.schedule_resource_destruction(temp, last_submit_index);
         }
-        }
 
         Ok(())
     }
@@ -1960,16 +1967,19 @@
             })
         };
 
-        let mut pending_writes = device.pending_writes.lock();
-        if pending_writes.contains_blas(self) {
-            pending_writes.consume_temp(temp);
-        } else {
+        if let Some(queue) = device.get_queue() {
+            let mut pending_writes = queue.pending_writes.lock();
+            if pending_writes.contains_blas(self) {
+                pending_writes.consume_temp(temp);
+                return Ok(());
+            }
+        }
+
         let mut life_lock = device.lock_life();
         let last_submit_index = life_lock.get_blas_latest_submission_index(self);
         if let Some(last_submit_index) = last_submit_index {
             life_lock.schedule_resource_destruction(temp, last_submit_index);
         }
-        }
 
         Ok(())
     }
@@ -2047,16 +2057,19 @@
             })
         };
 
-        let mut pending_writes = device.pending_writes.lock();
-        if pending_writes.contains_tlas(self) {
-            pending_writes.consume_temp(temp);
-        } else {
+        if let Some(queue) = device.get_queue() {
+            let mut pending_writes = queue.pending_writes.lock();
+            if pending_writes.contains_tlas(self) {
+                pending_writes.consume_temp(temp);
+                return Ok(());
+            }
+        }
+
         let mut life_lock = device.lock_life();
         let last_submit_index = life_lock.get_tlas_latest_submission_index(self);
         if let Some(last_submit_index) = last_submit_index {
             life_lock.schedule_resource_destruction(temp, last_submit_index);
         }
-        }
 
        Ok(())
     }
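All four `destroy` paths above now share one shape: if the queue is still alive and its pending writes reference the resource, the temporary moves into `PendingWrites` and is freed when that work completes; otherwise destruction is scheduled against the last submission that used the resource. A condensed sketch with stub types (hypothetical names, mirroring the diff's logic):

```rust
use std::sync::{Arc, Mutex, MutexGuard};

struct TempResource;
struct Resource { id: u64 }

struct PendingWrites { referenced: Vec<u64> }
impl PendingWrites {
    fn contains(&self, r: &Resource) -> bool { self.referenced.contains(&r.id) }
    fn consume_temp(&mut self, _temp: TempResource) { /* freed after submit */ }
}

struct Queue { pending_writes: Mutex<PendingWrites> }

struct LifetimeTracker;
impl LifetimeTracker {
    fn latest_submission_index(&self, _r: &Resource) -> Option<u64> { None }
    fn schedule_resource_destruction(&mut self, _temp: TempResource, _idx: u64) {}
}

struct Device { queue: Option<Arc<Queue>>, life: Mutex<LifetimeTracker> }
impl Device {
    fn get_queue(&self) -> Option<Arc<Queue>> { self.queue.clone() }
    fn lock_life(&self) -> MutexGuard<'_, LifetimeTracker> { self.life.lock().unwrap() }
}

fn destroy(device: &Device, this: &Resource, temp: TempResource) {
    if let Some(queue) = device.get_queue() {
        let mut pending_writes = queue.pending_writes.lock().unwrap();
        if pending_writes.contains(this) {
            // Unsubmitted work still references the resource; the
            // pending-writes machinery frees it after that work runs.
            pending_writes.consume_temp(temp);
            return;
        }
    }
    // No queue, or no pending use: defer to the last submission that
    // used the resource; if there was none, `temp` is dropped here.
    let mut life = device.lock_life();
    if let Some(last) = life.latest_submission_index(this) {
        life.schedule_resource_destruction(temp, last);
    }
}

fn main() {
    let device = Device { queue: None, life: Mutex::new(LifetimeTracker) };
    destroy(&device, &Resource { id: 1 }, TempResource);
}
```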