diff --git a/player/src/lib.rs b/player/src/lib.rs
index ca3a4b6a5..c67c605e5 100644
--- a/player/src/lib.rs
+++ b/player/src/lib.rs
@@ -99,7 +99,7 @@ impl GlobalPlay for wgc::global::Global {
base,
timestamp_writes,
} => {
- self.command_encoder_run_compute_pass_impl::(
+ self.command_encoder_run_compute_pass_with_unresolved_commands::(
encoder,
base.as_ref(),
timestamp_writes.as_ref(),
diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs
index 49492ac03..4ee48f008 100644
--- a/wgpu-core/src/command/compute.rs
+++ b/wgpu-core/src/command/compute.rs
@@ -1,3 +1,4 @@
+use crate::command::compute_command::{ArcComputeCommand, ComputeCommand};
use crate::device::DeviceError;
use crate::resource::Resource;
use crate::snatch::SnatchGuard;
@@ -20,7 +21,6 @@ use crate::{
hal_label, id,
id::DeviceId,
init_tracker::MemoryInitKind,
- pipeline,
resource::{self},
storage::Storage,
track::{Tracker, UsageConflict, UsageScope},
@@ -39,59 +39,6 @@ use thiserror::Error;
use std::sync::Arc;
use std::{fmt, mem, str};
-#[doc(hidden)]
-#[derive(Clone, Copy, Debug)]
-#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
-pub enum ComputeCommand {
- SetBindGroup {
- index: u32,
- num_dynamic_offsets: usize,
- bind_group_id: id::BindGroupId,
- },
- SetPipeline(id::ComputePipelineId),
-
- /// Set a range of push constants to values stored in [`BasePass::push_constant_data`].
- SetPushConstant {
- /// The byte offset within the push constant storage to write to. This
- /// must be a multiple of four.
- offset: u32,
-
- /// The number of bytes to write. This must be a multiple of four.
- size_bytes: u32,
-
- /// Index in [`BasePass::push_constant_data`] of the start of the data
- /// to be written.
- ///
- /// Note: this is not a byte offset like `offset`. Rather, it is the
- /// index of the first `u32` element in `push_constant_data` to read.
- values_offset: u32,
- },
-
- Dispatch([u32; 3]),
- DispatchIndirect {
- buffer_id: id::BufferId,
- offset: wgt::BufferAddress,
- },
- PushDebugGroup {
- color: u32,
- len: usize,
- },
- PopDebugGroup,
- InsertDebugMarker {
- color: u32,
- len: usize,
- },
- WriteTimestamp {
- query_set_id: id::QuerySetId,
- query_index: u32,
- },
- BeginPipelineStatisticsQuery {
- query_set_id: id::QuerySetId,
- query_index: u32,
- },
- EndPipelineStatisticsQuery,
-}
-
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
base: BasePass,
@@ -185,7 +132,7 @@ pub enum ComputePassErrorInner {
#[error(transparent)]
Encoder(#[from] CommandEncoderError),
#[error("Bind group at index {0:?} is invalid")]
- InvalidBindGroup(usize),
+ InvalidBindGroup(u32),
#[error("Device {0:?} is invalid")]
InvalidDevice(DeviceId),
#[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
@@ -250,7 +197,7 @@ impl PrettyError for ComputePassErrorInner {
pub struct ComputePassError {
pub scope: PassErrorScope,
#[source]
- inner: ComputePassErrorInner,
+ pub(super) inner: ComputePassErrorInner,
}
impl PrettyError for ComputePassError {
fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
@@ -347,7 +294,8 @@ impl Global {
encoder_id: id::CommandEncoderId,
pass: &ComputePass,
) -> Result<(), ComputePassError> {
- self.command_encoder_run_compute_pass_impl::(
+ // TODO: This should go directly to `command_encoder_run_compute_pass_impl` by means of storing `ArcComputeCommand` internally.
+ self.command_encoder_run_compute_pass_with_unresolved_commands::(
encoder_id,
pass.base.as_ref(),
pass.timestamp_writes.as_ref(),
@@ -355,11 +303,33 @@ impl Global {
}
#[doc(hidden)]
- pub fn command_encoder_run_compute_pass_impl(
+ pub fn command_encoder_run_compute_pass_with_unresolved_commands(
&self,
encoder_id: id::CommandEncoderId,
base: BasePassRef,
timestamp_writes: Option<&ComputePassTimestampWrites>,
+ ) -> Result<(), ComputePassError> {
+ let resolved_commands =
+ ComputeCommand::resolve_compute_command_ids(A::hub(self), base.commands)?;
+
+ self.command_encoder_run_compute_pass_impl::(
+ encoder_id,
+ BasePassRef {
+ label: base.label,
+ commands: &resolved_commands,
+ dynamic_offsets: base.dynamic_offsets,
+ string_data: base.string_data,
+ push_constant_data: base.push_constant_data,
+ },
+ timestamp_writes,
+ )
+ }
+
+ fn command_encoder_run_compute_pass_impl(
+ &self,
+ encoder_id: id::CommandEncoderId,
+ base: BasePassRef>,
+ timestamp_writes: Option<&ComputePassTimestampWrites>,
) -> Result<(), ComputePassError> {
profiling::scope!("CommandEncoder::run_compute_pass");
let pass_scope = PassErrorScope::Pass(encoder_id);
@@ -382,7 +352,13 @@ impl Global {
#[cfg(feature = "trace")]
if let Some(ref mut list) = cmd_buf_data.commands {
list.push(crate::device::trace::Command::RunComputePass {
- base: BasePass::from_ref(base),
+ base: BasePass {
+ label: base.label.map(str::to_string),
+ commands: base.commands.iter().map(Into::into).collect(),
+ dynamic_offsets: base.dynamic_offsets.to_vec(),
+ string_data: base.string_data.to_vec(),
+ push_constant_data: base.push_constant_data.to_vec(),
+ },
timestamp_writes: timestamp_writes.cloned(),
});
}
@@ -402,9 +378,7 @@ impl Global {
let raw = encoder.open().map_pass_err(pass_scope)?;
let bind_group_guard = hub.bind_groups.read();
- let pipeline_guard = hub.compute_pipelines.read();
let query_set_guard = hub.query_sets.read();
- let buffer_guard = hub.buffers.read();
let mut state = State {
binder: Binder::new(),
@@ -482,19 +456,21 @@ impl Global {
// be inserted before texture reads.
let mut pending_discard_init_fixups = SurfacesInDiscardState::new();
+ // TODO: We should be draining the commands here, avoiding extra copies in the process.
+ // (A command encoder can't be executed twice!)
for command in base.commands {
- match *command {
- ComputeCommand::SetBindGroup {
+ match command {
+ ArcComputeCommand::SetBindGroup {
index,
num_dynamic_offsets,
- bind_group_id,
+ bind_group,
} => {
- let scope = PassErrorScope::SetBindGroup(bind_group_id);
+ let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());
let max_bind_groups = cmd_buf.limits.max_bind_groups;
- if index >= max_bind_groups {
+ if index >= &max_bind_groups {
return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
- index,
+ index: *index,
max: max_bind_groups,
})
.map_pass_err(scope);
@@ -507,13 +483,9 @@ impl Global {
);
dynamic_offset_count += num_dynamic_offsets;
- let bind_group = tracker
- .bind_groups
- .add_single(&*bind_group_guard, bind_group_id)
- .ok_or(ComputePassErrorInner::InvalidBindGroup(index as usize))
- .map_pass_err(scope)?;
+ let bind_group = tracker.bind_groups.insert_single(bind_group.clone());
bind_group
- .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
+ .validate_dynamic_bindings(*index, &temp_offsets, &cmd_buf.limits)
.map_pass_err(scope)?;
buffer_memory_init_actions.extend(
@@ -535,14 +507,14 @@ impl Global {
let entries =
state
.binder
- .assign_group(index as usize, bind_group, &temp_offsets);
+ .assign_group(*index as usize, bind_group, &temp_offsets);
if !entries.is_empty() && pipeline_layout.is_some() {
let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group
.raw(&snatch_guard)
- .ok_or(ComputePassErrorInner::InvalidBindGroup(i))
+ .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
.map_pass_err(scope)?;
unsafe {
raw.set_bind_group(
@@ -556,16 +528,13 @@ impl Global {
}
}
}
- ComputeCommand::SetPipeline(pipeline_id) => {
+ ArcComputeCommand::SetPipeline(pipeline) => {
+ let pipeline_id = pipeline.as_info().id();
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
state.pipeline = Some(pipeline_id);
- let pipeline: &pipeline::ComputePipeline = tracker
- .compute_pipelines
- .add_single(&*pipeline_guard, pipeline_id)
- .ok_or(ComputePassErrorInner::InvalidPipeline(pipeline_id))
- .map_pass_err(scope)?;
+ tracker.compute_pipelines.insert_single(pipeline.clone());
unsafe {
raw.set_compute_pipeline(pipeline.raw());
@@ -589,7 +558,7 @@ impl Global {
if let Some(group) = e.group.as_ref() {
let raw_bg = group
.raw(&snatch_guard)
- .ok_or(ComputePassErrorInner::InvalidBindGroup(i))
+ .ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
.map_pass_err(scope)?;
unsafe {
raw.set_bind_group(
@@ -625,7 +594,7 @@ impl Global {
}
}
}
- ComputeCommand::SetPushConstant {
+ ArcComputeCommand::SetPushConstant {
offset,
size_bytes,
values_offset,
@@ -636,7 +605,7 @@ impl Global {
let values_end_offset =
(values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
let data_slice =
- &base.push_constant_data[(values_offset as usize)..values_end_offset];
+ &base.push_constant_data[(*values_offset as usize)..values_end_offset];
let pipeline_layout = state
.binder
@@ -651,7 +620,7 @@ impl Global {
pipeline_layout
.validate_push_constant_ranges(
wgt::ShaderStages::COMPUTE,
- offset,
+ *offset,
end_offset_bytes,
)
.map_pass_err(scope)?;
@@ -660,12 +629,12 @@ impl Global {
raw.set_push_constants(
pipeline_layout.raw(),
wgt::ShaderStages::COMPUTE,
- offset,
+ *offset,
data_slice,
);
}
}
- ComputeCommand::Dispatch(groups) => {
+ ArcComputeCommand::Dispatch(groups) => {
let scope = PassErrorScope::Dispatch {
indirect: false,
pipeline: state.pipeline,
@@ -690,7 +659,7 @@ impl Global {
{
return Err(ComputePassErrorInner::Dispatch(
DispatchError::InvalidGroupSize {
- current: groups,
+ current: *groups,
limit: groups_size_limit,
},
))
@@ -698,10 +667,11 @@ impl Global {
}
unsafe {
- raw.dispatch(groups);
+ raw.dispatch(*groups);
}
}
- ComputeCommand::DispatchIndirect { buffer_id, offset } => {
+ ArcComputeCommand::DispatchIndirect { buffer, offset } => {
+ let buffer_id = buffer.as_info().id();
let scope = PassErrorScope::Dispatch {
indirect: true,
pipeline: state.pipeline,
@@ -713,29 +683,25 @@ impl Global {
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
.map_pass_err(scope)?;
- let indirect_buffer = state
+ state
.scope
.buffers
- .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDIRECT)
+ .insert_merge_single(buffer.clone(), hal::BufferUses::INDIRECT)
+ .map_pass_err(scope)?;
+ check_buffer_usage(buffer_id, buffer.usage, wgt::BufferUsages::INDIRECT)
.map_pass_err(scope)?;
- check_buffer_usage(
- buffer_id,
- indirect_buffer.usage,
- wgt::BufferUsages::INDIRECT,
- )
- .map_pass_err(scope)?;
let end_offset = offset + mem::size_of::() as u64;
- if end_offset > indirect_buffer.size {
+ if end_offset > buffer.size {
return Err(ComputePassErrorInner::IndirectBufferOverrun {
- offset,
+ offset: *offset,
end_offset,
- buffer_size: indirect_buffer.size,
+ buffer_size: buffer.size,
})
.map_pass_err(scope);
}
- let buf_raw = indirect_buffer
+ let buf_raw = buffer
.raw
.get(&snatch_guard)
.ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
@@ -744,9 +710,9 @@ impl Global {
let stride = 3 * 4; // 3 integers, x/y/z group size
buffer_memory_init_actions.extend(
- indirect_buffer.initialization_status.read().create_action(
- indirect_buffer,
- offset..(offset + stride),
+ buffer.initialization_status.read().create_action(
+ buffer,
+ *offset..(*offset + stride),
MemoryInitKind::NeedsInitializedMemory,
),
);
@@ -756,15 +722,15 @@ impl Global {
raw,
&mut intermediate_trackers,
&*bind_group_guard,
- Some(indirect_buffer.as_info().tracker_index()),
+ Some(buffer.as_info().tracker_index()),
&snatch_guard,
)
.map_pass_err(scope)?;
unsafe {
- raw.dispatch_indirect(buf_raw, offset);
+ raw.dispatch_indirect(buf_raw, *offset);
}
}
- ComputeCommand::PushDebugGroup { color: _, len } => {
+ ArcComputeCommand::PushDebugGroup { color: _, len } => {
state.debug_scope_depth += 1;
if !discard_hal_labels {
let label =
@@ -776,7 +742,7 @@ impl Global {
}
string_offset += len;
}
- ComputeCommand::PopDebugGroup => {
+ ArcComputeCommand::PopDebugGroup => {
let scope = PassErrorScope::PopDebugGroup;
if state.debug_scope_depth == 0 {
@@ -790,7 +756,7 @@ impl Global {
}
}
}
- ComputeCommand::InsertDebugMarker { color: _, len } => {
+ ArcComputeCommand::InsertDebugMarker { color: _, len } => {
if !discard_hal_labels {
let label =
str::from_utf8(&base.string_data[string_offset..string_offset + len])
@@ -799,49 +765,43 @@ impl Global {
}
string_offset += len;
}
- ComputeCommand::WriteTimestamp {
- query_set_id,
+ ArcComputeCommand::WriteTimestamp {
+ query_set,
query_index,
} => {
+ let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::WriteTimestamp;
device
.require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
.map_pass_err(scope)?;
- let query_set: &resource::QuerySet = tracker
- .query_sets
- .add_single(&*query_set_guard, query_set_id)
- .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
- .map_pass_err(scope)?;
+ let query_set = tracker.query_sets.insert_single(query_set.clone());
query_set
- .validate_and_write_timestamp(raw, query_set_id, query_index, None)
+ .validate_and_write_timestamp(raw, query_set_id, *query_index, None)
.map_pass_err(scope)?;
}
- ComputeCommand::BeginPipelineStatisticsQuery {
- query_set_id,
+ ArcComputeCommand::BeginPipelineStatisticsQuery {
+ query_set,
query_index,
} => {
+ let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
- let query_set: &resource::QuerySet = tracker
- .query_sets
- .add_single(&*query_set_guard, query_set_id)
- .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
- .map_pass_err(scope)?;
+ let query_set = tracker.query_sets.insert_single(query_set.clone());
query_set
.validate_and_begin_pipeline_statistics_query(
raw,
query_set_id,
- query_index,
+ *query_index,
None,
&mut active_query,
)
.map_pass_err(scope)?;
}
- ComputeCommand::EndPipelineStatisticsQuery => {
+ ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery;
end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
diff --git a/wgpu-core/src/command/compute_command.rs b/wgpu-core/src/command/compute_command.rs
new file mode 100644
index 000000000..49fdbbec2
--- /dev/null
+++ b/wgpu-core/src/command/compute_command.rs
@@ -0,0 +1,322 @@
+use std::sync::Arc;
+
+use crate::{
+ binding_model::BindGroup,
+ hal_api::HalApi,
+ id,
+ pipeline::ComputePipeline,
+ resource::{Buffer, QuerySet},
+};
+
+use super::{ComputePassError, ComputePassErrorInner, PassErrorScope};
+/// A compute pass command as recorded: resources are referred to by id, not by `Arc`.
+#[derive(Clone, Copy, Debug)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub enum ComputeCommand {
+ SetBindGroup {
+ index: u32, // Rejected at execution if not below the `max_bind_groups` limit.
+ num_dynamic_offsets: usize, // Added to the pass's running dynamic offset count.
+ bind_group_id: id::BindGroupId, // Looked up in `hub.bind_groups` when resolved.
+ },
+ /// Sets the compute pipeline with the given id.
+ SetPipeline(id::ComputePipelineId),
+
+ /// Set a range of push constants to values stored in `push_constant_data`.
+ SetPushConstant {
+ /// The byte offset within the push constant storage to write to. This
+ /// must be a multiple of four.
+ offset: u32,
+
+ /// The number of bytes to write. This must be a multiple of four.
+ size_bytes: u32,
+
+ /// Index in `push_constant_data` of the start of the data
+ /// to be written.
+ ///
+ /// Note: this is not a byte offset like `offset`. Rather, it is the
+ /// index of the first `u32` element in `push_constant_data` to read.
+ values_offset: u32,
+ },
+ /// Direct dispatch; the three values are handed to the raw encoder's `dispatch`.
+ Dispatch([u32; 3]),
+ /// Dispatch whose arguments are read from the buffer, starting at `offset`.
+ DispatchIndirect {
+ buffer_id: id::BufferId,
+ offset: wgt::BufferAddress,
+ },
+ /// Opens a debug group; `len` advances the pass's offset into `string_data`.
+ PushDebugGroup {
+ color: u32,
+ len: usize,
+ },
+ /// Closes a debug group; execution checks the pass's debug scope depth first.
+ PopDebugGroup,
+ /// Inserts a debug marker whose label is the next `len` bytes of `string_data`.
+ InsertDebugMarker {
+ color: u32,
+ len: usize,
+ },
+ /// Writes a timestamp; execution requires `TIMESTAMP_QUERY_INSIDE_PASSES`.
+ WriteTimestamp {
+ query_set_id: id::QuerySetId,
+ query_index: u32,
+ },
+ /// Begins a pipeline statistics query at `query_index` of the query set.
+ BeginPipelineStatisticsQuery {
+ query_set_id: id::QuerySetId,
+ query_index: u32,
+ },
+ /// Ends a pipeline statistics query.
+ EndPipelineStatisticsQuery,
+}
+
+impl ComputeCommand {
+ /// Resolves every id in `commands` into the corresponding resource `Arc`, looking it up in `hub`.
+ /// Stops at, and returns, the first `ComputePassError` produced by an id whose lookup fails.
+ // TODO: Once resolving is done on-the-fly during recording, this function should only be needed with the replay feature:
+ // #[cfg(feature = "replay")]
+ pub fn resolve_compute_command_ids(
+ hub: &crate::hub::Hub,
+ commands: &[ComputeCommand],
+ ) -> Result>, ComputePassError> {
+ let buffers_guard = hub.buffers.read();
+ let bind_group_guard = hub.bind_groups.read();
+ let query_set_guard = hub.query_sets.read();
+ let pipelines_guard = hub.compute_pipelines.read();
+ // The four registry read guards above are held until this function returns.
+ let resolved_commands: Vec> = commands
+ .iter()
+ .map(|c| -> Result, ComputePassError> {
+ Ok(match *c {
+ ComputeCommand::SetBindGroup {
+ index,
+ num_dynamic_offsets,
+ bind_group_id,
+ } => ArcComputeCommand::SetBindGroup {
+ index,
+ num_dynamic_offsets,
+ bind_group: bind_group_guard.get_owned(bind_group_id).map_err(|_| {
+ ComputePassError {
+ scope: PassErrorScope::SetBindGroup(bind_group_id),
+ inner: ComputePassErrorInner::InvalidBindGroup(index),
+ }
+ })?,
+ },
+
+ ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
+ pipelines_guard
+ .get_owned(pipeline_id)
+ .map_err(|_| ComputePassError {
+ scope: PassErrorScope::SetPipelineCompute(pipeline_id),
+ inner: ComputePassErrorInner::InvalidPipeline(pipeline_id),
+ })?,
+ ),
+ // This variant carries no resource id, so its fields are copied over unchanged.
+ ComputeCommand::SetPushConstant {
+ offset,
+ size_bytes,
+ values_offset,
+ } => ArcComputeCommand::SetPushConstant {
+ offset,
+ size_bytes,
+ values_offset,
+ },
+
+ ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
+
+ ComputeCommand::DispatchIndirect { buffer_id, offset } => {
+ ArcComputeCommand::DispatchIndirect {
+ buffer: buffers_guard.get_owned(buffer_id).map_err(|_| {
+ ComputePassError {
+ scope: PassErrorScope::Dispatch {
+ indirect: true,
+ pipeline: None, // TODO: not used right now, but once we do the resolve during recording we can use this again.
+ },
+ inner: ComputePassErrorInner::InvalidBuffer(buffer_id),
+ }
+ })?,
+ offset,
+ }
+ }
+
+ ComputeCommand::PushDebugGroup { color, len } => {
+ ArcComputeCommand::PushDebugGroup { color, len }
+ }
+
+ ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
+
+ ComputeCommand::InsertDebugMarker { color, len } => {
+ ArcComputeCommand::InsertDebugMarker { color, len }
+ }
+
+ ComputeCommand::WriteTimestamp {
+ query_set_id,
+ query_index,
+ } => ArcComputeCommand::WriteTimestamp {
+ query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
+ ComputePassError {
+ scope: PassErrorScope::WriteTimestamp,
+ inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
+ }
+ })?,
+ query_index,
+ },
+
+ ComputeCommand::BeginPipelineStatisticsQuery {
+ query_set_id,
+ query_index,
+ } => ArcComputeCommand::BeginPipelineStatisticsQuery {
+ query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
+ ComputePassError {
+ scope: PassErrorScope::BeginPipelineStatisticsQuery,
+ inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
+ }
+ })?,
+ query_index,
+ },
+
+ ComputeCommand::EndPipelineStatisticsQuery => {
+ ArcComputeCommand::EndPipelineStatisticsQuery
+ }
+ })
+ })
+ .collect::, ComputePassError>>()?;
+ Ok(resolved_commands)
+ }
+}
+
+/// Equivalent to `ComputeCommand`, but with the ids resolved into resource `Arc`s.
+#[derive(Clone, Debug)] // No `Copy` here: several variants own an `Arc`.
+pub enum ArcComputeCommand {
+ SetBindGroup {
+ index: u32,
+ num_dynamic_offsets: usize,
+ bind_group: Arc>,
+ },
+ /// Sets the compute pipeline; the payload is the resolved pipeline.
+ SetPipeline(Arc>),
+
+ /// Set a range of push constants to values stored in `push_constant_data`.
+ SetPushConstant {
+ /// The byte offset within the push constant storage to write to. This
+ /// must be a multiple of four.
+ offset: u32,
+
+ /// The number of bytes to write. This must be a multiple of four.
+ size_bytes: u32,
+
+ /// Index in `push_constant_data` of the start of the data
+ /// to be written.
+ ///
+ /// Note: this is not a byte offset like `offset`. Rather, it is the
+ /// index of the first `u32` element in `push_constant_data` to read.
+ values_offset: u32,
+ },
+
+ Dispatch([u32; 3]),
+ /// Indirect dispatch; `buffer` is the resolved buffer the arguments are read from.
+ DispatchIndirect {
+ buffer: Arc>,
+ offset: wgt::BufferAddress,
+ },
+
+ PushDebugGroup {
+ color: u32,
+ len: usize,
+ },
+
+ PopDebugGroup,
+
+ InsertDebugMarker {
+ color: u32,
+ len: usize,
+ },
+
+ WriteTimestamp {
+ query_set: Arc>,
+ query_index: u32,
+ },
+
+ BeginPipelineStatisticsQuery {
+ query_set: Arc>,
+ query_index: u32,
+ },
+
+ EndPipelineStatisticsQuery,
+}
+
+#[cfg(feature = "trace")] // Used when a `RunComputePass` command is recorded into the trace.
+impl From<&ArcComputeCommand> for ComputeCommand {
+ fn from(value: &ArcComputeCommand) -> Self {
+ use crate::resource::Resource as _;
+ // Resource-carrying variants are turned back into ids via `as_info().id()`; the rest copy their fields.
+ match value {
+ ArcComputeCommand::SetBindGroup {
+ index,
+ num_dynamic_offsets,
+ bind_group,
+ } => ComputeCommand::SetBindGroup {
+ index: *index,
+ num_dynamic_offsets: *num_dynamic_offsets,
+ bind_group_id: bind_group.as_info().id(),
+ },
+
+ ArcComputeCommand::SetPipeline(pipeline) => {
+ ComputeCommand::SetPipeline(pipeline.as_info().id())
+ }
+
+ ArcComputeCommand::SetPushConstant {
+ offset,
+ size_bytes,
+ values_offset,
+ } => ComputeCommand::SetPushConstant {
+ offset: *offset,
+ size_bytes: *size_bytes,
+ values_offset: *values_offset,
+ },
+
+ ArcComputeCommand::Dispatch(dim) => ComputeCommand::Dispatch(*dim),
+
+ ArcComputeCommand::DispatchIndirect { buffer, offset } => {
+ ComputeCommand::DispatchIndirect {
+ buffer_id: buffer.as_info().id(),
+ offset: *offset,
+ }
+ }
+
+ ArcComputeCommand::PushDebugGroup { color, len } => ComputeCommand::PushDebugGroup {
+ color: *color,
+ len: *len,
+ },
+
+ ArcComputeCommand::PopDebugGroup => ComputeCommand::PopDebugGroup,
+
+ ArcComputeCommand::InsertDebugMarker { color, len } => {
+ ComputeCommand::InsertDebugMarker {
+ color: *color,
+ len: *len,
+ }
+ }
+
+ ArcComputeCommand::WriteTimestamp {
+ query_set,
+ query_index,
+ } => ComputeCommand::WriteTimestamp {
+ query_set_id: query_set.as_info().id(),
+ query_index: *query_index,
+ },
+
+ ArcComputeCommand::BeginPipelineStatisticsQuery {
+ query_set,
+ query_index,
+ } => ComputeCommand::BeginPipelineStatisticsQuery {
+ query_set_id: query_set.as_info().id(),
+ query_index: *query_index,
+ },
+
+ ArcComputeCommand::EndPipelineStatisticsQuery => {
+ ComputeCommand::EndPipelineStatisticsQuery
+ }
+ }
+ }
+}
diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs
index e812b4f8f..d53f47bf4 100644
--- a/wgpu-core/src/command/mod.rs
+++ b/wgpu-core/src/command/mod.rs
@@ -3,6 +3,7 @@ mod bind;
mod bundle;
mod clear;
mod compute;
+mod compute_command;
mod draw;
mod memory_init;
mod query;
@@ -13,7 +14,8 @@ use std::sync::Arc;
pub(crate) use self::clear::clear_texture;
pub use self::{
- bundle::*, clear::ClearError, compute::*, draw::*, query::*, render::*, transfer::*,
+ bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, query::*,
+ render::*, transfer::*,
};
pub(crate) use allocator::CommandAllocator;
diff --git a/wgpu-core/src/track/buffer.rs b/wgpu-core/src/track/buffer.rs
index 8ea92d484..9a52a5325 100644
--- a/wgpu-core/src/track/buffer.rs
+++ b/wgpu-core/src/track/buffer.rs
@@ -245,6 +245,22 @@ impl BufferUsageScope {
.get(id)
.map_err(|_| UsageConflict::BufferInvalid { id })?;
+ self.insert_merge_single(buffer.clone(), new_state)
+ .map(|_| buffer)
+ }
+
+ /// Merge a single state into the UsageScope, using an already resolved buffer.
+ ///
+ /// If the resulting state is invalid, returns a usage
+ /// conflict with the details of the invalid state.
+ ///
+ /// If the buffer's tracker index is higher than the length of internal vectors,
+ /// the vectors will be extended. A call to set_size is not needed.
+ pub fn insert_merge_single(
+ &mut self,
+ buffer: Arc>,
+ new_state: BufferUses,
+ ) -> Result<(), UsageConflict> {
let index = buffer.info.tracker_index().as_usize();
self.allow_index(index);
@@ -260,12 +276,12 @@ impl BufferUsageScope {
index,
BufferStateProvider::Direct { state: new_state },
ResourceMetadataProvider::Direct {
- resource: Cow::Owned(buffer.clone()),
+ resource: Cow::Owned(buffer),
},
)?;
}
- Ok(buffer)
+ Ok(())
}
}
diff --git a/wgpu-core/src/track/metadata.rs b/wgpu-core/src/track/metadata.rs
index 3e71e0e08..d6e8d6f90 100644
--- a/wgpu-core/src/track/metadata.rs
+++ b/wgpu-core/src/track/metadata.rs
@@ -87,16 +87,18 @@ impl ResourceMetadata {
/// Add the resource with the given index, epoch, and reference count to the
/// set.
///
+ /// Returns a reference to the newly inserted resource.
+ /// (This allows avoiding a clone/reference count increase in many cases.)
+ ///
/// # Safety
///
/// The given `index` must be in bounds for this `ResourceMetadata`'s
/// existing tables. See `tracker_assert_in_bounds`.
#[inline(always)]
- pub(super) unsafe fn insert(&mut self, index: usize, resource: Arc) {
+ pub(super) unsafe fn insert(&mut self, index: usize, resource: Arc) -> &Arc {
self.owned.set(index, true);
- unsafe {
- *self.resources.get_unchecked_mut(index) = Some(resource);
- }
+ let resource_dst = unsafe { self.resources.get_unchecked_mut(index) };
+ resource_dst.insert(resource)
}
/// Get the resource with the given index.
diff --git a/wgpu-core/src/track/stateless.rs b/wgpu-core/src/track/stateless.rs
index 26caf165c..25ffc027e 100644
--- a/wgpu-core/src/track/stateless.rs
+++ b/wgpu-core/src/track/stateless.rs
@@ -158,16 +158,17 @@ impl StatelessTracker {
///
/// If the ID is higher than the length of internal vectors,
/// the vectors will be extended. A call to set_size is not needed.
- pub fn insert_single(&mut self, resource: Arc) {
+ ///
+ /// Returns a reference to the newly inserted resource.
+ /// (This allows avoiding a clone/reference count increase in many cases.)
+ pub fn insert_single(&mut self, resource: Arc) -> &Arc {
let index = resource.as_info().tracker_index().as_usize();
self.allow_index(index);
self.tracker_assert_in_bounds(index);
- unsafe {
- self.metadata.insert(index, resource);
- }
+ unsafe { self.metadata.insert(index, resource) }
}
/// Adds the given resource to the tracker.