move PendingWrites into the Queue

The `Device` should not contain any `Arc`s to resources as that creates cycles (since all resources hold strong references to the `Device`).
Note that `PendingWrites` internally has `Arc`s to resources.

I think this change also makes more sense conceptually since most operations that use `PendingWrites` are on the `Queue`.
This commit is contained in:
teoxoy 2024-10-15 11:23:34 +02:00 committed by Teodor Tanasoaia
parent 97acfd26ce
commit 6c7dbba399
6 changed files with 169 additions and 136 deletions

View File

@ -1,5 +1,5 @@
use crate::{
device::queue::TempResource,
device::{queue::TempResource, Device},
global::Global,
hub::Hub,
id::CommandEncoderId,
@ -20,10 +20,7 @@ use crate::{
use wgt::{math::align_to, BufferUsages, Features};
use super::CommandBufferMutable;
use crate::device::queue::PendingWrites;
use hal::BufferUses;
use std::mem::ManuallyDrop;
use std::ops::DerefMut;
use std::{
cmp::max,
num::NonZeroU64,
@ -184,7 +181,7 @@ impl Global {
build_command_index,
&mut buf_storage,
hub,
device.pending_writes.lock().deref_mut(),
device,
)?;
let snatch_guard = device.snatchable_lock.read();
@ -248,7 +245,9 @@ impl Global {
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
device.pending_writes.lock().insert_tlas(&tlas);
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_tlas(&tlas);
}
cmd_buf_data.tlas_actions.push(TlasAction {
tlas: tlas.clone(),
@ -349,10 +348,12 @@ impl Global {
}
}
device
.pending_writes
.lock()
.consume_temp(TempResource::ScratchBuffer(scratch_buffer));
if let Some(queue) = device.get_queue() {
queue
.pending_writes
.lock()
.consume_temp(TempResource::ScratchBuffer(scratch_buffer));
}
Ok(())
}
@ -495,7 +496,7 @@ impl Global {
build_command_index,
&mut buf_storage,
hub,
device.pending_writes.lock().deref_mut(),
device,
)?;
let snatch_guard = device.snatchable_lock.read();
@ -516,7 +517,9 @@ impl Global {
.get(package.tlas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
device.pending_writes.lock().insert_tlas(&tlas);
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_tlas(&tlas);
}
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
tlas_lock_store.push((Some(package), tlas.clone()))
@ -742,17 +745,21 @@ impl Global {
}
if let Some(staging_buffer) = staging_buffer {
device
.pending_writes
.lock()
.consume_temp(TempResource::StagingBuffer(staging_buffer));
if let Some(queue) = device.get_queue() {
queue
.pending_writes
.lock()
.consume_temp(TempResource::StagingBuffer(staging_buffer));
}
}
}
device
.pending_writes
.lock()
.consume_temp(TempResource::ScratchBuffer(scratch_buffer));
if let Some(queue) = device.get_queue() {
queue
.pending_writes
.lock()
.consume_temp(TempResource::ScratchBuffer(scratch_buffer));
}
Ok(())
}
@ -839,7 +846,7 @@ fn iter_blas<'a>(
build_command_index: NonZeroU64,
buf_storage: &mut Vec<TriangleBufferStore<'a>>,
hub: &Hub,
pending_writes: &mut ManuallyDrop<PendingWrites>,
device: &Device,
) -> Result<(), BuildAccelerationStructureError> {
let mut temp_buffer = Vec::new();
for entry in blas_iter {
@ -849,7 +856,9 @@ fn iter_blas<'a>(
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?;
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
pending_writes.insert_blas(&blas);
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_blas(&blas);
}
cmd_buf_data.blas_actions.push(BlasAction {
blas: blas.clone(),

View File

@ -14,7 +14,7 @@ use crate::{
hal_label,
id::{self, QueueId},
init_tracker::{has_copy_partial_init_tracker_coverage, TextureInitRange},
lock::RwLockWriteGuard,
lock::{rank, Mutex, RwLockWriteGuard},
resource::{
Buffer, BufferAccessError, BufferMapState, DestroyedBuffer, DestroyedResourceError,
DestroyedTexture, Fallible, FlushedStagingBuffer, InvalidResourceError, Labeled,
@ -42,14 +42,59 @@ use super::Device;
/// A queue paired with the [`Device`] it belongs to.
///
/// Both `raw` and `pending_writes` are wrapped in `ManuallyDrop` because the
/// `Drop` impl moves them out by value in order to hand them back to the
/// device for cleanup.
pub struct Queue {
/// The underlying HAL queue. Taken out in `Drop` and passed to
/// `Device::release_queue`.
raw: ManuallyDrop<Box<dyn hal::DynQueue>>,
/// Strong reference to the owning device.
pub(crate) device: Arc<Device>,
/// Pending writes for this queue, guarded by a mutex.
///
/// Stored on the queue rather than on the `Device`: `PendingWrites` holds
/// `Arc`s to resources, and resources hold strong references to the
/// `Device`, so keeping it on the device would create reference cycles.
/// Taken out in `Drop` and disposed of with the device's raw handle.
pub(crate) pending_writes: Mutex<ManuallyDrop<PendingWrites>>,
}
impl Queue {
pub(crate) fn new(device: Arc<Device>, raw: Box<dyn hal::DynQueue>) -> Self {
Queue {
pub(crate) fn new(
device: Arc<Device>,
raw: Box<dyn hal::DynQueue>,
) -> Result<Self, DeviceError> {
let pending_encoder = device
.command_allocator
.acquire_encoder(device.raw(), raw.as_ref())
.map_err(DeviceError::from_hal);
let pending_encoder = match pending_encoder {
Ok(pending_encoder) => pending_encoder,
Err(e) => {
device.release_queue(raw);
return Err(e);
}
};
let mut pending_writes = PendingWrites::new(pending_encoder);
let zero_buffer = device.zero_buffer.as_ref();
pending_writes.activate();
unsafe {
pending_writes
.command_encoder
.transition_buffers(&[hal::BufferBarrier {
buffer: zero_buffer,
usage: hal::BufferUses::empty()..hal::BufferUses::COPY_DST,
}]);
pending_writes
.command_encoder
.clear_buffer(zero_buffer, 0..super::ZERO_BUFFER_SIZE);
pending_writes
.command_encoder
.transition_buffers(&[hal::BufferBarrier {
buffer: zero_buffer,
usage: hal::BufferUses::COPY_DST..hal::BufferUses::COPY_SRC,
}]);
}
let pending_writes = Mutex::new(
rank::QUEUE_PENDING_WRITES,
ManuallyDrop::new(pending_writes),
);
Ok(Queue {
raw: ManuallyDrop::new(raw),
device,
}
pending_writes,
})
}
pub(crate) fn raw(&self) -> &dyn hal::DynQueue {
@ -70,6 +115,9 @@ crate::impl_storage_item!(Queue);
impl Drop for Queue {
fn drop(&mut self) {
resource_log!("Drop {}", self.error_ident());
// SAFETY: We are in the Drop impl and we don't use self.pending_writes anymore after this point.
let pending_writes = unsafe { ManuallyDrop::take(&mut self.pending_writes.lock()) };
pending_writes.dispose(self.device.raw());
// SAFETY: we never access `self.raw` beyond this point.
let queue = unsafe { ManuallyDrop::take(&mut self.raw) };
self.device.release_queue(queue);
@ -418,7 +466,7 @@ impl Queue {
// freed, even if an error occurs. All paths from here must call
// `self.pending_writes.consume`.
let mut staging_buffer = StagingBuffer::new(&self.device, data_size)?;
let mut pending_writes = self.device.pending_writes.lock();
let mut pending_writes = self.pending_writes.lock();
let staging_buffer = {
profiling::scope!("copy");
@ -460,7 +508,7 @@ impl Queue {
let buffer = buffer.get()?;
let mut pending_writes = self.device.pending_writes.lock();
let mut pending_writes = self.pending_writes.lock();
// At this point, we have taken ownership of the staging_buffer from the
// user. Platform validation requires that the staging buffer always
@ -636,7 +684,7 @@ impl Queue {
.map_err(TransferError::from)?;
}
let mut pending_writes = self.device.pending_writes.lock();
let mut pending_writes = self.pending_writes.lock();
let encoder = pending_writes.activate();
// If the copy does not fully cover the layers, we need to initialize to
@ -888,7 +936,7 @@ impl Queue {
let (selector, dst_base) = extract_texture_selector(&destination, &size, &dst)?;
let mut pending_writes = self.device.pending_writes.lock();
let mut pending_writes = self.pending_writes.lock();
let encoder = pending_writes.activate();
// If the copy does not fully cover the layers, we need to initialize to
@ -1157,7 +1205,7 @@ impl Queue {
}
}
let mut pending_writes = self.device.pending_writes.lock();
let mut pending_writes = self.pending_writes.lock();
{
used_surface_textures.set_size(self.device.tracker_indices.textures.size());

View File

@ -6,10 +6,8 @@ use crate::{
device::{
bgl, create_validator,
life::{LifetimeTracker, WaitIdleError},
map_buffer,
queue::PendingWrites,
AttachmentData, DeviceLostInvocation, HostMap, MissingDownlevelFlags, MissingFeatures,
RenderPassContext, CLEANUP_WAIT_MS,
map_buffer, AttachmentData, DeviceLostInvocation, HostMap, MissingDownlevelFlags,
MissingFeatures, RenderPassContext, CLEANUP_WAIT_MS,
},
hal_label,
init_tracker::{
@ -141,7 +139,6 @@ pub struct Device {
pub(crate) features: wgt::Features,
pub(crate) downlevel: wgt::DownlevelCapabilities,
pub(crate) instance_flags: wgt::InstanceFlags,
pub(crate) pending_writes: Mutex<ManuallyDrop<PendingWrites>>,
pub(crate) deferred_destroy: Mutex<Vec<DeferredDestroy>>,
pub(crate) usage_scopes: UsageScopePool,
pub(crate) last_acceleration_structure_build_command_index: AtomicU64,
@ -181,11 +178,8 @@ impl Drop for Device {
let raw = unsafe { ManuallyDrop::take(&mut self.raw) };
// SAFETY: We are in the Drop impl and we don't use self.zero_buffer anymore after this point.
let zero_buffer = unsafe { ManuallyDrop::take(&mut self.zero_buffer) };
// SAFETY: We are in the Drop impl and we don't use self.pending_writes anymore after this point.
let pending_writes = unsafe { ManuallyDrop::take(&mut self.pending_writes.lock()) };
// SAFETY: We are in the Drop impl and we don't use self.fence anymore after this point.
let fence = unsafe { ManuallyDrop::take(&mut self.fence.write()) };
pending_writes.dispose(raw.as_ref());
self.command_allocator.dispose(raw.as_ref());
#[cfg(feature = "indirect-validation")]
self.indirect_validation
@ -228,7 +222,6 @@ impl Device {
impl Device {
pub(crate) fn new(
raw_device: Box<dyn hal::DynDevice>,
raw_queue: &dyn hal::DynQueue,
adapter: &Arc<Adapter>,
desc: &DeviceDescriptor,
trace_path: Option<&std::path::Path>,
@ -241,10 +234,6 @@ impl Device {
let fence = unsafe { raw_device.create_fence() }.map_err(DeviceError::from_hal)?;
let command_allocator = command::CommandAllocator::new();
let pending_encoder = command_allocator
.acquire_encoder(raw_device.as_ref(), raw_queue)
.map_err(DeviceError::from_hal)?;
let mut pending_writes = PendingWrites::new(pending_encoder);
// Create zeroed buffer used for texture clears.
let zero_buffer = unsafe {
@ -256,24 +245,6 @@ impl Device {
})
}
.map_err(DeviceError::from_hal)?;
pending_writes.activate();
unsafe {
pending_writes
.command_encoder
.transition_buffers(&[hal::BufferBarrier {
buffer: zero_buffer.as_ref(),
usage: hal::BufferUses::empty()..hal::BufferUses::COPY_DST,
}]);
pending_writes
.command_encoder
.clear_buffer(zero_buffer.as_ref(), 0..ZERO_BUFFER_SIZE);
pending_writes
.command_encoder
.transition_buffers(&[hal::BufferBarrier {
buffer: zero_buffer.as_ref(),
usage: hal::BufferUses::COPY_DST..hal::BufferUses::COPY_SRC,
}]);
}
let alignments = adapter.raw.capabilities.alignments.clone();
let downlevel = adapter.raw.capabilities.downlevel.clone();
@ -336,10 +307,6 @@ impl Device {
features: desc.required_features,
downlevel,
instance_flags,
pending_writes: Mutex::new(
rank::DEVICE_PENDING_WRITES,
ManuallyDrop::new(pending_writes),
),
deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()),
usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()),
// By starting at one, we can put the result in a NonZeroU64.

View File

@ -573,18 +573,14 @@ impl Adapter {
) -> Result<(Arc<Device>, Arc<Queue>), RequestDeviceError> {
api_log!("Adapter::create_device");
let device = Device::new(
hal_device.device,
hal_device.queue.as_ref(),
self,
desc,
trace_path,
instance_flags,
)?;
let device = Device::new(hal_device.device, self, desc, trace_path, instance_flags)?;
let device = Arc::new(device);
let queue = Arc::new(Queue::new(device.clone(), hal_device.queue));
let queue = Queue::new(device.clone(), hal_device.queue)?;
let queue = Arc::new(queue);
device.set_queue(&queue);
Ok((device, queue))
}

View File

@ -111,11 +111,11 @@ define_lock_ranks! {
// COMMAND_BUFFER_DATA,
}
rank BUFFER_MAP_STATE "Buffer::map_state" followed by {
DEVICE_PENDING_WRITES,
QUEUE_PENDING_WRITES,
SHARED_TRACKER_INDEX_ALLOCATOR_INNER,
DEVICE_TRACE,
}
rank DEVICE_PENDING_WRITES "Device::pending_writes" followed by {
rank QUEUE_PENDING_WRITES "Queue::pending_writes" followed by {
COMMAND_ALLOCATOR_FREE_ENCODERS,
SHARED_TRACKER_INDEX_ALLOCATOR_INNER,
DEVICE_LIFE_TRACKER,

View File

@ -676,36 +676,37 @@ impl Buffer {
});
}
let mut pending_writes = device.pending_writes.lock();
let staging_buffer = staging_buffer.flush();
let region = wgt::BufferSize::new(self.size).map(|size| hal::BufferCopy {
src_offset: 0,
dst_offset: 0,
size,
});
let transition_src = hal::BufferBarrier {
buffer: staging_buffer.raw(),
usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC,
};
let transition_dst = hal::BufferBarrier::<dyn hal::DynBuffer> {
buffer: raw_buf,
usage: hal::BufferUses::empty()..hal::BufferUses::COPY_DST,
};
let encoder = pending_writes.activate();
unsafe {
encoder.transition_buffers(&[transition_src, transition_dst]);
if self.size > 0 {
encoder.copy_buffer_to_buffer(
staging_buffer.raw(),
raw_buf,
region.as_slice(),
);
if let Some(queue) = device.get_queue() {
let region = wgt::BufferSize::new(self.size).map(|size| hal::BufferCopy {
src_offset: 0,
dst_offset: 0,
size,
});
let transition_src = hal::BufferBarrier {
buffer: staging_buffer.raw(),
usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC,
};
let transition_dst = hal::BufferBarrier::<dyn hal::DynBuffer> {
buffer: raw_buf,
usage: hal::BufferUses::empty()..hal::BufferUses::COPY_DST,
};
let mut pending_writes = queue.pending_writes.lock();
let encoder = pending_writes.activate();
unsafe {
encoder.transition_buffers(&[transition_src, transition_dst]);
if self.size > 0 {
encoder.copy_buffer_to_buffer(
staging_buffer.raw(),
raw_buf,
region.as_slice(),
);
}
}
pending_writes.consume(staging_buffer);
pending_writes.insert_buffer(self);
}
pending_writes.consume(staging_buffer);
pending_writes.insert_buffer(self);
}
BufferMapState::Idle => {
return Err(BufferAccessError::NotMapped);
@ -778,17 +779,20 @@ impl Buffer {
})
};
let mut pending_writes = device.pending_writes.lock();
if pending_writes.contains_buffer(self) {
pending_writes.consume_temp(temp);
} else {
let mut life_lock = device.lock_life();
let last_submit_index = life_lock.get_buffer_latest_submission_index(self);
if let Some(last_submit_index) = last_submit_index {
life_lock.schedule_resource_destruction(temp, last_submit_index);
if let Some(queue) = device.get_queue() {
let mut pending_writes = queue.pending_writes.lock();
if pending_writes.contains_buffer(self) {
pending_writes.consume_temp(temp);
return Ok(());
}
}
let mut life_lock = device.lock_life();
let last_submit_index = life_lock.get_buffer_latest_submission_index(self);
if let Some(last_submit_index) = last_submit_index {
life_lock.schedule_resource_destruction(temp, last_submit_index);
}
Ok(())
}
}
@ -1244,17 +1248,20 @@ impl Texture {
})
};
let mut pending_writes = device.pending_writes.lock();
if pending_writes.contains_texture(self) {
pending_writes.consume_temp(temp);
} else {
let mut life_lock = device.lock_life();
let last_submit_index = life_lock.get_texture_latest_submission_index(self);
if let Some(last_submit_index) = last_submit_index {
life_lock.schedule_resource_destruction(temp, last_submit_index);
if let Some(queue) = device.get_queue() {
let mut pending_writes = queue.pending_writes.lock();
if pending_writes.contains_texture(self) {
pending_writes.consume_temp(temp);
return Ok(());
}
}
let mut life_lock = device.lock_life();
let last_submit_index = life_lock.get_texture_latest_submission_index(self);
if let Some(last_submit_index) = last_submit_index {
life_lock.schedule_resource_destruction(temp, last_submit_index);
}
Ok(())
}
}
@ -1960,17 +1967,20 @@ impl Blas {
})
};
let mut pending_writes = device.pending_writes.lock();
if pending_writes.contains_blas(self) {
pending_writes.consume_temp(temp);
} else {
let mut life_lock = device.lock_life();
let last_submit_index = life_lock.get_blas_latest_submission_index(self);
if let Some(last_submit_index) = last_submit_index {
life_lock.schedule_resource_destruction(temp, last_submit_index);
if let Some(queue) = device.get_queue() {
let mut pending_writes = queue.pending_writes.lock();
if pending_writes.contains_blas(self) {
pending_writes.consume_temp(temp);
return Ok(());
}
}
let mut life_lock = device.lock_life();
let last_submit_index = life_lock.get_blas_latest_submission_index(self);
if let Some(last_submit_index) = last_submit_index {
life_lock.schedule_resource_destruction(temp, last_submit_index);
}
Ok(())
}
}
@ -2047,17 +2057,20 @@ impl Tlas {
})
};
let mut pending_writes = device.pending_writes.lock();
if pending_writes.contains_tlas(self) {
pending_writes.consume_temp(temp);
} else {
let mut life_lock = device.lock_life();
let last_submit_index = life_lock.get_tlas_latest_submission_index(self);
if let Some(last_submit_index) = last_submit_index {
life_lock.schedule_resource_destruction(temp, last_submit_index);
if let Some(queue) = device.get_queue() {
let mut pending_writes = queue.pending_writes.lock();
if pending_writes.contains_tlas(self) {
pending_writes.consume_temp(temp);
return Ok(());
}
}
let mut life_lock = device.lock_life();
let last_submit_index = life_lock.get_tlas_latest_submission_index(self);
if let Some(last_submit_index) = last_submit_index {
life_lock.schedule_resource_destruction(temp, last_submit_index);
}
Ok(())
}
}