Separate out ComputeCommand id->arc resolve (a step towards no lifetimes on wgpu::ComputePass) (#5432)

* move out compute command to separate module

* introduce ArcComputeCommand

* stateless tracker now returns reference to arc upon insertion

* add insert_merge_single to buffer tracker

* compute pass execution now works internally with an ArcComputeCommand

* compute pass execution now translates Command to ArcCommand ahead of time

* don't clone commands in compute pass execution

* remove doc hiding

* use option insert

* clippy fix

* fix private doc issue

* remove unnecessary copied over doc hide
This commit is contained in:
Andreas Reich 2024-04-23 09:01:29 +02:00 committed by GitHub
parent b3c5a6fb84
commit edf1a86148
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 442 additions and 139 deletions

View File

@ -99,7 +99,7 @@ impl GlobalPlay for wgc::global::Global {
base, base,
timestamp_writes, timestamp_writes,
} => { } => {
self.command_encoder_run_compute_pass_impl::<A>( self.command_encoder_run_compute_pass_with_unresolved_commands::<A>(
encoder, encoder,
base.as_ref(), base.as_ref(),
timestamp_writes.as_ref(), timestamp_writes.as_ref(),

View File

@ -1,3 +1,4 @@
use crate::command::compute_command::{ArcComputeCommand, ComputeCommand};
use crate::device::DeviceError; use crate::device::DeviceError;
use crate::resource::Resource; use crate::resource::Resource;
use crate::snatch::SnatchGuard; use crate::snatch::SnatchGuard;
@ -20,7 +21,6 @@ use crate::{
hal_label, id, hal_label, id,
id::DeviceId, id::DeviceId,
init_tracker::MemoryInitKind, init_tracker::MemoryInitKind,
pipeline,
resource::{self}, resource::{self},
storage::Storage, storage::Storage,
track::{Tracker, UsageConflict, UsageScope}, track::{Tracker, UsageConflict, UsageScope},
@ -39,59 +39,6 @@ use thiserror::Error;
use std::sync::Arc; use std::sync::Arc;
use std::{fmt, mem, str}; use std::{fmt, mem, str};
#[doc(hidden)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ComputeCommand {
SetBindGroup {
index: u32,
num_dynamic_offsets: usize,
bind_group_id: id::BindGroupId,
},
SetPipeline(id::ComputePipelineId),
/// Set a range of push constants to values stored in [`BasePass::push_constant_data`].
SetPushConstant {
/// The byte offset within the push constant storage to write to. This
/// must be a multiple of four.
offset: u32,
/// The number of bytes to write. This must be a multiple of four.
size_bytes: u32,
/// Index in [`BasePass::push_constant_data`] of the start of the data
/// to be written.
///
/// Note: this is not a byte offset like `offset`. Rather, it is the
/// index of the first `u32` element in `push_constant_data` to read.
values_offset: u32,
},
Dispatch([u32; 3]),
DispatchIndirect {
buffer_id: id::BufferId,
offset: wgt::BufferAddress,
},
PushDebugGroup {
color: u32,
len: usize,
},
PopDebugGroup,
InsertDebugMarker {
color: u32,
len: usize,
},
WriteTimestamp {
query_set_id: id::QuerySetId,
query_index: u32,
},
BeginPipelineStatisticsQuery {
query_set_id: id::QuerySetId,
query_index: u32,
},
EndPipelineStatisticsQuery,
}
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass { pub struct ComputePass {
base: BasePass<ComputeCommand>, base: BasePass<ComputeCommand>,
@ -185,7 +132,7 @@ pub enum ComputePassErrorInner {
#[error(transparent)] #[error(transparent)]
Encoder(#[from] CommandEncoderError), Encoder(#[from] CommandEncoderError),
#[error("Bind group at index {0:?} is invalid")] #[error("Bind group at index {0:?} is invalid")]
InvalidBindGroup(usize), InvalidBindGroup(u32),
#[error("Device {0:?} is invalid")] #[error("Device {0:?} is invalid")]
InvalidDevice(DeviceId), InvalidDevice(DeviceId),
#[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")] #[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
@ -250,7 +197,7 @@ impl PrettyError for ComputePassErrorInner {
pub struct ComputePassError { pub struct ComputePassError {
pub scope: PassErrorScope, pub scope: PassErrorScope,
#[source] #[source]
inner: ComputePassErrorInner, pub(super) inner: ComputePassErrorInner,
} }
impl PrettyError for ComputePassError { impl PrettyError for ComputePassError {
fn fmt_pretty(&self, fmt: &mut ErrorFormatter) { fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
@ -347,7 +294,8 @@ impl Global {
encoder_id: id::CommandEncoderId, encoder_id: id::CommandEncoderId,
pass: &ComputePass, pass: &ComputePass,
) -> Result<(), ComputePassError> { ) -> Result<(), ComputePassError> {
self.command_encoder_run_compute_pass_impl::<A>( // TODO: This should go directly to `command_encoder_run_compute_pass_impl` by means of storing `ArcComputeCommand` internally.
self.command_encoder_run_compute_pass_with_unresolved_commands::<A>(
encoder_id, encoder_id,
pass.base.as_ref(), pass.base.as_ref(),
pass.timestamp_writes.as_ref(), pass.timestamp_writes.as_ref(),
@ -355,11 +303,33 @@ impl Global {
} }
#[doc(hidden)] #[doc(hidden)]
pub fn command_encoder_run_compute_pass_impl<A: HalApi>( pub fn command_encoder_run_compute_pass_with_unresolved_commands<A: HalApi>(
&self, &self,
encoder_id: id::CommandEncoderId, encoder_id: id::CommandEncoderId,
base: BasePassRef<ComputeCommand>, base: BasePassRef<ComputeCommand>,
timestamp_writes: Option<&ComputePassTimestampWrites>, timestamp_writes: Option<&ComputePassTimestampWrites>,
) -> Result<(), ComputePassError> {
let resolved_commands =
ComputeCommand::resolve_compute_command_ids(A::hub(self), base.commands)?;
self.command_encoder_run_compute_pass_impl::<A>(
encoder_id,
BasePassRef {
label: base.label,
commands: &resolved_commands,
dynamic_offsets: base.dynamic_offsets,
string_data: base.string_data,
push_constant_data: base.push_constant_data,
},
timestamp_writes,
)
}
fn command_encoder_run_compute_pass_impl<A: HalApi>(
&self,
encoder_id: id::CommandEncoderId,
base: BasePassRef<ArcComputeCommand<A>>,
timestamp_writes: Option<&ComputePassTimestampWrites>,
) -> Result<(), ComputePassError> { ) -> Result<(), ComputePassError> {
profiling::scope!("CommandEncoder::run_compute_pass"); profiling::scope!("CommandEncoder::run_compute_pass");
let pass_scope = PassErrorScope::Pass(encoder_id); let pass_scope = PassErrorScope::Pass(encoder_id);
@ -382,7 +352,13 @@ impl Global {
#[cfg(feature = "trace")] #[cfg(feature = "trace")]
if let Some(ref mut list) = cmd_buf_data.commands { if let Some(ref mut list) = cmd_buf_data.commands {
list.push(crate::device::trace::Command::RunComputePass { list.push(crate::device::trace::Command::RunComputePass {
base: BasePass::from_ref(base), base: BasePass {
label: base.label.map(str::to_string),
commands: base.commands.iter().map(Into::into).collect(),
dynamic_offsets: base.dynamic_offsets.to_vec(),
string_data: base.string_data.to_vec(),
push_constant_data: base.push_constant_data.to_vec(),
},
timestamp_writes: timestamp_writes.cloned(), timestamp_writes: timestamp_writes.cloned(),
}); });
} }
@ -402,9 +378,7 @@ impl Global {
let raw = encoder.open().map_pass_err(pass_scope)?; let raw = encoder.open().map_pass_err(pass_scope)?;
let bind_group_guard = hub.bind_groups.read(); let bind_group_guard = hub.bind_groups.read();
let pipeline_guard = hub.compute_pipelines.read();
let query_set_guard = hub.query_sets.read(); let query_set_guard = hub.query_sets.read();
let buffer_guard = hub.buffers.read();
let mut state = State { let mut state = State {
binder: Binder::new(), binder: Binder::new(),
@ -482,19 +456,21 @@ impl Global {
// be inserted before texture reads. // be inserted before texture reads.
let mut pending_discard_init_fixups = SurfacesInDiscardState::new(); let mut pending_discard_init_fixups = SurfacesInDiscardState::new();
// TODO: We should be draining the commands here, avoiding extra copies in the process.
// (A command encoder can't be executed twice!)
for command in base.commands { for command in base.commands {
match *command { match command {
ComputeCommand::SetBindGroup { ArcComputeCommand::SetBindGroup {
index, index,
num_dynamic_offsets, num_dynamic_offsets,
bind_group_id, bind_group,
} => { } => {
let scope = PassErrorScope::SetBindGroup(bind_group_id); let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());
let max_bind_groups = cmd_buf.limits.max_bind_groups; let max_bind_groups = cmd_buf.limits.max_bind_groups;
if index >= max_bind_groups { if index >= &max_bind_groups {
return Err(ComputePassErrorInner::BindGroupIndexOutOfRange { return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
index, index: *index,
max: max_bind_groups, max: max_bind_groups,
}) })
.map_pass_err(scope); .map_pass_err(scope);
@ -507,13 +483,9 @@ impl Global {
); );
dynamic_offset_count += num_dynamic_offsets; dynamic_offset_count += num_dynamic_offsets;
let bind_group = tracker let bind_group = tracker.bind_groups.insert_single(bind_group.clone());
.bind_groups
.add_single(&*bind_group_guard, bind_group_id)
.ok_or(ComputePassErrorInner::InvalidBindGroup(index as usize))
.map_pass_err(scope)?;
bind_group bind_group
.validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits) .validate_dynamic_bindings(*index, &temp_offsets, &cmd_buf.limits)
.map_pass_err(scope)?; .map_pass_err(scope)?;
buffer_memory_init_actions.extend( buffer_memory_init_actions.extend(
@ -535,14 +507,14 @@ impl Global {
let entries = let entries =
state state
.binder .binder
.assign_group(index as usize, bind_group, &temp_offsets); .assign_group(*index as usize, bind_group, &temp_offsets);
if !entries.is_empty() && pipeline_layout.is_some() { if !entries.is_empty() && pipeline_layout.is_some() {
let pipeline_layout = pipeline_layout.as_ref().unwrap().raw(); let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
for (i, e) in entries.iter().enumerate() { for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() { if let Some(group) = e.group.as_ref() {
let raw_bg = group let raw_bg = group
.raw(&snatch_guard) .raw(&snatch_guard)
.ok_or(ComputePassErrorInner::InvalidBindGroup(i)) .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
.map_pass_err(scope)?; .map_pass_err(scope)?;
unsafe { unsafe {
raw.set_bind_group( raw.set_bind_group(
@ -556,16 +528,13 @@ impl Global {
} }
} }
} }
ComputeCommand::SetPipeline(pipeline_id) => { ArcComputeCommand::SetPipeline(pipeline) => {
let pipeline_id = pipeline.as_info().id();
let scope = PassErrorScope::SetPipelineCompute(pipeline_id); let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
state.pipeline = Some(pipeline_id); state.pipeline = Some(pipeline_id);
let pipeline: &pipeline::ComputePipeline<A> = tracker tracker.compute_pipelines.insert_single(pipeline.clone());
.compute_pipelines
.add_single(&*pipeline_guard, pipeline_id)
.ok_or(ComputePassErrorInner::InvalidPipeline(pipeline_id))
.map_pass_err(scope)?;
unsafe { unsafe {
raw.set_compute_pipeline(pipeline.raw()); raw.set_compute_pipeline(pipeline.raw());
@ -589,7 +558,7 @@ impl Global {
if let Some(group) = e.group.as_ref() { if let Some(group) = e.group.as_ref() {
let raw_bg = group let raw_bg = group
.raw(&snatch_guard) .raw(&snatch_guard)
.ok_or(ComputePassErrorInner::InvalidBindGroup(i)) .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
.map_pass_err(scope)?; .map_pass_err(scope)?;
unsafe { unsafe {
raw.set_bind_group( raw.set_bind_group(
@ -625,7 +594,7 @@ impl Global {
} }
} }
} }
ComputeCommand::SetPushConstant { ArcComputeCommand::SetPushConstant {
offset, offset,
size_bytes, size_bytes,
values_offset, values_offset,
@ -636,7 +605,7 @@ impl Global {
let values_end_offset = let values_end_offset =
(values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize; (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
let data_slice = let data_slice =
&base.push_constant_data[(values_offset as usize)..values_end_offset]; &base.push_constant_data[(*values_offset as usize)..values_end_offset];
let pipeline_layout = state let pipeline_layout = state
.binder .binder
@ -651,7 +620,7 @@ impl Global {
pipeline_layout pipeline_layout
.validate_push_constant_ranges( .validate_push_constant_ranges(
wgt::ShaderStages::COMPUTE, wgt::ShaderStages::COMPUTE,
offset, *offset,
end_offset_bytes, end_offset_bytes,
) )
.map_pass_err(scope)?; .map_pass_err(scope)?;
@ -660,12 +629,12 @@ impl Global {
raw.set_push_constants( raw.set_push_constants(
pipeline_layout.raw(), pipeline_layout.raw(),
wgt::ShaderStages::COMPUTE, wgt::ShaderStages::COMPUTE,
offset, *offset,
data_slice, data_slice,
); );
} }
} }
ComputeCommand::Dispatch(groups) => { ArcComputeCommand::Dispatch(groups) => {
let scope = PassErrorScope::Dispatch { let scope = PassErrorScope::Dispatch {
indirect: false, indirect: false,
pipeline: state.pipeline, pipeline: state.pipeline,
@ -690,7 +659,7 @@ impl Global {
{ {
return Err(ComputePassErrorInner::Dispatch( return Err(ComputePassErrorInner::Dispatch(
DispatchError::InvalidGroupSize { DispatchError::InvalidGroupSize {
current: groups, current: *groups,
limit: groups_size_limit, limit: groups_size_limit,
}, },
)) ))
@ -698,10 +667,11 @@ impl Global {
} }
unsafe { unsafe {
raw.dispatch(groups); raw.dispatch(*groups);
} }
} }
ComputeCommand::DispatchIndirect { buffer_id, offset } => { ArcComputeCommand::DispatchIndirect { buffer, offset } => {
let buffer_id = buffer.as_info().id();
let scope = PassErrorScope::Dispatch { let scope = PassErrorScope::Dispatch {
indirect: true, indirect: true,
pipeline: state.pipeline, pipeline: state.pipeline,
@ -713,29 +683,25 @@ impl Global {
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION) .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
.map_pass_err(scope)?; .map_pass_err(scope)?;
let indirect_buffer = state state
.scope .scope
.buffers .buffers
.merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDIRECT) .insert_merge_single(buffer.clone(), hal::BufferUses::INDIRECT)
.map_pass_err(scope)?;
check_buffer_usage(buffer_id, buffer.usage, wgt::BufferUsages::INDIRECT)
.map_pass_err(scope)?; .map_pass_err(scope)?;
check_buffer_usage(
buffer_id,
indirect_buffer.usage,
wgt::BufferUsages::INDIRECT,
)
.map_pass_err(scope)?;
let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64; let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
if end_offset > indirect_buffer.size { if end_offset > buffer.size {
return Err(ComputePassErrorInner::IndirectBufferOverrun { return Err(ComputePassErrorInner::IndirectBufferOverrun {
offset, offset: *offset,
end_offset, end_offset,
buffer_size: indirect_buffer.size, buffer_size: buffer.size,
}) })
.map_pass_err(scope); .map_pass_err(scope);
} }
let buf_raw = indirect_buffer let buf_raw = buffer
.raw .raw
.get(&snatch_guard) .get(&snatch_guard)
.ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id)) .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
@ -744,9 +710,9 @@ impl Global {
let stride = 3 * 4; // 3 integers, x/y/z group size let stride = 3 * 4; // 3 integers, x/y/z group size
buffer_memory_init_actions.extend( buffer_memory_init_actions.extend(
indirect_buffer.initialization_status.read().create_action( buffer.initialization_status.read().create_action(
indirect_buffer, buffer,
offset..(offset + stride), *offset..(*offset + stride),
MemoryInitKind::NeedsInitializedMemory, MemoryInitKind::NeedsInitializedMemory,
), ),
); );
@ -756,15 +722,15 @@ impl Global {
raw, raw,
&mut intermediate_trackers, &mut intermediate_trackers,
&*bind_group_guard, &*bind_group_guard,
Some(indirect_buffer.as_info().tracker_index()), Some(buffer.as_info().tracker_index()),
&snatch_guard, &snatch_guard,
) )
.map_pass_err(scope)?; .map_pass_err(scope)?;
unsafe { unsafe {
raw.dispatch_indirect(buf_raw, offset); raw.dispatch_indirect(buf_raw, *offset);
} }
} }
ComputeCommand::PushDebugGroup { color: _, len } => { ArcComputeCommand::PushDebugGroup { color: _, len } => {
state.debug_scope_depth += 1; state.debug_scope_depth += 1;
if !discard_hal_labels { if !discard_hal_labels {
let label = let label =
@ -776,7 +742,7 @@ impl Global {
} }
string_offset += len; string_offset += len;
} }
ComputeCommand::PopDebugGroup => { ArcComputeCommand::PopDebugGroup => {
let scope = PassErrorScope::PopDebugGroup; let scope = PassErrorScope::PopDebugGroup;
if state.debug_scope_depth == 0 { if state.debug_scope_depth == 0 {
@ -790,7 +756,7 @@ impl Global {
} }
} }
} }
ComputeCommand::InsertDebugMarker { color: _, len } => { ArcComputeCommand::InsertDebugMarker { color: _, len } => {
if !discard_hal_labels { if !discard_hal_labels {
let label = let label =
str::from_utf8(&base.string_data[string_offset..string_offset + len]) str::from_utf8(&base.string_data[string_offset..string_offset + len])
@ -799,49 +765,43 @@ impl Global {
} }
string_offset += len; string_offset += len;
} }
ComputeCommand::WriteTimestamp { ArcComputeCommand::WriteTimestamp {
query_set_id, query_set,
query_index, query_index,
} => { } => {
let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::WriteTimestamp; let scope = PassErrorScope::WriteTimestamp;
device device
.require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES) .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
.map_pass_err(scope)?; .map_pass_err(scope)?;
let query_set: &resource::QuerySet<A> = tracker let query_set = tracker.query_sets.insert_single(query_set.clone());
.query_sets
.add_single(&*query_set_guard, query_set_id)
.ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?;
query_set query_set
.validate_and_write_timestamp(raw, query_set_id, query_index, None) .validate_and_write_timestamp(raw, query_set_id, *query_index, None)
.map_pass_err(scope)?; .map_pass_err(scope)?;
} }
ComputeCommand::BeginPipelineStatisticsQuery { ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set_id, query_set,
query_index, query_index,
} => { } => {
let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::BeginPipelineStatisticsQuery; let scope = PassErrorScope::BeginPipelineStatisticsQuery;
let query_set: &resource::QuerySet<A> = tracker let query_set = tracker.query_sets.insert_single(query_set.clone());
.query_sets
.add_single(&*query_set_guard, query_set_id)
.ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?;
query_set query_set
.validate_and_begin_pipeline_statistics_query( .validate_and_begin_pipeline_statistics_query(
raw, raw,
query_set_id, query_set_id,
query_index, *query_index,
None, None,
&mut active_query, &mut active_query,
) )
.map_pass_err(scope)?; .map_pass_err(scope)?;
} }
ComputeCommand::EndPipelineStatisticsQuery => { ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery; let scope = PassErrorScope::EndPipelineStatisticsQuery;
end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query) end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)

View File

@ -0,0 +1,322 @@
use std::sync::Arc;
use crate::{
binding_model::BindGroup,
hal_api::HalApi,
id,
pipeline::ComputePipeline,
resource::{Buffer, QuerySet},
};
use super::{ComputePassError, ComputePassErrorInner, PassErrorScope};
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ComputeCommand {
SetBindGroup {
index: u32,
num_dynamic_offsets: usize,
bind_group_id: id::BindGroupId,
},
SetPipeline(id::ComputePipelineId),
/// Set a range of push constants to values stored in `push_constant_data`.
SetPushConstant {
/// The byte offset within the push constant storage to write to. This
/// must be a multiple of four.
offset: u32,
/// The number of bytes to write. This must be a multiple of four.
size_bytes: u32,
/// Index in `push_constant_data` of the start of the data
/// to be written.
///
/// Note: this is not a byte offset like `offset`. Rather, it is the
/// index of the first `u32` element in `push_constant_data` to read.
values_offset: u32,
},
Dispatch([u32; 3]),
DispatchIndirect {
buffer_id: id::BufferId,
offset: wgt::BufferAddress,
},
PushDebugGroup {
color: u32,
len: usize,
},
PopDebugGroup,
InsertDebugMarker {
color: u32,
len: usize,
},
WriteTimestamp {
query_set_id: id::QuerySetId,
query_index: u32,
},
BeginPipelineStatisticsQuery {
query_set_id: id::QuerySetId,
query_index: u32,
},
EndPipelineStatisticsQuery,
}
impl ComputeCommand {
/// Resolves all ids in a list of commands into the corresponding resource Arc.
///
// TODO: Once resolving is done on-the-fly during recording, this function should be only needed with the replay feature:
// #[cfg(feature = "replay")]
pub fn resolve_compute_command_ids<A: HalApi>(
hub: &crate::hub::Hub<A>,
commands: &[ComputeCommand],
) -> Result<Vec<ArcComputeCommand<A>>, ComputePassError> {
let buffers_guard = hub.buffers.read();
let bind_group_guard = hub.bind_groups.read();
let query_set_guard = hub.query_sets.read();
let pipelines_guard = hub.compute_pipelines.read();
let resolved_commands: Vec<ArcComputeCommand<A>> = commands
.iter()
.map(|c| -> Result<ArcComputeCommand<A>, ComputePassError> {
Ok(match *c {
ComputeCommand::SetBindGroup {
index,
num_dynamic_offsets,
bind_group_id,
} => ArcComputeCommand::SetBindGroup {
index,
num_dynamic_offsets,
bind_group: bind_group_guard.get_owned(bind_group_id).map_err(|_| {
ComputePassError {
scope: PassErrorScope::SetBindGroup(bind_group_id),
inner: ComputePassErrorInner::InvalidBindGroup(index),
}
})?,
},
ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
pipelines_guard
.get_owned(pipeline_id)
.map_err(|_| ComputePassError {
scope: PassErrorScope::SetPipelineCompute(pipeline_id),
inner: ComputePassErrorInner::InvalidPipeline(pipeline_id),
})?,
),
ComputeCommand::SetPushConstant {
offset,
size_bytes,
values_offset,
} => ArcComputeCommand::SetPushConstant {
offset,
size_bytes,
values_offset,
},
ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
ComputeCommand::DispatchIndirect { buffer_id, offset } => {
ArcComputeCommand::DispatchIndirect {
buffer: buffers_guard.get_owned(buffer_id).map_err(|_| {
ComputePassError {
scope: PassErrorScope::Dispatch {
indirect: true,
pipeline: None, // TODO: not used right now, but once we do the resolve during recording we can use this again.
},
inner: ComputePassErrorInner::InvalidBuffer(buffer_id),
}
})?,
offset,
}
}
ComputeCommand::PushDebugGroup { color, len } => {
ArcComputeCommand::PushDebugGroup { color, len }
}
ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
ComputeCommand::InsertDebugMarker { color, len } => {
ArcComputeCommand::InsertDebugMarker { color, len }
}
ComputeCommand::WriteTimestamp {
query_set_id,
query_index,
} => ArcComputeCommand::WriteTimestamp {
query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
ComputePassError {
scope: PassErrorScope::WriteTimestamp,
inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
}
})?,
query_index,
},
ComputeCommand::BeginPipelineStatisticsQuery {
query_set_id,
query_index,
} => ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
ComputePassError {
scope: PassErrorScope::BeginPipelineStatisticsQuery,
inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
}
})?,
query_index,
},
ComputeCommand::EndPipelineStatisticsQuery => {
ArcComputeCommand::EndPipelineStatisticsQuery
}
})
})
.collect::<Result<Vec<_>, ComputePassError>>()?;
Ok(resolved_commands)
}
}
/// Equivalent to `ComputeCommand` but the Ids resolved into resource Arcs.
#[derive(Clone, Debug)]
pub enum ArcComputeCommand<A: HalApi> {
SetBindGroup {
index: u32,
num_dynamic_offsets: usize,
bind_group: Arc<BindGroup<A>>,
},
SetPipeline(Arc<ComputePipeline<A>>),
/// Set a range of push constants to values stored in `push_constant_data`.
SetPushConstant {
/// The byte offset within the push constant storage to write to. This
/// must be a multiple of four.
offset: u32,
/// The number of bytes to write. This must be a multiple of four.
size_bytes: u32,
/// Index in `push_constant_data` of the start of the data
/// to be written.
///
/// Note: this is not a byte offset like `offset`. Rather, it is the
/// index of the first `u32` element in `push_constant_data` to read.
values_offset: u32,
},
Dispatch([u32; 3]),
DispatchIndirect {
buffer: Arc<Buffer<A>>,
offset: wgt::BufferAddress,
},
PushDebugGroup {
color: u32,
len: usize,
},
PopDebugGroup,
InsertDebugMarker {
color: u32,
len: usize,
},
WriteTimestamp {
query_set: Arc<QuerySet<A>>,
query_index: u32,
},
BeginPipelineStatisticsQuery {
query_set: Arc<QuerySet<A>>,
query_index: u32,
},
EndPipelineStatisticsQuery,
}
#[cfg(feature = "trace")]
impl<A: HalApi> From<&ArcComputeCommand<A>> for ComputeCommand {
fn from(value: &ArcComputeCommand<A>) -> Self {
use crate::resource::Resource as _;
match value {
ArcComputeCommand::SetBindGroup {
index,
num_dynamic_offsets,
bind_group,
} => ComputeCommand::SetBindGroup {
index: *index,
num_dynamic_offsets: *num_dynamic_offsets,
bind_group_id: bind_group.as_info().id(),
},
ArcComputeCommand::SetPipeline(pipeline) => {
ComputeCommand::SetPipeline(pipeline.as_info().id())
}
ArcComputeCommand::SetPushConstant {
offset,
size_bytes,
values_offset,
} => ComputeCommand::SetPushConstant {
offset: *offset,
size_bytes: *size_bytes,
values_offset: *values_offset,
},
ArcComputeCommand::Dispatch(dim) => ComputeCommand::Dispatch(*dim),
ArcComputeCommand::DispatchIndirect { buffer, offset } => {
ComputeCommand::DispatchIndirect {
buffer_id: buffer.as_info().id(),
offset: *offset,
}
}
ArcComputeCommand::PushDebugGroup { color, len } => ComputeCommand::PushDebugGroup {
color: *color,
len: *len,
},
ArcComputeCommand::PopDebugGroup => ComputeCommand::PopDebugGroup,
ArcComputeCommand::InsertDebugMarker { color, len } => {
ComputeCommand::InsertDebugMarker {
color: *color,
len: *len,
}
}
ArcComputeCommand::WriteTimestamp {
query_set,
query_index,
} => ComputeCommand::WriteTimestamp {
query_set_id: query_set.as_info().id(),
query_index: *query_index,
},
ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set,
query_index,
} => ComputeCommand::BeginPipelineStatisticsQuery {
query_set_id: query_set.as_info().id(),
query_index: *query_index,
},
ArcComputeCommand::EndPipelineStatisticsQuery => {
ComputeCommand::EndPipelineStatisticsQuery
}
}
}
}

View File

@ -3,6 +3,7 @@ mod bind;
mod bundle; mod bundle;
mod clear; mod clear;
mod compute; mod compute;
mod compute_command;
mod draw; mod draw;
mod memory_init; mod memory_init;
mod query; mod query;
@ -13,7 +14,8 @@ use std::sync::Arc;
pub(crate) use self::clear::clear_texture; pub(crate) use self::clear::clear_texture;
pub use self::{ pub use self::{
bundle::*, clear::ClearError, compute::*, draw::*, query::*, render::*, transfer::*, bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, query::*,
render::*, transfer::*,
}; };
pub(crate) use allocator::CommandAllocator; pub(crate) use allocator::CommandAllocator;

View File

@ -245,6 +245,22 @@ impl<A: HalApi> BufferUsageScope<A> {
.get(id) .get(id)
.map_err(|_| UsageConflict::BufferInvalid { id })?; .map_err(|_| UsageConflict::BufferInvalid { id })?;
self.insert_merge_single(buffer.clone(), new_state)
.map(|_| buffer)
}
/// Merge a single state into the UsageScope, using an already resolved buffer.
///
/// If the resulting state is invalid, returns a usage
/// conflict with the details of the invalid state.
///
/// If the ID is higher than the length of internal vectors,
/// the vectors will be extended. A call to set_size is not needed.
pub fn insert_merge_single(
&mut self,
buffer: Arc<Buffer<A>>,
new_state: BufferUses,
) -> Result<(), UsageConflict> {
let index = buffer.info.tracker_index().as_usize(); let index = buffer.info.tracker_index().as_usize();
self.allow_index(index); self.allow_index(index);
@ -260,12 +276,12 @@ impl<A: HalApi> BufferUsageScope<A> {
index, index,
BufferStateProvider::Direct { state: new_state }, BufferStateProvider::Direct { state: new_state },
ResourceMetadataProvider::Direct { ResourceMetadataProvider::Direct {
resource: Cow::Owned(buffer.clone()), resource: Cow::Owned(buffer),
}, },
)?; )?;
} }
Ok(buffer) Ok(())
} }
} }

View File

@ -87,16 +87,18 @@ impl<T: Resource> ResourceMetadata<T> {
/// Add the resource with the given index, epoch, and reference count to the /// Add the resource with the given index, epoch, and reference count to the
/// set. /// set.
/// ///
/// Returns a reference to the newly inserted resource.
/// (This allows avoiding a clone/reference count increase in many cases.)
///
/// # Safety /// # Safety
/// ///
/// The given `index` must be in bounds for this `ResourceMetadata`'s /// The given `index` must be in bounds for this `ResourceMetadata`'s
/// existing tables. See `tracker_assert_in_bounds`. /// existing tables. See `tracker_assert_in_bounds`.
#[inline(always)] #[inline(always)]
pub(super) unsafe fn insert(&mut self, index: usize, resource: Arc<T>) { pub(super) unsafe fn insert(&mut self, index: usize, resource: Arc<T>) -> &Arc<T> {
self.owned.set(index, true); self.owned.set(index, true);
unsafe { let resource_dst = unsafe { self.resources.get_unchecked_mut(index) };
*self.resources.get_unchecked_mut(index) = Some(resource); resource_dst.insert(resource)
}
} }
/// Get the resource with the given index. /// Get the resource with the given index.

View File

@ -158,16 +158,17 @@ impl<T: Resource> StatelessTracker<T> {
/// ///
/// If the ID is higher than the length of internal vectors, /// If the ID is higher than the length of internal vectors,
/// the vectors will be extended. A call to set_size is not needed. /// the vectors will be extended. A call to set_size is not needed.
pub fn insert_single(&mut self, resource: Arc<T>) { ///
/// Returns a reference to the newly inserted resource.
/// (This allows avoiding a clone/reference count increase in many cases.)
pub fn insert_single(&mut self, resource: Arc<T>) -> &Arc<T> {
let index = resource.as_info().tracker_index().as_usize(); let index = resource.as_info().tracker_index().as_usize();
self.allow_index(index); self.allow_index(index);
self.tracker_assert_in_bounds(index); self.tracker_assert_in_bounds(index);
unsafe { unsafe { self.metadata.insert(index, resource) }
self.metadata.insert(index, resource);
}
} }
/// Adds the given resource to the tracker. /// Adds the given resource to the tracker.