diff --git a/wgpu-core/src/command/clear.rs b/wgpu-core/src/command/clear.rs index fecf58894..547356180 100644 --- a/wgpu-core/src/command/clear.rs +++ b/wgpu-core/src/command/clear.rs @@ -4,7 +4,7 @@ use std::{ops::Range, sync::Arc}; use crate::device::trace::Command as TraceCommand; use crate::{ api_log, - command::CommandBuffer, + command::CommandEncoderError, device::DeviceError, get_lowest_common_denom, global::Global, @@ -76,7 +76,7 @@ whereas subesource range specified start {subresource_base_array_layer} and coun #[error(transparent)] Device(#[from] DeviceError), #[error(transparent)] - CommandEncoderError(#[from] super::CommandEncoderError), + CommandEncoderError(#[from] CommandEncoderError), } impl Global { @@ -92,7 +92,15 @@ impl Global { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; + let cmd_buf = match hub + .command_buffers + .get(command_encoder_id.into_command_buffer_id()) + { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid.into()), + }; + cmd_buf.check_recording()?; + let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); @@ -176,7 +184,15 @@ impl Global { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; + let cmd_buf = match hub + .command_buffers + .get(command_encoder_id.into_command_buffer_id()) + { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid.into()), + }; + cmd_buf.check_recording()?; + let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index b1dae2b49..ff2bdf37e 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -301,35 +301,40 @@ impl Global { timestamp_writes: None, // Handle only once we resolved the encoder. }; - match CommandBuffer::lock_encoder(hub, encoder_id) { - Ok(cmd_buf) => { - arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes { - let Ok(query_set) = hub.query_sets.get(tw.query_set) else { - return ( - ComputePass::new(None, arc_desc), - Some(CommandEncoderError::InvalidTimestampWritesQuerySetId( - tw.query_set, - )), - ); - }; + let make_err = |e, arc_desc| (ComputePass::new(None, arc_desc), Some(e)); - if let Err(e) = query_set.same_device_as(cmd_buf.as_ref()) { - return (ComputePass::new(None, arc_desc), Some(e.into())); - } + let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) { + Ok(cmd_buf) => cmd_buf, + Err(_) => return make_err(CommandEncoderError::Invalid, arc_desc), + }; - Some(ArcPassTimestampWrites { - query_set, - beginning_of_pass_write_index: tw.beginning_of_pass_write_index, - end_of_pass_write_index: tw.end_of_pass_write_index, - }) - } else { - None - }; + match cmd_buf.lock_encoder() { + Ok(_) => {} + Err(e) => return make_err(e, arc_desc), + }; - (ComputePass::new(Some(cmd_buf), arc_desc), None) + arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes { + let Ok(query_set) = hub.query_sets.get(tw.query_set) else { + return make_err( + CommandEncoderError::InvalidTimestampWritesQuerySetId(tw.query_set), + arc_desc, + ); + }; + + if let Err(e) = query_set.same_device_as(cmd_buf.as_ref()) { + return make_err(e.into(), arc_desc); } - Err(err) => (ComputePass::new(None, arc_desc), Some(err)), - } + + Some(ArcPassTimestampWrites { + query_set, + beginning_of_pass_write_index: tw.beginning_of_pass_write_index, + end_of_pass_write_index: tw.end_of_pass_write_index, + }) + } else { + None + }; + + (ComputePass::new(Some(cmd_buf), arc_desc), None) } /// Creates a type erased compute pass. @@ -378,7 +383,11 @@ impl Global { let hub = A::hub(self); let scope = PassErrorScope::Pass; - let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(scope)?; + let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid).map_pass_err(scope), + }; + cmd_buf.check_recording().map_pass_err(scope)?; #[cfg(feature = "trace")] { diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index fbd5e1d09..f5bfcec24 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -30,7 +30,6 @@ pub use timestamp_writes::PassTimestampWrites; use self::memory_init::CommandBufferTextureMemoryActions; use crate::device::{Device, DeviceError}; -use crate::hub::Hub; use crate::lock::{rank, Mutex}; use crate::snatch::SnatchGuard; @@ -425,65 +424,41 @@ impl CommandBuffer { } impl CommandBuffer { - fn get_encoder_impl( - hub: &Hub, - id: id::CommandEncoderId, - lock_on_acquire: bool, - ) -> Result, CommandEncoderError> { - match hub.command_buffers.get(id.into_command_buffer_id()) { - Ok(cmd_buf) => { - let mut cmd_buf_data_guard = cmd_buf.data.lock(); - let cmd_buf_data = cmd_buf_data_guard.as_mut().unwrap(); - match cmd_buf_data.status { - CommandEncoderStatus::Recording => { - if lock_on_acquire { - cmd_buf_data.status = CommandEncoderStatus::Locked; - } - drop(cmd_buf_data_guard); - Ok(cmd_buf) - } - CommandEncoderStatus::Locked => { - // Any operation on a locked encoder is required to put it into the invalid/error state. - // See https://www.w3.org/TR/webgpu/#encoder-state-locked - cmd_buf_data.encoder.discard(); - cmd_buf_data.status = CommandEncoderStatus::Error; - Err(CommandEncoderError::Locked) - } - CommandEncoderStatus::Finished => Err(CommandEncoderError::NotRecording), - CommandEncoderStatus::Error => Err(CommandEncoderError::Invalid), + fn lock_encoder_impl(&self, lock: bool) -> Result<(), CommandEncoderError> { + let mut cmd_buf_data_guard = self.data.lock(); + let cmd_buf_data = cmd_buf_data_guard.as_mut().unwrap(); + match cmd_buf_data.status { + CommandEncoderStatus::Recording => { + if lock { + cmd_buf_data.status = CommandEncoderStatus::Locked; } + Ok(()) } - Err(_) => Err(CommandEncoderError::Invalid), + CommandEncoderStatus::Locked => { + // Any operation on a locked encoder is required to put it into the invalid/error state. + // See https://www.w3.org/TR/webgpu/#encoder-state-locked + cmd_buf_data.encoder.discard(); + cmd_buf_data.status = CommandEncoderStatus::Error; + Err(CommandEncoderError::Locked) + } + CommandEncoderStatus::Finished => Err(CommandEncoderError::NotRecording), + CommandEncoderStatus::Error => Err(CommandEncoderError::Invalid), } } - /// Return the [`CommandBuffer`] for `id`, for recording new commands. - /// - /// In `wgpu_core`, the [`CommandBuffer`] type serves both as encoder and - /// buffer, which is why this function takes an [`id::CommandEncoderId`] but - /// returns a [`CommandBuffer`]. The returned command buffer must be in the - /// "recording" state. Otherwise, an error is returned. - fn get_encoder( - hub: &Hub, - id: id::CommandEncoderId, - ) -> Result, CommandEncoderError> { - let lock_on_acquire = false; - Self::get_encoder_impl(hub, id, lock_on_acquire) + /// Checks that the encoder is in the [`CommandEncoderStatus::Recording`] state. + fn check_recording(&self) -> Result<(), CommandEncoderError> { + self.lock_encoder_impl(false) } - /// Return the [`CommandBuffer`] for `id` and if successful puts it into the [`CommandEncoderStatus::Locked`] state. + /// Locks the encoder by putting it in the [`CommandEncoderStatus::Locked`] state. /// - /// See [`CommandBuffer::get_encoder`]. /// Call [`CommandBuffer::unlock_encoder`] to put the [`CommandBuffer`] back into the [`CommandEncoderStatus::Recording`] state. - fn lock_encoder( - hub: &Hub, - id: id::CommandEncoderId, - ) -> Result, CommandEncoderError> { - let lock_on_acquire = true; - Self::get_encoder_impl(hub, id, lock_on_acquire) + fn lock_encoder(&self) -> Result<(), CommandEncoderError> { + self.lock_encoder_impl(true) } - /// Unlocks the [`CommandBuffer`] for `id` and puts it back into the [`CommandEncoderStatus::Recording`] state. + /// Unlocks the [`CommandBuffer`] and puts it back into the [`CommandEncoderStatus::Recording`] state. /// /// This function is the counterpart to [`CommandBuffer::lock_encoder`]. /// It is only valid to call this function if the encoder is in the [`CommandEncoderStatus::Locked`] state. @@ -661,7 +636,12 @@ impl Global { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)?; + let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid), + }; + cmd_buf.check_recording()?; + let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); #[cfg(feature = "trace")] @@ -692,7 +672,12 @@ impl Global { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)?; + let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid), + }; + cmd_buf.check_recording()?; + let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); @@ -723,7 +708,12 @@ impl Global { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)?; + let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid), + }; + cmd_buf.check_recording()?; + let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); diff --git a/wgpu-core/src/command/query.rs b/wgpu-core/src/command/query.rs index f01050d7f..f6601bddd 100644 --- a/wgpu-core/src/command/query.rs +++ b/wgpu-core/src/command/query.rs @@ -324,7 +324,14 @@ impl Global { ) -> Result<(), QueryError> { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; + let cmd_buf = match hub + .command_buffers + .get(command_encoder_id.into_command_buffer_id()) + { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid.into()), + }; + cmd_buf.check_recording()?; cmd_buf .device @@ -369,7 +376,15 @@ impl Global { ) -> Result<(), QueryError> { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; + let cmd_buf = match hub + .command_buffers + .get(command_encoder_id.into_command_buffer_id()) + { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid.into()), + }; + cmd_buf.check_recording()?; + let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index 1324c68d8..66abd33b6 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -1432,9 +1432,16 @@ impl Global { occlusion_query_set: None, }; - let cmd_buf = match CommandBuffer::lock_encoder(hub, encoder_id) { + let make_err = |e, arc_desc| (RenderPass::new(None, arc_desc), Some(e)); + + let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) { Ok(cmd_buf) => cmd_buf, - Err(e) => return (RenderPass::new(None, arc_desc), Some(e)), + Err(_) => return make_err(CommandEncoderError::Invalid, arc_desc), + }; + + match cmd_buf.lock_encoder() { + Ok(_) => {} + Err(e) => return make_err(e, arc_desc), }; let err = fill_arc_desc(hub, &cmd_buf.device, desc, &mut arc_desc).err(); @@ -1471,8 +1478,11 @@ impl Global { #[cfg(feature = "trace")] { let hub = A::hub(self); - let cmd_buf: Arc> = - CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?; + + let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid).map_pass_err(pass_scope)?, + }; let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); diff --git a/wgpu-core/src/command/transfer.rs b/wgpu-core/src/command/transfer.rs index 5748c0c99..4379777eb 100644 --- a/wgpu-core/src/command/transfer.rs +++ b/wgpu-core/src/command/transfer.rs @@ -2,7 +2,7 @@ use crate::device::trace::Command as TraceCommand; use crate::{ api_log, - command::{clear_texture, CommandBuffer, CommandEncoderError}, + command::{clear_texture, CommandEncoderError}, conv, device::{Device, DeviceError, MissingDownlevelFlags}, global::Global, @@ -544,7 +544,15 @@ impl Global { } let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; + let cmd_buf = match hub + .command_buffers + .get(command_encoder_id.into_command_buffer_id()) + { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid.into()), + }; + cmd_buf.check_recording()?; + let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); @@ -702,7 +710,15 @@ impl Global { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; + let cmd_buf = match hub + .command_buffers + .get(command_encoder_id.into_command_buffer_id()) + { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid.into()), + }; + cmd_buf.check_recording()?; + let device = &cmd_buf.device; device.check_is_valid()?; @@ -858,7 +874,15 @@ impl Global { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; + let cmd_buf = match hub + .command_buffers + .get(command_encoder_id.into_command_buffer_id()) + { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid.into()), + }; + cmd_buf.check_recording()?; + let device = &cmd_buf.device; device.check_is_valid()?; @@ -1026,7 +1050,15 @@ impl Global { let hub = A::hub(self); - let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?; + let cmd_buf = match hub + .command_buffers + .get(command_encoder_id.into_command_buffer_id()) + { + Ok(cmd_buf) => cmd_buf, + Err(_) => return Err(CommandEncoderError::Invalid.into()), + }; + cmd_buf.check_recording()?; + let device = &cmd_buf.device; device.check_is_valid()?; diff --git a/wgpu-core/src/id.rs b/wgpu-core/src/id.rs index 5bc86b377..05efbd2e4 100644 --- a/wgpu-core/src/id.rs +++ b/wgpu-core/src/id.rs @@ -323,6 +323,9 @@ ids! { pub type QuerySetId QuerySet; } +// The CommandBuffer type serves both as encoder and +// buffer, which is why the 2 functions below exist. + impl CommandEncoderId { pub fn into_command_buffer_id(self) -> CommandBufferId { Id(self.0, PhantomData)