mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-21 22:33:49 +00:00
Separate out ComputeCommand id->arc resolve (a step towards no lifetimes on wgpu::ComputePass
) (#5432)
* move out compute command to separate module * introduce ArcComputeCommand * stateless tracker now returns reference to arc upon insertion * add insert_merge_single to buffer tracker * compute pass execution now works internally with an ArcComputeCommand * compute pass execution now translates Command to ArcCommand ahead of time * don't clone commands in compute pass execution * remove doc hiding * use option insert * clippy fix * fix private doc issue * remove unnecessary copied over doc hide
This commit is contained in:
parent
b3c5a6fb84
commit
edf1a86148
@ -99,7 +99,7 @@ impl GlobalPlay for wgc::global::Global {
|
||||
base,
|
||||
timestamp_writes,
|
||||
} => {
|
||||
self.command_encoder_run_compute_pass_impl::<A>(
|
||||
self.command_encoder_run_compute_pass_with_unresolved_commands::<A>(
|
||||
encoder,
|
||||
base.as_ref(),
|
||||
timestamp_writes.as_ref(),
|
||||
|
@ -1,3 +1,4 @@
|
||||
use crate::command::compute_command::{ArcComputeCommand, ComputeCommand};
|
||||
use crate::device::DeviceError;
|
||||
use crate::resource::Resource;
|
||||
use crate::snatch::SnatchGuard;
|
||||
@ -20,7 +21,6 @@ use crate::{
|
||||
hal_label, id,
|
||||
id::DeviceId,
|
||||
init_tracker::MemoryInitKind,
|
||||
pipeline,
|
||||
resource::{self},
|
||||
storage::Storage,
|
||||
track::{Tracker, UsageConflict, UsageScope},
|
||||
@ -39,59 +39,6 @@ use thiserror::Error;
|
||||
use std::sync::Arc;
|
||||
use std::{fmt, mem, str};
|
||||
|
||||
#[doc(hidden)]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum ComputeCommand {
|
||||
SetBindGroup {
|
||||
index: u32,
|
||||
num_dynamic_offsets: usize,
|
||||
bind_group_id: id::BindGroupId,
|
||||
},
|
||||
SetPipeline(id::ComputePipelineId),
|
||||
|
||||
/// Set a range of push constants to values stored in [`BasePass::push_constant_data`].
|
||||
SetPushConstant {
|
||||
/// The byte offset within the push constant storage to write to. This
|
||||
/// must be a multiple of four.
|
||||
offset: u32,
|
||||
|
||||
/// The number of bytes to write. This must be a multiple of four.
|
||||
size_bytes: u32,
|
||||
|
||||
/// Index in [`BasePass::push_constant_data`] of the start of the data
|
||||
/// to be written.
|
||||
///
|
||||
/// Note: this is not a byte offset like `offset`. Rather, it is the
|
||||
/// index of the first `u32` element in `push_constant_data` to read.
|
||||
values_offset: u32,
|
||||
},
|
||||
|
||||
Dispatch([u32; 3]),
|
||||
DispatchIndirect {
|
||||
buffer_id: id::BufferId,
|
||||
offset: wgt::BufferAddress,
|
||||
},
|
||||
PushDebugGroup {
|
||||
color: u32,
|
||||
len: usize,
|
||||
},
|
||||
PopDebugGroup,
|
||||
InsertDebugMarker {
|
||||
color: u32,
|
||||
len: usize,
|
||||
},
|
||||
WriteTimestamp {
|
||||
query_set_id: id::QuerySetId,
|
||||
query_index: u32,
|
||||
},
|
||||
BeginPipelineStatisticsQuery {
|
||||
query_set_id: id::QuerySetId,
|
||||
query_index: u32,
|
||||
},
|
||||
EndPipelineStatisticsQuery,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct ComputePass {
|
||||
base: BasePass<ComputeCommand>,
|
||||
@ -185,7 +132,7 @@ pub enum ComputePassErrorInner {
|
||||
#[error(transparent)]
|
||||
Encoder(#[from] CommandEncoderError),
|
||||
#[error("Bind group at index {0:?} is invalid")]
|
||||
InvalidBindGroup(usize),
|
||||
InvalidBindGroup(u32),
|
||||
#[error("Device {0:?} is invalid")]
|
||||
InvalidDevice(DeviceId),
|
||||
#[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
|
||||
@ -250,7 +197,7 @@ impl PrettyError for ComputePassErrorInner {
|
||||
pub struct ComputePassError {
|
||||
pub scope: PassErrorScope,
|
||||
#[source]
|
||||
inner: ComputePassErrorInner,
|
||||
pub(super) inner: ComputePassErrorInner,
|
||||
}
|
||||
impl PrettyError for ComputePassError {
|
||||
fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
|
||||
@ -347,7 +294,8 @@ impl Global {
|
||||
encoder_id: id::CommandEncoderId,
|
||||
pass: &ComputePass,
|
||||
) -> Result<(), ComputePassError> {
|
||||
self.command_encoder_run_compute_pass_impl::<A>(
|
||||
// TODO: This should go directly to `command_encoder_run_compute_pass_impl` by means of storing `ArcComputeCommand` internally.
|
||||
self.command_encoder_run_compute_pass_with_unresolved_commands::<A>(
|
||||
encoder_id,
|
||||
pass.base.as_ref(),
|
||||
pass.timestamp_writes.as_ref(),
|
||||
@ -355,11 +303,33 @@ impl Global {
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn command_encoder_run_compute_pass_impl<A: HalApi>(
|
||||
pub fn command_encoder_run_compute_pass_with_unresolved_commands<A: HalApi>(
|
||||
&self,
|
||||
encoder_id: id::CommandEncoderId,
|
||||
base: BasePassRef<ComputeCommand>,
|
||||
timestamp_writes: Option<&ComputePassTimestampWrites>,
|
||||
) -> Result<(), ComputePassError> {
|
||||
let resolved_commands =
|
||||
ComputeCommand::resolve_compute_command_ids(A::hub(self), base.commands)?;
|
||||
|
||||
self.command_encoder_run_compute_pass_impl::<A>(
|
||||
encoder_id,
|
||||
BasePassRef {
|
||||
label: base.label,
|
||||
commands: &resolved_commands,
|
||||
dynamic_offsets: base.dynamic_offsets,
|
||||
string_data: base.string_data,
|
||||
push_constant_data: base.push_constant_data,
|
||||
},
|
||||
timestamp_writes,
|
||||
)
|
||||
}
|
||||
|
||||
fn command_encoder_run_compute_pass_impl<A: HalApi>(
|
||||
&self,
|
||||
encoder_id: id::CommandEncoderId,
|
||||
base: BasePassRef<ArcComputeCommand<A>>,
|
||||
timestamp_writes: Option<&ComputePassTimestampWrites>,
|
||||
) -> Result<(), ComputePassError> {
|
||||
profiling::scope!("CommandEncoder::run_compute_pass");
|
||||
let pass_scope = PassErrorScope::Pass(encoder_id);
|
||||
@ -382,7 +352,13 @@ impl Global {
|
||||
#[cfg(feature = "trace")]
|
||||
if let Some(ref mut list) = cmd_buf_data.commands {
|
||||
list.push(crate::device::trace::Command::RunComputePass {
|
||||
base: BasePass::from_ref(base),
|
||||
base: BasePass {
|
||||
label: base.label.map(str::to_string),
|
||||
commands: base.commands.iter().map(Into::into).collect(),
|
||||
dynamic_offsets: base.dynamic_offsets.to_vec(),
|
||||
string_data: base.string_data.to_vec(),
|
||||
push_constant_data: base.push_constant_data.to_vec(),
|
||||
},
|
||||
timestamp_writes: timestamp_writes.cloned(),
|
||||
});
|
||||
}
|
||||
@ -402,9 +378,7 @@ impl Global {
|
||||
let raw = encoder.open().map_pass_err(pass_scope)?;
|
||||
|
||||
let bind_group_guard = hub.bind_groups.read();
|
||||
let pipeline_guard = hub.compute_pipelines.read();
|
||||
let query_set_guard = hub.query_sets.read();
|
||||
let buffer_guard = hub.buffers.read();
|
||||
|
||||
let mut state = State {
|
||||
binder: Binder::new(),
|
||||
@ -482,19 +456,21 @@ impl Global {
|
||||
// be inserted before texture reads.
|
||||
let mut pending_discard_init_fixups = SurfacesInDiscardState::new();
|
||||
|
||||
// TODO: We should be draining the commands here, avoiding extra copies in the process.
|
||||
// (A command encoder can't be executed twice!)
|
||||
for command in base.commands {
|
||||
match *command {
|
||||
ComputeCommand::SetBindGroup {
|
||||
match command {
|
||||
ArcComputeCommand::SetBindGroup {
|
||||
index,
|
||||
num_dynamic_offsets,
|
||||
bind_group_id,
|
||||
bind_group,
|
||||
} => {
|
||||
let scope = PassErrorScope::SetBindGroup(bind_group_id);
|
||||
let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id());
|
||||
|
||||
let max_bind_groups = cmd_buf.limits.max_bind_groups;
|
||||
if index >= max_bind_groups {
|
||||
if index >= &max_bind_groups {
|
||||
return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
|
||||
index,
|
||||
index: *index,
|
||||
max: max_bind_groups,
|
||||
})
|
||||
.map_pass_err(scope);
|
||||
@ -507,13 +483,9 @@ impl Global {
|
||||
);
|
||||
dynamic_offset_count += num_dynamic_offsets;
|
||||
|
||||
let bind_group = tracker
|
||||
.bind_groups
|
||||
.add_single(&*bind_group_guard, bind_group_id)
|
||||
.ok_or(ComputePassErrorInner::InvalidBindGroup(index as usize))
|
||||
.map_pass_err(scope)?;
|
||||
let bind_group = tracker.bind_groups.insert_single(bind_group.clone());
|
||||
bind_group
|
||||
.validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
|
||||
.validate_dynamic_bindings(*index, &temp_offsets, &cmd_buf.limits)
|
||||
.map_pass_err(scope)?;
|
||||
|
||||
buffer_memory_init_actions.extend(
|
||||
@ -535,14 +507,14 @@ impl Global {
|
||||
let entries =
|
||||
state
|
||||
.binder
|
||||
.assign_group(index as usize, bind_group, &temp_offsets);
|
||||
.assign_group(*index as usize, bind_group, &temp_offsets);
|
||||
if !entries.is_empty() && pipeline_layout.is_some() {
|
||||
let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
|
||||
for (i, e) in entries.iter().enumerate() {
|
||||
if let Some(group) = e.group.as_ref() {
|
||||
let raw_bg = group
|
||||
.raw(&snatch_guard)
|
||||
.ok_or(ComputePassErrorInner::InvalidBindGroup(i))
|
||||
.ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
|
||||
.map_pass_err(scope)?;
|
||||
unsafe {
|
||||
raw.set_bind_group(
|
||||
@ -556,16 +528,13 @@ impl Global {
|
||||
}
|
||||
}
|
||||
}
|
||||
ComputeCommand::SetPipeline(pipeline_id) => {
|
||||
ArcComputeCommand::SetPipeline(pipeline) => {
|
||||
let pipeline_id = pipeline.as_info().id();
|
||||
let scope = PassErrorScope::SetPipelineCompute(pipeline_id);
|
||||
|
||||
state.pipeline = Some(pipeline_id);
|
||||
|
||||
let pipeline: &pipeline::ComputePipeline<A> = tracker
|
||||
.compute_pipelines
|
||||
.add_single(&*pipeline_guard, pipeline_id)
|
||||
.ok_or(ComputePassErrorInner::InvalidPipeline(pipeline_id))
|
||||
.map_pass_err(scope)?;
|
||||
tracker.compute_pipelines.insert_single(pipeline.clone());
|
||||
|
||||
unsafe {
|
||||
raw.set_compute_pipeline(pipeline.raw());
|
||||
@ -589,7 +558,7 @@ impl Global {
|
||||
if let Some(group) = e.group.as_ref() {
|
||||
let raw_bg = group
|
||||
.raw(&snatch_guard)
|
||||
.ok_or(ComputePassErrorInner::InvalidBindGroup(i))
|
||||
.ok_or(ComputePassErrorInner::InvalidBindGroup(i as u32))
|
||||
.map_pass_err(scope)?;
|
||||
unsafe {
|
||||
raw.set_bind_group(
|
||||
@ -625,7 +594,7 @@ impl Global {
|
||||
}
|
||||
}
|
||||
}
|
||||
ComputeCommand::SetPushConstant {
|
||||
ArcComputeCommand::SetPushConstant {
|
||||
offset,
|
||||
size_bytes,
|
||||
values_offset,
|
||||
@ -636,7 +605,7 @@ impl Global {
|
||||
let values_end_offset =
|
||||
(values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
|
||||
let data_slice =
|
||||
&base.push_constant_data[(values_offset as usize)..values_end_offset];
|
||||
&base.push_constant_data[(*values_offset as usize)..values_end_offset];
|
||||
|
||||
let pipeline_layout = state
|
||||
.binder
|
||||
@ -651,7 +620,7 @@ impl Global {
|
||||
pipeline_layout
|
||||
.validate_push_constant_ranges(
|
||||
wgt::ShaderStages::COMPUTE,
|
||||
offset,
|
||||
*offset,
|
||||
end_offset_bytes,
|
||||
)
|
||||
.map_pass_err(scope)?;
|
||||
@ -660,12 +629,12 @@ impl Global {
|
||||
raw.set_push_constants(
|
||||
pipeline_layout.raw(),
|
||||
wgt::ShaderStages::COMPUTE,
|
||||
offset,
|
||||
*offset,
|
||||
data_slice,
|
||||
);
|
||||
}
|
||||
}
|
||||
ComputeCommand::Dispatch(groups) => {
|
||||
ArcComputeCommand::Dispatch(groups) => {
|
||||
let scope = PassErrorScope::Dispatch {
|
||||
indirect: false,
|
||||
pipeline: state.pipeline,
|
||||
@ -690,7 +659,7 @@ impl Global {
|
||||
{
|
||||
return Err(ComputePassErrorInner::Dispatch(
|
||||
DispatchError::InvalidGroupSize {
|
||||
current: groups,
|
||||
current: *groups,
|
||||
limit: groups_size_limit,
|
||||
},
|
||||
))
|
||||
@ -698,10 +667,11 @@ impl Global {
|
||||
}
|
||||
|
||||
unsafe {
|
||||
raw.dispatch(groups);
|
||||
raw.dispatch(*groups);
|
||||
}
|
||||
}
|
||||
ComputeCommand::DispatchIndirect { buffer_id, offset } => {
|
||||
ArcComputeCommand::DispatchIndirect { buffer, offset } => {
|
||||
let buffer_id = buffer.as_info().id();
|
||||
let scope = PassErrorScope::Dispatch {
|
||||
indirect: true,
|
||||
pipeline: state.pipeline,
|
||||
@ -713,29 +683,25 @@ impl Global {
|
||||
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
|
||||
.map_pass_err(scope)?;
|
||||
|
||||
let indirect_buffer = state
|
||||
state
|
||||
.scope
|
||||
.buffers
|
||||
.merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDIRECT)
|
||||
.insert_merge_single(buffer.clone(), hal::BufferUses::INDIRECT)
|
||||
.map_pass_err(scope)?;
|
||||
check_buffer_usage(
|
||||
buffer_id,
|
||||
indirect_buffer.usage,
|
||||
wgt::BufferUsages::INDIRECT,
|
||||
)
|
||||
check_buffer_usage(buffer_id, buffer.usage, wgt::BufferUsages::INDIRECT)
|
||||
.map_pass_err(scope)?;
|
||||
|
||||
let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
|
||||
if end_offset > indirect_buffer.size {
|
||||
if end_offset > buffer.size {
|
||||
return Err(ComputePassErrorInner::IndirectBufferOverrun {
|
||||
offset,
|
||||
offset: *offset,
|
||||
end_offset,
|
||||
buffer_size: indirect_buffer.size,
|
||||
buffer_size: buffer.size,
|
||||
})
|
||||
.map_pass_err(scope);
|
||||
}
|
||||
|
||||
let buf_raw = indirect_buffer
|
||||
let buf_raw = buffer
|
||||
.raw
|
||||
.get(&snatch_guard)
|
||||
.ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
|
||||
@ -744,9 +710,9 @@ impl Global {
|
||||
let stride = 3 * 4; // 3 integers, x/y/z group size
|
||||
|
||||
buffer_memory_init_actions.extend(
|
||||
indirect_buffer.initialization_status.read().create_action(
|
||||
indirect_buffer,
|
||||
offset..(offset + stride),
|
||||
buffer.initialization_status.read().create_action(
|
||||
buffer,
|
||||
*offset..(*offset + stride),
|
||||
MemoryInitKind::NeedsInitializedMemory,
|
||||
),
|
||||
);
|
||||
@ -756,15 +722,15 @@ impl Global {
|
||||
raw,
|
||||
&mut intermediate_trackers,
|
||||
&*bind_group_guard,
|
||||
Some(indirect_buffer.as_info().tracker_index()),
|
||||
Some(buffer.as_info().tracker_index()),
|
||||
&snatch_guard,
|
||||
)
|
||||
.map_pass_err(scope)?;
|
||||
unsafe {
|
||||
raw.dispatch_indirect(buf_raw, offset);
|
||||
raw.dispatch_indirect(buf_raw, *offset);
|
||||
}
|
||||
}
|
||||
ComputeCommand::PushDebugGroup { color: _, len } => {
|
||||
ArcComputeCommand::PushDebugGroup { color: _, len } => {
|
||||
state.debug_scope_depth += 1;
|
||||
if !discard_hal_labels {
|
||||
let label =
|
||||
@ -776,7 +742,7 @@ impl Global {
|
||||
}
|
||||
string_offset += len;
|
||||
}
|
||||
ComputeCommand::PopDebugGroup => {
|
||||
ArcComputeCommand::PopDebugGroup => {
|
||||
let scope = PassErrorScope::PopDebugGroup;
|
||||
|
||||
if state.debug_scope_depth == 0 {
|
||||
@ -790,7 +756,7 @@ impl Global {
|
||||
}
|
||||
}
|
||||
}
|
||||
ComputeCommand::InsertDebugMarker { color: _, len } => {
|
||||
ArcComputeCommand::InsertDebugMarker { color: _, len } => {
|
||||
if !discard_hal_labels {
|
||||
let label =
|
||||
str::from_utf8(&base.string_data[string_offset..string_offset + len])
|
||||
@ -799,49 +765,43 @@ impl Global {
|
||||
}
|
||||
string_offset += len;
|
||||
}
|
||||
ComputeCommand::WriteTimestamp {
|
||||
query_set_id,
|
||||
ArcComputeCommand::WriteTimestamp {
|
||||
query_set,
|
||||
query_index,
|
||||
} => {
|
||||
let query_set_id = query_set.as_info().id();
|
||||
let scope = PassErrorScope::WriteTimestamp;
|
||||
|
||||
device
|
||||
.require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
|
||||
.map_pass_err(scope)?;
|
||||
|
||||
let query_set: &resource::QuerySet<A> = tracker
|
||||
.query_sets
|
||||
.add_single(&*query_set_guard, query_set_id)
|
||||
.ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
|
||||
.map_pass_err(scope)?;
|
||||
let query_set = tracker.query_sets.insert_single(query_set.clone());
|
||||
|
||||
query_set
|
||||
.validate_and_write_timestamp(raw, query_set_id, query_index, None)
|
||||
.validate_and_write_timestamp(raw, query_set_id, *query_index, None)
|
||||
.map_pass_err(scope)?;
|
||||
}
|
||||
ComputeCommand::BeginPipelineStatisticsQuery {
|
||||
query_set_id,
|
||||
ArcComputeCommand::BeginPipelineStatisticsQuery {
|
||||
query_set,
|
||||
query_index,
|
||||
} => {
|
||||
let query_set_id = query_set.as_info().id();
|
||||
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
|
||||
|
||||
let query_set: &resource::QuerySet<A> = tracker
|
||||
.query_sets
|
||||
.add_single(&*query_set_guard, query_set_id)
|
||||
.ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
|
||||
.map_pass_err(scope)?;
|
||||
let query_set = tracker.query_sets.insert_single(query_set.clone());
|
||||
|
||||
query_set
|
||||
.validate_and_begin_pipeline_statistics_query(
|
||||
raw,
|
||||
query_set_id,
|
||||
query_index,
|
||||
*query_index,
|
||||
None,
|
||||
&mut active_query,
|
||||
)
|
||||
.map_pass_err(scope)?;
|
||||
}
|
||||
ComputeCommand::EndPipelineStatisticsQuery => {
|
||||
ArcComputeCommand::EndPipelineStatisticsQuery => {
|
||||
let scope = PassErrorScope::EndPipelineStatisticsQuery;
|
||||
|
||||
end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
|
||||
|
322
wgpu-core/src/command/compute_command.rs
Normal file
322
wgpu-core/src/command/compute_command.rs
Normal file
@ -0,0 +1,322 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
binding_model::BindGroup,
|
||||
hal_api::HalApi,
|
||||
id,
|
||||
pipeline::ComputePipeline,
|
||||
resource::{Buffer, QuerySet},
|
||||
};
|
||||
|
||||
use super::{ComputePassError, ComputePassErrorInner, PassErrorScope};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum ComputeCommand {
|
||||
SetBindGroup {
|
||||
index: u32,
|
||||
num_dynamic_offsets: usize,
|
||||
bind_group_id: id::BindGroupId,
|
||||
},
|
||||
|
||||
SetPipeline(id::ComputePipelineId),
|
||||
|
||||
/// Set a range of push constants to values stored in `push_constant_data`.
|
||||
SetPushConstant {
|
||||
/// The byte offset within the push constant storage to write to. This
|
||||
/// must be a multiple of four.
|
||||
offset: u32,
|
||||
|
||||
/// The number of bytes to write. This must be a multiple of four.
|
||||
size_bytes: u32,
|
||||
|
||||
/// Index in `push_constant_data` of the start of the data
|
||||
/// to be written.
|
||||
///
|
||||
/// Note: this is not a byte offset like `offset`. Rather, it is the
|
||||
/// index of the first `u32` element in `push_constant_data` to read.
|
||||
values_offset: u32,
|
||||
},
|
||||
|
||||
Dispatch([u32; 3]),
|
||||
|
||||
DispatchIndirect {
|
||||
buffer_id: id::BufferId,
|
||||
offset: wgt::BufferAddress,
|
||||
},
|
||||
|
||||
PushDebugGroup {
|
||||
color: u32,
|
||||
len: usize,
|
||||
},
|
||||
|
||||
PopDebugGroup,
|
||||
|
||||
InsertDebugMarker {
|
||||
color: u32,
|
||||
len: usize,
|
||||
},
|
||||
|
||||
WriteTimestamp {
|
||||
query_set_id: id::QuerySetId,
|
||||
query_index: u32,
|
||||
},
|
||||
|
||||
BeginPipelineStatisticsQuery {
|
||||
query_set_id: id::QuerySetId,
|
||||
query_index: u32,
|
||||
},
|
||||
|
||||
EndPipelineStatisticsQuery,
|
||||
}
|
||||
|
||||
impl ComputeCommand {
|
||||
/// Resolves all ids in a list of commands into the corresponding resource Arc.
|
||||
///
|
||||
// TODO: Once resolving is done on-the-fly during recording, this function should be only needed with the replay feature:
|
||||
// #[cfg(feature = "replay")]
|
||||
pub fn resolve_compute_command_ids<A: HalApi>(
|
||||
hub: &crate::hub::Hub<A>,
|
||||
commands: &[ComputeCommand],
|
||||
) -> Result<Vec<ArcComputeCommand<A>>, ComputePassError> {
|
||||
let buffers_guard = hub.buffers.read();
|
||||
let bind_group_guard = hub.bind_groups.read();
|
||||
let query_set_guard = hub.query_sets.read();
|
||||
let pipelines_guard = hub.compute_pipelines.read();
|
||||
|
||||
let resolved_commands: Vec<ArcComputeCommand<A>> = commands
|
||||
.iter()
|
||||
.map(|c| -> Result<ArcComputeCommand<A>, ComputePassError> {
|
||||
Ok(match *c {
|
||||
ComputeCommand::SetBindGroup {
|
||||
index,
|
||||
num_dynamic_offsets,
|
||||
bind_group_id,
|
||||
} => ArcComputeCommand::SetBindGroup {
|
||||
index,
|
||||
num_dynamic_offsets,
|
||||
bind_group: bind_group_guard.get_owned(bind_group_id).map_err(|_| {
|
||||
ComputePassError {
|
||||
scope: PassErrorScope::SetBindGroup(bind_group_id),
|
||||
inner: ComputePassErrorInner::InvalidBindGroup(index),
|
||||
}
|
||||
})?,
|
||||
},
|
||||
|
||||
ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline(
|
||||
pipelines_guard
|
||||
.get_owned(pipeline_id)
|
||||
.map_err(|_| ComputePassError {
|
||||
scope: PassErrorScope::SetPipelineCompute(pipeline_id),
|
||||
inner: ComputePassErrorInner::InvalidPipeline(pipeline_id),
|
||||
})?,
|
||||
),
|
||||
|
||||
ComputeCommand::SetPushConstant {
|
||||
offset,
|
||||
size_bytes,
|
||||
values_offset,
|
||||
} => ArcComputeCommand::SetPushConstant {
|
||||
offset,
|
||||
size_bytes,
|
||||
values_offset,
|
||||
},
|
||||
|
||||
ComputeCommand::Dispatch(dim) => ArcComputeCommand::Dispatch(dim),
|
||||
|
||||
ComputeCommand::DispatchIndirect { buffer_id, offset } => {
|
||||
ArcComputeCommand::DispatchIndirect {
|
||||
buffer: buffers_guard.get_owned(buffer_id).map_err(|_| {
|
||||
ComputePassError {
|
||||
scope: PassErrorScope::Dispatch {
|
||||
indirect: true,
|
||||
pipeline: None, // TODO: not used right now, but once we do the resolve during recording we can use this again.
|
||||
},
|
||||
inner: ComputePassErrorInner::InvalidBuffer(buffer_id),
|
||||
}
|
||||
})?,
|
||||
offset,
|
||||
}
|
||||
}
|
||||
|
||||
ComputeCommand::PushDebugGroup { color, len } => {
|
||||
ArcComputeCommand::PushDebugGroup { color, len }
|
||||
}
|
||||
|
||||
ComputeCommand::PopDebugGroup => ArcComputeCommand::PopDebugGroup,
|
||||
|
||||
ComputeCommand::InsertDebugMarker { color, len } => {
|
||||
ArcComputeCommand::InsertDebugMarker { color, len }
|
||||
}
|
||||
|
||||
ComputeCommand::WriteTimestamp {
|
||||
query_set_id,
|
||||
query_index,
|
||||
} => ArcComputeCommand::WriteTimestamp {
|
||||
query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
|
||||
ComputePassError {
|
||||
scope: PassErrorScope::WriteTimestamp,
|
||||
inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
|
||||
}
|
||||
})?,
|
||||
query_index,
|
||||
},
|
||||
|
||||
ComputeCommand::BeginPipelineStatisticsQuery {
|
||||
query_set_id,
|
||||
query_index,
|
||||
} => ArcComputeCommand::BeginPipelineStatisticsQuery {
|
||||
query_set: query_set_guard.get_owned(query_set_id).map_err(|_| {
|
||||
ComputePassError {
|
||||
scope: PassErrorScope::BeginPipelineStatisticsQuery,
|
||||
inner: ComputePassErrorInner::InvalidQuerySet(query_set_id),
|
||||
}
|
||||
})?,
|
||||
query_index,
|
||||
},
|
||||
|
||||
ComputeCommand::EndPipelineStatisticsQuery => {
|
||||
ArcComputeCommand::EndPipelineStatisticsQuery
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, ComputePassError>>()?;
|
||||
Ok(resolved_commands)
|
||||
}
|
||||
}
|
||||
|
||||
/// Equivalent to `ComputeCommand` but the Ids resolved into resource Arcs.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ArcComputeCommand<A: HalApi> {
|
||||
SetBindGroup {
|
||||
index: u32,
|
||||
num_dynamic_offsets: usize,
|
||||
bind_group: Arc<BindGroup<A>>,
|
||||
},
|
||||
|
||||
SetPipeline(Arc<ComputePipeline<A>>),
|
||||
|
||||
/// Set a range of push constants to values stored in `push_constant_data`.
|
||||
SetPushConstant {
|
||||
/// The byte offset within the push constant storage to write to. This
|
||||
/// must be a multiple of four.
|
||||
offset: u32,
|
||||
|
||||
/// The number of bytes to write. This must be a multiple of four.
|
||||
size_bytes: u32,
|
||||
|
||||
/// Index in `push_constant_data` of the start of the data
|
||||
/// to be written.
|
||||
///
|
||||
/// Note: this is not a byte offset like `offset`. Rather, it is the
|
||||
/// index of the first `u32` element in `push_constant_data` to read.
|
||||
values_offset: u32,
|
||||
},
|
||||
|
||||
Dispatch([u32; 3]),
|
||||
|
||||
DispatchIndirect {
|
||||
buffer: Arc<Buffer<A>>,
|
||||
offset: wgt::BufferAddress,
|
||||
},
|
||||
|
||||
PushDebugGroup {
|
||||
color: u32,
|
||||
len: usize,
|
||||
},
|
||||
|
||||
PopDebugGroup,
|
||||
|
||||
InsertDebugMarker {
|
||||
color: u32,
|
||||
len: usize,
|
||||
},
|
||||
|
||||
WriteTimestamp {
|
||||
query_set: Arc<QuerySet<A>>,
|
||||
query_index: u32,
|
||||
},
|
||||
|
||||
BeginPipelineStatisticsQuery {
|
||||
query_set: Arc<QuerySet<A>>,
|
||||
query_index: u32,
|
||||
},
|
||||
|
||||
EndPipelineStatisticsQuery,
|
||||
}
|
||||
|
||||
#[cfg(feature = "trace")]
|
||||
impl<A: HalApi> From<&ArcComputeCommand<A>> for ComputeCommand {
|
||||
fn from(value: &ArcComputeCommand<A>) -> Self {
|
||||
use crate::resource::Resource as _;
|
||||
|
||||
match value {
|
||||
ArcComputeCommand::SetBindGroup {
|
||||
index,
|
||||
num_dynamic_offsets,
|
||||
bind_group,
|
||||
} => ComputeCommand::SetBindGroup {
|
||||
index: *index,
|
||||
num_dynamic_offsets: *num_dynamic_offsets,
|
||||
bind_group_id: bind_group.as_info().id(),
|
||||
},
|
||||
|
||||
ArcComputeCommand::SetPipeline(pipeline) => {
|
||||
ComputeCommand::SetPipeline(pipeline.as_info().id())
|
||||
}
|
||||
|
||||
ArcComputeCommand::SetPushConstant {
|
||||
offset,
|
||||
size_bytes,
|
||||
values_offset,
|
||||
} => ComputeCommand::SetPushConstant {
|
||||
offset: *offset,
|
||||
size_bytes: *size_bytes,
|
||||
values_offset: *values_offset,
|
||||
},
|
||||
|
||||
ArcComputeCommand::Dispatch(dim) => ComputeCommand::Dispatch(*dim),
|
||||
|
||||
ArcComputeCommand::DispatchIndirect { buffer, offset } => {
|
||||
ComputeCommand::DispatchIndirect {
|
||||
buffer_id: buffer.as_info().id(),
|
||||
offset: *offset,
|
||||
}
|
||||
}
|
||||
|
||||
ArcComputeCommand::PushDebugGroup { color, len } => ComputeCommand::PushDebugGroup {
|
||||
color: *color,
|
||||
len: *len,
|
||||
},
|
||||
|
||||
ArcComputeCommand::PopDebugGroup => ComputeCommand::PopDebugGroup,
|
||||
|
||||
ArcComputeCommand::InsertDebugMarker { color, len } => {
|
||||
ComputeCommand::InsertDebugMarker {
|
||||
color: *color,
|
||||
len: *len,
|
||||
}
|
||||
}
|
||||
|
||||
ArcComputeCommand::WriteTimestamp {
|
||||
query_set,
|
||||
query_index,
|
||||
} => ComputeCommand::WriteTimestamp {
|
||||
query_set_id: query_set.as_info().id(),
|
||||
query_index: *query_index,
|
||||
},
|
||||
|
||||
ArcComputeCommand::BeginPipelineStatisticsQuery {
|
||||
query_set,
|
||||
query_index,
|
||||
} => ComputeCommand::BeginPipelineStatisticsQuery {
|
||||
query_set_id: query_set.as_info().id(),
|
||||
query_index: *query_index,
|
||||
},
|
||||
|
||||
ArcComputeCommand::EndPipelineStatisticsQuery => {
|
||||
ComputeCommand::EndPipelineStatisticsQuery
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@ mod bind;
|
||||
mod bundle;
|
||||
mod clear;
|
||||
mod compute;
|
||||
mod compute_command;
|
||||
mod draw;
|
||||
mod memory_init;
|
||||
mod query;
|
||||
@ -13,7 +14,8 @@ use std::sync::Arc;
|
||||
|
||||
pub(crate) use self::clear::clear_texture;
|
||||
pub use self::{
|
||||
bundle::*, clear::ClearError, compute::*, draw::*, query::*, render::*, transfer::*,
|
||||
bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, query::*,
|
||||
render::*, transfer::*,
|
||||
};
|
||||
pub(crate) use allocator::CommandAllocator;
|
||||
|
||||
|
@ -245,6 +245,22 @@ impl<A: HalApi> BufferUsageScope<A> {
|
||||
.get(id)
|
||||
.map_err(|_| UsageConflict::BufferInvalid { id })?;
|
||||
|
||||
self.insert_merge_single(buffer.clone(), new_state)
|
||||
.map(|_| buffer)
|
||||
}
|
||||
|
||||
/// Merge a single state into the UsageScope, using an already resolved buffer.
|
||||
///
|
||||
/// If the resulting state is invalid, returns a usage
|
||||
/// conflict with the details of the invalid state.
|
||||
///
|
||||
/// If the ID is higher than the length of internal vectors,
|
||||
/// the vectors will be extended. A call to set_size is not needed.
|
||||
pub fn insert_merge_single(
|
||||
&mut self,
|
||||
buffer: Arc<Buffer<A>>,
|
||||
new_state: BufferUses,
|
||||
) -> Result<(), UsageConflict> {
|
||||
let index = buffer.info.tracker_index().as_usize();
|
||||
|
||||
self.allow_index(index);
|
||||
@ -260,12 +276,12 @@ impl<A: HalApi> BufferUsageScope<A> {
|
||||
index,
|
||||
BufferStateProvider::Direct { state: new_state },
|
||||
ResourceMetadataProvider::Direct {
|
||||
resource: Cow::Owned(buffer.clone()),
|
||||
resource: Cow::Owned(buffer),
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(buffer)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -87,16 +87,18 @@ impl<T: Resource> ResourceMetadata<T> {
|
||||
/// Add the resource with the given index, epoch, and reference count to the
|
||||
/// set.
|
||||
///
|
||||
/// Returns a reference to the newly inserted resource.
|
||||
/// (This allows avoiding a clone/reference count increase in many cases.)
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The given `index` must be in bounds for this `ResourceMetadata`'s
|
||||
/// existing tables. See `tracker_assert_in_bounds`.
|
||||
#[inline(always)]
|
||||
pub(super) unsafe fn insert(&mut self, index: usize, resource: Arc<T>) {
|
||||
pub(super) unsafe fn insert(&mut self, index: usize, resource: Arc<T>) -> &Arc<T> {
|
||||
self.owned.set(index, true);
|
||||
unsafe {
|
||||
*self.resources.get_unchecked_mut(index) = Some(resource);
|
||||
}
|
||||
let resource_dst = unsafe { self.resources.get_unchecked_mut(index) };
|
||||
resource_dst.insert(resource)
|
||||
}
|
||||
|
||||
/// Get the resource with the given index.
|
||||
|
@ -158,16 +158,17 @@ impl<T: Resource> StatelessTracker<T> {
|
||||
///
|
||||
/// If the ID is higher than the length of internal vectors,
|
||||
/// the vectors will be extended. A call to set_size is not needed.
|
||||
pub fn insert_single(&mut self, resource: Arc<T>) {
|
||||
///
|
||||
/// Returns a reference to the newly inserted resource.
|
||||
/// (This allows avoiding a clone/reference count increase in many cases.)
|
||||
pub fn insert_single(&mut self, resource: Arc<T>) -> &Arc<T> {
|
||||
let index = resource.as_info().tracker_index().as_usize();
|
||||
|
||||
self.allow_index(index);
|
||||
|
||||
self.tracker_assert_in_bounds(index);
|
||||
|
||||
unsafe {
|
||||
self.metadata.insert(index, resource);
|
||||
}
|
||||
unsafe { self.metadata.insert(index, resource) }
|
||||
}
|
||||
|
||||
/// Adds the given resource to the tracker.
|
||||
|
Loading…
Reference in New Issue
Block a user