put all state in State

teoxoy 2024-06-25 13:59:09 +02:00 committed by Teodor Tanasoaia
parent 83a8699be7
commit d3eed4920b


@ -9,12 +9,12 @@ use crate::{
CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr, PassErrorScope,
QueryUseError, StateChange,
},
- device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
+ device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
error::{ErrorFormatter, PrettyError},
global::Global,
hal_api::HalApi,
hal_label, id,
- init_tracker::MemoryInitKind,
+ init_tracker::{BufferInitTrackerAction, MemoryInitKind},
resource::{self, DestroyedResourceError, MissingBufferUsageError, ParentDevice, Resource},
snatch::SnatchGuard,
track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope},
@ -33,7 +33,7 @@ use wgt::{BufferAddress, DynamicOffset};
use std::sync::Arc;
use std::{fmt, mem, str};
- use super::DynComputePass;
+ use super::{memory_init::CommandBufferTextureMemoryActions, DynComputePass};
pub struct ComputePass<A: HalApi> {
/// All pass data & records are stored here.
@ -248,14 +248,32 @@ where
}
}
- struct State<'a, A: HalApi> {
+ struct State<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> {
binder: Binder<A>,
pipeline: Option<id::ComputePipelineId>,
- scope: UsageScope<'a, A>,
+ scope: UsageScope<'scope, A>,
debug_scope_depth: u32,
+ snatch_guard: SnatchGuard<'snatch_guard>,
+ device: &'cmd_buf Arc<Device<A>>,
+ tracker: &'cmd_buf mut Tracker<A>,
+ buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction<A>>,
+ texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions<A>,
+ temp_offsets: Vec<u32>,
+ dynamic_offset_count: usize,
+ string_offset: usize,
+ active_query: Option<(Arc<resource::QuerySet<A>>, u32)>,
+ intermediate_trackers: Tracker<A>,
+ /// Immediate texture inits required because of prior discards. Need to
+ /// be inserted before texture reads.
+ pending_discard_init_fixups: SurfacesInDiscardState<A>,
}
- impl<'a, A: HalApi> State<'a, A> {
+ impl<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> State<'scope, 'snatch_guard, 'cmd_buf, A> {
fn is_ready(&self) -> Result<(), DispatchError> {
let bind_mask = self.binder.invalid_mask();
if bind_mask != 0 {
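This hunk is the heart of the change: `State` goes from four fields behind a single lifetime to a struct that owns the pass-wide scratch (`temp_offsets`, `active_query`, `intermediate_trackers`, ...) and borrows the command buffer's data, with one named lifetime per borrow source. A minimal, self-contained sketch of the pattern, using toy stand-ins rather than wgpu's actual `Tracker`/`SnatchGuard`:

    // Toy stand-ins for wgpu's Tracker and SnatchGuard; illustrative only.
    struct Tracker {
        indices: Vec<u32>,
    }
    struct Guard<'a> {
        data: &'a [u8],
    }

    // One lifetime parameter per borrow source, mirroring
    // `State<'scope, 'snatch_guard, 'cmd_buf, A>` above.
    struct State<'snatch_guard, 'cmd_buf> {
        // Guard held for the duration of the pass.
        snatch_guard: Guard<'snatch_guard>,
        // Mutable borrow into the command buffer's data.
        tracker: &'cmd_buf mut Tracker,
        // Scratch that used to live in loose `let mut` bindings.
        temp_offsets: Vec<u32>,
        string_offset: usize,
    }

    impl State<'_, '_> {
        // Methods reach everything through `self` instead of threading
        // each piece through as a parameter.
        fn record(&mut self, offset: u32) {
            self.temp_offsets.push(offset);
            self.tracker.indices.push(offset);
            self.string_offset += self.snatch_guard.data.len();
        }
    }

    fn main() {
        let bytes = [0u8; 4];
        let mut tracker = Tracker { indices: Vec::new() };
        let mut state = State {
            snatch_guard: Guard { data: &bytes },
            tracker: &mut tracker,
            temp_offsets: Vec::new(),
            string_offset: 0,
        };
        state.record(7);
        assert_eq!(state.string_offset, 4);
    }

Naming the lifetimes (`'scope`, `'snatch_guard`, `'cmd_buf`) instead of reusing one `'a` documents which owner each field borrows from.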
@ -280,9 +298,7 @@ impl<'a, A: HalApi> State<'a, A> {
fn flush_states(
&mut self,
raw_encoder: &mut A::CommandEncoder,
- base_trackers: &mut Tracker<A>,
indirect_buffer: Option<TrackerIndex>,
- snatch_guard: &SnatchGuard,
) -> Result<(), ResourceUsageCompatibilityError> {
for bind_group in self.binder.list_active() {
unsafe { self.scope.merge_bind_group(&bind_group.used)? };
@ -292,21 +308,25 @@ impl<'a, A: HalApi> State<'a, A> {
for bind_group in self.binder.list_active() {
unsafe {
- base_trackers
+ self.intermediate_trackers
.set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used)
}
}
// Add the state of the indirect buffer if it hasn't been hit before.
unsafe {
- base_trackers
+ self.intermediate_trackers
.buffers
.set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
}
log::trace!("Encoding dispatch barriers");
- CommandBuffer::drain_barriers(raw_encoder, base_trackers, snatch_guard);
+ CommandBuffer::drain_barriers(
+ raw_encoder,
+ &mut self.intermediate_trackers,
+ &self.snatch_guard,
+ );
Ok(())
}
}
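With `intermediate_trackers` and the snatch guard stored on `State`, `flush_states` drops two parameters. Note the call shape that makes this legal: `self.intermediate_trackers.set_and_remove_from_usage_scope_sparse(&mut self.scope, ...)` takes two simultaneous mutable borrows of `self`, which the borrow checker accepts because they target disjoint fields. A tiny sketch of that rule, with toy types:

    struct Scope {
        pending: Vec<u32>,
    }
    struct Trackers {
        committed: Vec<u32>,
    }
    struct State {
        scope: Scope,
        intermediate_trackers: Trackers,
    }

    impl Trackers {
        // Drain another field's pending entries into this tracker.
        fn absorb(&mut self, scope: &mut Scope) {
            self.committed.append(&mut scope.pending);
        }
    }

    impl State {
        fn flush_states(&mut self) {
            // `self.intermediate_trackers` (borrowed mutably as the receiver)
            // and `&mut self.scope` coexist: disjoint fields of one struct.
            self.intermediate_trackers.absorb(&mut self.scope);
        }
    }

    fn main() {
        let mut state = State {
            scope: Scope { pending: vec![1, 2] },
            intermediate_trackers: Trackers { committed: Vec::new() },
        };
        state.flush_states();
        assert_eq!(state.intermediate_trackers.committed, [1, 2]);
    }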
@ -474,9 +494,6 @@ impl Global {
let encoder = &mut cmd_buf_data.encoder;
let status = &mut cmd_buf_data.status;
- let tracker = &mut cmd_buf_data.trackers;
- let buffer_memory_init_actions = &mut cmd_buf_data.buffer_memory_init_actions;
- let texture_memory_actions = &mut cmd_buf_data.texture_memory_actions;
// We automatically keep extending command buffers over time, and because
// we want to insert a command buffer _before_ what we're about to record,
@ -491,25 +508,39 @@ impl Global {
pipeline: None,
scope: device.new_usage_scope(),
debug_scope_depth: 0,
+ snatch_guard: device.snatchable_lock.read(),
+ device,
+ tracker: &mut cmd_buf_data.trackers,
+ buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
+ texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
+ temp_offsets: Vec::new(),
+ dynamic_offset_count: 0,
+ string_offset: 0,
+ active_query: None,
+ intermediate_trackers: Tracker::new(),
+ pending_discard_init_fixups: SurfacesInDiscardState::new(),
};
- let mut temp_offsets = Vec::new();
- let mut dynamic_offset_count = 0;
- let mut string_offset = 0;
- let mut active_query = None;
- let snatch_guard = device.snatchable_lock.read();
- let indices = &device.tracker_indices;
- tracker.buffers.set_size(indices.buffers.size());
- tracker.textures.set_size(indices.textures.size());
- tracker.bind_groups.set_size(indices.bind_groups.size());
- tracker
+ let indices = &state.device.tracker_indices;
+ state.tracker.buffers.set_size(indices.buffers.size());
+ state.tracker.textures.set_size(indices.textures.size());
+ state
+ .tracker
+ .bind_groups
+ .set_size(indices.bind_groups.size());
+ state
+ .tracker
.compute_pipelines
.set_size(indices.compute_pipelines.size());
- tracker.query_sets.set_size(indices.query_sets.size());
+ state.tracker.query_sets.set_size(indices.query_sets.size());
let timestamp_writes = if let Some(tw) = timestamp_writes.take() {
- let query_set = tracker.query_sets.insert_single(tw.query_set);
+ let query_set = state.tracker.query_sets.insert_single(tw.query_set);
// Unlike in render passes we can't delay resetting the query sets since
// there is no auxiliary pass.
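Note `snatch_guard: device.snatchable_lock.read()` in the initializer above: the read guard now lives inside `State`, so the snatchable lock is held for exactly as long as the pass is recorded, rather than as a separate local. A hedged sketch of the same idea, using the standard library's `RwLock` in place of wgpu's snatchable lock:

    use std::sync::{RwLock, RwLockReadGuard};

    // Stand-in for the device's snatchable resource storage.
    struct Device {
        snatchable: RwLock<Vec<u32>>,
    }

    // Holding the guard as a field keeps the resources readable (and
    // un-snatchable) for as long as the State value lives.
    struct State<'snatch_guard> {
        snatch_guard: RwLockReadGuard<'snatch_guard, Vec<u32>>,
    }

    fn main() {
        let device = Device { snatchable: RwLock::new(vec![1, 2, 3]) };
        let state = State { snatch_guard: device.snatchable.read().unwrap() };
        assert_eq!(state.snatch_guard.len(), 3);
        // The read lock is released when `state` is dropped,
        // i.e. when the pass ends.
    }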
@ -539,10 +570,6 @@ impl Global {
None
};
- let discard_hal_labels = self
- .instance
- .flags
- .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS);
let hal_desc = hal::ComputePassDescriptor {
label: hal_label(base.label.as_deref(), self.instance.flags),
timestamp_writes,
@ -552,12 +579,6 @@ impl Global {
raw.begin_compute_pass(&hal_desc);
}
- let mut intermediate_trackers = Tracker::<A>::new();
- // Immediate texture inits required because of prior discards. Need to
- // be inserted before texture reads.
- let mut pending_discard_init_fixups = SurfacesInDiscardState::new();
// TODO: We should be draining the commands here, avoiding extra copies in the process.
// (A command encoder can't be executed twice!)
for command in base.commands {
@ -580,19 +601,19 @@ impl Global {
.map_pass_err(scope);
}
- temp_offsets.clear();
- temp_offsets.extend_from_slice(
- &base.dynamic_offsets
- [dynamic_offset_count..dynamic_offset_count + num_dynamic_offsets],
+ state.temp_offsets.clear();
+ state.temp_offsets.extend_from_slice(
+ &base.dynamic_offsets[state.dynamic_offset_count
+ ..state.dynamic_offset_count + num_dynamic_offsets],
);
- dynamic_offset_count += num_dynamic_offsets;
+ state.dynamic_offset_count += num_dynamic_offsets;
- let bind_group = tracker.bind_groups.insert_single(bind_group);
+ let bind_group = state.tracker.bind_groups.insert_single(bind_group);
bind_group
- .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
+ .validate_dynamic_bindings(index, &state.temp_offsets, &cmd_buf.limits)
.map_pass_err(scope)?;
- buffer_memory_init_actions.extend(
+ state.buffer_memory_init_actions.extend(
bind_group.used_buffer_ranges.iter().filter_map(|action| {
action
.buffer
@ -603,20 +624,22 @@ impl Global {
);
for action in bind_group.used_texture_ranges.iter() {
- pending_discard_init_fixups
- .extend(texture_memory_actions.register_init_action(action));
+ state
+ .pending_discard_init_fixups
+ .extend(state.texture_memory_actions.register_init_action(action));
}
let pipeline_layout = state.binder.pipeline_layout.clone();
let entries =
state
.binder
- .assign_group(index as usize, bind_group, &temp_offsets);
+ .assign_group(index as usize, bind_group, &state.temp_offsets);
if !entries.is_empty() && pipeline_layout.is_some() {
let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
- let raw_bg = group.try_raw(&snatch_guard).map_pass_err(scope)?;
+ let raw_bg =
+ group.try_raw(&state.snatch_guard).map_pass_err(scope)?;
unsafe {
raw.set_bind_group(
pipeline_layout,
@ -636,7 +659,7 @@ impl Global {
state.pipeline = Some(pipeline.as_info().id());
- let pipeline = tracker.compute_pipelines.insert_single(pipeline);
+ let pipeline = state.tracker.compute_pipelines.insert_single(pipeline);
unsafe {
raw.set_compute_pipeline(pipeline.raw());
@ -659,7 +682,7 @@ impl Global {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg =
- group.try_raw(&snatch_guard).map_pass_err(scope)?;
+ group.try_raw(&state.snatch_guard).map_pass_err(scope)?;
unsafe {
raw.set_bind_group(
pipeline.layout.raw(),
@ -741,9 +764,7 @@ impl Global {
};
state.is_ready().map_pass_err(scope)?;
- state
- .flush_states(raw, &mut intermediate_trackers, None, &snatch_guard)
- .map_pass_err(scope)?;
+ state.flush_states(raw, None).map_pass_err(scope)?;
let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;
@ -774,7 +795,8 @@ impl Global {
state.is_ready().map_pass_err(scope)?;
- device
+ state
+ .device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
.map_pass_err(scope)?;
@ -797,11 +819,9 @@ impl Global {
.map_pass_err(scope);
}
- let buf_raw = buffer.try_raw(&snatch_guard).map_pass_err(scope)?;
let stride = 3 * 4; // 3 integers, x/y/z group size
- buffer_memory_init_actions.extend(
+ state.buffer_memory_init_actions.extend(
buffer.initialization_status.read().create_action(
&buffer,
offset..(offset + stride),
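The `stride = 3 * 4` this hunk validates is the size of the indirect dispatch arguments: three `u32` workgroup counts, per the inline comment "3 integers, x/y/z group size". Sketched as a struct (the name here is illustrative, not taken from this file):

    // The 12-byte region whose initialization is tracked above:
    // three u32 workgroup counts read by dispatch_indirect.
    #[repr(C)]
    #[derive(Clone, Copy)]
    struct DispatchIndirectArgs {
        x: u32,
        y: u32,
        z: u32,
    }

    fn main() {
        let stride = 3 * 4;
        assert_eq!(std::mem::size_of::<DispatchIndirectArgs>(), stride);
    }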
@ -810,28 +830,30 @@ impl Global {
);
state
- .flush_states(
- raw,
- &mut intermediate_trackers,
- Some(buffer.as_info().tracker_index()),
- &snatch_guard,
- )
+ .flush_states(raw, Some(buffer.as_info().tracker_index()))
.map_pass_err(scope)?;
+ let buf_raw = buffer.try_raw(&state.snatch_guard).map_pass_err(scope)?;
unsafe {
raw.dispatch_indirect(buf_raw, offset);
}
}
ArcComputeCommand::PushDebugGroup { color: _, len } => {
state.debug_scope_depth += 1;
- if !discard_hal_labels {
- let label =
- str::from_utf8(&base.string_data[string_offset..string_offset + len])
- .unwrap();
+ if !state
+ .device
+ .instance_flags
+ .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
+ {
+ let label = str::from_utf8(
+ &base.string_data[state.string_offset..state.string_offset + len],
+ )
+ .unwrap();
unsafe {
raw.begin_debug_marker(label);
}
}
- string_offset += len;
+ state.string_offset += len;
}
ArcComputeCommand::PopDebugGroup => {
let scope = PassErrorScope::PopDebugGroup;
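The PushDebugGroup arm above also shows how pass labels are encoded (the InsertDebugMarker arm below follows the same pattern): all labels sit back-to-back in `base.string_data`, each command carries only its `len`, and `state.string_offset` is the cursor, which advances even when labels are discarded. A minimal sketch of that cursor:

    fn main() {
        // Labels concatenated into one buffer, as in `base.string_data`.
        let string_data = b"first passsecond pass";
        let lens = [10usize, 11];
        let mut string_offset = 0;

        for len in lens {
            let label =
                std::str::from_utf8(&string_data[string_offset..string_offset + len]).unwrap();
            println!("begin_debug_marker({label})");
            string_offset += len; // mirrors `state.string_offset += len;`
        }
    }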
@ -841,20 +863,29 @@ impl Global {
.map_pass_err(scope);
}
state.debug_scope_depth -= 1;
- if !discard_hal_labels {
+ if !state
+ .device
+ .instance_flags
+ .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
+ {
unsafe {
raw.end_debug_marker();
}
}
}
ArcComputeCommand::InsertDebugMarker { color: _, len } => {
- if !discard_hal_labels {
- let label =
- str::from_utf8(&base.string_data[string_offset..string_offset + len])
- .unwrap();
+ if !state
+ .device
+ .instance_flags
+ .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
+ {
+ let label = str::from_utf8(
+ &base.string_data[state.string_offset..state.string_offset + len],
+ )
+ .unwrap();
unsafe { raw.insert_debug_marker(label) }
}
- string_offset += len;
+ state.string_offset += len;
}
ArcComputeCommand::WriteTimestamp {
query_set,
@ -864,11 +895,12 @@ impl Global {
query_set.same_device_as(cmd_buf).map_pass_err(scope)?;
- device
+ state
+ .device
.require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
.map_pass_err(scope)?;
- let query_set = tracker.query_sets.insert_single(query_set);
+ let query_set = state.tracker.query_sets.insert_single(query_set);
query_set
.validate_and_write_timestamp(raw, query_index, None)
@ -882,20 +914,21 @@ impl Global {
query_set.same_device_as(cmd_buf).map_pass_err(scope)?;
- let query_set = tracker.query_sets.insert_single(query_set);
+ let query_set = state.tracker.query_sets.insert_single(query_set);
validate_and_begin_pipeline_statistics_query(
query_set.clone(),
raw,
query_index,
None,
- &mut active_query,
+ &mut state.active_query,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery;
- end_pipeline_statistics_query(raw, &mut active_query).map_pass_err(scope)?;
+ end_pipeline_statistics_query(raw, &mut state.active_query)
+ .map_pass_err(scope)?;
}
}
}
@ -916,17 +949,17 @@ impl Global {
// Use that buffer to insert barriers and clear discarded images.
let transit = encoder.open().map_pass_err(pass_scope)?;
fixup_discarded_surfaces(
- pending_discard_init_fixups.into_iter(),
+ state.pending_discard_init_fixups.into_iter(),
transit,
- &mut tracker.textures,
- device,
- &snatch_guard,
+ &mut state.tracker.textures,
+ state.device,
+ &state.snatch_guard,
);
CommandBuffer::insert_barriers_from_tracker(
transit,
- tracker,
- &intermediate_trackers,
- &snatch_guard,
+ state.tracker,
+ &state.intermediate_trackers,
+ &state.snatch_guard,
);
// Close the command buffer, and swap it with the previous.
encoder.close_and_swap().map_pass_err(pass_scope)?;
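The barriers and discard fixups recorded into `transit` must execute before the pass that was just encoded; that is what the earlier comment ("we want to insert a command buffer _before_ what we're about to record") and `close_and_swap` accomplish. A toy sketch of the reordering, assuming the encoder keeps a simple list of finished command buffers:

    fn main() {
        // The pass's commands were recorded first...
        let mut command_buffers = vec!["compute pass commands"];
        // ...then `transit` is opened, barriers and fixups are recorded,
        // and close_and_swap() places that buffer *before* the pass.
        command_buffers.push("barriers + discard fixups");
        let n = command_buffers.len();
        command_buffers.swap(n - 2, n - 1);
        assert_eq!(
            command_buffers,
            ["barriers + discard fixups", "compute pass commands"]
        );
    }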