move command buffer resolving in Global's methods

This commit is contained in:
teoxoy 2024-07-03 08:59:30 +02:00 committed by Teodor Tanasoaia
parent a9c74f42d6
commit e26d2d7763
7 changed files with 168 additions and 93 deletions

View File

@ -4,7 +4,7 @@ use std::{ops::Range, sync::Arc};
use crate::device::trace::Command as TraceCommand;
use crate::{
api_log,
command::CommandBuffer,
command::CommandEncoderError,
device::DeviceError,
get_lowest_common_denom,
global::Global,
@ -76,7 +76,7 @@ whereas subesource range specified start {subresource_base_array_layer} and coun
#[error(transparent)]
Device(#[from] DeviceError),
#[error(transparent)]
CommandEncoderError(#[from] super::CommandEncoderError),
CommandEncoderError(#[from] CommandEncoderError),
}
impl Global {
@ -92,7 +92,15 @@ impl Global {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();
@ -176,7 +184,15 @@ impl Global {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

View File

@ -301,35 +301,40 @@ impl Global {
timestamp_writes: None, // Handle only once we resolved the encoder.
};
match CommandBuffer::lock_encoder(hub, encoder_id) {
Ok(cmd_buf) => {
arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let Ok(query_set) = hub.query_sets.get(tw.query_set) else {
return (
ComputePass::new(None, arc_desc),
Some(CommandEncoderError::InvalidTimestampWritesQuerySetId(
tw.query_set,
)),
);
};
let make_err = |e, arc_desc| (ComputePass::new(None, arc_desc), Some(e));
if let Err(e) = query_set.same_device_as(cmd_buf.as_ref()) {
return (ComputePass::new(None, arc_desc), Some(e.into()));
}
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return make_err(CommandEncoderError::Invalid, arc_desc),
};
Some(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};
match cmd_buf.lock_encoder() {
Ok(_) => {}
Err(e) => return make_err(e, arc_desc),
};
(ComputePass::new(Some(cmd_buf), arc_desc), None)
arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let Ok(query_set) = hub.query_sets.get(tw.query_set) else {
return make_err(
CommandEncoderError::InvalidTimestampWritesQuerySetId(tw.query_set),
arc_desc,
);
};
if let Err(e) = query_set.same_device_as(cmd_buf.as_ref()) {
return make_err(e.into(), arc_desc);
}
Err(err) => (ComputePass::new(None, arc_desc), Some(err)),
}
Some(ArcPassTimestampWrites {
query_set,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};
(ComputePass::new(Some(cmd_buf), arc_desc), None)
}
/// Creates a type erased compute pass.
@ -378,7 +383,11 @@ impl Global {
let hub = A::hub(self);
let scope = PassErrorScope::Pass;
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(scope)?;
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid).map_pass_err(scope),
};
cmd_buf.check_recording().map_pass_err(scope)?;
#[cfg(feature = "trace")]
{

View File

@ -30,7 +30,6 @@ pub use timestamp_writes::PassTimestampWrites;
use self::memory_init::CommandBufferTextureMemoryActions;
use crate::device::{Device, DeviceError};
use crate::hub::Hub;
use crate::lock::{rank, Mutex};
use crate::snatch::SnatchGuard;
@ -425,65 +424,41 @@ impl<A: HalApi> CommandBuffer<A> {
}
impl<A: HalApi> CommandBuffer<A> {
fn get_encoder_impl(
hub: &Hub<A>,
id: id::CommandEncoderId,
lock_on_acquire: bool,
) -> Result<Arc<Self>, CommandEncoderError> {
match hub.command_buffers.get(id.into_command_buffer_id()) {
Ok(cmd_buf) => {
let mut cmd_buf_data_guard = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data_guard.as_mut().unwrap();
match cmd_buf_data.status {
CommandEncoderStatus::Recording => {
if lock_on_acquire {
cmd_buf_data.status = CommandEncoderStatus::Locked;
}
drop(cmd_buf_data_guard);
Ok(cmd_buf)
}
CommandEncoderStatus::Locked => {
// Any operation on a locked encoder is required to put it into the invalid/error state.
// See https://www.w3.org/TR/webgpu/#encoder-state-locked
cmd_buf_data.encoder.discard();
cmd_buf_data.status = CommandEncoderStatus::Error;
Err(CommandEncoderError::Locked)
}
CommandEncoderStatus::Finished => Err(CommandEncoderError::NotRecording),
CommandEncoderStatus::Error => Err(CommandEncoderError::Invalid),
fn lock_encoder_impl(&self, lock: bool) -> Result<(), CommandEncoderError> {
let mut cmd_buf_data_guard = self.data.lock();
let cmd_buf_data = cmd_buf_data_guard.as_mut().unwrap();
match cmd_buf_data.status {
CommandEncoderStatus::Recording => {
if lock {
cmd_buf_data.status = CommandEncoderStatus::Locked;
}
Ok(())
}
Err(_) => Err(CommandEncoderError::Invalid),
CommandEncoderStatus::Locked => {
// Any operation on a locked encoder is required to put it into the invalid/error state.
// See https://www.w3.org/TR/webgpu/#encoder-state-locked
cmd_buf_data.encoder.discard();
cmd_buf_data.status = CommandEncoderStatus::Error;
Err(CommandEncoderError::Locked)
}
CommandEncoderStatus::Finished => Err(CommandEncoderError::NotRecording),
CommandEncoderStatus::Error => Err(CommandEncoderError::Invalid),
}
}
/// Return the [`CommandBuffer`] for `id`, for recording new commands.
///
/// In `wgpu_core`, the [`CommandBuffer`] type serves both as encoder and
/// buffer, which is why this function takes an [`id::CommandEncoderId`] but
/// returns a [`CommandBuffer`]. The returned command buffer must be in the
/// "recording" state. Otherwise, an error is returned.
fn get_encoder(
hub: &Hub<A>,
id: id::CommandEncoderId,
) -> Result<Arc<Self>, CommandEncoderError> {
let lock_on_acquire = false;
Self::get_encoder_impl(hub, id, lock_on_acquire)
/// Checks that the encoder is in the [`CommandEncoderStatus::Recording`] state.
fn check_recording(&self) -> Result<(), CommandEncoderError> {
self.lock_encoder_impl(false)
}
/// Return the [`CommandBuffer`] for `id` and if successful puts it into the [`CommandEncoderStatus::Locked`] state.
/// Locks the encoder by putting it in the [`CommandEncoderStatus::Locked`] state.
///
/// See [`CommandBuffer::get_encoder`].
/// Call [`CommandBuffer::unlock_encoder`] to put the [`CommandBuffer`] back into the [`CommandEncoderStatus::Recording`] state.
fn lock_encoder(
hub: &Hub<A>,
id: id::CommandEncoderId,
) -> Result<Arc<Self>, CommandEncoderError> {
let lock_on_acquire = true;
Self::get_encoder_impl(hub, id, lock_on_acquire)
fn lock_encoder(&self) -> Result<(), CommandEncoderError> {
self.lock_encoder_impl(true)
}
/// Unlocks the [`CommandBuffer`] for `id` and puts it back into the [`CommandEncoderStatus::Recording`] state.
/// Unlocks the [`CommandBuffer`] and puts it back into the [`CommandEncoderStatus::Recording`] state.
///
/// This function is the counterpart to [`CommandBuffer::lock_encoder`].
/// It is only valid to call this function if the encoder is in the [`CommandEncoderStatus::Locked`] state.
@ -661,7 +636,12 @@ impl Global {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)?;
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid),
};
cmd_buf.check_recording()?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();
#[cfg(feature = "trace")]
@ -692,7 +672,12 @@ impl Global {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)?;
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid),
};
cmd_buf.check_recording()?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();
@ -723,7 +708,12 @@ impl Global {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)?;
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid),
};
cmd_buf.check_recording()?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

View File

@ -324,7 +324,14 @@ impl Global {
) -> Result<(), QueryError> {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
cmd_buf
.device
@ -369,7 +376,15 @@ impl Global {
) -> Result<(), QueryError> {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

View File

@ -1432,9 +1432,16 @@ impl Global {
occlusion_query_set: None,
};
let cmd_buf = match CommandBuffer::lock_encoder(hub, encoder_id) {
let make_err = |e, arc_desc| (RenderPass::new(None, arc_desc), Some(e));
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(e) => return (RenderPass::new(None, arc_desc), Some(e)),
Err(_) => return make_err(CommandEncoderError::Invalid, arc_desc),
};
match cmd_buf.lock_encoder() {
Ok(_) => {}
Err(e) => return make_err(e, arc_desc),
};
let err = fill_arc_desc(hub, &cmd_buf.device, desc, &mut arc_desc).err();
@ -1471,8 +1478,11 @@ impl Global {
#[cfg(feature = "trace")]
{
let hub = A::hub(self);
let cmd_buf: Arc<CommandBuffer<A>> =
CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;
let cmd_buf = match hub.command_buffers.get(encoder_id.into_command_buffer_id()) {
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid).map_pass_err(pass_scope)?,
};
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();

View File

@ -2,7 +2,7 @@
use crate::device::trace::Command as TraceCommand;
use crate::{
api_log,
command::{clear_texture, CommandBuffer, CommandEncoderError},
command::{clear_texture, CommandEncoderError},
conv,
device::{Device, DeviceError, MissingDownlevelFlags},
global::Global,
@ -544,7 +544,15 @@ impl Global {
}
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();
@ -702,7 +710,15 @@ impl Global {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
let device = &cmd_buf.device;
device.check_is_valid()?;
@ -858,7 +874,15 @@ impl Global {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
let device = &cmd_buf.device;
device.check_is_valid()?;
@ -1026,7 +1050,15 @@ impl Global {
let hub = A::hub(self);
let cmd_buf = CommandBuffer::get_encoder(hub, command_encoder_id)?;
let cmd_buf = match hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id())
{
Ok(cmd_buf) => cmd_buf,
Err(_) => return Err(CommandEncoderError::Invalid.into()),
};
cmd_buf.check_recording()?;
let device = &cmd_buf.device;
device.check_is_valid()?;

View File

@ -323,6 +323,9 @@ ids! {
pub type QuerySetId QuerySet;
}
// The CommandBuffer type serves both as encoder and
// buffer, which is why the 2 functions below exist.
impl CommandEncoderId {
pub fn into_command_buffer_id(self) -> CommandBufferId {
Id(self.0, PhantomData)