From f0f61d9bb664437974a9ab66d2cec661ae765edb Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Tue, 25 Jun 2024 15:22:13 +0200 Subject: [PATCH] move the raw encoder in `State` --- wgpu-core/src/command/compute.rs | 121 +++++++++++++++---------------- 1 file changed, 57 insertions(+), 64 deletions(-) diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 428783105..217cde115 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -252,7 +252,7 @@ where } } -struct State<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> { +struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A: HalApi> { binder: Binder, pipeline: Option, scope: UsageScope<'scope, A>, @@ -261,6 +261,9 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> { snatch_guard: SnatchGuard<'snatch_guard>, device: &'cmd_buf Arc>, + + raw_encoder: &'raw_encoder mut A::CommandEncoder, + tracker: &'cmd_buf mut Tracker, buffer_memory_init_actions: &'cmd_buf mut Vec>, texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions, @@ -277,7 +280,9 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> { pending_discard_init_fixups: SurfacesInDiscardState, } -impl<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> State<'scope, 'snatch_guard, 'cmd_buf, A> { +impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A: HalApi> + State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder, A> +{ fn is_ready(&self) -> Result<(), DispatchError> { let bind_mask = self.binder.invalid_mask(); if bind_mask != 0 { @@ -301,7 +306,6 @@ impl<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> State<'scope, 'snatch_guard, 'c // part of the usage scope. fn flush_states( &mut self, - raw_encoder: &mut A::CommandEncoder, indirect_buffer: Option, ) -> Result<(), ResourceUsageCompatibilityError> { for bind_group in self.binder.list_active() { @@ -327,7 +331,7 @@ impl<'scope, 'snatch_guard, 'cmd_buf, A: HalApi> State<'scope, 'snatch_guard, 'c log::trace!("Encoding dispatch barriers"); CommandBuffer::drain_barriers( - raw_encoder, + self.raw_encoder, &mut self.intermediate_trackers, &self.snatch_guard, ); @@ -505,7 +509,7 @@ impl Global { encoder.close().map_pass_err(pass_scope)?; // will be reset to true if recording is done without errors *status = CommandEncoderStatus::Error; - let raw = encoder.open().map_pass_err(pass_scope)?; + let raw_encoder = encoder.open().map_pass_err(pass_scope)?; let mut state = State { binder: Binder::new(), @@ -516,6 +520,7 @@ impl Global { snatch_guard: device.snatchable_lock.read(), device, + raw_encoder, tracker: &mut cmd_buf_data.trackers, buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions, texture_memory_actions: &mut cmd_buf_data.texture_memory_actions, @@ -561,7 +566,9 @@ impl Global { // But no point in erroring over that nuance here! if let Some(range) = range { unsafe { - raw.reset_queries(query_set.raw.as_ref().unwrap(), range); + state + .raw_encoder + .reset_queries(query_set.raw.as_ref().unwrap(), range); } } @@ -580,7 +587,7 @@ impl Global { }; unsafe { - raw.begin_compute_pass(&hal_desc); + state.raw_encoder.begin_compute_pass(&hal_desc); } // TODO: We should be draining the commands here, avoiding extra copies in the process. @@ -595,7 +602,6 @@ impl Global { let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id()); set_bind_group( &mut state, - raw, cmd_buf, &base.dynamic_offsets, index, @@ -606,7 +612,7 @@ impl Global { } ArcComputeCommand::SetPipeline(pipeline) => { let scope = PassErrorScope::SetPipelineCompute(pipeline.as_info().id()); - set_pipeline(&mut state, raw, cmd_buf, pipeline).map_pass_err(scope)?; + set_pipeline(&mut state, cmd_buf, pipeline).map_pass_err(scope)?; } ArcComputeCommand::SetPushConstant { offset, @@ -615,8 +621,7 @@ impl Global { } => { let scope = PassErrorScope::SetPushConstant; set_push_constant( - &state, - raw, + &mut state, &base.push_constant_data, offset, size_bytes, @@ -629,32 +634,31 @@ impl Global { indirect: false, pipeline: state.pipeline, }; - dispatch(&mut state, raw, groups).map_pass_err(scope)?; + dispatch(&mut state, groups).map_pass_err(scope)?; } ArcComputeCommand::DispatchIndirect { buffer, offset } => { let scope = PassErrorScope::Dispatch { indirect: true, pipeline: state.pipeline, }; - dispatch_indirect(&mut state, raw, cmd_buf, buffer, offset) - .map_pass_err(scope)?; + dispatch_indirect(&mut state, cmd_buf, buffer, offset).map_pass_err(scope)?; } ArcComputeCommand::PushDebugGroup { color: _, len } => { - push_debug_group(&mut state, raw, &base.string_data, len); + push_debug_group(&mut state, &base.string_data, len); } ArcComputeCommand::PopDebugGroup => { let scope = PassErrorScope::PopDebugGroup; - pop_debug_group(&mut state, raw).map_pass_err(scope)?; + pop_debug_group(&mut state).map_pass_err(scope)?; } ArcComputeCommand::InsertDebugMarker { color: _, len } => { - insert_debug_marker(&mut state, raw, &base.string_data, len); + insert_debug_marker(&mut state, &base.string_data, len); } ArcComputeCommand::WriteTimestamp { query_set, query_index, } => { let scope = PassErrorScope::WriteTimestamp; - write_timestamp(&mut state, raw, cmd_buf, query_set, query_index) + write_timestamp(&mut state, cmd_buf, query_set, query_index) .map_pass_err(scope)?; } ArcComputeCommand::BeginPipelineStatisticsQuery { @@ -664,7 +668,7 @@ impl Global { let scope = PassErrorScope::BeginPipelineStatisticsQuery; validate_and_begin_pipeline_statistics_query( query_set, - raw, + state.raw_encoder, &mut state.tracker.query_sets, cmd_buf, query_index, @@ -675,20 +679,28 @@ impl Global { } ArcComputeCommand::EndPipelineStatisticsQuery => { let scope = PassErrorScope::EndPipelineStatisticsQuery; - end_pipeline_statistics_query(raw, &mut state.active_query) + end_pipeline_statistics_query(state.raw_encoder, &mut state.active_query) .map_pass_err(scope)?; } } } unsafe { - raw.end_compute_pass(); + state.raw_encoder.end_compute_pass(); } // We've successfully recorded the compute pass, bring the // command buffer out of the error state. *status = CommandEncoderStatus::Recording; + let State { + snatch_guard, + tracker, + intermediate_trackers, + pending_discard_init_fixups, + .. + } = state; + // Stop the current command buffer. encoder.close().map_pass_err(pass_scope)?; @@ -697,17 +709,17 @@ impl Global { // Use that buffer to insert barriers and clear discarded images. let transit = encoder.open().map_pass_err(pass_scope)?; fixup_discarded_surfaces( - state.pending_discard_init_fixups.into_iter(), + pending_discard_init_fixups.into_iter(), transit, - &mut state.tracker.textures, - state.device, - &state.snatch_guard, + &mut tracker.textures, + device, + &snatch_guard, ); CommandBuffer::insert_barriers_from_tracker( transit, - state.tracker, - &state.intermediate_trackers, - &state.snatch_guard, + tracker, + &intermediate_trackers, + &snatch_guard, ); // Close the command buffer, and swap it with the previous. encoder.close_and_swap().map_pass_err(pass_scope)?; @@ -718,7 +730,6 @@ impl Global { fn set_bind_group( state: &mut State, - raw: &mut A::CommandEncoder, cmd_buf: &CommandBuffer, dynamic_offsets: &[DynamicOffset], index: u32, @@ -771,7 +782,7 @@ fn set_bind_group( if let Some(group) = e.group.as_ref() { let raw_bg = group.try_raw(&state.snatch_guard)?; unsafe { - raw.set_bind_group( + state.raw_encoder.set_bind_group( pipeline_layout, index + i as u32, raw_bg, @@ -786,7 +797,6 @@ fn set_bind_group( fn set_pipeline( state: &mut State, - raw: &mut A::CommandEncoder, cmd_buf: &CommandBuffer, pipeline: Arc>, ) -> Result<(), ComputePassErrorInner> { @@ -797,7 +807,7 @@ fn set_pipeline( let pipeline = state.tracker.compute_pipelines.insert_single(pipeline); unsafe { - raw.set_compute_pipeline(pipeline.raw()); + state.raw_encoder.set_compute_pipeline(pipeline.raw()); } // Rebind resources @@ -817,7 +827,7 @@ fn set_pipeline( if let Some(group) = e.group.as_ref() { let raw_bg = group.try_raw(&state.snatch_guard)?; unsafe { - raw.set_bind_group( + state.raw_encoder.set_bind_group( pipeline.layout.raw(), start_index as u32 + i as u32, raw_bg, @@ -835,7 +845,7 @@ fn set_pipeline( let offset = range.range.start; let size_bytes = range.range.end - offset; super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe { - raw.set_push_constants( + state.raw_encoder.set_push_constants( pipeline.layout.raw(), wgt::ShaderStages::COMPUTE, clear_offset, @@ -848,8 +858,7 @@ fn set_pipeline( } fn set_push_constant( - state: &State, - raw: &mut A::CommandEncoder, + state: &mut State, push_constant_data: &[u32], offset: u32, size_bytes: u32, @@ -875,7 +884,7 @@ fn set_push_constant( )?; unsafe { - raw.set_push_constants( + state.raw_encoder.set_push_constants( pipeline_layout.raw(), wgt::ShaderStages::COMPUTE, offset, @@ -887,12 +896,11 @@ fn set_push_constant( fn dispatch( state: &mut State, - raw: &mut A::CommandEncoder, groups: [u32; 3], ) -> Result<(), ComputePassErrorInner> { state.is_ready()?; - state.flush_states(raw, None)?; + state.flush_states(None)?; let groups_size_limit = state.device.limits.max_compute_workgroups_per_dimension; @@ -909,14 +917,13 @@ fn dispatch( } unsafe { - raw.dispatch(groups); + state.raw_encoder.dispatch(groups); } Ok(()) } fn dispatch_indirect( state: &mut State, - raw: &mut A::CommandEncoder, cmd_buf: &CommandBuffer, buffer: Arc>, offset: u64, @@ -954,21 +961,16 @@ fn dispatch_indirect( MemoryInitKind::NeedsInitializedMemory, )); - state.flush_states(raw, Some(buffer.as_info().tracker_index()))?; + state.flush_states(Some(buffer.as_info().tracker_index()))?; let buf_raw = buffer.try_raw(&state.snatch_guard)?; unsafe { - raw.dispatch_indirect(buf_raw, offset); + state.raw_encoder.dispatch_indirect(buf_raw, offset); } Ok(()) } -fn push_debug_group( - state: &mut State, - raw: &mut A::CommandEncoder, - string_data: &[u8], - len: usize, -) { +fn push_debug_group(state: &mut State, string_data: &[u8], len: usize) { state.debug_scope_depth += 1; if !state .device @@ -978,16 +980,13 @@ fn push_debug_group( let label = str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap(); unsafe { - raw.begin_debug_marker(label); + state.raw_encoder.begin_debug_marker(label); } } state.string_offset += len; } -fn pop_debug_group( - state: &mut State, - raw: &mut A::CommandEncoder, -) -> Result<(), ComputePassErrorInner> { +fn pop_debug_group(state: &mut State) -> Result<(), ComputePassErrorInner> { if state.debug_scope_depth == 0 { return Err(ComputePassErrorInner::InvalidPopDebugGroup); } @@ -998,18 +997,13 @@ fn pop_debug_group( .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS) { unsafe { - raw.end_debug_marker(); + state.raw_encoder.end_debug_marker(); } } Ok(()) } -fn insert_debug_marker( - state: &mut State, - raw: &mut A::CommandEncoder, - string_data: &[u8], - len: usize, -) { +fn insert_debug_marker(state: &mut State, string_data: &[u8], len: usize) { if !state .device .instance_flags @@ -1017,14 +1011,13 @@ fn insert_debug_marker( { let label = str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap(); - unsafe { raw.insert_debug_marker(label) } + unsafe { state.raw_encoder.insert_debug_marker(label) } } state.string_offset += len; } fn write_timestamp( state: &mut State, - raw: &mut A::CommandEncoder, cmd_buf: &CommandBuffer, query_set: Arc>, query_index: u32, @@ -1037,7 +1030,7 @@ fn write_timestamp( let query_set = state.tracker.query_sets.insert_single(query_set); - query_set.validate_and_write_timestamp(raw, query_index, None)?; + query_set.validate_and_write_timestamp(state.raw_encoder, query_index, None)?; Ok(()) }