From d3eed4920b9eb364448bf96b09c7c69449d1d904 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Tue, 25 Jun 2024 13:59:09 +0200 Subject: [PATCH] put all state in `State` --- wgpu-core/src/command/compute.rs | 207 ++++++++++++++++++------------- 1 file changed, 120 insertions(+), 87 deletions(-) diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 19dd1b0da..4c6e005d4 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -9,12 +9,12 @@ use crate::{ CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange, }, - device::{DeviceError, MissingDownlevelFlags, MissingFeatures}, + device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures}, error::{ErrorFormatter, PrettyError}, global::Global, hal_api::HalApi, hal_label, id, - init_tracker::MemoryInitKind, + init_tracker::{BufferInitTrackerAction, MemoryInitKind}, resource::{self, DestroyedResourceError, MissingBufferUsageError, ParentDevice, Resource}, snatch::SnatchGuard, track::{ResourceUsageCompatibilityError, Tracker, TrackerIndex, UsageScope}, @@ -33,7 +33,7 @@ use wgt::{BufferAddress, DynamicOffset}; use std::sync::Arc; use std::{fmt, mem, str}; -use super::DynComputePass; +use super::{memory_init::CommandBufferTextureMemoryActions, DynComputePass}; pub struct ComputePass { /// All pass data & records is stored here. @@ -248,14 +248,32 @@ where } } -struct State<'a, A: HalApi> { +struct State<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> { binder: Binder, pipeline: Option, - scope: UsageScope<'a, A>, + scope: UsageScope<'scope, A>, debug_scope_depth: u32, + + snatch_guard: SnatchGuard<'snatch_guard>, + + device: &'cmd_buf Arc>, + tracker: &'cmd_buf mut Tracker, + buffer_memory_init_actions: &'cmd_buf mut Vec>, + texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions, + + temp_offsets: Vec, + dynamic_offset_count: usize, + string_offset: usize, + active_query: Option<(Arc>, u32)>, + + intermediate_trackers: Tracker, + + /// Immediate texture inits required because of prior discards. Need to + /// be inserted before texture reads. + pending_discard_init_fixups: SurfacesInDiscardState, } -impl<'a, A: HalApi> State<'a, A> { +impl<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> State<'scope, 'snatch_guard, 'cmd_buf, A> { fn is_ready(&self) -> Result<(), DispatchError> { let bind_mask = self.binder.invalid_mask(); if bind_mask != 0 { @@ -280,9 +298,7 @@ impl<'a, A: HalApi> State<'a, A> { fn flush_states( &mut self, raw_encoder: &mut A::CommandEncoder, - base_trackers: &mut Tracker, indirect_buffer: Option, - snatch_guard: &SnatchGuard, ) -> Result<(), ResourceUsageCompatibilityError> { for bind_group in self.binder.list_active() { unsafe { self.scope.merge_bind_group(&bind_group.used)? }; @@ -292,21 +308,25 @@ impl<'a, A: HalApi> State<'a, A> { for bind_group in self.binder.list_active() { unsafe { - base_trackers + self.intermediate_trackers .set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used) } } // Add the state of the indirect buffer if it hasn't been hit before. unsafe { - base_trackers + self.intermediate_trackers .buffers .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer); } log::trace!("Encoding dispatch barriers"); - CommandBuffer::drain_barriers(raw_encoder, base_trackers, snatch_guard); + CommandBuffer::drain_barriers( + raw_encoder, + &mut self.intermediate_trackers, + &self.snatch_guard, + ); Ok(()) } } @@ -474,9 +494,6 @@ impl Global { let encoder = &mut cmd_buf_data.encoder; let status = &mut cmd_buf_data.status; - let tracker = &mut cmd_buf_data.trackers; - let buffer_memory_init_actions = &mut cmd_buf_data.buffer_memory_init_actions; - let texture_memory_actions = &mut cmd_buf_data.texture_memory_actions; // We automatically keep extending command buffers over time, and because // we want to insert a command buffer _before_ what we're about to record, @@ -491,25 +508,39 @@ impl Global { pipeline: None, scope: device.new_usage_scope(), debug_scope_depth: 0, + + snatch_guard: device.snatchable_lock.read(), + + device, + tracker: &mut cmd_buf_data.trackers, + buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions, + texture_memory_actions: &mut cmd_buf_data.texture_memory_actions, + + temp_offsets: Vec::new(), + dynamic_offset_count: 0, + string_offset: 0, + active_query: None, + + intermediate_trackers: Tracker::new(), + + pending_discard_init_fixups: SurfacesInDiscardState::new(), }; - let mut temp_offsets = Vec::new(); - let mut dynamic_offset_count = 0; - let mut string_offset = 0; - let mut active_query = None; - let snatch_guard = device.snatchable_lock.read(); - - let indices = &device.tracker_indices; - tracker.buffers.set_size(indices.buffers.size()); - tracker.textures.set_size(indices.textures.size()); - tracker.bind_groups.set_size(indices.bind_groups.size()); - tracker + let indices = &state.device.tracker_indices; + state.tracker.buffers.set_size(indices.buffers.size()); + state.tracker.textures.set_size(indices.textures.size()); + state + .tracker + .bind_groups + .set_size(indices.bind_groups.size()); + state + .tracker .compute_pipelines .set_size(indices.compute_pipelines.size()); - tracker.query_sets.set_size(indices.query_sets.size()); + state.tracker.query_sets.set_size(indices.query_sets.size()); let timestamp_writes = if let Some(tw) = timestamp_writes.take() { - let query_set = tracker.query_sets.insert_single(tw.query_set); + let query_set = state.tracker.query_sets.insert_single(tw.query_set); // Unlike in render passes we can't delay resetting the query sets since // there is no auxiliary pass. @@ -539,10 +570,6 @@ impl Global { None }; - let discard_hal_labels = self - .instance - .flags - .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS); let hal_desc = hal::ComputePassDescriptor { label: hal_label(base.label.as_deref(), self.instance.flags), timestamp_writes, @@ -552,12 +579,6 @@ impl Global { raw.begin_compute_pass(&hal_desc); } - let mut intermediate_trackers = Tracker::::new(); - - // Immediate texture inits required because of prior discards. Need to - // be inserted before texture reads. - let mut pending_discard_init_fixups = SurfacesInDiscardState::new(); - // TODO: We should be draining the commands here, avoiding extra copies in the process. // (A command encoder can't be executed twice!) for command in base.commands { @@ -580,19 +601,19 @@ impl Global { .map_pass_err(scope); } - temp_offsets.clear(); - temp_offsets.extend_from_slice( - &base.dynamic_offsets - [dynamic_offset_count..dynamic_offset_count + num_dynamic_offsets], + state.temp_offsets.clear(); + state.temp_offsets.extend_from_slice( + &base.dynamic_offsets[state.dynamic_offset_count + ..state.dynamic_offset_count + num_dynamic_offsets], ); - dynamic_offset_count += num_dynamic_offsets; + state.dynamic_offset_count += num_dynamic_offsets; - let bind_group = tracker.bind_groups.insert_single(bind_group); + let bind_group = state.tracker.bind_groups.insert_single(bind_group); bind_group - .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits) + .validate_dynamic_bindings(index, &state.temp_offsets, &cmd_buf.limits) .map_pass_err(scope)?; - buffer_memory_init_actions.extend( + state.buffer_memory_init_actions.extend( bind_group.used_buffer_ranges.iter().filter_map(|action| { action .buffer @@ -603,20 +624,22 @@ impl Global { ); for action in bind_group.used_texture_ranges.iter() { - pending_discard_init_fixups - .extend(texture_memory_actions.register_init_action(action)); + state + .pending_discard_init_fixups + .extend(state.texture_memory_actions.register_init_action(action)); } let pipeline_layout = state.binder.pipeline_layout.clone(); let entries = state .binder - .assign_group(index as usize, bind_group, &temp_offsets); + .assign_group(index as usize, bind_group, &state.temp_offsets); if !entries.is_empty() && pipeline_layout.is_some() { let pipeline_layout = pipeline_layout.as_ref().unwrap().raw(); for (i, e) in entries.iter().enumerate() { if let Some(group) = e.group.as_ref() { - let raw_bg = group.try_raw(&snatch_guard).map_pass_err(scope)?; + let raw_bg = + group.try_raw(&state.snatch_guard).map_pass_err(scope)?; unsafe { raw.set_bind_group( pipeline_layout, @@ -636,7 +659,7 @@ impl Global { state.pipeline = Some(pipeline.as_info().id()); - let pipeline = tracker.compute_pipelines.insert_single(pipeline); + let pipeline = state.tracker.compute_pipelines.insert_single(pipeline); unsafe { raw.set_compute_pipeline(pipeline.raw()); @@ -659,7 +682,7 @@ impl Global { for (i, e) in entries.iter().enumerate() { if let Some(group) = e.group.as_ref() { let raw_bg = - group.try_raw(&snatch_guard).map_pass_err(scope)?; + group.try_raw(&state.snatch_guard).map_pass_err(scope)?; unsafe { raw.set_bind_group( pipeline.layout.raw(), @@ -741,9 +764,7 @@ impl Global { }; state.is_ready().map_pass_err(scope)?; - state - .flush_states(raw, &mut intermediate_trackers, None, &snatch_guard) - .map_pass_err(scope)?; + state.flush_states(raw, None).map_pass_err(scope)?; let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension; @@ -774,7 +795,8 @@ impl Global { state.is_ready().map_pass_err(scope)?; - device + state + .device .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION) .map_pass_err(scope)?; @@ -797,11 +819,9 @@ impl Global { .map_pass_err(scope); } - let buf_raw = buffer.try_raw(&snatch_guard).map_pass_err(scope)?; - let stride = 3 * 4; // 3 integers, x/y/z group size - buffer_memory_init_actions.extend( + state.buffer_memory_init_actions.extend( buffer.initialization_status.read().create_action( &buffer, offset..(offset + stride), @@ -810,28 +830,30 @@ impl Global { ); state - .flush_states( - raw, - &mut intermediate_trackers, - Some(buffer.as_info().tracker_index()), - &snatch_guard, - ) + .flush_states(raw, Some(buffer.as_info().tracker_index())) .map_pass_err(scope)?; + + let buf_raw = buffer.try_raw(&state.snatch_guard).map_pass_err(scope)?; unsafe { raw.dispatch_indirect(buf_raw, offset); } } ArcComputeCommand::PushDebugGroup { color: _, len } => { state.debug_scope_depth += 1; - if !discard_hal_labels { - let label = - str::from_utf8(&base.string_data[string_offset..string_offset + len]) - .unwrap(); + if !state + .device + .instance_flags + .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS) + { + let label = str::from_utf8( + &base.string_data[state.string_offset..state.string_offset + len], + ) + .unwrap(); unsafe { raw.begin_debug_marker(label); } } - string_offset += len; + state.string_offset += len; } ArcComputeCommand::PopDebugGroup => { let scope = PassErrorScope::PopDebugGroup; @@ -841,20 +863,29 @@ impl Global { .map_pass_err(scope); } state.debug_scope_depth -= 1; - if !discard_hal_labels { + if !state + .device + .instance_flags + .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS) + { unsafe { raw.end_debug_marker(); } } } ArcComputeCommand::InsertDebugMarker { color: _, len } => { - if !discard_hal_labels { - let label = - str::from_utf8(&base.string_data[string_offset..string_offset + len]) - .unwrap(); + if !state + .device + .instance_flags + .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS) + { + let label = str::from_utf8( + &base.string_data[state.string_offset..state.string_offset + len], + ) + .unwrap(); unsafe { raw.insert_debug_marker(label) } } - string_offset += len; + state.string_offset += len; } ArcComputeCommand::WriteTimestamp { query_set, @@ -864,11 +895,12 @@ impl Global { query_set.same_device_as(cmd_buf).map_pass_err(scope)?; - device + state + .device .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES) .map_pass_err(scope)?; - let query_set = tracker.query_sets.insert_single(query_set); + let query_set = state.tracker.query_sets.insert_single(query_set); query_set .validate_and_write_timestamp(raw, query_index, None) @@ -882,20 +914,21 @@ impl Global { query_set.same_device_as(cmd_buf).map_pass_err(scope)?; - let query_set = tracker.query_sets.insert_single(query_set); + let query_set = state.tracker.query_sets.insert_single(query_set); validate_and_begin_pipeline_statistics_query( query_set.clone(), raw, query_index, None, - &mut active_query, + &mut state.active_query, ) .map_pass_err(scope)?; } ArcComputeCommand::EndPipelineStatisticsQuery => { let scope = PassErrorScope::EndPipelineStatisticsQuery; - end_pipeline_statistics_query(raw, &mut active_query).map_pass_err(scope)?; + end_pipeline_statistics_query(raw, &mut state.active_query) + .map_pass_err(scope)?; } } } @@ -916,17 +949,17 @@ impl Global { // Use that buffer to insert barriers and clear discarded images. let transit = encoder.open().map_pass_err(pass_scope)?; fixup_discarded_surfaces( - pending_discard_init_fixups.into_iter(), + state.pending_discard_init_fixups.into_iter(), transit, - &mut tracker.textures, - device, - &snatch_guard, + &mut state.tracker.textures, + state.device, + &state.snatch_guard, ); CommandBuffer::insert_barriers_from_tracker( transit, - tracker, - &intermediate_trackers, - &snatch_guard, + state.tracker, + &state.intermediate_trackers, + &state.snatch_guard, ); // Close the command buffer, and swap it with the previous. encoder.close_and_swap().map_pass_err(pass_scope)?;