diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 533383c63..96a1d4488 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -1,5 +1,7 @@ use crate::{ - binding_model::{BindError, LateMinBufferBindingSizeMismatch, PushConstantUploadError}, + binding_model::{ + BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError, + }, command::{ bind::Binder, compute_command::{ArcComputeCommand, ComputeCommand}, @@ -589,68 +591,16 @@ impl Global { bind_group, } => { let scope = PassErrorScope::SetBindGroup(bind_group.as_info().id()); - - bind_group.same_device_as(cmd_buf).map_pass_err(scope)?; - - let max_bind_groups = state.device.limits.max_bind_groups; - if index >= max_bind_groups { - return Err(ComputePassErrorInner::BindGroupIndexOutOfRange { - index, - max: max_bind_groups, - }) - .map_pass_err(scope); - } - - state.temp_offsets.clear(); - state.temp_offsets.extend_from_slice( - &base.dynamic_offsets[state.dynamic_offset_count - ..state.dynamic_offset_count + num_dynamic_offsets], - ); - state.dynamic_offset_count += num_dynamic_offsets; - - let bind_group = state.tracker.bind_groups.insert_single(bind_group); - bind_group - .validate_dynamic_bindings(index, &state.temp_offsets) - .map_pass_err(scope)?; - - state.buffer_memory_init_actions.extend( - bind_group.used_buffer_ranges.iter().filter_map(|action| { - action - .buffer - .initialization_status - .read() - .check_action(action) - }), - ); - - for action in bind_group.used_texture_ranges.iter() { - state - .pending_discard_init_fixups - .extend(state.texture_memory_actions.register_init_action(action)); - } - - let pipeline_layout = state.binder.pipeline_layout.clone(); - let entries = - state - .binder - .assign_group(index as usize, bind_group, &state.temp_offsets); - if !entries.is_empty() && pipeline_layout.is_some() { - let pipeline_layout = pipeline_layout.as_ref().unwrap().raw(); - for (i, e) in entries.iter().enumerate() { - if let Some(group) = e.group.as_ref() { - let raw_bg = - group.try_raw(&state.snatch_guard).map_pass_err(scope)?; - unsafe { - raw.set_bind_group( - pipeline_layout, - index + i as u32, - raw_bg, - &e.dynamic_offsets, - ); - } - } - } - } + set_bind_group( + &mut state, + raw, + cmd_buf, + &base.dynamic_offsets, + index, + num_dynamic_offsets, + bind_group, + ) + .map_pass_err(scope)?; } ArcComputeCommand::SetPipeline(pipeline) => { let scope = PassErrorScope::SetPipelineCompute(pipeline.as_info().id()); @@ -969,6 +919,74 @@ impl Global { } } +fn set_bind_group( + state: &mut State, + raw: &mut A::CommandEncoder, + cmd_buf: &CommandBuffer, + dynamic_offsets: &[DynamicOffset], + index: u32, + num_dynamic_offsets: usize, + bind_group: Arc>, +) -> Result<(), ComputePassErrorInner> { + bind_group.same_device_as(cmd_buf)?; + + let max_bind_groups = state.device.limits.max_bind_groups; + if index >= max_bind_groups { + return Err(ComputePassErrorInner::BindGroupIndexOutOfRange { + index, + max: max_bind_groups, + }); + } + + state.temp_offsets.clear(); + state.temp_offsets.extend_from_slice( + &dynamic_offsets + [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets], + ); + state.dynamic_offset_count += num_dynamic_offsets; + + let bind_group = state.tracker.bind_groups.insert_single(bind_group); + bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?; + + state + .buffer_memory_init_actions + .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| { + action + .buffer + .initialization_status + .read() + .check_action(action) + })); + + for action in bind_group.used_texture_ranges.iter() { + state + .pending_discard_init_fixups + .extend(state.texture_memory_actions.register_init_action(action)); + } + + let pipeline_layout = state.binder.pipeline_layout.clone(); + let entries = state + .binder + .assign_group(index as usize, bind_group, &state.temp_offsets); + if !entries.is_empty() && pipeline_layout.is_some() { + let pipeline_layout = pipeline_layout.as_ref().unwrap().raw(); + for (i, e) in entries.iter().enumerate() { + if let Some(group) = e.group.as_ref() { + let raw_bg = group.try_raw(&state.snatch_guard)?; + unsafe { + raw.set_bind_group( + pipeline_layout, + index + i as u32, + raw_bg, + &e.dynamic_offsets, + ); + } + } + } + } + Ok(()) +} + // Recording a compute pass. impl Global { pub fn compute_pass_set_bind_group(